diff --git a/backend/app/api/v1/__init__.py b/backend/app/api/v1/__init__.py index 8ebf576..5ab9c7f 100644 --- a/backend/app/api/v1/__init__.py +++ b/backend/app/api/v1/__init__.py @@ -1,7 +1,7 @@ # API v1模块 from fastapi import APIRouter -from app.api.v1 import auth, configs, base_data, stock, future, realtime, finance, cache, test, data_import, index +from app.api.v1 import auth, configs, base_data, stock, future, realtime, finance, cache, test, data_import, index, ws api_router = APIRouter(prefix="/api/v1") @@ -16,3 +16,4 @@ api_router.include_router(finance.router, prefix="/finance", tags=["财务数据 api_router.include_router(cache.router, prefix="/cache", tags=["缓存管理"]) api_router.include_router(test.router, prefix="/test", tags=["测试中心"]) api_router.include_router(data_import.router, prefix="/import", tags=["数据导入"]) +api_router.include_router(ws.router, tags=["WebSocket进度"]) diff --git a/backend/app/api/v1/cache.py b/backend/app/api/v1/cache.py index 1fe1b7a..af59709 100644 --- a/backend/app/api/v1/cache.py +++ b/backend/app/api/v1/cache.py @@ -190,6 +190,7 @@ async def cache_all_missing_data( @router.post("/detect-all-missing", response_model=ResponseModel) async def detect_all_missing_data( request: AllDataRequest, + task_id: str = Query(None, description="WebSocket任务ID"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): @@ -204,7 +205,8 @@ async def detect_all_missing_data( request.period_type, start, end, - request.contract_type + request.contract_type, + task_id ) return ResponseModel(data=result) diff --git a/backend/app/api/v1/ws.py b/backend/app/api/v1/ws.py new file mode 100644 index 0000000..3a11609 --- /dev/null +++ b/backend/app/api/v1/ws.py @@ -0,0 +1,33 @@ +""" +WebSocket进度路由 +""" +from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query +from app.core.progress_manager import progress_manager +import logging + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.websocket("/progress/{task_id}") +async def websocket_progress( + websocket: WebSocket, + task_id: str +): + """WebSocket进度推送""" + await progress_manager.connect(websocket, task_id) + + try: + while True: + data = await websocket.receive_text() + if data == "ping": + await websocket.send_text("pong") + elif data == "close": + break + except WebSocketDisconnect: + logger.info(f"WebSocket断开连接: task_id={task_id}") + except Exception as e: + logger.error(f"WebSocket错误: {e}") + finally: + await progress_manager.disconnect(websocket, task_id) \ No newline at end of file diff --git a/backend/app/core/progress_manager.py b/backend/app/core/progress_manager.py new file mode 100644 index 0000000..838b35d --- /dev/null +++ b/backend/app/core/progress_manager.py @@ -0,0 +1,105 @@ +""" +进度管理器 - WebSocket实时进度推送 +""" +import asyncio +import json +import logging +from typing import Dict, Set, Optional +from datetime import datetime +from fastapi import WebSocket + +logger = logging.getLogger(__name__) + + +class ProgressManager: + """进度管理器""" + + _instance = None + _lock = asyncio.Lock() + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._connections: Dict[str, Set[WebSocket]] = {} + cls._instance._progress_data: Dict[str, Dict] = {} + return cls._instance + + async def connect(self, websocket: WebSocket, task_id: str): + """连接WebSocket""" + await websocket.accept() + if task_id not in self._connections: + self._connections[task_id] = set() + self._connections[task_id].add(websocket) + + if task_id in self._progress_data: + await websocket.send_json(self._progress_data[task_id]) + + logger.info(f"WebSocket连接: task_id={task_id}") + + async def disconnect(self, websocket: WebSocket, task_id: str): + """断开WebSocket连接""" + if task_id in self._connections: + self._connections[task_id].discard(websocket) + if not self._connections[task_id]: + del self._connections[task_id] + logger.info(f"WebSocket断开: task_id={task_id}") + + async def update_progress(self, task_id: str, progress_data: Dict): + """更新进度并推送""" + progress_data["timestamp"] = datetime.utcnow().isoformat() + self._progress_data[task_id] = progress_data + + if task_id in self._connections: + disconnected = set() + for websocket in self._connections[task_id]: + try: + await websocket.send_json(progress_data) + except Exception as e: + logger.warning(f"WebSocket发送失败: {e}") + disconnected.add(websocket) + + for ws in disconnected: + self._connections[task_id].discard(ws) + + async def complete_task(self, task_id: str, result: Dict): + """完成任务""" + result["status"] = "completed" + result["timestamp"] = datetime.utcnow().isoformat() + self._progress_data[task_id] = result + + if task_id in self._connections: + for websocket in self._connections[task_id]: + try: + await websocket.send_json(result) + except Exception as e: + logger.warning(f"WebSocket发送失败: {e}") + + async def fail_task(self, task_id: str, error: str): + """任务失败""" + result = { + "status": "failed", + "error": error, + "timestamp": datetime.utcnow().isoformat() + } + self._progress_data[task_id] = result + + if task_id in self._connections: + for websocket in self._connections[task_id]: + try: + await websocket.send_json(result) + except Exception as e: + logger.warning(f"WebSocket发送失败: {e}") + + def get_progress(self, task_id: str) -> Optional[Dict]: + """获取进度数据""" + return self._progress_data.get(task_id) + + def clear_task(self, task_id: str): + """清除任务数据""" + if task_id in self._progress_data: + del self._progress_data[task_id] + if task_id in self._connections: + del self._connections[task_id] + + +progress_manager = ProgressManager() \ No newline at end of file diff --git a/backend/app/services/cache_service.py b/backend/app/services/cache_service.py index 5fdda41..f6fc277 100644 --- a/backend/app/services/cache_service.py +++ b/backend/app/services/cache_service.py @@ -1,6 +1,7 @@ """ 缓存管理服务 """ +import asyncio import logging from typing import List, Dict, Optional from datetime import date, datetime @@ -17,6 +18,7 @@ from app.services.sdk_manager import sdk_manager from app.utils.date_utils import parse_date, format_date, get_market_from_code from app.config import settings from app.core.redis_client import redis_client +from app.core.progress_manager import progress_manager logger = logging.getLogger(__name__) @@ -166,7 +168,8 @@ class CacheService: period_type: str, start_date: date, end_date: date, - contract_type: str = "all" + contract_type: str = "all", + task_id: str = None ) -> Dict: """ 一键检测所有数据的缺失情况 @@ -177,11 +180,11 @@ class CacheService: start_date: 开始日期 end_date: 结束日期 contract_type: 合约类型 (all, main) - 仅对期货有效 + task_id: WebSocket任务ID Returns: 检测结果字典 """ - # 获取所有代码 code_list = self.get_all_codes(security_type, contract_type) if not code_list: @@ -189,7 +192,23 @@ class CacheService: logger.info(f"获取到{len(code_list)}个{security_type}代码") - # 创建检测任务 + ws_task_id = task_id or f"detect_{security_type}_{datetime.utcnow().strftime('%Y%m%d%H%M%S')}" + + def push_progress(progress, status, **kwargs): + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + asyncio.create_task(progress_manager.update_progress(ws_task_id, { + "progress": progress, + "status": status, + "total_count": len(code_list), + **kwargs + })) + except RuntimeError: + pass + + push_progress(0, "starting", message="开始检测...") + task = CacheTask( task_name=f"一键检测所有数据 - {security_type} - {contract_type} - {len(code_list)}个代码", task_type="detect_all_missing", @@ -206,8 +225,9 @@ class CacheService: self.db.commit() self.db.refresh(task) + push_progress(5, "running", message="获取交易日历...") + try: - # 获取交易日历 market = "CFE" if security_type == "future" else "SH" trading_days = self.base_service.get_trading_calendar(market, start_date, end_date) expected_count = len(trading_days) @@ -216,7 +236,6 @@ class CacheService: complete_codes = [] error_count = 0 - # 统计每个交易日的缺失情况 daily_stats = {} for td in trading_days: daily_stats[format_date(td)] = { @@ -225,44 +244,114 @@ class CacheService: "missing": 0 } - for i, code in enumerate(code_list): - try: - # 查询实际数据量 - if security_type == "stock" and period_type == "daily": - records = self.db.query(StockKlineDaily).filter( - and_( - StockKlineDaily.code == code, - StockKlineDaily.trade_date >= start_date, - StockKlineDaily.trade_date <= end_date - ) - ).all() - actual_count = len(records) - - # 更新每日统计 - for r in records: - date_key = format_date(r.trade_date) - if date_key in daily_stats: - daily_stats[date_key]["actual"] += 1 - - elif security_type == "future" and period_type == "daily": - records = self.db.query(FutureKlineDaily).filter( - and_( - FutureKlineDaily.code == code, - FutureKlineDaily.trade_date >= start_date, - FutureKlineDaily.trade_date <= end_date - ) - ).all() - actual_count = len(records) - - for r in records: - date_key = format_date(r.trade_date) - if date_key in daily_stats: - daily_stats[date_key]["actual"] += 1 + push_progress(10, "running", message="查询数据库统计...") + + if security_type == "stock" and period_type == "daily": + from sqlalchemy import func + + code_count_query = self.db.query( + StockKlineDaily.code, + func.count(StockKlineDaily.id).label('count') + ).filter( + and_( + StockKlineDaily.trade_date >= start_date, + StockKlineDaily.trade_date <= end_date + ) + ).group_by(StockKlineDaily.code).all() + + code_counts = {r.code: r.count for r in code_count_query} + + date_count_query = self.db.query( + func.date(StockKlineDaily.trade_date).label('trade_date'), + func.count(StockKlineDaily.id).label('count') + ).filter( + and_( + StockKlineDaily.trade_date >= start_date, + StockKlineDaily.trade_date <= end_date + ) + ).group_by(func.date(StockKlineDaily.trade_date)).all() + + for r in date_count_query: + date_key = format_date(r.trade_date) if hasattr(r.trade_date, 'strftime') else str(r.trade_date) + if date_key in daily_stats: + daily_stats[date_key]["actual"] = r.count + + push_progress(20, "running", message="分析数据完整性...") + + for i, code in enumerate(code_list): + actual_count = code_counts.get(code, 0) + is_missing = actual_count < expected_count + + if is_missing: + missing_codes.append({ + "code": code, + "actual_count": actual_count, + "expected_count": expected_count, + "missing_count": expected_count - actual_count, + "missing_ratio": (expected_count - actual_count) / expected_count if expected_count > 0 else 0 + }) + detail = CacheTaskDetail( + task_id=task.id, + code=code, + trade_date=start_date, + expected_count=expected_count, + actual_count=actual_count, + is_missing=True, + status="pending" + ) + self.db.add(detail) else: - actual_count = 0 + complete_codes.append(code) - # 判断是否缺失 + if (i + 1) % 500 == 0 or i == len(code_list) - 1: + task.success_count = len(missing_codes) + len(complete_codes) + task.error_count = error_count + task.progress = min(100, int((i + 1) / len(code_list) * 100)) + self.db.commit() + + push_progress( + 20 + int((i + 1) / len(code_list) * 70), + "running", + processed=len(missing_codes) + len(complete_codes), + missing=len(missing_codes), + complete=len(complete_codes) + ) + + elif security_type == "future" and period_type == "daily": + from sqlalchemy import func + + code_count_query = self.db.query( + FutureKlineDaily.code, + func.count(FutureKlineDaily.id).label('count') + ).filter( + and_( + FutureKlineDaily.trade_date >= start_date, + FutureKlineDaily.trade_date <= end_date + ) + ).group_by(FutureKlineDaily.code).all() + + code_counts = {r.code: r.count for r in code_count_query} + + date_count_query = self.db.query( + func.date(FutureKlineDaily.trade_date).label('trade_date'), + func.count(FutureKlineDaily.id).label('count') + ).filter( + and_( + FutureKlineDaily.trade_date >= start_date, + FutureKlineDaily.trade_date <= end_date + ) + ).group_by(func.date(FutureKlineDaily.trade_date)).all() + + for r in date_count_query: + date_key = format_date(r.trade_date) if hasattr(r.trade_date, 'strftime') else str(r.trade_date) + if date_key in daily_stats: + daily_stats[date_key]["actual"] = r.count + + push_progress(20, "running", message="分析数据完整性...") + + for i, code in enumerate(code_list): + actual_count = code_counts.get(code, 0) is_missing = actual_count < expected_count if is_missing: @@ -287,31 +376,59 @@ class CacheService: else: complete_codes.append(code) - except Exception as e: - logger.error(f"检测{code}缺失数据失败: {str(e)}") - error_count += 1 + if (i + 1) % 500 == 0 or i == len(code_list) - 1: + task.success_count = len(missing_codes) + len(complete_codes) + task.error_count = error_count + task.progress = min(100, int((i + 1) / len(code_list) * 100)) + self.db.commit() + + push_progress( + 20 + int((i + 1) / len(code_list) * 70), + "running", + processed=len(missing_codes) + len(complete_codes), + missing=len(missing_codes), + complete=len(complete_codes) + ) + else: + for i, code in enumerate(code_list): + actual_count = 0 + is_missing = True + + missing_codes.append({ + "code": code, + "actual_count": 0, + "expected_count": expected_count, + "missing_count": expected_count, + "missing_ratio": 1.0 + }) detail = CacheTaskDetail( task_id=task.id, code=code, trade_date=start_date, - status="failed", - error_message=str(e) + expected_count=expected_count, + actual_count=0, + is_missing=True, + status="pending" ) self.db.add(detail) - - # 每100个代码更新一次进度 - if (i + 1) % 100 == 0 or i == len(code_list) - 1: - task.success_count = len(missing_codes) + len(complete_codes) - task.error_count = error_count - task.progress = min(100, int((i + 1) / len(code_list) * 100)) - self.db.commit() + + if (i + 1) % 500 == 0 or i == len(code_list) - 1: + task.success_count = len(missing_codes) + task.error_count = error_count + task.progress = min(100, int((i + 1) / len(code_list) * 100)) + self.db.commit() + + push_progress( + 20 + int((i + 1) / len(code_list) * 70), + "running", + processed=len(missing_codes), + missing=len(missing_codes) + ) - # 计算每日缺失数 for date_key in daily_stats: daily_stats[date_key]["missing"] = daily_stats[date_key]["expected"] - daily_stats[date_key]["actual"] - # 保存缺失代码列表到任务记录 missing_code_list = [m["code"] for m in missing_codes] task.code_list = ",".join(missing_code_list[:500]) if missing_code_list else "" @@ -321,10 +438,18 @@ class CacheService: task.completed_at = datetime.utcnow() self.db.commit() + push_progress(100, "completed", + message="检测完成", + complete_count=len(complete_codes), + missing_count=len(missing_codes), + error_count=error_count + ) + logger.info(f"检测完成: 完整{len(complete_codes)}个, 缺失{len(missing_codes)}个, 错误{error_count}个") return { "task_id": task.id, + "ws_task_id": ws_task_id, "task_name": task.task_name, "status": task.status, "progress": float(task.progress), @@ -349,8 +474,11 @@ class CacheService: self.db.commit() logger.error(f"一键检测缺失数据任务失败: {str(e)}") + push_progress(100, "failed", error=str(e)) + return { "task_id": task.id, + "ws_task_id": ws_task_id, "task_name": task.task_name, "status": task.status, "error_message": str(e) diff --git a/frontend/src/api/cache.ts b/frontend/src/api/cache.ts index e83bde1..3d50988 100644 --- a/frontend/src/api/cache.ts +++ b/frontend/src/api/cache.ts @@ -1,4 +1,4 @@ -import request from '@/utils/request' +import request, { cacheRequest } from '@/utils/request' export const detectMissingData = (data: { security_type: string @@ -7,7 +7,7 @@ export const detectMissingData = (data: { end_date: string code_list: string[] }) => { - return request.post('/cache/detect-missing', data) + return cacheRequest.post('/cache/detect-missing', data) } export const batchCacheData = (data: { @@ -17,7 +17,7 @@ export const batchCacheData = (data: { end_date: string code_list: string[] }) => { - return request.post('/cache/batch-cache', data) + return cacheRequest.post('/cache/batch-cache', data) } export const detectAllMissingData = (data: { @@ -26,8 +26,10 @@ export const detectAllMissingData = (data: { contract_type?: string start_date: string end_date: string + task_id?: string }) => { - return request.post('/cache/detect-all-missing', data) + const params = data.task_id ? { task_id: data.task_id } : {} + return cacheRequest.post('/cache/detect-all-missing', data, { params }) } export const cacheAllMissingData = (data: { @@ -37,7 +39,7 @@ export const cacheAllMissingData = (data: { start_date: string end_date: string }) => { - return request.post('/cache/cache-all-missing', data) + return cacheRequest.post('/cache/cache-all-missing', data) } export const getCacheTasks = (params?: { page?: number; page_size?: number }) => { @@ -74,5 +76,5 @@ export const fillMissingData = (data: { start_date: string end_date: string }) => { - return request.post('/cache/fill-missing', data) + return cacheRequest.post('/cache/fill-missing', data) } diff --git a/frontend/src/utils/request.ts b/frontend/src/utils/request.ts index 98993f7..6df9f4c 100644 --- a/frontend/src/utils/request.ts +++ b/frontend/src/utils/request.ts @@ -2,13 +2,16 @@ import axios from 'axios' import { ElMessage } from 'element-plus' import { useUserStore } from '@/store/user' -// 创建axios实例 const request = axios.create({ baseURL: '/api/v1', timeout: 30000 }) -// 请求拦截器 +const cacheRequest = axios.create({ + baseURL: '/api/v1', + timeout: 300000 +}) + request.interceptors.request.use( (config) => { const userStore = useUserStore() @@ -22,7 +25,19 @@ request.interceptors.request.use( } ) -// 响应拦截器 +cacheRequest.interceptors.request.use( + (config) => { + const userStore = useUserStore() + if (userStore.token) { + config.headers.Authorization = `Bearer ${userStore.token}` + } + return config + }, + (error) => { + return Promise.reject(error) + } +) + request.interceptors.response.use( (response) => { const res = response.data @@ -30,7 +45,31 @@ request.interceptors.response.use( if (res.code !== 200) { ElMessage.error(res.message || '请求失败') - // 401未授权,跳转到登录页 + if (res.code === 401) { + const userStore = useUserStore() + userStore.logout() + window.location.href = '/login' + } + + return Promise.reject(new Error(res.message)) + } + + return res + }, + (error) => { + const message = error.response?.data?.message || error.message || '网络错误' + ElMessage.error(message) + return Promise.reject(error) + } +) + +cacheRequest.interceptors.response.use( + (response) => { + const res = response.data + + if (res.code !== 200) { + ElMessage.error(res.message || '请求失败') + if (res.code === 401) { const userStore = useUserStore() userStore.logout() @@ -50,3 +89,4 @@ request.interceptors.response.use( ) export default request +export { cacheRequest } diff --git a/frontend/src/views/CacheManager/DetectMissing.vue b/frontend/src/views/CacheManager/DetectMissing.vue index be6b32e..2776079 100644 --- a/frontend/src/views/CacheManager/DetectMissing.vue +++ b/frontend/src/views/CacheManager/DetectMissing.vue @@ -69,6 +69,32 @@ + + + + + + {{ wsProgress.message || '-' }} + {{ wsProgress.total_count || 0 }} + {{ wsProgress.processed || 0 }} + {{ wsProgress.missing || 0 }} + + +