You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

205 lines
6.9 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
实时数据服务
"""
import asyncio
import logging
from typing import Dict, Set, List, Callable, Optional
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from fastapi import WebSocket
from app.models.realtime import RealtimeSnapshot
from app.services.base_data_service import BaseDataService
from app.config import settings
logger = logging.getLogger(__name__)


class RealtimeManager:
    """Realtime data manager (process-wide singleton).

    Tracks which WebSocket clients are subscribed to which codes, starts/stops
    the per-code SDK subscription as the first/last client arrives/leaves, and
    fans out SDK data callbacks to every subscribed client.
    """

    _instance = None

    def __new__(cls):
        # Every RealtimeManager() call returns the same object; the
        # ``_initialized`` flag makes __init__ run its body only once.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        # code -> set of WebSocket connections subscribed to that code
        self.subscribers: Dict[str, Set[WebSocket]] = {}
        # code -> callbacks; not read by any method in this class
        self.code_callbacks: Dict[str, List[Callable]] = {}
        self._adapter = None
        # NOTE(review): created at import time (module-level instance below);
        # on Python < 3.10 asyncio.Lock binds to the loop current at creation —
        # confirm the deployment runs 3.10+ or a single loop.
        self._lock = asyncio.Lock()
        # Event loop that owns the WebSocket connections, captured in
        # subscribe() so SDK callbacks arriving on another thread can hand
        # their sends back to it.
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        # Strong references to in-flight send tasks: the event loop only keeps
        # weak references, so an unreferenced task may vanish mid-execution.
        self._pending_tasks: Set["asyncio.Task"] = set()
        self._initialized = True

    async def subscribe(self, websocket: WebSocket, codes: List[str]):
        """Accept a client connection and subscribe it to realtime data.

        Args:
            websocket: WebSocket connection (accepted by this method).
            codes: Codes to subscribe to. The SDK subscription for a code is
                started when its first subscriber arrives.
        """
        await websocket.accept()
        # Remember the serving loop for cross-thread delivery in on_sdk_data.
        self._loop = asyncio.get_running_loop()
        async with self._lock:
            for code in codes:
                if code not in self.subscribers:
                    self.subscribers[code] = set()
                    # First subscriber for this code: start the SDK feed.
                    await self._start_sdk_subscription(code)
                self.subscribers[code].add(websocket)
        logger.info(f"WebSocket {id(websocket)} 订阅了: {codes}")

    async def unsubscribe(self, websocket: WebSocket, codes: Optional[List[str]] = None):
        """Remove a client's subscriptions.

        Args:
            websocket: WebSocket connection.
            codes: Codes to unsubscribe from. ``None`` means every code; an
                explicit empty list removes nothing.
        """
        async with self._lock:
            # Only None means "all codes" (previously [] also meant "all").
            codes_to_remove = list(self.subscribers) if codes is None else codes
            for code in codes_to_remove:
                code_subscribers = self.subscribers.get(code)
                if code_subscribers is None:
                    continue
                code_subscribers.discard(websocket)
                # No subscribers left: drop the entry and stop the SDK feed.
                if not code_subscribers:
                    del self.subscribers[code]
                    await self._stop_sdk_subscription(code)
        logger.info(f"WebSocket {id(websocket)} 取消订阅")

    async def _start_sdk_subscription(self, code: str):
        """Start the SDK subscription for *code* (placeholder: only logs)."""
        # TODO: implement the real SDK subscription. The SDK's realtime
        # subscription is synchronous, so its callback is expected to run on a
        # background thread and call on_sdk_data from there.
        logger.info(f"开始SDK订阅: {code}")

    async def _stop_sdk_subscription(self, code: str):
        """Stop the SDK subscription for *code* (placeholder: only logs)."""
        logger.info(f"停止SDK订阅: {code}")

    def on_sdk_data(self, code: str, data: dict):
        """SDK data callback: push a snapshot message to all subscribers of *code*.

        Safe to call either from the event loop thread or from another thread.
        Off-loop calls are forwarded to the loop captured in subscribe().

        Args:
            code: Code the data belongs to.
            data: Data dictionary, forwarded as-is under the "data" key.
        """
        # TODO: persist snapshots; _save_snapshot needs a Session (db, code, data).
        code_subscribers = self.subscribers.get(code)
        if not code_subscribers:
            return
        message = {
            "type": "snapshot",
            "code": code,
            "data": data,
            "timestamp": datetime.utcnow().isoformat()
        }
        # Snapshot the set: a failing send unsubscribes (mutates it).
        targets = list(code_subscribers)

        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None  # called from a non-async (e.g. SDK) thread

        for ws in targets:
            if running_loop is not None:
                task = running_loop.create_task(self._send_to_ws(ws, message))
                self._pending_tasks.add(task)
                task.add_done_callback(self._pending_tasks.discard)
            elif self._loop is not None and not self._loop.is_closed():
                asyncio.run_coroutine_threadsafe(self._send_to_ws(ws, message), self._loop)
            else:
                logger.warning(f"无可用事件循环, 丢弃实时推送: {code}")

    async def _send_to_ws(self, websocket: WebSocket, message: dict):
        """Send *message* as JSON; on failure drop all of the client's subscriptions."""
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.error(f"发送WebSocket消息失败: {str(e)}")
            # Remove the broken connection from every subscription list.
            await self.unsubscribe(websocket)

    def _save_snapshot(self, db: Session, code: str, data: dict):
        """Persist one snapshot row; errors are logged and the session rolled back.

        Args:
            db: SQLAlchemy session to write with.
            code: Code the snapshot belongs to.
            data: Snapshot fields. "trade_time" may be an ISO-8601 string or a
                datetime; when missing/empty the current UTC time is used.
        """
        try:
            now = datetime.utcnow()
            expires_at = now + timedelta(days=settings.CACHE_AUTO_CLEANUP_DAYS)
            raw_trade_time = data.get("trade_time")
            if isinstance(raw_trade_time, datetime):
                trade_time = raw_trade_time
            elif raw_trade_time:
                trade_time = datetime.fromisoformat(raw_trade_time)
            else:
                trade_time = now
            snapshot = RealtimeSnapshot(
                code=code,
                security_type=data.get("security_type", "stock"),
                trade_time=trade_time,
                pre_close=data.get("pre_close"),
                last=data.get("last"),
                open=data.get("open"),
                high=data.get("high"),
                low=data.get("low"),
                close=data.get("close"),
                volume=data.get("volume"),
                amount=data.get("amount"),
                expires_at=expires_at
            )
            db.add(snapshot)
            db.commit()
        except Exception as e:
            # Without a rollback the session stays in a failed transaction and
            # every later use of it raises.
            db.rollback()
            logger.error(f"保存快照失败: {str(e)}")


