You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

349 lines
12 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
股票数据服务
"""
from typing import List, Optional, Dict
from datetime import date
from sqlalchemy.orm import Session
from sqlalchemy import and_
import pandas as pd
import logging
from app.models.stock import StockKlineDaily, StockKlineMin
from app.services.sdk_manager import sdk_manager
from app.services.base_data_service import BaseDataService
from app.utils.date_utils import parse_date, format_date, get_market_from_code
from app.utils.data_utils import dataframe_to_dict_list, merge_kline_data
logger = logging.getLogger(__name__)
class StockService:
    """Service layer for stock K-line data."""

    def __init__(self, db: Session):
        """Bind the service to a database session.

        Args:
            db: Session used for every cache read and write; it is also
                handed to the ``BaseDataService`` helper created here.
        """
        self.db = db
        self.base_service = BaseDataService(db)
def _get_adapter(self):
    """Return the default SDK connection provided by ``sdk_manager``."""
    connection = sdk_manager.get_default_connection()
    return connection
def get_kline(
    self,
    codes: List[str],
    start_date: date,
    end_date: date,
    period: str = "daily"
) -> Dict[str, List[dict]]:
    """Fetch K-line data for several stock codes through the local cache.

    Every code is resolved independently, so a failure for one code never
    prevents the others from being returned.

    Args:
        codes: Stock codes to load. Repeated codes are resolved only once;
            ``None`` is treated like an empty list.
        start_date: First day of the requested range (inclusive).
        end_date: Last day of the requested range (inclusive).
        period: ``"daily"`` selects the daily path. Any other value is
            forwarded to the minute path as the period type (documented
            values: min1, min5, min15, min30, min60).

    Returns:
        A dict mapping each requested code to its list of bar dicts. A code
        whose lookup raised maps to an empty list.
    """
    result: Dict[str, List[dict]] = {}
    requested = codes if codes is not None else ()
    # dict.fromkeys removes repeated codes while keeping first-seen order,
    # so the same code is never queried (or fetched from the SDK) twice.
    for code in dict.fromkeys(requested):
        try:
            if period == "daily":
                data = self._get_daily_kline_with_cache(code, start_date, end_date)
            else:
                data = self._get_min_kline_with_cache(code, start_date, end_date, period)
        except Exception as e:
            # Deliberate best-effort boundary: report the failure and carry
            # on. exc_info keeps the traceback, which the message alone
            # would lose.
            logger.error(f"获取{code}的K线数据失败: {str(e)}", exc_info=True)
            data = []
        result[code] = data
    return result
def _get_daily_kline_with_cache(
    self,
    code: str,
    start_date: date,
    end_date: date
) -> List[dict]:
    """Return daily bars for ``code``, filling cache gaps from the SDK.

    The cached rows are compared against the trading calendar for the
    code's market. If any calendar day is missing, the whole range is
    requested from the SDK adapter, saved, and the cache is read again.
    SDK or save failures are logged and the rows read so far are returned.

    Args:
        code: Stock code to load.
        start_date: First day of the range (inclusive).
        end_date: Last day of the range (inclusive).

    Returns:
        One dict per cached row, ordered by trade date, with the keys
        trade_date, open, high, low, close, volume and amount.
    """

    def load_cached():
        # Cached rows for this code inside the date range, oldest first.
        in_range = and_(
            StockKlineDaily.code == code,
            StockKlineDaily.trade_date >= start_date,
            StockKlineDaily.trade_date <= end_date
        )
        query = self.db.query(StockKlineDaily).filter(in_range)
        return query.order_by(StockKlineDaily.trade_date).all()

    # 1. Read whatever is already cached.
    records = load_cached()

    # 2. Work out which calendar days have no cached row.
    have = {rec.trade_date for rec in records}
    want = set(self.base_service.get_trading_calendar(
        get_market_from_code(code),
        start_date,
        end_date
    ))
    gaps = want.difference(have)

    # 3. Any gap triggers a refresh of the full range from the SDK.
    if gaps:
        try:
            adapter = self._get_adapter()
            if adapter:
                fetched = adapter.get_kline([code], start_date, end_date, "daily")
                if code in fetched and not fetched[code].empty:
                    self._save_daily_kline(code, fetched[code])
                    # Read again so the result reflects what was just saved.
                    records = load_cached()
        except Exception as e:
            logger.error(f"从SDK获取{code}数据失败: {str(e)}")

    # 4. Convert the rows into plain dicts.
    rows = []
    for rec in records:
        rows.append({
            "trade_date": format_date(rec.trade_date),
            "open": float(rec.open),
            "high": float(rec.high),
            "low": float(rec.low),
            "close": float(rec.close),
            "volume": int(rec.volume),
            "amount": float(rec.amount)
        })
    return rows
def _get_min_kline_with_cache(
    self,
    code: str,
    start_date: date,
    end_date: date,
    period: str
) -> List[dict]:
    """Return minute-level bars for ``code``, topping up a sparse cache.

    The date range is widened to whole days (start of ``start_date`` to end
    of ``end_date``). When fewer than 10 rows are cached for that window,
    the SDK adapter is asked for the range, the result is saved, and the
    cache is read again. SDK or save failures are logged and the rows read
    so far are returned.

    Args:
        code: Stock code to load.
        start_date: First day of the range.
        end_date: Last day of the range.
        period: Period type stored with each row (e.g. min1, min5).

    Returns:
        One dict per cached row, ordered by timestamp, with the keys
        trade_datetime (ISO string), open, high, low, close, volume and
        amount.
    """
    from datetime import datetime

    window_start = datetime.combine(start_date, datetime.min.time())
    window_end = datetime.combine(end_date, datetime.max.time())

    def load_cached():
        # Cached rows for this code/period inside the window, oldest first.
        in_window = and_(
            StockKlineMin.code == code,
            StockKlineMin.period_type == period,
            StockKlineMin.trade_datetime >= window_start,
            StockKlineMin.trade_datetime <= window_end
        )
        query = self.db.query(StockKlineMin).filter(in_window)
        return query.order_by(StockKlineMin.trade_datetime).all()

    # 1. Read whatever is already cached.
    records = load_cached()

    # 2. A nearly empty cache triggers a refresh from the SDK.
    if len(records) < 10:
        try:
            adapter = self._get_adapter()
            if adapter:
                fetched = adapter.get_kline([code], start_date, end_date, period)
                if code in fetched and not fetched[code].empty:
                    self._save_min_kline(code, fetched[code], period)
                    # Read again so the result reflects what was just saved.
                    records = load_cached()
        except Exception as e:
            logger.error(f"从SDK获取{code}分钟数据失败: {str(e)}")

    rows = []
    for rec in records:
        rows.append({
            "trade_datetime": rec.trade_datetime.isoformat(),
            "open": float(rec.open),
            "high": float(rec.high),
            "low": float(rec.low),
            "close": float(rec.close),
            "volume": int(rec.volume),
            "amount": float(rec.amount)
        })
    return rows
def _save_daily_kline(self, code: str, df: pd.DataFrame):
"""保存日线数据到数据库"""
if df.empty:
return
for idx, row in df.iterrows():
kline_time = row.get("kline_time")
if kline_time is None:
continue
trade_date = kline_time.date() if hasattr(kline_time, 'date') else parse_date(str(kline_time)[:10])
existing = self.db.query(StockKlineDaily).filter(
and_(
StockKlineDaily.code == code,
StockKlineDaily.trade_date == trade_date
)
).first()
if existing:
existing.open = float(row.get("open", 0))
existing.high = float(row.get("high", 0))
existing.low = float(row.get("low", 0))
existing.close = float(row.get("close", 0))
existing.volume = int(row.get("volume", 0))
existing.amount = float(row.get("amount", 0))
else:
record = StockKlineDaily(
code=code,
trade_date=trade_date,
open=float(row.get("open", 0)),
high=float(row.get("high", 0)),
low=float(row.get("low", 0)),
close=float(row.get("close", 0)),
volume=int(row.get("volume", 0)),
amount=float(row.get("amount", 0))
)
self.db.add(record)
self.db.commit()
def _save_min_kline(self, code: str, df: pd.DataFrame, period: str):
"""保存分钟线数据到数据库"""
if df.empty:
return
from datetime import datetime
for idx, row in df.iterrows():
kline_time = row.get("kline_time")
if kline_time is None:
continue
trade_datetime = kline_time if isinstance(kline_time, datetime) else datetime.fromisoformat(str(kline_time))
existing = self.db.query(StockKlineMin).filter(
and_(
StockKlineMin.code == code,
StockKlineMin.period_type == period,
StockKlineMin.trade_datetime == trade_datetime
)
).first()
if not existing:
record = StockKlineMin(
code=code,
period_type=period,
trade_datetime=trade_datetime,
open=float(row.get("open", 0)),
high=float(row.get("high", 0)),
low=float(row.get("low", 0)),
close=float(row.get("close", 0)),
volume=int(row.get("volume", 0)),
amount=float(row.get("amount", 0))
)
self.db.add(record)
self.db.commit()
def get_kline_chart_data(
    self,
    code: str,
    start_date: date,
    end_date: date,
    period: str = "daily"
) -> dict:
    """Build ECharts-style candlestick data for one stock code.

    Args:
        code: Stock code to chart.
        start_date: First day of the range.
        end_date: Last day of the range.
        period: Period passed through to ``get_kline``.

    Returns:
        {
            "categoryData": ["2024-01-02", ...],
            "values": [[open, close, low, high, volume], ...],
            "volumes": [[index, volume, sign], ...]
        }
        Daily bars are labelled with their trade date. Intraday bars are
        labelled with their full ``trade_datetime`` string, so bars of the
        same day do not all share one label. ``sign`` is 1 when
        close >= open and -1 otherwise. All three lists are empty when no
        data is available.
    """
    kline_data = self.get_kline([code], start_date, end_date, period)
    data = kline_data.get(code) or []

    category_data = []
    values = []
    volumes = []
    for i, item in enumerate(data):
        # Daily rows carry "trade_date"; minute rows only carry
        # "trade_datetime". Cutting that to 10 characters would give every
        # bar of a day the same category label, so it is kept whole.
        label = item.get("trade_date") or item.get("trade_datetime", "")
        category_data.append(label)
        # Candlestick order [open, close, low, high], with volume appended
        # as a fifth element.
        values.append([
            item["open"],
            item["close"],
            item["low"],
            item["high"],
            item["volume"]
        ])
        # The sign drives the volume bar colour.
        sign = 1 if item["close"] >= item["open"] else -1
        volumes.append([i, item["volume"], sign])

    return {
        "categoryData": category_data,
        "values": values,
        "volumes": volumes
    }
def get_cache_status(self, code: str, period: str = "daily") -> dict:
    """Summarise what is cached locally for one stock code and period.

    Args:
        code: Stock code to inspect.
        period: ``"daily"`` inspects ``StockKlineDaily``; any other value
            inspects ``StockKlineMin`` rows stored with that period type.

    Returns:
        A dict with code, security_type ("stock"), period_type,
        record_count, min_date and max_date. The two dates are the
        ``format_date`` rendering of the earliest and latest cached day,
        or None when nothing is cached. For minute periods they are the
        calendar days of the earliest and latest cached timestamps.
    """
    if period == "daily":
        query = self.db.query(StockKlineDaily).filter(StockKlineDaily.code == code)
        time_column = StockKlineDaily.trade_date
    else:
        query = self.db.query(StockKlineMin).filter(
            StockKlineMin.code == code,
            StockKlineMin.period_type == period
        )
        time_column = StockKlineMin.trade_datetime

    def _day(record):
        # Daily rows store a date; minute rows store a timestamp whose
        # calendar day is reported so both periods share one format.
        if record is None:
            return None
        if period == "daily":
            return format_date(record.trade_date)
        return format_date(record.trade_datetime.date())

    count = query.count()
    min_date = None
    max_date = None
    # An empty cache has no range, so the two ordered lookups are skipped.
    if count:
        min_date = _day(query.order_by(time_column).first())
        max_date = _day(query.order_by(time_column.desc()).first())

    return {
        "code": code,
        "security_type": "stock",
        "period_type": period,
        "record_count": count,
        "min_date": min_date,
        "max_date": max_date
    }