You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

439 lines
15 KiB

from fastapi import FastAPI, Depends, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from typing import Optional
from datetime import datetime
import logging
# Configure the root logger at import time: INFO level, timestamped lines
# carrying the logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
from app.config import settings
from app.database import get_db, engine, Base
from app.schemas import (
KlineRequest, KlineResponse, KlineItem,
ContractInfo as ContractSchema, ContractListResponse,
DataSourceConfigItem, DataSourceConfigUpdate, DataSourceCreate,
ApiResponse, HealthResponse, DataSourceStatus,
BatchSyncRequest, BatchSyncResult,
ProductInfo as ProductSchema, ProductTreeResponse,
)
from app.services.kline_service import kline_service
from app.services.contract_service import contract_service
from app.services.product_service import product_service
from app.services.datasource.manager import DataSourceManager
from app.models import DataSourceConfig
# Module-level logger for this API module.
logger = logging.getLogger(__name__)
# The FastAPI application; interactive docs are served at /docs and /redoc.
app = FastAPI(
    title=settings.PROJECT_NAME,
    version=settings.VERSION,
    docs_url="/docs",
    redoc_url="/redoc",
)
# CORS: accept cross-origin requests from any origin, with any method and header.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# flagged by the Starlette CORS docs. Depending on the Starlette version, the
# middleware either reflects the request Origin (credentialed requests accepted
# from ANY site) or browsers reject the credentialed response. Confirm the
# intent; if cookies/auth headers are used, list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ========== 启动事件 ==========
@app.on_event("startup")
async def startup():
    """Prepare the database schema and data-source registry at application start."""
    # NOTE(review): `on_event` is deprecated in newer FastAPI releases in favour
    # of lifespan handlers; consider migrating when upgrading.
    # Create every table declared on Base's metadata that does not exist yet.
    Base.metadata.create_all(bind=engine)
    # Load the data-source configurations that are currently enabled.
    DataSourceManager.load_enabled_sources()
    # Seed the default data-source config rows when they are missing.
    # NOTE(review): this runs after load_enabled_sources(); the seeded rows are
    # created with is_enabled=False, so presumably they would not be loaded
    # anyway — revisit the order if the defaults ever become enabled.
    _init_default_datasource()
def _init_default_datasource():
    """Insert the built-in data-source config rows that do not exist yet.

    Each missing default is created disabled, with status ``"unknown"`` and its
    default settings serialised as a JSON string. Rows that already exist
    (matched by ``source_name``) are left untouched. All inserts are committed
    together. On error the session is rolled back and the exception re-raised.
    """
    import json

    from app.database import SessionLocal

    # (source_name, display_name, default config, priority)
    defaults = (
        ("tushare", "Tushare", {"token": ""}, 1),
        ("akshare", "AKShare", {"max_retries": 3}, 2),
    )

    db = SessionLocal()
    try:
        for source_name, display_name, config, priority in defaults:
            existing = db.query(DataSourceConfig).filter(
                DataSourceConfig.source_name == source_name
            ).first()
            if existing:
                continue
            db.add(
                DataSourceConfig(
                    source_name=source_name,
                    display_name=display_name,
                    is_enabled=False,
                    config_json=json.dumps(config),
                    priority=priority,
                    status="unknown",
                )
            )
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()
# ========== 健康检查 ==========
@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """Report database and Redis connectivity.

    The status is "healthy" when every required service responds. Redis is
    optional, so an unavailable Redis does not degrade the status.
    """
    services = {}

    # Required: the database must answer a trivial query.
    try:
        from sqlalchemy import text

        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        services["database"] = "ok"
    except Exception as e:
        services["database"] = f"error: {str(e)}"

    # Optional: Redis. Bounded timeouts stop an unreachable host from stalling
    # this async handler, which would otherwise block the event loop.
    redis_client = None
    try:
        import redis

        redis_client = redis.from_url(
            settings.REDIS_URL, socket_connect_timeout=2, socket_timeout=2
        )
        redis_client.ping()
        services["redis"] = "ok"
    except Exception:
        services["redis"] = "not configured"  # Redis is not required
    finally:
        if redis_client is not None:
            # Release the connection pool created just for this check.
            redis_client.connection_pool.disconnect()

    # Optional services never downgrade the overall status.
    optional_services = {"redis"}
    healthy = all(
        state == "ok" or name in optional_services
        for name, state in services.items()
    )
    return HealthResponse(
        status="healthy" if healthy else "degraded",
        services=services,
        version=settings.VERSION,
    )
