diff --git a/backend/app/api/v1/configs.py b/backend/app/api/v1/configs.py index c5be307..067b6a0 100644 --- a/backend/app/api/v1/configs.py +++ b/backend/app/api/v1/configs.py @@ -156,3 +156,140 @@ async def set_default_config( """设为默认配置""" ConfigService.set_default_config(db, config_id) return ResponseModel(message="设置成功") + + +@router.get("/system", response_model=ResponseModel[dict]) +async def get_system_configs( + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取系统配置(数据库、Redis等)""" + configs = { + "database": ConfigService.get_system_config(db, "DATABASE_URL") or "sqlite:///./amazing_data.db", + "redis": ConfigService.get_system_config(db, "REDIS_URL") or "redis://localhost:6379/0" + } + return ResponseModel(data=configs) + + +@router.put("/system", response_model=ResponseModel) +async def update_system_configs( + configs: dict, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """更新系统配置""" + if "database" in configs: + ConfigService.set_system_config( + db, + "DATABASE_URL", + configs["database"], + "数据库连接URL" + ) + + if "redis" in configs: + ConfigService.set_system_config( + db, + "REDIS_URL", + configs["redis"], + "Redis连接URL" + ) + + return ResponseModel(message="更新成功") + + +@router.post("/system/test", response_model=ResponseModel[dict]) +async def test_system_connection( + configs: dict, + current_user: User = Depends(get_current_user) +): + """测试系统连接(数据库和Redis)""" + import sqlalchemy + import redis + + result = { + "database": False, + "redis": False + } + + # 测试数据库连接 + if "database" in configs: + try: + engine = sqlalchemy.create_engine( + configs["database"], + connect_args={"check_same_thread": False} if "sqlite" in configs["database"] else {} + ) + with engine.connect() as conn: + result["database"] = True + except Exception as e: + pass + + # 测试Redis连接 + if "redis" in configs: + try: + redis_client = redis.from_url(configs["redis"]) + redis_client.ping() + result["redis"] = True + except Exception as e: + pass + + return ResponseModel(data=result) + + +@router.post("/system/init", response_model=ResponseModel[dict]) +async def init_system_database( + current_user: User = Depends(get_current_user) +): + """初始化数据库结构""" + try: + from app.db.session import init_db + init_db() + return ResponseModel(data={"success": True}) + except Exception as e: + return ResponseModel( + code=1001, + message=str(e), + data={"success": False} + ) + + +@router.get("/system/structure", response_model=ResponseModel[dict]) +async def check_database_structure( + current_user: User = Depends(get_current_user) +): + """检测数据库结构是否完整""" + try: + from sqlalchemy import inspect + from app.db.session import engine + + inspector = inspect(engine) + existing_tables = inspector.get_table_names() + + # 检查必要的表是否存在 + required_tables = [ + 'users', + 'sdk_configs', + 'system_configs', + 'stock_kline_daily', + 'stock_kline_minute', + 'future_kline_daily', + 'future_kline_minute', + 'cache_tasks', + 'stock_basic', + 'index_basic', + 'index_trade' + ] + + missing_tables = [table for table in required_tables if table not in existing_tables] + complete = len(missing_tables) == 0 + + return ResponseModel(data={ + "complete": complete, + "missing_tables": missing_tables, + "existing_tables": existing_tables + }) + except Exception as e: + return ResponseModel( + code=1001, + message=str(e), + data={"complete": False, "missing_tables": [], "existing_tables": []} + ) diff --git a/backend/app/api/v1/data_import.py b/backend/app/api/v1/data_import.py new file mode 100644 index 0000000..36b97fb --- /dev/null +++ b/backend/app/api/v1/data_import.py @@ -0,0 +1,457 @@ +""" +数据导入路由 +""" +import pandas as pd +import logging +from fastapi import APIRouter, Depends, UploadFile, File, HTTPException +from sqlalchemy.orm import Session +from datetime import datetime, date + +from app.db.session import get_db +from app.schemas.base import ResponseModel +from app.models.stock_basic import StockBasic, IndexBasic, IndexTrade +from app.core.security import get_current_user +from app.models.user import User + +router = APIRouter() +logger = logging.getLogger(__name__) + +INDEX_TRADE_COLUMN_MAP = { + '证券代码': 'index_code', + '证券名称': 'name', + '成分个数 [交易日期]最新': 'component_count', + '开盘价 [交易日期]最新': 'open', + '收盘价 [交易日期]最新': 'close', + '成交量 [交易日期]最新 [单位]股': 'volume', + '成交额 [交易日期]最新 [单位]百万元': 'amount', + '总市值 [截止日期]最新 [单位]百万元': 'total_market_value', + '自由流通市值 [交易日期]最新 [单位]百万元': 'float_market_value', + '涨跌幅 [交易日期]最新 [单位]%': 'change_pct', + '最高价 [交易日期]最新': 'high', + '最低价 [交易日期]最新': 'low', + '上涨家数 [交易日期]最新': 'up_count', + '下跌家数 [交易日期]最新': 'down_count', + '平盘家数 [交易日期]最新': 'flat_count', + '涨停家数 [交易日期]最新': 'limit_up_count', + '跌停家数 [交易日期]最新': 'limit_down_count', + '停牌家数 [交易日期]最新': 'suspend_count', + '近期创历史新高 [交易日期]最新 [近N日内]300 [复权方式]不复权': 'is_new_high', + '近期创历史新低 [交易日期]最新 [近N日内]300 [复权方式]不复权': 'is_new_low', + '市盈率PE(TTM) [交易日期]最新 [剔除规则]不调整': 'pe_ratio', + '市盈率PE(TTM)中位值 [交易日期]最新 [剔除规则]不调整': 'pe_median' +} + + +@router.post("/index-data", response_model=ResponseModel) +async def import_index_data( + file: UploadFile = File(...), + trade_date: str = None, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """导入指数数据(同时更新指数基础表和指数交易表)""" + if not file.filename.endswith(('.xls', '.xlsx')): + raise HTTPException(status_code=400, detail="只支持xls或xlsx格式文件") + + if not trade_date: + raise HTTPException(status_code=400, detail="请提供交易日期参数(YYYY-MM-DD格式)") + + try: + trade_date_obj = datetime.strptime(trade_date, '%Y-%m-%d').date() + except: + raise HTTPException(status_code=400, detail="交易日期格式错误,请使用YYYY-MM-DD格式") + + try: + df = pd.read_excel(file.file) + + df.columns = df.columns.str.strip() + + renamed_df = df.rename(columns=INDEX_TRADE_COLUMN_MAP) + + if 'index_code' not in renamed_df.columns: + raise HTTPException(status_code=400, detail="缺少必要列:证券代码") + + success_count = 0 + error_count = 0 + index_basic_updated = 0 + index_basic_added = 0 + + for _, row in renamed_df.iterrows(): + try: + index_code = str(row['index_code']).strip() + if not index_code: + continue + + name = str(row.get('name', '')) if pd.notna(row.get('name')) else None + component_count = int(row.get('component_count')) if pd.notna(row.get('component_count')) else None + + index_basic = db.query(IndexBasic).filter(IndexBasic.code == index_code).first() + + if index_basic: + if component_count and index_basic.component_count != component_count: + index_basic.component_count = component_count + index_basic.name = name or index_basic.name + index_basic.updated_at = datetime.utcnow() + index_basic_updated += 1 + else: + index_basic = IndexBasic( + code=index_code, + name=name, + component_count=component_count + ) + db.add(index_basic) + db.flush() + index_basic_added += 1 + + existing_trade = db.query(IndexTrade).filter( + IndexTrade.index_code == index_code, + IndexTrade.trade_date == trade_date_obj + ).first() + + def get_float_val(col_name): + val = row.get(col_name) + if pd.notna(val): + try: + return float(val) + except: + return None + return None + + def get_int_val(col_name): + val = row.get(col_name) + if pd.notna(val): + try: + return int(float(val)) + except: + return None + return None + + def get_bool_val(col_name): + val = row.get(col_name) + if pd.notna(val): + if isinstance(val, bool): + return val + if isinstance(val, str): + return val.lower() in ['true', '1', 'yes', '是'] + return bool(val) + return False + + open_price = get_float_val('open') + close_price = get_float_val('close') + high_price = get_float_val('high') + low_price = get_float_val('low') + change_pct = get_float_val('change_pct') + volume = get_int_val('volume') + amount = get_float_val('amount') + total_market_value = get_float_val('total_market_value') + float_market_value = get_float_val('float_market_value') + up_count = get_int_val('up_count') + down_count = get_int_val('down_count') + flat_count = get_int_val('flat_count') + limit_up_count = get_int_val('limit_up_count') + limit_down_count = get_int_val('limit_down_count') + suspend_count = get_int_val('suspend_count') + pe_ratio = get_float_val('pe_ratio') + pe_median = get_float_val('pe_median') + is_new_high = get_bool_val('is_new_high') + is_new_low = get_bool_val('is_new_low') + + if existing_trade: + existing_trade.open = open_price + existing_trade.close = close_price + existing_trade.high = high_price + existing_trade.low = low_price + existing_trade.change_pct = change_pct + existing_trade.volume = volume + existing_trade.amount = amount + existing_trade.total_market_value = total_market_value + existing_trade.float_market_value = float_market_value + existing_trade.up_count = up_count + existing_trade.down_count = down_count + existing_trade.flat_count = flat_count + existing_trade.limit_up_count = limit_up_count + existing_trade.limit_down_count = limit_down_count + existing_trade.suspend_count = suspend_count + existing_trade.pe_ratio = pe_ratio + existing_trade.pe_median = pe_median + existing_trade.is_new_high = is_new_high + existing_trade.is_new_low = is_new_low + existing_trade.updated_at = datetime.utcnow() + else: + trade = IndexTrade( + index_code=index_code, + trade_date=trade_date_obj, + open=open_price, + close=close_price, + high=high_price, + low=low_price, + change_pct=change_pct, + volume=volume, + amount=amount, + total_market_value=total_market_value, + float_market_value=float_market_value, + up_count=up_count, + down_count=down_count, + flat_count=flat_count, + limit_up_count=limit_up_count, + limit_down_count=limit_down_count, + suspend_count=suspend_count, + pe_ratio=pe_ratio, + pe_median=pe_median, + is_new_high=is_new_high, + is_new_low=is_new_low + ) + db.add(trade) + + success_count += 1 + + except Exception as e: + logger.error(f"导入指数{row.get('index_code')}失败: {str(e)}") + error_count += 1 + + db.commit() + + return ResponseModel(data={ + "success_count": success_count, + "error_count": error_count, + "total_count": len(df), + "index_basic_added": index_basic_added, + "index_basic_updated": index_basic_updated, + "trade_date": trade_date + }) + + except Exception as e: + logger.error(f"导入指数数据失败: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/stock-basic", response_model=ResponseModel) +async def import_stock_basic( + file: UploadFile = File(...), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """导入股票基础数据""" + if not file.filename.endswith(('.xls', '.xlsx')): + raise HTTPException(status_code=400, detail="只支持xls或xlsx格式文件") + + try: + df = pd.read_excel(file.file) + + required_columns = ['code', 'name', 'total_shares', 'float_shares', + 'industry_index_name', 'industry_index_code', + 'institution_hold_ratio', 'industry_level3', 'list_date'] + + missing_columns = [col for col in required_columns if col not in df.columns] + if missing_columns: + raise HTTPException(status_code=400, detail=f"缺少必要列: {missing_columns}") + + success_count = 0 + error_count = 0 + + for _, row in df.iterrows(): + try: + existing = db.query(StockBasic).filter(StockBasic.code == str(row['code'])).first() + + list_date = None + if pd.notna(row['list_date']): + if isinstance(row['list_date'], datetime): + list_date = row['list_date'].date() + elif isinstance(row['list_date'], str): + list_date = datetime.strptime(row['list_date'], '%Y-%m-%d').date() + + if existing: + existing.name = str(row.get('name', existing.name)) + existing.total_shares = int(row.get('total_shares', existing.total_shares)) if pd.notna(row.get('total_shares')) else existing.total_shares + existing.float_shares = int(row.get('float_shares', existing.float_shares)) if pd.notna(row.get('float_shares')) else existing.float_shares + existing.industry_index_name = str(row.get('industry_index_name', existing.industry_index_name)) if pd.notna(row.get('industry_index_name')) else existing.industry_index_name + existing.industry_index_code = str(row.get('industry_index_code', existing.industry_index_code)) if pd.notna(row.get('industry_index_code')) else existing.industry_index_code + existing.institution_hold_ratio = float(row.get('institution_hold_ratio', existing.institution_hold_ratio)) if pd.notna(row.get('institution_hold_ratio')) else existing.institution_hold_ratio + existing.industry_level3 = str(row.get('industry_level3', existing.industry_level3)) if pd.notna(row.get('industry_level3')) else existing.industry_level3 + existing.list_date = list_date + existing.updated_at = datetime.utcnow() + else: + stock = StockBasic( + code=str(row['code']), + name=str(row.get('name', '')), + total_shares=int(row['total_shares']) if pd.notna(row['total_shares']) else None, + float_shares=int(row['float_shares']) if pd.notna(row['float_shares']) else None, + industry_index_name=str(row.get('industry_index_name', '')) if pd.notna(row.get('industry_index_name')) else None, + industry_index_code=str(row.get('industry_index_code', '')) if pd.notna(row.get('industry_index_code')) else None, + institution_hold_ratio=float(row['institution_hold_ratio']) if pd.notna(row['institution_hold_ratio']) else None, + industry_level3=str(row.get('industry_level3', '')) if pd.notna(row.get('industry_level3')) else None, + list_date=list_date + ) + db.add(stock) + + success_count += 1 + except Exception as e: + logger.error(f"导入股票{row.get('code')}失败: {str(e)}") + error_count += 1 + + db.commit() + + return ResponseModel(data={ + "success_count": success_count, + "error_count": error_count, + "total_count": len(df) + }) + + except Exception as e: + logger.error(f"导入股票基础数据失败: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/index-basic", response_model=ResponseModel) +async def import_index_basic( + file: UploadFile = File(...), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """导入指数基础数据""" + if not file.filename.endswith(('.xls', '.xlsx')): + raise HTTPException(status_code=400, detail="只支持xls或xlsx格式文件") + + try: + df = pd.read_excel(file.file) + + required_columns = ['code', 'name', 'component_count'] + missing_columns = [col for col in required_columns if col not in df.columns] + if missing_columns: + raise HTTPException(status_code=400, detail=f"缺少必要列: {missing_columns}") + + success_count = 0 + error_count = 0 + + for _, row in df.iterrows(): + try: + existing = db.query(IndexBasic).filter(IndexBasic.code == str(row['code'])).first() + + if existing: + existing.name = str(row.get('name', existing.name)) + existing.component_count = int(row.get('component_count', existing.component_count)) if pd.notna(row.get('component_count')) else existing.component_count + existing.updated_at = datetime.utcnow() + else: + index = IndexBasic( + code=str(row['code']), + name=str(row.get('name', '')), + component_count=int(row['component_count']) if pd.notna(row['component_count']) else None + ) + db.add(index) + + success_count += 1 + except Exception as e: + logger.error(f"导入指数{row.get('code')}失败: {str(e)}") + error_count += 1 + + db.commit() + + return ResponseModel(data={ + "success_count": success_count, + "error_count": error_count, + "total_count": len(df) + }) + + except Exception as e: + logger.error(f"导入指数基础数据失败: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/index-trade", response_model=ResponseModel) +async def import_index_trade( + file: UploadFile = File(...), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """导入指数交易数据""" + if not file.filename.endswith(('.xls', '.xlsx')): + raise HTTPException(status_code=400, detail="只支持xls或xlsx格式文件") + + try: + df = pd.read_excel(file.file) + + required_columns = ['index_code', 'trade_date', 'open', 'close', 'high', 'low'] + missing_columns = [col for col in required_columns if col not in df.columns] + if missing_columns: + raise HTTPException(status_code=400, detail=f"缺少必要列: {missing_columns}") + + success_count = 0 + error_count = 0 + + for _, row in df.iterrows(): + try: + trade_date = None + if pd.notna(row['trade_date']): + if isinstance(row['trade_date'], datetime): + trade_date = row['trade_date'].date() + elif isinstance(row['trade_date'], str): + trade_date = datetime.strptime(row['trade_date'], '%Y-%m-%d').date() + + existing = db.query(IndexTrade).filter( + IndexTrade.index_code == str(row['index_code']), + IndexTrade.trade_date == trade_date + ).first() + + if existing: + existing.open = float(row.get('open', existing.open)) if pd.notna(row.get('open')) else existing.open + existing.close = float(row.get('close', existing.close)) if pd.notna(row.get('close')) else existing.close + existing.high = float(row.get('high', existing.high)) if pd.notna(row.get('high')) else existing.high + existing.low = float(row.get('low', existing.low)) if pd.notna(row.get('low')) else existing.low + existing.change_pct = float(row.get('change_pct', existing.change_pct)) if pd.notna(row.get('change_pct')) else existing.change_pct + existing.volume = int(row.get('volume', existing.volume)) if pd.notna(row.get('volume')) else existing.volume + existing.amount = float(row.get('amount', existing.amount)) if pd.notna(row.get('amount')) else existing.amount + existing.total_market_value = float(row.get('total_market_value', existing.total_market_value)) if pd.notna(row.get('total_market_value')) else existing.total_market_value + existing.float_market_value = float(row.get('float_market_value', existing.float_market_value)) if pd.notna(row.get('float_market_value')) else existing.float_market_value + existing.up_count = int(row.get('up_count', existing.up_count)) if pd.notna(row.get('up_count')) else existing.up_count + existing.down_count = int(row.get('down_count', existing.down_count)) if pd.notna(row.get('down_count')) else existing.down_count + existing.flat_count = int(row.get('flat_count', existing.flat_count)) if pd.notna(row.get('flat_count')) else existing.flat_count + existing.limit_up_count = int(row.get('limit_up_count', existing.limit_up_count)) if pd.notna(row.get('limit_up_count')) else existing.limit_up_count + existing.limit_down_count = int(row.get('limit_down_count', existing.limit_down_count)) if pd.notna(row.get('limit_down_count')) else existing.limit_down_count + existing.suspend_count = int(row.get('suspend_count', existing.suspend_count)) if pd.notna(row.get('suspend_count')) else existing.suspend_count + existing.pe_ratio = float(row.get('pe_ratio', existing.pe_ratio)) if pd.notna(row.get('pe_ratio')) else existing.pe_ratio + existing.pe_median = float(row.get('pe_median', existing.pe_median)) if pd.notna(row.get('pe_median')) else existing.pe_median + existing.is_new_high = bool(row.get('is_new_high', existing.is_new_high)) if pd.notna(row.get('is_new_high')) else existing.is_new_high + existing.is_new_low = bool(row.get('is_new_low', existing.is_new_low)) if pd.notna(row.get('is_new_low')) else existing.is_new_low + existing.updated_at = datetime.utcnow() + else: + trade = IndexTrade( + index_code=str(row['index_code']), + trade_date=trade_date, + open=float(row['open']) if pd.notna(row['open']) else None, + close=float(row['close']) if pd.notna(row['close']) else None, + high=float(row['high']) if pd.notna(row['high']) else None, + low=float(row['low']) if pd.notna(row['low']) else None, + change_pct=float(row.get('change_pct')) if pd.notna(row.get('change_pct')) else None, + volume=int(row.get('volume')) if pd.notna(row.get('volume')) else None, + amount=float(row.get('amount')) if pd.notna(row.get('amount')) else None, + total_market_value=float(row.get('total_market_value')) if pd.notna(row.get('total_market_value')) else None, + float_market_value=float(row.get('float_market_value')) if pd.notna(row.get('float_market_value')) else None, + up_count=int(row.get('up_count')) if pd.notna(row.get('up_count')) else None, + down_count=int(row.get('down_count')) if pd.notna(row.get('down_count')) else None, + flat_count=int(row.get('flat_count')) if pd.notna(row.get('flat_count')) else None, + limit_up_count=int(row.get('limit_up_count')) if pd.notna(row.get('limit_up_count')) else None, + limit_down_count=int(row.get('limit_down_count')) if pd.notna(row.get('limit_down_count')) else None, + suspend_count=int(row.get('suspend_count')) if pd.notna(row.get('suspend_count')) else None, + pe_ratio=float(row.get('pe_ratio')) if pd.notna(row.get('pe_ratio')) else None, + pe_median=float(row.get('pe_median')) if pd.notna(row.get('pe_median')) else None, + is_new_high=bool(row.get('is_new_high')) if pd.notna(row.get('is_new_high')) else False, + is_new_low=bool(row.get('is_new_low')) if pd.notna(row.get('is_new_low')) else False + ) + db.add(trade) + + success_count += 1 + except Exception as e: + logger.error(f"导入指数交易{row.get('index_code')}-{row.get('trade_date')}失败: {str(e)}") + error_count += 1 + + db.commit() + + return ResponseModel(data={ + "success_count": success_count, + "error_count": error_count, + "total_count": len(df) + }) + + except Exception as e: + logger.error(f"导入指数交易数据失败: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file diff --git a/backend/app/api/v1/index.py b/backend/app/api/v1/index.py new file mode 100644 index 0000000..8b346bc --- /dev/null +++ b/backend/app/api/v1/index.py @@ -0,0 +1,151 @@ +""" +指数数据查询路由 +""" +from typing import List +from fastapi import APIRouter, Depends, Query +from sqlalchemy.orm import Session +from sqlalchemy import and_ +from datetime import date + +from app.db.session import get_db +from app.schemas.base import ResponseModel +from app.models.stock_basic import IndexBasic, IndexTrade +from app.core.security import get_current_user +from app.models.user import User +from app.utils.date_utils import parse_date, format_date + +router = APIRouter() + + +@router.get("/list", response_model=ResponseModel) +async def get_index_list( + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取指数列表""" + indexes = db.query(IndexBasic).order_by(IndexBasic.code).all() + + result = [] + for idx in indexes: + result.append({ + "code": idx.code, + "name": idx.name, + "component_count": idx.component_count + }) + + return ResponseModel(data=result) + + +@router.get("/trade", response_model=ResponseModel) +async def get_index_trade_data( + codes: str = Query(..., description="指数代码列表,逗号分隔"), + start_date: str = Query(..., description="开始日期 YYYYMMDD"), + end_date: str = Query(..., description="结束日期 YYYYMMDD"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取指数交易数据""" + code_list = codes.split(",") + start = parse_date(start_date) + end = parse_date(end_date) + + result = {} + + for code in code_list: + code = code.strip() + + index_basic = db.query(IndexBasic).filter(IndexBasic.code == code).first() + + trades = db.query(IndexTrade).filter( + and_( + IndexTrade.index_code == code, + IndexTrade.trade_date >= start, + IndexTrade.trade_date <= end + ) + ).order_by(IndexTrade.trade_date).all() + + trade_list = [] + for trade in trades: + trade_list.append({ + "trade_date": format_date(trade.trade_date), + "open": float(trade.open) if trade.open else None, + "close": float(trade.close) if trade.close else None, + "high": float(trade.high) if trade.high else None, + "low": float(trade.low) if trade.low else None, + "change_pct": float(trade.change_pct) if trade.change_pct else None, + "volume": trade.volume, + "amount": float(trade.amount) if trade.amount else None, + "total_market_value": float(trade.total_market_value) if trade.total_market_value else None, + "float_market_value": float(trade.float_market_value) if trade.float_market_value else None, + "up_count": trade.up_count, + "down_count": trade.down_count, + "flat_count": trade.flat_count, + "limit_up_count": trade.limit_up_count, + "limit_down_count": trade.limit_down_count, + "suspend_count": trade.suspend_count, + "pe_ratio": float(trade.pe_ratio) if trade.pe_ratio else None, + "pe_median": float(trade.pe_median) if trade.pe_median else None, + "is_new_high": trade.is_new_high, + "is_new_low": trade.is_new_low + }) + + result[code] = { + "basic": { + "code": index_basic.code if index_basic else code, + "name": index_basic.name if index_basic else None, + "component_count": index_basic.component_count if index_basic else None + }, + "trades": trade_list + } + + return ResponseModel(data=result) + + +@router.get("/{code}/chart", response_model=ResponseModel) +async def get_index_chart_data( + code: str, + start_date: str = Query(..., description="开始日期 YYYYMMDD"), + end_date: str = Query(..., description="结束日期 YYYYMMDD"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取指数K线图表数据(ECharts格式)""" + start = parse_date(start_date) + end = parse_date(end_date) + + trades = db.query(IndexTrade).filter( + and_( + IndexTrade.index_code == code, + IndexTrade.trade_date >= start, + IndexTrade.trade_date <= end + ) + ).order_by(IndexTrade.trade_date).all() + + category_data = [] + values = [] + volumes = [] + + for trade in trades: + category_data.append(format_date(trade.trade_date)) + values.append([ + float(trade.open) if trade.open else 0, + float(trade.close) if trade.close else 0, + float(trade.low) if trade.low else 0, + float(trade.high) if trade.high else 0, + trade.volume if trade.volume else 0 + ]) + volumes.append([ + trade.volume if trade.volume else 0, + 1 if (trade.close and trade.open and trade.close >= trade.open) else -1 + ]) + + index_basic = db.query(IndexBasic).filter(IndexBasic.code == code).first() + + return ResponseModel(data={ + "code": code, + "name": index_basic.name if index_basic else None, + "component_count": index_basic.component_count if index_basic else None, + "categoryData": category_data, + "values": values, + "volumes": volumes + }) \ No newline at end of file diff --git a/backend/app/config.py b/backend/app/config.py index 543ff15..e13a1f6 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -5,6 +5,7 @@ import os from typing import Optional from pydantic_settings import BaseSettings from pydantic import Field +from sqlalchemy.orm import Session class Settings(BaseSettings): @@ -52,3 +53,19 @@ class Settings(BaseSettings): # 全局配置实例 settings = Settings() + + +def load_system_configs(db: Session): + """ + 从数据库加载系统配置到全局settings + 注意:这需要在数据库初始化后调用 + """ + from app.services.config_service import ConfigService + + system_configs = ConfigService.get_all_system_configs(db) + + if "DATABASE_URL" in system_configs: + settings.DATABASE_URL = system_configs["DATABASE_URL"] + + if "REDIS_URL" in system_configs: + settings.REDIS_URL = system_configs["REDIS_URL"] diff --git a/backend/app/db/session.py b/backend/app/db/session.py index 56481ec..91eb3ba 100644 --- a/backend/app/db/session.py +++ b/backend/app/db/session.py @@ -8,21 +8,37 @@ from typing import Generator from app.config import settings from app.db.base import Base +# 确保使用SQLite作为默认数据库 +database_url = settings.DATABASE_URL or "sqlite:///./amazing_data.db" + # 创建数据库引擎 -if settings.DATABASE_URL.startswith("sqlite"): +try: + if database_url.startswith("sqlite"): + engine = create_engine( + database_url, + connect_args={"check_same_thread": False}, + echo=settings.DEBUG + ) + else: + engine = create_engine( + database_url, + pool_pre_ping=True, + pool_size=10, + max_overflow=20, + echo=settings.DEBUG + ) + # 测试连接 + with engine.connect() as conn: + pass +except Exception as e: + print(f"数据库连接失败: {e}") + print("使用SQLite作为备选数据库...") + # 使用SQLite作为备选 engine = create_engine( - settings.DATABASE_URL, + "sqlite:///./amazing_data.db", connect_args={"check_same_thread": False}, echo=settings.DEBUG ) -else: - engine = create_engine( - settings.DATABASE_URL, - pool_pre_ping=True, - pool_size=10, - max_overflow=20, - echo=settings.DEBUG - ) # 创建会话工厂 SessionLocal = sessionmaker( diff --git a/backend/app/main.py b/backend/app/main.py index fa72160..c9ce090 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -30,6 +30,13 @@ async def lifespan(app: FastAPI): try: init_db() print("Database initialized successfully") + + # 加载系统配置 + from app.db.session import get_db + from app.config import load_system_configs + db = next(get_db()) + load_system_configs(db) + print("System configs loaded successfully") except Exception as e: print(f"Database initialization warning: {e}") diff --git a/backend/app/models/stock_basic.py b/backend/app/models/stock_basic.py new file mode 100644 index 0000000..26a6260 --- /dev/null +++ b/backend/app/models/stock_basic.py @@ -0,0 +1,78 @@ +""" +股票基础数据模型 +""" +from datetime import datetime, date +from sqlalchemy import Column, Integer, BigInteger, String, Numeric, Text, Date, DateTime, ForeignKey, Boolean +from sqlalchemy.orm import relationship +from app.db.base import Base + + +class StockBasic(Base): + """股票基础数据表""" + __tablename__ = "stock_basic" + + id = Column(BigInteger, primary_key=True, index=True) + code = Column(String(20), unique=True, nullable=False, index=True) + name = Column(String(50)) + total_shares = Column(BigInteger) + float_shares = Column(BigInteger) + industry_index_name = Column(String(100)) + industry_index_code = Column(String(20), ForeignKey("index_basic.code")) + institution_hold_ratio = Column(Numeric(10, 4)) + industry_level3 = Column(String(100)) + list_date = Column(Date) + created_at = Column(DateTime(timezone=True), default=datetime.utcnow) + updated_at = Column(DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow) + + industry_index = relationship("IndexBasic", back_populates="stocks") + + +class IndexBasic(Base): + """指数基础表""" + __tablename__ = "index_basic" + + id = Column(BigInteger, primary_key=True, index=True) + code = Column(String(20), unique=True, nullable=False, index=True) + name = Column(String(100)) + component_count = Column(Integer) + created_at = Column(DateTime(timezone=True), default=datetime.utcnow) + updated_at = Column(DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow) + + stocks = relationship("StockBasic", back_populates="industry_index") + trades = relationship("IndexTrade", back_populates="index") + + +class IndexTrade(Base): + """指数交易表""" + __tablename__ = "index_trade" + + id = Column(BigInteger, primary_key=True, index=True) + index_code = Column(String(20), ForeignKey("index_basic.code"), nullable=False, index=True) + trade_date = Column(Date, nullable=False, index=True) + open = Column(Numeric(10, 3)) + close = Column(Numeric(10, 3)) + high = Column(Numeric(10, 3)) + low = Column(Numeric(10, 3)) + change_pct = Column(Numeric(10, 4)) + volume = Column(BigInteger) + amount = Column(Numeric(18, 2)) + total_market_value = Column(Numeric(18, 2)) + float_market_value = Column(Numeric(18, 2)) + up_count = Column(Integer) + down_count = Column(Integer) + flat_count = Column(Integer) + limit_up_count = Column(Integer) + limit_down_count = Column(Integer) + suspend_count = Column(Integer) + pe_ratio = Column(Numeric(10, 4)) + pe_median = Column(Numeric(10, 4)) + is_new_high = Column(Boolean, default=False) + is_new_low = Column(Boolean, default=False) + created_at = Column(DateTime(timezone=True), default=datetime.utcnow) + updated_at = Column(DateTime(timezone=True), default=datetime.utcnow, onupdate=datetime.utcnow) + + index = relationship("IndexBasic", back_populates="trades") + + __table_args__ = ( + {'unique_constraint': None}, + ) \ No newline at end of file diff --git a/backend/app/services/config_service.py b/backend/app/services/config_service.py index 7aec6ec..fb566da 100644 --- a/backend/app/services/config_service.py +++ b/backend/app/services/config_service.py @@ -147,3 +147,9 @@ class ConfigService: db.add(config) db.commit() + + @staticmethod + def get_all_system_configs(db: Session) -> dict: + """获取所有系统配置""" + configs = db.query(SystemConfig).all() + return {config.config_key: config.config_value for config in configs} diff --git a/backend/create_stock_basic_tables.py b/backend/create_stock_basic_tables.py new file mode 100644 index 0000000..2440b69 --- /dev/null +++ b/backend/create_stock_basic_tables.py @@ -0,0 +1,96 @@ +""" +创建股票基础数据相关表 +""" +from sqlalchemy import text +from app.db.session import SessionLocal + +db = SessionLocal() + +try: + # 创建股票基础数据表 + db.execute(text(""" + CREATE TABLE IF NOT EXISTS stock_basic ( + id BIGSERIAL PRIMARY KEY, + code VARCHAR(20) UNIQUE NOT NULL, + name VARCHAR(50), + total_shares BIGINT, + float_shares BIGINT, + industry_index_name VARCHAR(100), + industry_index_code VARCHAR(20), + institution_hold_ratio DECIMAL(10, 4), + industry_level3 VARCHAR(100), + list_date DATE, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """)) + + # 创建指数基础表 + db.execute(text(""" + CREATE TABLE IF NOT EXISTS index_basic ( + id BIGSERIAL PRIMARY KEY, + code VARCHAR(20) UNIQUE NOT NULL, + name VARCHAR(100), + component_count INTEGER, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """)) + + # 创建指数交易表 + db.execute(text(""" + CREATE TABLE IF NOT EXISTS index_trade ( + id BIGSERIAL PRIMARY KEY, + index_code VARCHAR(20) NOT NULL, + trade_date DATE NOT NULL, + open DECIMAL(10, 3), + close DECIMAL(10, 3), + high DECIMAL(10, 3), + low DECIMAL(10, 3), + change_pct DECIMAL(10, 4), + volume BIGINT, + amount DECIMAL(18, 2), + total_market_value DECIMAL(18, 2), + float_market_value DECIMAL(18, 2), + up_count INTEGER, + down_count INTEGER, + flat_count INTEGER, + limit_up_count INTEGER, + limit_down_count INTEGER, + suspend_count INTEGER, + pe_ratio DECIMAL(10, 4), + pe_median DECIMAL(10, 4), + is_new_high BOOLEAN DEFAULT FALSE, + is_new_low BOOLEAN DEFAULT FALSE, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(index_code, trade_date) + ) + """)) + + # 创建索引 + db.execute(text("CREATE INDEX IF NOT EXISTS idx_stock_basic_code ON stock_basic(code)")) + db.execute(text("CREATE INDEX IF NOT EXISTS idx_index_basic_code ON index_basic(code)")) + db.execute(text("CREATE INDEX IF NOT EXISTS idx_index_trade_code ON index_trade(index_code)")) + db.execute(text("CREATE INDEX IF NOT EXISTS idx_index_trade_date ON index_trade(trade_date)")) + + # 添加外键约束 + db.execute(text(""" + ALTER TABLE stock_basic + ADD CONSTRAINT fk_stock_basic_index_code + FOREIGN KEY (industry_index_code) REFERENCES index_basic(code) + """)) + + db.execute(text(""" + ALTER TABLE index_trade + ADD CONSTRAINT fk_index_trade_index_code + FOREIGN KEY (index_code) REFERENCES index_basic(code) + """)) + + db.commit() + print("表创建成功") +except Exception as e: + print(f"创建表失败: {str(e)}") + db.rollback() +finally: + db.close() \ No newline at end of file diff --git a/frontend/src/api/config.ts b/frontend/src/api/config.ts index 390b3df..bf7b5be 100644 --- a/frontend/src/api/config.ts +++ b/frontend/src/api/config.ts @@ -23,3 +23,23 @@ export const testSDKConfig = (id: number) => { export const setDefaultConfig = (id: number) => { return request.post(`/configs/sdk/${id}/set-default`) } + +export const getSystemConfigs = () => { + return request.get('/configs/system') +} + +export const updateSystemConfigs = (data: any) => { + return request.put('/configs/system', data) +} + +export const testSystemConnection = (data: any) => { + return request.post('/configs/system/test', data) +} + +export const initDatabase = () => { + return request.post('/configs/system/init') +} + +export const checkDatabaseStructure = () => { + return request.get('/configs/system/structure') +} diff --git a/frontend/src/api/dataImport.ts b/frontend/src/api/dataImport.ts new file mode 100644 index 0000000..b780251 --- /dev/null +++ b/frontend/src/api/dataImport.ts @@ -0,0 +1,41 @@ +import request from '@/utils/request' + +export const importStockBasic = (file: File) => { + const formData = new FormData() + formData.append('file', file) + return request.post('/import/stock-basic', formData, { + headers: { + 'Content-Type': 'multipart/form-data' + } + }) +} + +export const importIndexBasic = (file: File) => { + const formData = new FormData() + formData.append('file', file) + return request.post('/import/index-basic', formData, { + headers: { + 'Content-Type': 'multipart/form-data' + } + }) +} + +export const importIndexTrade = (file: File) => { + const formData = new FormData() + formData.append('file', file) + return request.post('/import/index-trade', formData, { + headers: { + 'Content-Type': 'multipart/form-data' + } + }) +} + +export const importIndexData = (file: File, tradeDate: string) => { + const formData = new FormData() + formData.append('file', file) + return request.post(`/import/index-data?trade_date=${tradeDate}`, formData, { + headers: { + 'Content-Type': 'multipart/form-data' + } + }) +} \ No newline at end of file diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts new file mode 100644 index 0000000..e479634 --- /dev/null +++ b/frontend/src/api/index.ts @@ -0,0 +1,23 @@ +import request from '@/utils/request' + +export const getIndexList = () => { + return request.get('/index/list') +} + +export const getIndexTradeData = (params: { + codes: string + start_date: string + end_date: string +}) => { + return request.get('/index/trade', { params }) +} + +export const getIndexChartData = ( + code: string, + params: { + start_date: string + end_date: string + } +) => { + return request.get(`/index/${code}/chart`, { params }) +} \ No newline at end of file diff --git a/frontend/src/views/ConfigManager.vue b/frontend/src/views/ConfigManager.vue index 7eef9e1..8b7b57c 100644 --- a/frontend/src/views/ConfigManager.vue +++ b/frontend/src/views/ConfigManager.vue @@ -1,54 +1,128 @@