diff --git a/backend/app/api/v1/data_import.py b/backend/app/api/v1/data_import.py
index 36b97fb..ca0e2f5 100644
--- a/backend/app/api/v1/data_import.py
+++ b/backend/app/api/v1/data_import.py
@@ -41,6 +41,16 @@ INDEX_TRADE_COLUMN_MAP = {
'市盈率PE(TTM)中位值 [交易日期]最新 [剔除规则]不调整': 'pe_median'
}
+# Maps header text of the stock-basic Excel sheet to the internal column names
+# used by the import. The import renames columns by exact key match after
+# trimming only leading/trailing whitespace, so each key must equal the header
+# cell text exactly, including the embedded "\n" line breaks.
+STOCK_BASIC_COLUMN_MAP = {
+ '证券代码': 'code',
+ '证券名称': 'name',
+ '首发上市日': 'list_date',
+ '所属东财行业指数名称\n[行业类别]2级': 'industry_index_name',
+ '所属东财行业指数代码\n[行业类别]2级': 'industry_index_code',
+ '机构持股比例合计\n[报告期]最新一期\n[单位]%\n[比例类型]占总股本比例': 'institution_hold_ratio',
+ '所属东财行业名称\n[行业类别]3级': 'industry_level3'
+}
+
@router.post("/index-data", response_model=ResponseModel)
async def import_index_data(
@@ -231,70 +241,152 @@ async def import_stock_basic(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
- """导入股票基础数据"""
+ """导入股票基础数据(支持模板格式)"""
if not file.filename.endswith(('.xls', '.xlsx')):
raise HTTPException(status_code=400, detail="只支持xls或xlsx格式文件")
try:
df = pd.read_excel(file.file)
- required_columns = ['code', 'name', 'total_shares', 'float_shares',
- 'industry_index_name', 'industry_index_code',
- 'institution_hold_ratio', 'industry_level3', 'list_date']
+ df.columns = df.columns.str.strip()
- missing_columns = [col for col in required_columns if col not in df.columns]
- if missing_columns:
- raise HTTPException(status_code=400, detail=f"缺少必要列: {missing_columns}")
+ renamed_df = df.rename(columns=STOCK_BASIC_COLUMN_MAP)
+
+ if 'code' not in renamed_df.columns:
+ raise HTTPException(status_code=400, detail="缺少必要列:证券代码")
+
+ if 'name' not in renamed_df.columns:
+ raise HTTPException(status_code=400, detail="缺少必要列:证券名称")
success_count = 0
error_count = 0
+ added_count = 0
+ updated_count = 0
+ skipped_count = 0
+ skipped_details = []
+ error_details = []
- for _, row in df.iterrows():
+ for _, row in renamed_df.iterrows():
try:
- existing = db.query(StockBasic).filter(StockBasic.code == str(row['code'])).first()
+ code_val = row.get('code')
+ if pd.isna(code_val):
+ continue
+ code = str(code_val).strip()
+ if not code or code.lower() == 'nan':
+ continue
+
+ existing = db.query(StockBasic).filter(StockBasic.code == code).first()
list_date = None
- if pd.notna(row['list_date']):
- if isinstance(row['list_date'], datetime):
- list_date = row['list_date'].date()
- elif isinstance(row['list_date'], str):
- list_date = datetime.strptime(row['list_date'], '%Y-%m-%d').date()
+ list_date_val = row.get('list_date')
+ if pd.notna(list_date_val):
+ if isinstance(list_date_val, datetime):
+ list_date = list_date_val.date()
+ elif isinstance(list_date_val, str):
+ try:
+ list_date = datetime.strptime(list_date_val, '%Y-%m-%d').date()
+ except:
+ pass
+ elif hasattr(list_date_val, 'date'):
+ list_date = list_date_val.date()
+
+ name = str(row.get('name', '')) if pd.notna(row.get('name')) else None
+ industry_index_name = str(row.get('industry_index_name', '')) if pd.notna(row.get('industry_index_name')) else None
+ industry_index_code = str(row.get('industry_index_code', '')) if pd.notna(row.get('industry_index_code')) else None
+
+ institution_hold_ratio = None
+ ratio_val = row.get('institution_hold_ratio')
+ if pd.notna(ratio_val):
+ try:
+ if str(ratio_val).strip() == '--':
+ institution_hold_ratio = 0.0
+ else:
+ institution_hold_ratio = float(ratio_val)
+ except:
+ institution_hold_ratio = 0.0
+
+ industry_level3 = str(row.get('industry_level3', '')) if pd.notna(row.get('industry_level3')) else None
+
+ if industry_index_code:
+ index_basic = db.query(IndexBasic).filter(IndexBasic.code == industry_index_code).first()
+ if not index_basic:
+ index_basic = IndexBasic(
+ code=industry_index_code,
+ name=industry_index_name or industry_index_code
+ )
+ db.add(index_basic)
+ db.flush()
if existing:
- existing.name = str(row.get('name', existing.name))
- existing.total_shares = int(row.get('total_shares', existing.total_shares)) if pd.notna(row.get('total_shares')) else existing.total_shares
- existing.float_shares = int(row.get('float_shares', existing.float_shares)) if pd.notna(row.get('float_shares')) else existing.float_shares
- existing.industry_index_name = str(row.get('industry_index_name', existing.industry_index_name)) if pd.notna(row.get('industry_index_name')) else existing.industry_index_name
- existing.industry_index_code = str(row.get('industry_index_code', existing.industry_index_code)) if pd.notna(row.get('industry_index_code')) else existing.industry_index_code
- existing.institution_hold_ratio = float(row.get('institution_hold_ratio', existing.institution_hold_ratio)) if pd.notna(row.get('institution_hold_ratio')) else existing.institution_hold_ratio
- existing.industry_level3 = str(row.get('industry_level3', existing.industry_level3)) if pd.notna(row.get('industry_level3')) else existing.industry_level3
- existing.list_date = list_date
- existing.updated_at = datetime.utcnow()
+ def is_same_data():
+     """Return True when every imported field equals the stored record's value."""
+     def compare_ratio():
+         stored = existing.institution_hold_ratio
+         incoming = institution_hold_ratio
+         if stored is None or incoming is None:
+             # A missing ratio only matches another missing ratio.
+             return stored is None and incoming is None
+         # Ratios within 0.0001 of each other count as equal.
+         return abs(float(stored) - incoming) < 0.0001
+
+     # None == None is True, so plain equality also covers "both missing".
+     if not (
+         existing.name == name
+         and existing.industry_index_name == industry_index_name
+         and existing.industry_index_code == industry_index_code
+     ):
+         return False
+     if not compare_ratio():
+         return False
+     return (
+         existing.industry_level3 == industry_level3
+         and existing.list_date == list_date
+     )
+
+ if is_same_data():
+ skipped_count += 1
+ skipped_details.append({
+ "code": code,
+ "name": name,
+ "reason": "数据相同,无需更新"
+ })
+ else:
+ existing.name = name
+ existing.industry_index_name = industry_index_name
+ existing.industry_index_code = industry_index_code
+ existing.institution_hold_ratio = institution_hold_ratio
+ existing.industry_level3 = industry_level3
+ existing.list_date = list_date
+ existing.updated_at = datetime.utcnow()
+ updated_count += 1
else:
stock = StockBasic(
- code=str(row['code']),
- name=str(row.get('name', '')),
- total_shares=int(row['total_shares']) if pd.notna(row['total_shares']) else None,
- float_shares=int(row['float_shares']) if pd.notna(row['float_shares']) else None,
- industry_index_name=str(row.get('industry_index_name', '')) if pd.notna(row.get('industry_index_name')) else None,
- industry_index_code=str(row.get('industry_index_code', '')) if pd.notna(row.get('industry_index_code')) else None,
- institution_hold_ratio=float(row['institution_hold_ratio']) if pd.notna(row['institution_hold_ratio']) else None,
- industry_level3=str(row.get('industry_level3', '')) if pd.notna(row.get('industry_level3')) else None,
+ code=code,
+ name=name,
+ total_shares=None,
+ float_shares=None,
+ industry_index_name=industry_index_name,
+ industry_index_code=industry_index_code,
+ institution_hold_ratio=institution_hold_ratio,
+ industry_level3=industry_level3,
list_date=list_date
)
db.add(stock)
+ added_count += 1
success_count += 1
except Exception as e:
logger.error(f"导入股票{row.get('code')}失败: {str(e)}")
error_count += 1
+ error_details.append({
+ "code": str(row.get('code', '')),
+ "name": str(row.get('name', '')) if pd.notna(row.get('name')) else '',
+ "reason": str(e)
+ })
db.commit()
return ResponseModel(data={
"success_count": success_count,
"error_count": error_count,
- "total_count": len(df)
+ "total_count": len(df),
+ "added_count": added_count,
+ "updated_count": updated_count,
+ "skipped_count": skipped_count,
+ "skipped_details": skipped_details[:100],
+ "error_details": error_details[:100]
})
except Exception as e:
diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py
index cf59865..69f8949 100644
--- a/backend/app/models/__init__.py
+++ b/backend/app/models/__init__.py
@@ -7,6 +7,7 @@ from app.models.realtime import RealtimeSnapshot
from app.models.finance import FinanceBalanceSheet, FinanceCashFlow, FinanceIncome
from app.models.cache import CacheTask, CacheTaskDetail
from app.models.test import APITestLog
+from app.models.stock_basic import StockBasic, IndexBasic, IndexTrade
__all__ = [
"User",
@@ -25,4 +26,7 @@ __all__ = [
"CacheTask",
"CacheTaskDetail",
"APITestLog",
+ "StockBasic",
+ "IndexBasic",
+ "IndexTrade",
]
diff --git a/backend/create_stock_basic_tables.py b/backend/create_stock_basic_tables.py
index 2440b69..1593930 100644
--- a/backend/create_stock_basic_tables.py
+++ b/backend/create_stock_basic_tables.py
@@ -1,96 +1,18 @@
"""
-创建股票基础数据相关表
+创建股票基础数据相关表(使用 SQLAlchemy ORM,兼容 SQLite 和 PostgreSQL)
"""
-from sqlalchemy import text
-from app.db.session import SessionLocal
+from app.db.session import engine, SessionLocal
+from app.db.base import Base
+from app.models.stock_basic import StockBasic, IndexBasic, IndexTrade
-db = SessionLocal()
-
-try:
- # 创建股票基础数据表
- db.execute(text("""
- CREATE TABLE IF NOT EXISTS stock_basic (
- id BIGSERIAL PRIMARY KEY,
- code VARCHAR(20) UNIQUE NOT NULL,
- name VARCHAR(50),
- total_shares BIGINT,
- float_shares BIGINT,
- industry_index_name VARCHAR(100),
- industry_index_code VARCHAR(20),
- institution_hold_ratio DECIMAL(10, 4),
- industry_level3 VARCHAR(100),
- list_date DATE,
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
- )
- """))
-
- # 创建指数基础表
- db.execute(text("""
- CREATE TABLE IF NOT EXISTS index_basic (
- id BIGSERIAL PRIMARY KEY,
- code VARCHAR(20) UNIQUE NOT NULL,
- name VARCHAR(100),
- component_count INTEGER,
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
- )
- """))
-
- # 创建指数交易表
- db.execute(text("""
- CREATE TABLE IF NOT EXISTS index_trade (
- id BIGSERIAL PRIMARY KEY,
- index_code VARCHAR(20) NOT NULL,
- trade_date DATE NOT NULL,
- open DECIMAL(10, 3),
- close DECIMAL(10, 3),
- high DECIMAL(10, 3),
- low DECIMAL(10, 3),
- change_pct DECIMAL(10, 4),
- volume BIGINT,
- amount DECIMAL(18, 2),
- total_market_value DECIMAL(18, 2),
- float_market_value DECIMAL(18, 2),
- up_count INTEGER,
- down_count INTEGER,
- flat_count INTEGER,
- limit_up_count INTEGER,
- limit_down_count INTEGER,
- suspend_count INTEGER,
- pe_ratio DECIMAL(10, 4),
- pe_median DECIMAL(10, 4),
- is_new_high BOOLEAN DEFAULT FALSE,
- is_new_low BOOLEAN DEFAULT FALSE,
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- UNIQUE(index_code, trade_date)
- )
- """))
+def create_tables() -> None:
+    """Create the stock-basic related tables through the SQLAlchemy ORM.
+
+    Delegates to ``Base.metadata.create_all`` bound to the shared engine and
+    prints progress messages before and after. Any exception raised while
+    creating the tables propagates to the caller.
+    """
+    print("开始创建数据库表...")
- # 创建索引
- db.execute(text("CREATE INDEX IF NOT EXISTS idx_stock_basic_code ON stock_basic(code)"))
- db.execute(text("CREATE INDEX IF NOT EXISTS idx_index_basic_code ON index_basic(code)"))
- db.execute(text("CREATE INDEX IF NOT EXISTS idx_index_trade_code ON index_trade(index_code)"))
- db.execute(text("CREATE INDEX IF NOT EXISTS idx_index_trade_date ON index_trade(trade_date)"))
+    # NOTE(review): create_all works on the whole Base.metadata, so it may also
+    # create tables of other registered models, not only the three named below.
+    # NOTE(review): the indexes and foreign keys formerly added with raw SQL now
+    # have to be declared on the models themselves -- confirm they are.
+    Base.metadata.create_all(bind=engine)
- # 添加外键约束
- db.execute(text("""
- ALTER TABLE stock_basic
- ADD CONSTRAINT fk_stock_basic_index_code
- FOREIGN KEY (industry_index_code) REFERENCES index_basic(code)
- """))
-
- db.execute(text("""
- ALTER TABLE index_trade
- ADD CONSTRAINT fk_index_trade_index_code
- FOREIGN KEY (index_code) REFERENCES index_basic(code)
- """))
-
- db.commit()
- print("表创建成功")
-except Exception as e:
- print(f"创建表失败: {str(e)}")
- db.rollback()
-finally:
- db.close()
\ No newline at end of file
+    print("数据库表创建完成")
+    # Plain string: the message has no placeholders, so no f-prefix is needed.
+    print("创建的表: stock_basic, index_basic, index_trade")
+
+if __name__ == "__main__":
+    create_tables()
\ No newline at end of file
diff --git a/frontend/src/views/DataImport/index.vue b/frontend/src/views/DataImport/index.vue
index 0a8f38f..1018177 100644
--- a/frontend/src/views/DataImport/index.vue
+++ b/frontend/src/views/DataImport/index.vue
@@ -58,9 +58,9 @@
+ 证券代码、证券名称、首发上市日、所属东财行业指数名称[行业类别]2级、
+ 所属东财行业指数代码[行业类别]2级、机构持股比例合计、所属东财行业名称[行业类别]3级