# ========== 品种接口 ==========
@app.get("/api/v1/products")
async def list_products(
    exchange: Optional[str] = Query(None, description="交易所代码"),
    category: Optional[str] = Query(None, description="品种分类"),
    is_active: Optional[bool] = Query(None, description="是否活跃"),
):
    """List products, optionally filtered by exchange, category and active flag."""
    # NOTE(review): a later handler in this module reuses the name
    # `list_products`; both routes are registered, but the module-level name
    # ends up bound to the later function.
    filters = {"exchange": exchange, "category": category, "is_active": is_active}
    return {"code": 0, "data": product_service.get_products(**filters)}
@app.get("/api/v1/products/tree")
async def get_product_tree():
    """Return products grouped as a category tree."""
    categories = product_service.get_product_tree()
    payload = {"categories": categories}
    return {"code": 0, "data": payload}
@app.get("/api/v1/products/{product_code}/contracts")
async def get_product_contracts(
    product_code: str,
    is_active: Optional[bool] = Query(None, description="是否活跃"),
):
    """Return all contracts of the given product, optionally filtered by active flag."""
    return {
        "code": 0,
        "data": product_service.get_product_contracts(
            product_code=product_code,
            is_active=is_active,
        ),
    }
@app.post("/api/v1/contracts/{symbol}/set-main")
async def set_main_contract(symbol: str):
    """Mark the given contract as the main contract."""
    if not product_service.set_main_contract(symbol):
        # The service reported failure; the envelope says the contract does not exist.
        return {"code": 1, "message": "设置失败,合约不存在"}
    return {"code": 0, "message": "设置成功"}
@app.post("/api/v1/contracts/update-main")
async def update_main_contracts():
    """Recompute main-contract flags automatically from open interest."""
    updated = product_service.update_main_contracts()
    message = f"更新了 {updated} 个主力合约"
    return {"code": 0, "message": message}
# ========== 合约接口 ==========
@app.get("/api/v1/contracts", response_model=ContractListResponse)
async def list_contracts(
    exchange: Optional[str] = Query(None, description="交易所代码"),
    product: Optional[str] = Query(None, description="品种代码"),
    is_active: Optional[bool] = Query(None, description="是否活跃"),
):
    """List contracts, optionally filtered by exchange, product and active flag."""
    rows = contract_service.get_contracts(
        exchange=exchange,
        product=product,
        is_active=is_active,
    )
    schemas = [ContractSchema.model_validate(row) for row in rows]
    return ContractListResponse(total=len(rows), items=schemas)
@app.get("/api/v1/contracts/products", response_model=ApiResponse)
async def list_products(
    exchange: Optional[str] = Query(None, description="交易所代码"),
):
    """List distinct products as reported by the contract service."""
    # NOTE(review): this rebinds the module-level name `list_products`, which
    # an earlier handler in this module also uses; both routes stay registered.
    logger.info(f"[API-获取品种列表] exchange={exchange}")
    products = contract_service.get_products(exchange=exchange)
    data = {"items": products, "total": len(products)}
    return {"code": 0, "message": "ok", "data": data}
@app.get("/api/v1/contracts/by-month", response_model=ContractListResponse)
async def get_contracts_by_month(
    product: str = Query(..., description="品种代码"),
    start_month: str = Query(..., description="起始月份 YYYY-MM 或 YYYYMM"),
    limit: int = Query(5, ge=1, le=20, description="返回合约数量"),
):
    """List up to ``limit`` contracts of a product, starting from ``start_month``."""
    logger.info(f"[API-按月份查询合约] product={product}, start_month={start_month}, limit={limit}")
    matched = contract_service.get_contracts_by_month(
        product=product, start_month=start_month, limit=limit
    )
    return ContractListResponse(
        total=len(matched),
        items=list(map(ContractSchema.model_validate, matched)),
    )
@app.get("/api/v1/contracts/{symbol}", response_model=ContractSchema)
async def get_contract(symbol: str):
    """Return a single contract by symbol; respond 404 when it is not found."""
    found = contract_service.get_contract(symbol)
    if found:
        return ContractSchema.model_validate(found)
    raise HTTPException(status_code=404, detail="合约不存在")