# Global realtime data manager instance (shared singleton)
realtime_manager = RealtimeManager()
class RealtimeService:
    """Realtime data service bound to one database session.

    Reads the latest stored snapshots from the database and delegates
    WebSocket subscription management to the shared ``realtime_manager``.
    """

    def __init__(self, db: Session):
        self.db = db
        self.base_service = BaseDataService(db)
        self.manager = realtime_manager

    @staticmethod
    def _to_float(value) -> Optional[float]:
        """Convert a column value to float, keeping None (but not 0) as None."""
        return float(value) if value is not None else None

    def get_latest_snapshot(self, codes: List[str]) -> Dict[str, dict]:
        """Return the most recent stored snapshot for each code.

        Args:
            codes: Codes to look up. Duplicates are queried only once.

        Returns:
            Mapping of code -> snapshot dict, or code -> None when no snapshot
            is stored. Numeric fields that are NULL in the database are None;
            zero values are preserved as 0 / 0.0.
        """
        to_float = self._to_float
        result = {}
        # dict.fromkeys de-duplicates while keeping the caller's order.
        for code in dict.fromkeys(codes):
            snapshot = (
                self.db.query(RealtimeSnapshot)
                .filter(RealtimeSnapshot.code == code)
                .order_by(RealtimeSnapshot.trade_time.desc())
                .first()
            )
            if snapshot is None:
                result[code] = None
                continue
            # `is not None` checks: truthiness would turn a legitimate 0
            # (e.g. zero volume) into None.
            result[code] = {
                "code": snapshot.code,
                "trade_time": (
                    snapshot.trade_time.isoformat()
                    if snapshot.trade_time is not None else None
                ),
                "pre_close": to_float(snapshot.pre_close),
                "last": to_float(snapshot.last),
                "open": to_float(snapshot.open),
                "high": to_float(snapshot.high),
                "low": to_float(snapshot.low),
                "close": to_float(snapshot.close),
                "volume": int(snapshot.volume) if snapshot.volume is not None else None,
                "amount": to_float(snapshot.amount)
            }
        return result

    async def subscribe_websocket(self, websocket: WebSocket, codes: List[str]):
        """Subscribe a WebSocket client to *codes* via the shared manager."""
        await self.manager.subscribe(websocket, codes)

    async def unsubscribe_websocket(self, websocket: WebSocket, codes: List[str] = None):
        """Unsubscribe a WebSocket client; ``codes=None`` removes all its codes."""
        await self.manager.unsubscribe(websocket, codes)