You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
398 lines
11 KiB
398 lines
11 KiB
"""
|
|
缓存管理路由
|
|
"""
|
|
from typing import List
|
|
from fastapi import APIRouter, Depends, Query, BackgroundTasks
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.db.session import get_db
|
|
from app.schemas.base import ResponseModel, PaginatedResponse
|
|
from app.schemas.cache import (
|
|
DetectMissingRequest, DetectMissingResponse,
|
|
BatchCacheRequest, CacheTaskResponse, CacheStatusResponse,
|
|
AllDataRequest
|
|
)
|
|
from app.services.cache_service import CacheService
|
|
from app.core.security import get_current_user
|
|
from app.models.user import User
|
|
from app.utils.date_utils import parse_date, format_date
|
|
|
|
# Router that collects the cache-management endpoints defined in this module;
# presumably included into the FastAPI application elsewhere — TODO confirm
# the mount prefix.
router = APIRouter()
|
|
|
|
|
|
def run_cache_task(
    security_type: str,
    period_type: str,
    start_date_str: str,
    end_date_str: str,
    contract_type: str,
    task_id: int
):
    """Execute a cache task in the background.

    Scheduled by ``cache_all_missing_data`` through FastAPI ``BackgroundTasks``.
    It opens its own session via ``SessionLocal`` and always closes it.

    Args:
        security_type: Security type forwarded to the service (e.g. stock/future).
        period_type: Bar period forwarded to the service.
        start_date_str: Start date string, parsed with ``parse_date``.
        end_date_str: End date string, parsed with ``parse_date``.
        contract_type: Contract type forwarded to the service.
        task_id: Id of the task record created before scheduling.

    Any exception is logged together with its traceback and then swallowed, so
    a failing job does not propagate out of the background-task runner.
    """
    import logging

    # Imported locally as in the original; possibly to avoid an import cycle — TODO confirm.
    from app.db.session import SessionLocal

    logger = logging.getLogger(__name__)
    db = SessionLocal()
    try:
        service = CacheService(db)
        start_date = parse_date(start_date_str)
        end_date = parse_date(end_date_str)
        service._execute_cache_task(
            task_id,
            security_type,
            period_type,
            start_date,
            end_date,
            contract_type
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logging only str(e) lost.
        # NOTE(review): if _execute_cache_task does not mark the task as failed
        # itself, the task record may stay in a non-terminal status — confirm.
        logger.exception("后台缓存任务失败 (task_id=%s): %s", task_id, e)
    finally:
        db.close()
|
|
|
|
|
|
def run_fill_missing_task(
    security_type: str,
    period_type: str,
    start_date_str: str,
    end_date_str: str,
    missing_codes: List[str],
    task_id: int
):
    """Execute a fill-missing-data task in the background.

    Scheduled by ``fill_missing_data`` through FastAPI ``BackgroundTasks``.
    It opens its own session via ``SessionLocal`` and always closes it.

    Args:
        security_type: Security type forwarded to the service.
        period_type: Bar period forwarded to the service.
        start_date_str: Start date string, parsed with ``parse_date``.
        end_date_str: End date string, parsed with ``parse_date``.
        missing_codes: Codes reported as missing by the detection step.
        task_id: Id of the task record created before scheduling.

    Any exception is logged together with its traceback and then swallowed, so
    a failing job does not propagate out of the background-task runner.
    """
    import logging

    # Imported locally as in the original; possibly to avoid an import cycle — TODO confirm.
    from app.db.session import SessionLocal

    logger = logging.getLogger(__name__)
    db = SessionLocal()
    try:
        service = CacheService(db)
        start_date = parse_date(start_date_str)
        end_date = parse_date(end_date_str)
        service._execute_fill_missing_task(
            task_id,
            security_type,
            period_type,
            start_date,
            end_date,
            missing_codes
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logging only str(e) lost.
        # NOTE(review): if _execute_fill_missing_task does not mark the task as
        # failed itself, the task record may stay in a non-terminal status — confirm.
        logger.exception("后台补齐任务失败 (task_id=%s): %s", task_id, e)
    finally:
        db.close()
|
|
|
|
|
|
@router.post("/fill-missing", response_model=ResponseModel)
async def fill_missing_data(
    request: AllDataRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Fill missing data in one go (runs asynchronously; only missing codes are processed).

    The handler works in three steps:

    1. It runs missing-data detection synchronously to get the list of missing codes.
    2. It creates a fill task record.
    3. It schedules ``run_fill_missing_task`` on ``background_tasks``.

    Returns:
        ``ResponseModel`` whose ``data`` has ``task_id=None`` and
        ``missing_count=0`` when nothing is missing. Otherwise it holds the new
        task's id, name, status, total_count and progress, plus ``missing_count``.
        Detection errors are reported the same way as in ``/detect-all-missing``:
        ``ValueError`` -> ``code=400``, ``RuntimeError`` -> ``code=500``.
    """
    service = CacheService(db)

    # Detect first to obtain the list of missing codes. Errors are mapped to
    # response codes consistently with the /detect-all-missing endpoint instead
    # of escaping as an unstructured 500.
    try:
        start = parse_date(request.start_date)
        end = parse_date(request.end_date)
        result = service.detect_all_missing_data(
            request.security_type,
            request.period_type,
            start,
            end,
            request.contract_type
        )
    except ValueError as e:
        return ResponseModel(code=400, message=f"参数错误: {str(e)}")
    except RuntimeError as e:
        return ResponseModel(code=500, message=f"服务错误: {str(e)}")

    missing_codes = result.get("missing_code_list", [])

    if not missing_codes:
        return ResponseModel(data={
            "task_id": None,
            "message": "没有缺失数据需要补齐",
            "missing_count": 0
        })

    # Create the fill task record.
    # NOTE(review): request.contract_type is used for detection but is not passed
    # to task creation or to the background runner — confirm this is intended.
    task = service._create_fill_missing_task(
        request.security_type,
        request.period_type,
        start,
        end,
        missing_codes
    )

    # Run the fill in the background. The raw date strings are passed; the
    # runner parses them again with its own DB session.
    background_tasks.add_task(
        run_fill_missing_task,
        request.security_type,
        request.period_type,
        request.start_date,
        request.end_date,
        missing_codes,
        task.id
    )

    return ResponseModel(data={
        "task_id": task.id,
        "task_name": task.task_name,
        "status": task.status,
        "total_count": task.total_count,
        "progress": float(task.progress),
        "missing_count": len(missing_codes)
    })
|
|
|
|
|
|
@router.post("/cache-all-missing", response_model=ResponseModel)
async def cache_all_missing_data(
    request: AllDataRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Cache all missing data in one go; the heavy lifting runs as a background task.

    A task record is created synchronously so its id can be returned at once.
    ``run_cache_task`` is then scheduled with the raw date strings from the request.
    """
    svc = CacheService(db)
    security_type = request.security_type
    period_type = request.period_type
    contract_type = request.contract_type

    created = svc._create_cache_task(
        security_type,
        period_type,
        parse_date(request.start_date),
        parse_date(request.end_date),
        contract_type,
    )

    background_tasks.add_task(
        run_cache_task,
        security_type,
        period_type,
        request.start_date,
        request.end_date,
        contract_type,
        created.id,
    )

    payload = {
        "task_id": created.id,
        "task_name": created.task_name,
        "status": created.status,
        "total_count": created.total_count,
        "progress": float(created.progress),
    }
    return ResponseModel(data=payload)
|
|
|
|
|
|
@router.post("/detect-all-missing", response_model=ResponseModel)
async def detect_all_missing_data(
    request: AllDataRequest,
    task_id: str = Query(None, description="WebSocket任务ID"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Detect missing data across all codes for the requested range.

    Args:
        request: Security type, period type, date range and contract type.
        task_id: Optional WebSocket task id forwarded to the service (may be None).

    Returns:
        ``ResponseModel(data=result)`` with the service's detection result.
        On ``ValueError`` it returns ``code=400``. On ``RuntimeError`` or any
        other exception it returns ``code=500``. Unexpected failures are now
        logged with their traceback instead of being swallowed silently.
    """
    import logging

    logger = logging.getLogger(__name__)
    try:
        service = CacheService(db)
        start = parse_date(request.start_date)
        end = parse_date(request.end_date)

        result = service.detect_all_missing_data(
            request.security_type,
            request.period_type,
            start,
            end,
            request.contract_type,
            task_id
        )

        return ResponseModel(data=result)
    except ValueError as e:
        # Treated as a client-side parameter error; not logged as a failure.
        return ResponseModel(code=400, message=f"参数错误: {str(e)}")
    except RuntimeError as e:
        logger.exception("检测缺失数据服务错误: %s", e)
        return ResponseModel(code=500, message=f"服务错误: {str(e)}")
    except Exception as e:
        logger.exception("检测缺失数据失败: %s", e)
        return ResponseModel(code=500, message=f"检测失败: {str(e)}")
|
|
|
|
|
|
@router.post("/detect-missing", response_model=ResponseModel)
async def detect_missing_data(
    request: DetectMissingRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Detect missing data for an explicit list of codes.

    The handler runs detection and loads the task's detail rows. It then
    reports, for each requested code that has missing rows, the per-date
    expected and actual counts and the missing ratio.

    Returns:
        ``ResponseModel`` with ``task_id``, ``total_codes`` (the length of
        ``request.code_list``) and ``missing_codes``. ``missing_codes`` is a list
        of ``{"code", "missing_dates"}`` entries in ``request.code_list`` order.
    """
    service = CacheService(db)
    start = parse_date(request.start_date)
    end = parse_date(request.end_date)

    task = service.detect_missing_data(
        request.security_type,
        request.period_type,
        start,
        end,
        request.code_list
    )

    # Group the missing detail rows by code in one pass. This replaces
    # rescanning every detail row for each requested code, which was
    # O(codes * details). Row order within each code is preserved.
    missing_by_code = {}
    for d in service.get_task_details(task.id):
        if d.is_missing:
            missing_by_code.setdefault(d.code, []).append(d)

    missing_info = []
    for code in request.code_list:
        code_details = missing_by_code.get(code)
        if code_details:
            missing_info.append({
                "code": code,
                "missing_dates": [{
                    "date": format_date(d.trade_date),
                    "expected": d.expected_count,
                    "actual": d.actual_count,
                    # Guard against division by zero when nothing was expected.
                    "missing_ratio": (d.expected_count - d.actual_count) / d.expected_count if d.expected_count > 0 else 0
                } for d in code_details]
            })

    return ResponseModel(data={
        "task_id": task.id,
        "total_codes": len(request.code_list),
        "missing_codes": missing_info
    })
|
|
|
|
|
|
@router.post("/batch-cache", response_model=ResponseModel[CacheTaskResponse])
async def batch_cache_data(
    request: BatchCacheRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Cache data for an explicit list of codes and return the resulting task."""
    svc = CacheService(db)
    date_range = (parse_date(request.start_date), parse_date(request.end_date))

    created_task = svc.batch_cache_data(
        request.security_type,
        request.period_type,
        *date_range,
        request.code_list,
    )
    return ResponseModel(data=CacheTaskResponse.model_validate(created_task))
|
|
|
|
|
|
@router.get("/tasks", response_model=ResponseModel)
async def get_cache_tasks(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return one page of cache tasks, each serialized via ``CacheTaskResponse``.

    The pagination fields (total, page, page_size, total_pages) are copied
    straight from the service result.
    """
    page_data = CacheService(db).get_tasks(page, page_size)
    serialized = [CacheTaskResponse.model_validate(item) for item in page_data["items"]]
    pagination_keys = ("total", "page", "page_size", "total_pages")
    return ResponseModel(data={
        "items": serialized,
        **{key: page_data[key] for key in pagination_keys},
    })
|
|
|
|
|
|
@router.get("/tasks/{task_id}", response_model=ResponseModel)
async def get_cache_task(
    task_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return a cache task together with its detail rows.

    Responds with ``code=404`` when the task does not exist.
    """
    def _serialize_detail(row):
        # Convert one detail row into a JSON-friendly dict.
        return {
            "id": row.id,
            "code": row.code,
            "trade_date": row.trade_date.isoformat() if row.trade_date else None,
            "expected_count": row.expected_count,
            "actual_count": row.actual_count,
            "is_missing": bool(row.is_missing),
            "status": row.status,
            "error_message": row.error_message,
        }

    svc = CacheService(db)
    found = svc.get_task(task_id)
    if not found:
        return ResponseModel(code=404, message="任务不存在")

    rows = svc.get_task_details(task_id)
    return ResponseModel(data={
        "task": CacheTaskResponse.model_validate(found),
        "details": [_serialize_detail(row) for row in rows],
    })
|
|
|
|
|
|
@router.delete("/tasks/{task_id}", response_model=ResponseModel)
async def cancel_cache_task(
    task_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Cancel a cache task.

    Responds with ``code=400`` when the service reports that the task could
    not be cancelled.
    """
    cancelled = CacheService(db).cancel_task(task_id)
    if not cancelled:
        return ResponseModel(code=400, message="任务不存在或已完成")
    return ResponseModel(message="任务已取消")
|
|
|
|
|
|
@router.get("/status/{code}", response_model=ResponseModel)
async def get_cache_status(
    code: str,
    security_type: str = Query("stock", description="证券类型: stock, future"),
    period_type: str = Query("daily", description="周期类型: daily, min1, min5, etc."),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return the cache status of ``code`` for the given security and period type."""
    svc = CacheService(db)
    return ResponseModel(data=svc.get_cache_status(code, security_type, period_type))
|
|
|
|
|
|
@router.get("/future-varieties", response_model=ResponseModel)
async def get_future_varieties(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List the futures varieties known to the cache service."""
    svc = CacheService(db)
    return ResponseModel(data={"varieties": svc.get_future_varieties()})
|
|
|
|
|
|
@router.get("/main-contracts", response_model=ResponseModel)
async def get_main_contracts(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return the main contract of every futures variety via the default SDK connection.

    Responds with ``code=500`` when no default connection is available.
    """
    from app.services.sdk_manager import sdk_manager

    connection = sdk_manager.get_default_connection()
    if connection:
        return ResponseModel(data={"main_contracts": connection.get_all_main_contracts()})
    return ResponseModel(code=500, message="SDK连接失败")
|
|
|
|
|
|
from app.utils.date_utils import format_date
|