@app.post("/api/v1/contracts/sync")
async def sync_contracts():
    """Sync the contract list from the data sources.

    Returns the code-0 envelope with the synced count. On any failure it
    returns a code-1 envelope with the error message.
    """
    try:
        count = contract_service.sync_contracts()
    except Exception as e:
        # Log the traceback (as the K-line sync endpoints do) instead of only
        # echoing the message to the client, so failures are diagnosable.
        logger.error(f"[API-同步合约] 同步失败: {e}", exc_info=True)
        return {"code": 1, "message": f"同步失败: {str(e)}", "data": None}
    return {"code": 0, "message": "同步成功", "data": {"synced": count}}
# ========== K线接口 ==========
@app.get("/api/v1/kline", response_model=KlineResponse)
async def get_kline(
    symbol: str = Query(..., description="合约代码"),
    period: str = Query("daily", description="周期: daily/weekly/5m/15m/30m/60m"),
    start_date: Optional[str] = Query(None, description="开始日期 YYYY-MM-DD"),
    end_date: Optional[str] = Query(None, description="结束日期 YYYY-MM-DD"),
    limit: int = Query(500, ge=1, le=5000, description="返回条数"),
):
    """Return stored K-line bars for a symbol and period within an optional date range."""
    logger.info(f"[API-查询K线] 请求参数: symbol={symbol}, period={period}, start_date={start_date}, end_date={end_date}, limit={limit}")
    query = {
        "symbol": symbol,
        "period": period,
        "start_date": start_date,
        "end_date": end_date,
        "limit": limit,
    }
    rows = kline_service.get_kline(**query)
    logger.info(f"[API-查询K线] 返回 {len(rows)} 条记录")
    bars = [KlineItem(**row) for row in rows]
    return KlineResponse(symbol=symbol, period=period, total=len(rows), items=bars)
@app.post("/api/v1/kline/sync")
async def sync_kline(req: KlineRequest):
    """Sync K-line bars for one symbol from the data sources.

    "daily" and "weekly" use their dedicated sync calls; any other period is
    treated as intraday. The date range defaults to 2020-01-01 through today's
    local date.
    """
    logger.info(f"[API-同步K线] 请求参数: symbol={req.symbol}, period={req.period}, start_date={req.start_date}, end_date={req.end_date}")
    try:
        start = req.start_date or "2020-01-01"
        end = req.end_date or datetime.now().strftime("%Y-%m-%d")
        logger.info(f"[API-同步K线] 使用日期范围: {start} ~ {end}")
        period = req.period
        if period not in ("daily", "weekly"):
            count = kline_service.sync_intraday(req.symbol, period, start, end)
        elif period == "daily":
            count = kline_service.sync_daily(req.symbol, start, end)
        else:
            count = kline_service.sync_weekly(req.symbol, start, end)
        logger.info(f"[API-同步K线] 同步成功,共同步 {count} 条记录")
        return {"code": 0, "message": "同步成功", "data": {"synced": count}}
    except Exception as exc:
        logger.error(f"[API-同步K线] 同步失败: {exc}", exc_info=True)
        return {"code": 1, "message": f"同步失败: {str(exc)}", "data": None}
