diff --git a/backend/app/api/v1/data_import.py b/backend/app/api/v1/data_import.py index 36b97fb..ca0e2f5 100644 --- a/backend/app/api/v1/data_import.py +++ b/backend/app/api/v1/data_import.py @@ -41,6 +41,16 @@ INDEX_TRADE_COLUMN_MAP = { '市盈率PE(TTM)中位值 [交易日期]最新 [剔除规则]不调整': 'pe_median' } +STOCK_BASIC_COLUMN_MAP = { + '证券代码': 'code', + '证券名称': 'name', + '首发上市日': 'list_date', + '所属东财行业指数名称\n[行业类别]2级': 'industry_index_name', + '所属东财行业指数代码\n[行业类别]2级': 'industry_index_code', + '机构持股比例合计\n[报告期]最新一期\n[单位]%\n[比例类型]占总股本比例': 'institution_hold_ratio', + '所属东财行业名称\n[行业类别]3级': 'industry_level3' +} + @router.post("/index-data", response_model=ResponseModel) async def import_index_data( @@ -231,70 +241,152 @@ async def import_stock_basic( db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): - """导入股票基础数据""" + """导入股票基础数据(支持模板格式)""" if not file.filename.endswith(('.xls', '.xlsx')): raise HTTPException(status_code=400, detail="只支持xls或xlsx格式文件") try: df = pd.read_excel(file.file) - required_columns = ['code', 'name', 'total_shares', 'float_shares', - 'industry_index_name', 'industry_index_code', - 'institution_hold_ratio', 'industry_level3', 'list_date'] + df.columns = df.columns.str.strip() - missing_columns = [col for col in required_columns if col not in df.columns] - if missing_columns: - raise HTTPException(status_code=400, detail=f"缺少必要列: {missing_columns}") + renamed_df = df.rename(columns=STOCK_BASIC_COLUMN_MAP) + + if 'code' not in renamed_df.columns: + raise HTTPException(status_code=400, detail="缺少必要列:证券代码") + + if 'name' not in renamed_df.columns: + raise HTTPException(status_code=400, detail="缺少必要列:证券名称") success_count = 0 error_count = 0 + added_count = 0 + updated_count = 0 + skipped_count = 0 + skipped_details = [] + error_details = [] - for _, row in df.iterrows(): + for _, row in renamed_df.iterrows(): try: - existing = db.query(StockBasic).filter(StockBasic.code == str(row['code'])).first() + code_val = row.get('code') + if pd.isna(code_val): + continue + code = str(code_val).strip() + if not code or code.lower() == 'nan': + continue + + existing = db.query(StockBasic).filter(StockBasic.code == code).first() list_date = None - if pd.notna(row['list_date']): - if isinstance(row['list_date'], datetime): - list_date = row['list_date'].date() - elif isinstance(row['list_date'], str): - list_date = datetime.strptime(row['list_date'], '%Y-%m-%d').date() + list_date_val = row.get('list_date') + if pd.notna(list_date_val): + if isinstance(list_date_val, datetime): + list_date = list_date_val.date() + elif isinstance(list_date_val, str): + try: + list_date = datetime.strptime(list_date_val, '%Y-%m-%d').date() + except: + pass + elif hasattr(list_date_val, 'date'): + list_date = list_date_val.date() + + name = str(row.get('name', '')) if pd.notna(row.get('name')) else None + industry_index_name = str(row.get('industry_index_name', '')) if pd.notna(row.get('industry_index_name')) else None + industry_index_code = str(row.get('industry_index_code', '')) if pd.notna(row.get('industry_index_code')) else None + + institution_hold_ratio = None + ratio_val = row.get('institution_hold_ratio') + if pd.notna(ratio_val): + try: + if str(ratio_val).strip() == '--': + institution_hold_ratio = 0.0 + else: + institution_hold_ratio = float(ratio_val) + except: + institution_hold_ratio = 0.0 + + industry_level3 = str(row.get('industry_level3', '')) if pd.notna(row.get('industry_level3')) else None + + if industry_index_code: + index_basic = db.query(IndexBasic).filter(IndexBasic.code == industry_index_code).first() + if not index_basic: + index_basic = IndexBasic( + code=industry_index_code, + name=industry_index_name or industry_index_code + ) + db.add(index_basic) + db.flush() if existing: - existing.name = str(row.get('name', existing.name)) - existing.total_shares = int(row.get('total_shares', existing.total_shares)) if pd.notna(row.get('total_shares')) else existing.total_shares - existing.float_shares = int(row.get('float_shares', existing.float_shares)) if pd.notna(row.get('float_shares')) else existing.float_shares - existing.industry_index_name = str(row.get('industry_index_name', existing.industry_index_name)) if pd.notna(row.get('industry_index_name')) else existing.industry_index_name - existing.industry_index_code = str(row.get('industry_index_code', existing.industry_index_code)) if pd.notna(row.get('industry_index_code')) else existing.industry_index_code - existing.institution_hold_ratio = float(row.get('institution_hold_ratio', existing.institution_hold_ratio)) if pd.notna(row.get('institution_hold_ratio')) else existing.institution_hold_ratio - existing.industry_level3 = str(row.get('industry_level3', existing.industry_level3)) if pd.notna(row.get('industry_level3')) else existing.industry_level3 - existing.list_date = list_date - existing.updated_at = datetime.utcnow() + def is_same_data(): + def compare_ratio(): + if existing.institution_hold_ratio is None and institution_hold_ratio is None: + return True + if existing.institution_hold_ratio is None or institution_hold_ratio is None: + return False + return abs(float(existing.institution_hold_ratio) - institution_hold_ratio) < 0.0001 + + return ( + (existing.name == name or (existing.name is None and name is None)) and + (existing.industry_index_name == industry_index_name or (existing.industry_index_name is None and industry_index_name is None)) and + (existing.industry_index_code == industry_index_code or (existing.industry_index_code is None and industry_index_code is None)) and + compare_ratio() and + (existing.industry_level3 == industry_level3 or (existing.industry_level3 is None and industry_level3 is None)) and + (existing.list_date == list_date or (existing.list_date is None and list_date is None)) + ) + + if is_same_data(): + skipped_count += 1 + skipped_details.append({ + "code": code, + "name": name, + "reason": "数据相同,无需更新" + }) + else: + existing.name = name + existing.industry_index_name = industry_index_name + existing.industry_index_code = industry_index_code + existing.institution_hold_ratio = institution_hold_ratio + existing.industry_level3 = industry_level3 + existing.list_date = list_date + existing.updated_at = datetime.utcnow() + updated_count += 1 else: stock = StockBasic( - code=str(row['code']), - name=str(row.get('name', '')), - total_shares=int(row['total_shares']) if pd.notna(row['total_shares']) else None, - float_shares=int(row['float_shares']) if pd.notna(row['float_shares']) else None, - industry_index_name=str(row.get('industry_index_name', '')) if pd.notna(row.get('industry_index_name')) else None, - industry_index_code=str(row.get('industry_index_code', '')) if pd.notna(row.get('industry_index_code')) else None, - institution_hold_ratio=float(row['institution_hold_ratio']) if pd.notna(row['institution_hold_ratio']) else None, - industry_level3=str(row.get('industry_level3', '')) if pd.notna(row.get('industry_level3')) else None, + code=code, + name=name, + total_shares=None, + float_shares=None, + industry_index_name=industry_index_name, + industry_index_code=industry_index_code, + institution_hold_ratio=institution_hold_ratio, + industry_level3=industry_level3, list_date=list_date ) db.add(stock) + added_count += 1 success_count += 1 except Exception as e: logger.error(f"导入股票{row.get('code')}失败: {str(e)}") error_count += 1 + error_details.append({ + "code": str(row.get('code', '')), + "name": str(row.get('name', '')) if pd.notna(row.get('name')) else '', + "reason": str(e) + }) db.commit() return ResponseModel(data={ "success_count": success_count, "error_count": error_count, - "total_count": len(df) + "total_count": len(df), + "added_count": added_count, + "updated_count": updated_count, + "skipped_count": skipped_count, + "skipped_details": skipped_details[:100], + "error_details": error_details[:100] }) except Exception as e: diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index cf59865..69f8949 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -7,6 +7,7 @@ from app.models.realtime import RealtimeSnapshot from app.models.finance import FinanceBalanceSheet, FinanceCashFlow, FinanceIncome from app.models.cache import CacheTask, CacheTaskDetail from app.models.test import APITestLog +from app.models.stock_basic import StockBasic, IndexBasic, IndexTrade __all__ = [ "User", @@ -25,4 +26,7 @@ __all__ = [ "CacheTask", "CacheTaskDetail", "APITestLog", + "StockBasic", + "IndexBasic", + "IndexTrade", ] diff --git a/backend/create_stock_basic_tables.py b/backend/create_stock_basic_tables.py index 2440b69..1593930 100644 --- a/backend/create_stock_basic_tables.py +++ b/backend/create_stock_basic_tables.py @@ -1,96 +1,18 @@ """ -创建股票基础数据相关表 +创建股票基础数据相关表(使用 SQLAlchemy ORM,兼容 SQLite 和 PostgreSQL) """ -from sqlalchemy import text -from app.db.session import SessionLocal +from app.db.session import engine, SessionLocal +from app.db.base import Base +from app.models.stock_basic import StockBasic, IndexBasic, IndexTrade -db = SessionLocal() - -try: - # 创建股票基础数据表 - db.execute(text(""" - CREATE TABLE IF NOT EXISTS stock_basic ( - id BIGSERIAL PRIMARY KEY, - code VARCHAR(20) UNIQUE NOT NULL, - name VARCHAR(50), - total_shares BIGINT, - float_shares BIGINT, - industry_index_name VARCHAR(100), - industry_index_code VARCHAR(20), - institution_hold_ratio DECIMAL(10, 4), - industry_level3 VARCHAR(100), - list_date DATE, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) - """)) - - # 创建指数基础表 - db.execute(text(""" - CREATE TABLE IF NOT EXISTS index_basic ( - id BIGSERIAL PRIMARY KEY, - code VARCHAR(20) UNIQUE NOT NULL, - name VARCHAR(100), - component_count INTEGER, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) - """)) - - # 创建指数交易表 - db.execute(text(""" - CREATE TABLE IF NOT EXISTS index_trade ( - id BIGSERIAL PRIMARY KEY, - index_code VARCHAR(20) NOT NULL, - trade_date DATE NOT NULL, - open DECIMAL(10, 3), - close DECIMAL(10, 3), - high DECIMAL(10, 3), - low DECIMAL(10, 3), - change_pct DECIMAL(10, 4), - volume BIGINT, - amount DECIMAL(18, 2), - total_market_value DECIMAL(18, 2), - float_market_value DECIMAL(18, 2), - up_count INTEGER, - down_count INTEGER, - flat_count INTEGER, - limit_up_count INTEGER, - limit_down_count INTEGER, - suspend_count INTEGER, - pe_ratio DECIMAL(10, 4), - pe_median DECIMAL(10, 4), - is_new_high BOOLEAN DEFAULT FALSE, - is_new_low BOOLEAN DEFAULT FALSE, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - UNIQUE(index_code, trade_date) - ) - """)) +def create_tables(): + """创建所有股票基础数据相关表""" + print("开始创建数据库表...") - # 创建索引 - db.execute(text("CREATE INDEX IF NOT EXISTS idx_stock_basic_code ON stock_basic(code)")) - db.execute(text("CREATE INDEX IF NOT EXISTS idx_index_basic_code ON index_basic(code)")) - db.execute(text("CREATE INDEX IF NOT EXISTS idx_index_trade_code ON index_trade(index_code)")) - db.execute(text("CREATE INDEX IF NOT EXISTS idx_index_trade_date ON index_trade(trade_date)")) + Base.metadata.create_all(bind=engine) - # 添加外键约束 - db.execute(text(""" - ALTER TABLE stock_basic - ADD CONSTRAINT fk_stock_basic_index_code - FOREIGN KEY (industry_index_code) REFERENCES index_basic(code) - """)) - - db.execute(text(""" - ALTER TABLE index_trade - ADD CONSTRAINT fk_index_trade_index_code - FOREIGN KEY (index_code) REFERENCES index_basic(code) - """)) - - db.commit() - print("表创建成功") -except Exception as e: - print(f"创建表失败: {str(e)}") - db.rollback() -finally: - db.close() \ No newline at end of file + print("数据库表创建完成") + print(f"创建的表: stock_basic, index_basic, index_trade") + +if __name__ == "__main__": + create_tables() \ No newline at end of file diff --git a/frontend/src/views/DataImport/index.vue b/frontend/src/views/DataImport/index.vue index 0a8f38f..1018177 100644 --- a/frontend/src/views/DataImport/index.vue +++ b/frontend/src/views/DataImport/index.vue @@ -58,9 +58,9 @@ - 必须包含以下列:code, name, total_shares, float_shares, - industry_index_name, industry_index_code, institution_hold_ratio, - industry_level3, list_date + 支持模板格式(stock_info_template.xlsx),自动适配以下列:
+ 证券代码、证券名称、首发上市日、所属东财行业指数名称[行业类别]2级、
+ 所属东财行业指数代码[行业类别]2级、机构持股比例合计、所属东财行业名称[行业类别]3级
@@ -83,7 +83,18 @@
- +
+ + + 查看详情 + +
@@ -150,6 +161,27 @@ + + + + + + + + + +
共 {{ stockBasicDetailData.skipped_details?.length || 0 }} 条跳过数据
+
+ + + + + + +
共 {{ stockBasicDetailData.error_details?.length || 0 }} 条失败数据
+
+
+
@@ -175,6 +207,10 @@ const indexBasicResult = ref(null) const indexTradeResult = ref(null) const indexDataResult = ref(null) +const stockBasicDetailData = ref({}) +const detailDialogVisible = ref(false) +const detailTab = ref('skipped') + const handleStockBasicChange = (file: any) => { stockBasicFile.value = file.raw } @@ -238,15 +274,20 @@ const handleImportStockBasic = async () => { const res: any = await importStockBasic(stockBasicFile.value) if (res.data) { stockBasicResult.value = { - title: `导入完成:成功${res.data.success_count}条,失败${res.data.error_count}条,共${res.data.total_count}条`, - type: res.data.error_count > 0 ? 'warning' : 'success' + title: `导入完成:新增${res.data.added_count}条,更新${res.data.updated_count}条,跳过${res.data.skipped_count}条,失败${res.data.error_count}条`, + type: res.data.error_count > 0 ? 'warning' : 'success', + skipped_count: res.data.skipped_count, + error_count: res.data.error_count } + stockBasicDetailData.value = res.data ElMessage.success('导入完成') } } catch (error: any) { stockBasicResult.value = { title: `导入失败:${error.response?.data?.detail || error.message}`, - type: 'error' + type: 'error', + skipped_count: 0, + error_count: 0 } ElMessage.error('导入失败') } finally { @@ -254,6 +295,11 @@ const handleImportStockBasic = async () => { } } +const showStockBasicDetail = () => { + detailTab.value = stockBasicDetailData.value.error_count > 0 ? 'error' : 'skipped' + detailDialogVisible.value = true +} + const handleImportIndexBasic = async () => { if (!indexBasicFile.value) { ElMessage.warning('请先选择文件') @@ -317,4 +363,14 @@ const handleImportIndexTrade = async () => { .data-import { padding: 20px; } +.result-area { + display: flex; + align-items: center; + margin-top: 10px; +} +.detail-count { + margin-top: 10px; + color: #666; + font-size: 14px; +} \ No newline at end of file diff --git a/template/stock_info_template.xlsx b/template/stock_info_template.xlsx new file mode 100644 index 0000000..649d5ae Binary files /dev/null and b/template/stock_info_template.xlsx differ diff --git a/template/全部A股_20260417.xlsx b/template/全部A股_20260417.xlsx new file mode 100644 index 0000000..14cd581 Binary files /dev/null and b/template/全部A股_20260417.xlsx differ