@app.post("/api/v1/kline/batch-sync", response_model=BatchSyncResult)
async def batch_sync_kline(req: BatchSyncRequest):
    """Sync K-line bars for several symbols; any failure becomes an HTTP 500."""
    logger.info(f"[API-批量同步K线] 请求参数: symbols={req.symbols}, period={req.period}, start_date={req.start_date}, end_date={req.end_date}")
    try:
        summary = kline_service.batch_sync(
            symbols=req.symbols, period=req.period,
            start_date=req.start_date, end_date=req.end_date,
        )
        logger.info(f"[API-批量同步K线] 同步完成: 成功={summary['success']}, 失败={summary['failed']}, 总记录={summary['total_records']}")
        return BatchSyncResult(**summary)
    except Exception as exc:
        logger.error(f"[API-批量同步K线] 同步失败: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"批量同步失败: {str(exc)}")
# ========== 数据源管理接口 ==========
@app.get("/api/v1/datasources")
async def list_datasources():
    """Return the status of every registered data source."""
    return {"code": 0, "data": DataSourceManager.get_all_sources_status()}
@app.post("/api/v1/datasources")
async def create_datasource(req: DataSourceCreate):
    """Create a new data-source configuration (created disabled, status "unknown").

    Returns a code-1 envelope if the source already exists or creation fails.
    Otherwise it returns code 0 with the new row id.
    """
    import json

    from app.database import SessionLocal

    db = SessionLocal()
    try:
        existing = db.query(DataSourceConfig).filter(
            DataSourceConfig.source_name == req.source_name
        ).first()
        if existing:
            return {"code": 1, "message": "数据源已存在"}
        # Everywhere else in this module config_json is a JSON *string*
        # (json.dumps on write, json.loads in the connection test), so
        # serialise non-string configs instead of storing the raw object.
        raw_config = req.config_json or {}
        config_text = raw_config if isinstance(raw_config, str) else json.dumps(raw_config)
        cfg = DataSourceConfig(
            source_name=req.source_name,
            display_name=req.display_name or req.source_name,
            is_enabled=False,
            config_json=config_text,
            priority=req.priority,
            status="unknown",
        )
        db.add(cfg)
        db.commit()
        return {"code": 0, "message": "创建成功", "data": {"id": cfg.id}}
    except Exception as e:
        db.rollback()
        logger.error(f"[API-创建数据源] 创建失败: {e}", exc_info=True)
        return {"code": 1, "message": f"创建失败: {str(e)}"}
    finally:
        db.close()
@app.put("/api/v1/datasources/{source_name}")
async def update_datasource(source_name: str, req: DataSourceConfigUpdate):
    """Partially update a data-source config, then reload the enabled sources.

    Only fields that are not None in the request are applied. The config is
    stored as a JSON string.
    """
    import json

    from app.database import SessionLocal

    session = SessionLocal()
    try:
        record = (
            session.query(DataSourceConfig)
            .filter(DataSourceConfig.source_name == source_name)
            .first()
        )
        if not record:
            return {"code": 1, "message": "数据源不存在"}

        if req.is_enabled is not None:
            record.is_enabled = req.is_enabled
        if req.config_json is not None:
            record.config_json = json.dumps(req.config_json)
        if req.priority is not None:
            record.priority = req.priority
        session.commit()

        # Reload so enable/disable and config changes take effect right away.
        # NOTE(review): if this reload raises, the response reports failure
        # even though the DB change above was already committed.
        DataSourceManager.load_enabled_sources()
        return {"code": 0, "message": "更新成功"}
    except Exception as exc:
        session.rollback()
        return {"code": 1, "message": f"更新失败: {str(exc)}"}
    finally:
        session.close()
def _record_datasource_status(source_name, status, error_msg):
    """Persist a connection-test outcome on the matching config row (no-op if absent)."""
    from app.database import SessionLocal

    db = SessionLocal()
    try:
        cfg = db.query(DataSourceConfig).filter(
            DataSourceConfig.source_name == source_name
        ).first()
        if cfg:
            cfg.status = status
            cfg.error_msg = error_msg
            db.commit()
    finally:
        db.close()


@app.post("/api/v1/datasources/{source_name}/test")
async def test_datasource(source_name: str):
    """Test connectivity of a data source and record the result.

    It uses the loaded source if there is one. Otherwise it builds a temporary
    instance from the stored config. The outcome ("ok" or "error" plus
    message) is saved on the config row.
    """
    import json

    from app.database import SessionLocal

    source = DataSourceManager.get_source(source_name)
    if not source:
        # Not loaded (e.g. disabled): build a throwaway instance from the DB config.
        db = SessionLocal()
        try:
            cfg = db.query(DataSourceConfig).filter(
                DataSourceConfig.source_name == source_name
            ).first()
            if not cfg:
                return {"code": 1, "message": "数据源不存在"}
            config = json.loads(cfg.config_json) if cfg.config_json else {}
            # NOTE(review): relies on the manager's private `_source_map`.
            source_class = DataSourceManager._source_map.get(source_name)
            if not source_class:
                return {"code": 1, "message": "不支持的数据源类型"}
            source = source_class(config)
        finally:
            db.close()

    # A crashing health check is a failed connection, not a server error:
    # record it and report it like any other failure.
    try:
        ok, msg = source.health_check()
    except Exception as e:
        logger.error(f"[API-测试数据源] 健康检查异常: {e}", exc_info=True)
        ok, msg = False, str(e)

    if ok:
        _record_datasource_status(source_name, "ok", None)
        return {"code": 0, "message": "连接成功", "data": {"status": "ok"}}
    _record_datasource_status(source_name, "error", msg)
    return {"code": 1, "message": f"连接失败: {msg}", "data": {"status": "error"}}