fix: fix several issues

master
Lxy 3 months ago
parent be37192bf9
commit bdb43d97a2

@ -0,0 +1,159 @@
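# Futures K-Line Endpoint Fix Report
## Problem Description
The futures K-line endpoint `/v1/futures/klines/{symbol}` returned an empty `items` list.
## Root Cause
The code **hard-coded** the `amazingdata` adapter, while the configuration file selects the `custom` adapter. The failure sequence was:
1. The configuration file sets `sources.futures.active = "custom"`.
2. The code nevertheless tries to connect the `amazingdata` adapter.
3. `_connect_adapter` tries to read the adapter config from `file_config.sources.stock.list["amazingdata"]`.
4. The configuration has no `amazingdata` entry, so the lookup raises `KeyError: 'amazingdata'`.
5. The exception is caught and an empty list is returned.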
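The failure sequence above can be reproduced in a few lines. This is a minimal sketch, not the service's code: `load_adapter_config`, `fetch_klines`, and the plain-dict configuration are hypothetical stand-ins, chosen only to show how a broad exception handler turns a configuration error into an empty result.

```python
def load_adapter_config(sources: dict, name: str) -> dict:
    """Look up an adapter's config; raises KeyError if `name` is not configured."""
    return sources["stock"]["list"][name]


def fetch_klines(sources: dict, adapter_name: str) -> list:
    """Return K-line rows, or [] when anything goes wrong (the problematic pattern)."""
    try:
        load_adapter_config(sources, adapter_name)
    except Exception:
        # The KeyError is swallowed here, so the caller only sees "no data".
        return []
    return [{"close": 1.0}]  # placeholder row standing in for real data


# Only the `custom` adapter is configured, as in the report.
sources = {"stock": {"list": {"custom": {"config": {}}}}}

hardcoded = fetch_klines(sources, "amazingdata")  # hard-coded name, not configured
configured = fetch_klines(sources, "custom")      # name taken from the configuration
print(hardcoded, configured)
```

Logging the exception before returning would have made the misconfiguration visible much earlier.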
## Fixes
### 1. Remove the hard-coded adapter name
**Files changed:**
- `app/services/futures_service.py`
- `app/services/stock_service.py`
**Change:**
Replace the following code:
```python
if not adapter:
    loop.run_until_complete(adapter_service._connect_adapter("amazingdata"))
```
with:
```python
if not adapter:
    # Read the name of the currently active adapter from the configuration
    from app.core.config import get_config
    config = get_config()
    active_source = config.sources.futures.active  # or config.sources.stock.active
    info(f"Connecting to configured adapter: {active_source}")
    loop.run_until_complete(adapter_service._connect_adapter(active_source))
```
### 2. Fix the adapter configuration lookup
**Files changed:**
- `app/services/adapter_service.py`
**Change:**
Replace the following code:
```python
if name == "amazingdata":
    source_info = file_config.sources.stock.list["amazingdata"]
    adapter_config = dict(source_info.config) if source_info else {}
else:
    adapter_config = self.configs[name].get("config", {})
```
with:
```python
# Try to read the adapter config from the configuration file
adapter_config = None
# 1. Check the stock configuration first
if name in file_config.sources.stock.list:
    source_info = file_config.sources.stock.list[name]
    adapter_config = dict(source_info.config) if source_info else {}
# 2. Then check the futures configuration
elif name in file_config.sources.futures.list:
    source_info = file_config.sources.futures.list[name]
    adapter_config = dict(source_info.config) if source_info else {}
# 3. Otherwise fall back to the default configuration
else:
    adapter_config = self.configs[name].get("config", {})
```
### 3. Fix the frequency mapping in the futures repository
**Files changed:**
- `app/repositories/futures_repository.py`
**Change:**
Added mappings for the remaining frequencies. The database only stores 1-minute and daily bars, but the extra entries avoid a `KeyError`:
```python
def _get_kline_model(self, freq: Frequency):
    mapping = {
        Frequency.FREQ_1M: FuturesKLine1M,
        Frequency.FREQ_1D: FuturesKLine1D,
        Frequency.FREQ_5M: FuturesKLine1D,  # fall back to daily bars
        Frequency.FREQ_15M: FuturesKLine1D,
        Frequency.FREQ_30M: FuturesKLine1D,
        Frequency.FREQ_60M: FuturesKLine1D,
        Frequency.FREQ_1W: FuturesKLine1D,
        Frequency.FREQ_1MONTH: FuturesKLine1D,
    }
    return mapping.get(freq, FuturesKLine1D)
```
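One consequence of this fallback is that a request for, say, 5-minute bars is answered from the daily table. The sketch below shows one way a caller could detect that case. The reduced `Frequency` enum, the table-name strings, and `resolve_table` are assumptions made for illustration and are not part of this commit.

```python
from enum import Enum


class Frequency(Enum):
    """Illustrative subset of the service's Frequency enum."""
    FREQ_1M = "1m"
    FREQ_5M = "5m"
    FREQ_1D = "1d"


# Hypothetical table names standing in for the FuturesKLine1M / FuturesKLine1D models.
STORED = {
    Frequency.FREQ_1M: "futures_kline_1m",
    Frequency.FREQ_1D: "futures_kline_1d",
}


def resolve_table(freq: Frequency) -> tuple:
    """Return (table, exact); exact is False when the daily fallback was used."""
    if freq in STORED:
        return STORED[freq], True
    return STORED[Frequency.FREQ_1D], False


table, exact = resolve_table(Frequency.FREQ_5M)
print(table, exact)
```

Returning the flag alongside the table lets the API layer decide whether to serve daily bars, resample, or reject the request.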
## Current Status
The API now returns a well-formed response:
```json
{
"code": 0,
"message": "success",
"data": {
"symbol": "CU2504.SHFE",
"name": null,
"freq": "1d",
"adjust": "",
"count": 0,
"items": []
}
}
```
`items` is still empty because:
1. The database contains no data.
2. The configured `custom` adapter is not registered (only `amazingdata` is registered).
## Recommendations
For the endpoint to return real data:
1. **Configure the AmazingData adapter:**
Edit `config.json`:
```json
{
"sources": {
"futures": {
"active": "amazingdata",
"list": {
"amazingdata": {
"type": "sdk",
"config": {
"username": "your_username",
"password": "your_password",
"host": "your_host",
"port": "8600"
}
}
}
}
}
}
```
2. **Install the AmazingData SDK:**
```bash
pip install AmazingData tgw
```
3. **Or register a custom adapter:**
In `AdapterService._register_builtin_adapters()`, add:
```python
self.register_adapter("custom", lambda: YourCustomAdapter())
```

@ -0,0 +1,276 @@
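# System Improvement Report
## Overview
This round of work completed the development and improvement of 6 main features.
## Completed Features
### 1. Stock price adjustment ✅
**Files changed:**
- `app/repositories/models.py` - add the `StockAdjustFactor` adjustment-factor table
- `app/repositories/stock_repository.py` - add methods to query and save adjustment factors
- `app/services/stock_service.py` - implement the adjustment calculation
**Description:**
- Supports forward adjustment (qfq) and backward adjustment (hfq)
- Adjustment factors are fetched from the data source automatically and cached in the database
- Supports adjusting both prices and volumes
- Keeps the raw adjustment factor in the K-line data
**Implementation:**
- Forward adjustment: anchored to the latest price; historical prices are scaled down proportionally
- Backward adjustment: anchored to the earliest historical price; later prices are scaled up proportionally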
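The two adjustment modes can be illustrated with a small worked example. This is a sketch under stated assumptions, not the code in `stock_service.py`: it assumes each bar carries a cumulative adjustment factor that grows over time, and the function names `adjust_prices` and `adjust_volumes` are hypothetical.

```python
def _base_factor(factors: list, mode: str) -> float:
    """Pick the anchor factor: latest bar for qfq, earliest bar for hfq."""
    if mode == "qfq":
        return factors[-1]
    if mode == "hfq":
        return factors[0]
    raise ValueError(f"unknown adjust mode: {mode}")


def adjust_prices(closes: list, factors: list, mode: str) -> list:
    """Scale each price by its factor relative to the anchor factor."""
    base = _base_factor(factors, mode)
    return [c * f / base for c, f in zip(closes, factors)]


def adjust_volumes(volumes: list, factors: list, mode: str) -> list:
    """Volumes move inversely to prices so that turnover is preserved."""
    base = _base_factor(factors, mode)
    return [v * base / f for v, f in zip(volumes, factors)]


# A 2-for-1 split between the two bars: the raw close halves, the factor doubles.
closes = [10.0, 5.0]
volumes = [100.0, 200.0]
factors = [1.0, 2.0]

qfq = adjust_prices(closes, factors, "qfq")       # history scaled down to today's basis
hfq = adjust_prices(closes, factors, "hfq")       # later bars scaled up to the first bar's basis
qfq_vol = adjust_volumes(volumes, factors, "qfq")
print(qfq, hfq, qfq_vol)
```

In both modes the split no longer shows up as a price jump; the modes differ only in which end of the series keeps its raw values.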
---
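### 2. Prometheus metrics endpoint ✅
**New files:**
- `app/core/metrics.py` - metrics collection module
**Files changed:**
- `app/main.py` - add the metrics middleware and endpoint
- `requirements.txt` - add the prometheus-client dependency
**Description:**
- HTTP request count and duration monitoring
- Active request tracking
- Database operation performance monitoring
- Data source health monitoring
- WebSocket connection count monitoring
- Cache hit ratio monitoring
**Exposed endpoint:**
```
GET /metrics - metrics in Prometheus format
```
**Metrics:**
| Metric | Type | Description |
|--------|------|------|
| http_requests_total | Counter | Total HTTP requests |
| http_request_duration_seconds | Histogram | HTTP request duration |
| http_requests_active | Gauge | Active requests |
| api_calls_total | Counter | Total API calls |
| db_operation_duration_seconds | Histogram | Database operation duration |
| data_source_status | Gauge | Data source health status |
| websocket_connections | Gauge | WebSocket connections |
| websocket_messages_total | Counter | Total WebSocket messages |
| cache_hit_ratio | Gauge | Cache hit ratio |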
---
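### 3. Application-level rate limiting ✅
**New files:**
- `app/core/rate_limiter.py` - rate limiting module
**Files changed:**
- `app/main.py` - add the rate limiting middleware
**Description:**
- Supports three algorithms: fixed window, sliding window, and token bucket
- Rate limiting key based on client IP + path
- Configurable request rate and burst capacity
- Automatic cleanup of expired data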
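The sliding window strategy from the list above can be sketched as follows. This is a simplified, single-threaded illustration rather than the module's code: the class name `SlidingWindowLimiter` and the explicit `now` argument (passed in so the example is deterministic) are assumptions.

```python
from collections import deque


class SlidingWindowLimiter:
    """Allow at most `limit` requests per key within any `window_seconds` span."""

    def __init__(self, limit: int, window_seconds: float):
        self.limit = limit
        self.window = window_seconds
        self.hits = {}  # key -> deque of request timestamps, oldest first

    def allow(self, key: str, now: float) -> bool:
        timestamps = self.hits.setdefault(key, deque())
        cutoff = now - self.window
        # Drop timestamps that have slid out of the window.
        while timestamps and timestamps[0] < cutoff:
            timestamps.popleft()
        if len(timestamps) >= self.limit:
            return False
        timestamps.append(now)
        return True


limiter = SlidingWindowLimiter(limit=2, window_seconds=60)
key = "10.0.0.1:/v1/stock/klines"  # client IP + path, as in the report
results = [
    limiter.allow(key, now=0.0),   # first request in the window
    limiter.allow(key, now=1.0),   # second request, still within the limit
    limiter.allow(key, now=2.0),   # third request inside 60 s is rejected
    limiter.allow(key, now=61.0),  # the request at t=0 has expired by now
]
print(results)
```

Unlike a fixed window, this approach never lets a client send twice the limit around a window boundary, at the cost of storing one timestamp per accepted request.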
**Default configuration:**
```python
RateLimitConfig(
    requests_per_minute=120,   # 120 requests per minute
    burst_size=20,             # burst of 20 requests
    strategy="sliding_window"  # sliding window algorithm
)
```
**Response headers:**
```
X-RateLimit-Limit: 120
X-RateLimit-Remaining: 119
X-RateLimit-Reset: 1700000000
Retry-After: 60 # returned only when the request is rate-limited
```
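In the code of this commit, `burst_size` is read only by the token bucket strategy, so it has no effect under the default `sliding_window` setting. The sketch below shows how a token bucket uses the rate and the burst capacity, and how a `Retry-After` value can be derived. It is an illustration with an explicit `now` argument; the class shape and method names are assumptions, not the module's API.

```python
import math


class TokenBucket:
    """Refill at `rate_per_second` up to `capacity`; each request costs one token."""

    def __init__(self, rate_per_second: float, capacity: float, now: float):
        self.rate = rate_per_second
        self.capacity = capacity
        self.tokens = float(capacity)  # start full so a burst is allowed immediately
        self.last = now

    def try_consume(self, now: float) -> tuple:
        """Return (allowed, retry_after_seconds)."""
        elapsed = now - self.last
        self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
        self.last = now
        if self.tokens >= 1:
            self.tokens -= 1
            return True, 0
        # Seconds until one full token is available, rounded up for the header.
        return False, math.ceil((1 - self.tokens) / self.rate)


# 120 requests per minute is 2 tokens per second; a small capacity keeps the demo short.
bucket = TokenBucket(rate_per_second=2.0, capacity=2, now=0.0)
first = bucket.try_consume(now=0.0)
second = bucket.try_consume(now=0.0)
third = bucket.try_consume(now=0.0)   # bucket is empty: rejected with a retry hint
later = bucket.try_consume(now=0.5)   # half a second refills exactly one token
print(first, second, third, later)
```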
---
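### 4. Monitoring alert channels ✅
**New files:**
- `app/monitor/alert_channels.py` - alert channel module
**Files changed:**
- `app/monitor/__init__.py` - export the alert classes
- `app/monitor/monitor.py` - integrate the new alert manager
**Supported alert channels:**
| Channel | Type | Description |
|------|------|------|
| LogAlertChannel | Log | Default log output |
| DingTalkAlertChannel | DingTalk | DingTalk robot webhook |
| EmailAlertChannel | Email | Sent over SMTP |
| WebhookAlertChannel | Webhook | Custom HTTP callback |
**Features:**
- Message routing by alert level
- Batch sending
- Markdown-formatted DingTalk messages
- HTML-formatted email content
- Extensible architecture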
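Level-based routing, as listed above, amounts to a lookup from alert level to channel names followed by a send on each registered channel. The sketch below mirrors that idea with a stub channel. `RecordingChannel` and `dispatch` are hypothetical names for illustration; the real channels live in `app/monitor/alert_channels.py`.

```python
import asyncio

# Same shape as the routing table in the alert configuration.
ROUTING = {
    "info": ["log"],
    "warning": ["log", "dingtalk"],
    "error": ["log", "dingtalk", "email"],
}


class RecordingChannel:
    """Stub channel that records titles instead of sending anything."""

    def __init__(self, name: str):
        self.name = name
        self.sent = []

    async def send(self, title: str) -> bool:
        self.sent.append(title)
        return True


async def dispatch(level: str, title: str, channels: dict, routing: dict) -> dict:
    """Send to every channel routed for `level`; unregistered channels report False."""
    results = {}
    for name in routing.get(level, ["log"]):
        channel = channels.get(name)
        if channel is None:
            results[name] = False
        else:
            results[name] = await channel.send(title)
    return results


log_channel = RecordingChannel("log")
channels = {"log": log_channel}  # dingtalk is routed but not registered
results = asyncio.run(dispatch("warning", "data gap", channels, ROUTING))
print(results)
```

Reporting a per-channel result makes a missing or failing channel visible to the caller without blocking delivery on the others.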
**Usage example:**
```python
from app.monitor import get_alert_manager
# Send an alert
await get_alert_manager().send_simple(
    title="Missing data alert",
    content="Data is missing for stock 000001.SZ",
    level="warning"
)
```
---
### 5. Known issues fixed ✅
**Fixes:**
1. Added the missing `Response` import to `rate_limiter.py`
2. Fixed references to deleted classes in `app/monitor/__init__.py`
3. Updated `requirements.txt` to add `prometheus-client`
4. Installed the missing dependencies
---
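### 6. Service restart ✅
**Files changed:**
- `app/api/admin_routes.py` - implement the restart logic
**Description:**
- Restarts after a 2-second delay so that the current response is returned first
- Supports Windows and Linux/Mac
- Runs in a background thread so the API response is not blocked
**Usage:**
```bash
POST /v1/admin/system/restart
```
**Note:** in production, use Docker or systemd to manage the service lifecycle.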
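The platform-specific part of the restart can be isolated into a function that only builds the `subprocess.Popen` arguments, which makes it testable without spawning anything. This is a sketch of that idea; `build_restart_command` is a hypothetical helper, not a function in `admin_routes.py`.

```python
import subprocess


def build_restart_command(executable: str, argv: list, os_name: str) -> tuple:
    """Return (command, popen_kwargs) for relaunching the current process."""
    command = [executable, *argv]
    if os_name == "nt":
        # Windows: give the child its own console. The flag only exists on
        # Windows builds of Python, hence the getattr with a default.
        kwargs = {"creationflags": getattr(subprocess, "CREATE_NEW_CONSOLE", 0)}
    else:
        # Linux/Mac: detach from the parent's session and discard output.
        kwargs = {
            "stdout": subprocess.DEVNULL,
            "stderr": subprocess.DEVNULL,
            "start_new_session": True,
        }
    return command, kwargs


command, kwargs = build_restart_command("/usr/bin/python3", ["main.py"], "posix")
win_command, win_kwargs = build_restart_command("python.exe", ["main.py"], "nt")
print(command, sorted(kwargs))
```

A caller would pass `sys.executable`, `sys.argv`, and `os.name`, then run `subprocess.Popen(command, **kwargs)` before exiting the old process.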
---
## Suggested Configuration File Updates
### Add the alert configuration to config.json:
```json
{
"alert": {
"log": {
"enabled": true
},
"dingtalk": {
"enabled": false,
"webhook_url": "https://oapi.dingtalk.com/robot/send?access_token=xxx",
"secret": "your-secret",
"at_mobiles": ["13800138000"],
"at_all": false
},
"email": {
"enabled": false,
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"username": "alert@example.com",
"password": "your-password",
"from_addr": "alert@example.com",
"to_addrs": ["admin@example.com"],
"use_tls": true
},
"routing": {
"info": ["log"],
"warning": ["log", "dingtalk"],
"error": ["log", "dingtalk", "email"],
"critical": ["log", "dingtalk", "email"]
}
}
}
```
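When `dingtalk.secret` is set, each webhook call carries a timestamp and a signature. The sketch below follows the same steps as `_generate_sign` in this commit (HMAC-SHA256 over `"{timestamp}\n{secret}"`, Base64, then URL encoding). The helper names `dingtalk_sign` and `signed_url`, and the choice of `?` or `&` as separator, are assumptions for illustration.

```python
import base64
import hashlib
import hmac
import urllib.parse


def dingtalk_sign(secret: str, timestamp_ms: str) -> str:
    """Sign '<timestamp>\\n<secret>' with the secret as the HMAC-SHA256 key."""
    string_to_sign = f"{timestamp_ms}\n{secret}"
    digest = hmac.new(
        secret.encode("utf-8"),
        string_to_sign.encode("utf-8"),
        digestmod=hashlib.sha256,
    ).digest()
    return urllib.parse.quote_plus(base64.b64encode(digest))


def signed_url(webhook_url: str, secret: str, timestamp_ms: str) -> str:
    """Append timestamp and sign, using '&' only if the URL already has a query."""
    separator = "&" if "?" in webhook_url else "?"
    sign = dingtalk_sign(secret, timestamp_ms)
    return f"{webhook_url}{separator}timestamp={timestamp_ms}&sign={sign}"


sign = dingtalk_sign("your-secret", "1700000000000")
url = signed_url(
    "https://oapi.dingtalk.com/robot/send?access_token=xxx",
    "your-secret",
    "1700000000000",
)
print(url)
```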
---
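## API Endpoint Updates
### New endpoints:
| Endpoint | Method | Description |
|------|------|------|
| `/metrics` | GET | Prometheus metrics |
| `/v1/admin/system/restart` | POST | Restart the service |
### Rate-limited endpoints:
All `/v1/*` endpoints are rate-limited; `/health`, `/metrics`, `/docs` and similar paths are exempt.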
---
## System Architecture Diagram
```
┌─────────────────────────────────────────────────────────┐
│                   FastAPI Application                   │
├─────────────────────────────────────────────────────────┤
│  CORS Middleware                                        │
│  Metrics Middleware (Prometheus)                        │
│  Rate Limit Middleware (120 req/min)                    │
├─────────────────────────────────────────────────────────┤
│  Routes:                                                │
│    /v1/stock/*   - stock endpoints                      │
│    /v1/futures/* - futures endpoints                    │
│    /v1/admin/*   - admin endpoints                      │
│    /v1/stream    - WebSocket                            │
│    /metrics      - metrics endpoint                     │
│    /admin        - admin console UI                     │
├─────────────────────────────────────────────────────────┤
│  Services:                                              │
│    StockService       - price adjustment ✅             │
│    FuturesService     - futures logic                   │
│    AdminService       - admin functions                 │
│    AdapterService     - data source adapters            │
│    AlertManager       - alert management ✅             │
├─────────────────────────────────────────────────────────┤
│  Repositories:                                          │
│    StockRepository    - adjustment factor table ✅      │
│    FuturesRepository  - futures data                    │
├─────────────────────────────────────────────────────────┤
│  Data Sources:                                          │
│    AmazingDataAdapter - 星耀数智                        │
└─────────────────────────────────────────────────────────┘
```
---
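## Next Steps
1. **Prometheus integration**
- Deploy a Prometheus server to scrape the `/metrics` endpoint
- Configure Grafana dashboards to display the metrics
2. **Alert rule configuration**
- Configure the alert routing rules
- Set the DingTalk/email channel parameters
3. **Performance optimization**
- Add a Redis cache layer
- Implement database connection pool monitoring
4. **Security hardening**
- Implement API key validation
- Add request signature verification
---
## Completion Date
2026-03-14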

@ -160,6 +160,14 @@ class AmazingDataAdapter(DataSourceAdapter):
     def _do_login(self):
         """Perform the login (synchronous method)."""
+        # Check that the configuration is complete
+        if not self.config.username or not self.config.password or not self.config.host:
+            raise RuntimeError(
+                f"AmazingData configuration is incomplete: username={self.config.username}, "
+                f"host={self.config.host}, port={self.config.port}. "
+                f"Set valid account details in config.json"
+            )
         print("[amazingdata_adapter] Logging in to AmazingData...")
         print(f"[amazingdata_adapter] Login user: {self.config.username}")
         print(f"[amazingdata_adapter] Login address: {self.config.host}:{self.config.port}")

@ -59,12 +59,51 @@ def reload_config(
 def restart_service(
     token: str = Depends(verify_admin_token)
 ):
-    """Restart the service."""
-    # TODO: implement the service restart logic
+    """Restart the service.
+
+    Note: this restarts the service by spawning a child process, which suits
+    development environments. In production, use Docker or systemd to manage
+    the service lifecycle.
+    """
+    import os
+    import sys
+    import subprocess
+    import threading
+    import time
+
+    def delayed_restart():
+        """Restart after a short delay."""
+        time.sleep(2)  # give the current response time to be sent
+        # Reuse the current Python interpreter and launch arguments
+        command = [sys.executable] + sys.argv
+        if os.name == 'nt':  # Windows: start the child in a new console
+            subprocess.Popen(
+                command,
+                creationflags=subprocess.CREATE_NEW_CONSOLE
+            )
+        else:  # Linux/Mac: detach the child into its own session
+            subprocess.Popen(
+                command,
+                stdout=subprocess.DEVNULL,
+                stderr=subprocess.DEVNULL,
+                start_new_session=True
+            )
+        # Exit the current process
+        os._exit(0)
+
+    # Run the restart in a background thread
+    restart_thread = threading.Thread(target=delayed_restart, daemon=True)
+    restart_thread.start()
     return Response(
         code=0,
-        message="Restart command sent",
-        data={"status": "restarting"}
+        message="The service will restart in 2 seconds",
+        data={"status": "restarting", "delay_seconds": 2}
     )

@ -0,0 +1,221 @
"""Prometheus metrics collection module."""
from contextvars import ContextVar
from typing import Callable, Optional
import time
from prometheus_client import Counter, Histogram, Gauge, Info, generate_latest, CONTENT_TYPE_LATEST
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
# ============================================
# 定义指标
# ============================================
# HTTP请求计数器
http_requests_total = Counter(
'http_requests_total',
'Total HTTP requests',
['method', 'endpoint', 'status_code']
)
# HTTP请求持续时间
http_request_duration_seconds = Histogram(
'http_request_duration_seconds',
'HTTP request duration in seconds',
['method', 'endpoint'],
buckets=[.005, .01, .025, .05, .075, .1, .25, .5, .75, 1.0, 2.5, 5.0, 7.5, 10.0]
)
# 活跃请求数
http_requests_active = Gauge(
'http_requests_active',
'Active HTTP requests',
['method']
)
# API调用计数器按业务分类
api_calls_total = Counter(
'api_calls_total',
'Total API calls by category',
['category', 'operation']
)
# 数据库操作持续时间
db_operation_duration_seconds = Histogram(
'db_operation_duration_seconds',
'Database operation duration',
['operation', 'table'],
buckets=[.001, .005, .01, .025, .05, .1, .25, .5, 1.0]
)
# 数据源状态
data_source_status = Gauge(
'data_source_status',
'Data source health status (1=healthy, 0=unhealthy)',
['source', 'asset_class']
)
# WebSocket连接数
websocket_connections = Gauge(
'websocket_connections',
'Number of active WebSocket connections'
)
# WebSocket消息计数器
websocket_messages_total = Counter(
'websocket_messages_total',
'Total WebSocket messages',
['direction'] # 'in' or 'out'
)
# 缓存命中率
cache_hit_ratio = Gauge(
'cache_hit_ratio',
'Cache hit ratio',
['cache_type']
)
# 应用信息
app_info = Info(
'market_data_service',
'Application information'
)
# ============================================
# 上下文变量
# ============================================
# 用于存储请求开始时间
request_start_time: ContextVar[Optional[float]] = ContextVar('request_start_time', default=None)
# ============================================
# 指标收集函数
# ============================================
def record_http_request(method: str, endpoint: str, status_code: int, duration: float):
"""记录HTTP请求指标"""
http_requests_total.labels(
method=method,
endpoint=endpoint,
status_code=status_code
).inc()
http_request_duration_seconds.labels(
method=method,
endpoint=endpoint
).observe(duration)
def record_api_call(category: str, operation: str):
"""记录API调用"""
api_calls_total.labels(
category=category,
operation=operation
).inc()
def record_db_operation(operation: str, table: str, duration: float):
"""记录数据库操作"""
db_operation_duration_seconds.labels(
operation=operation,
table=table
).observe(duration)
def update_data_source_status(source: str, asset_class: str, is_healthy: bool):
"""更新数据源状态"""
data_source_status.labels(
source=source,
asset_class=asset_class
).set(1 if is_healthy else 0)
def increment_websocket_connections(delta: int = 1):
"""增加WebSocket连接数"""
websocket_connections.inc(delta)
def decrement_websocket_connections(delta: int = 1):
"""减少WebSocket连接数"""
websocket_connections.dec(delta)
def record_websocket_message(direction: str):
"""记录WebSocket消息"""
websocket_messages_total.labels(direction=direction).inc()
def set_cache_hit_ratio(cache_type: str, ratio: float):
"""设置缓存命中率"""
cache_hit_ratio.labels(cache_type=cache_type).set(ratio)
def set_app_info(version: str, build_time: str, git_commit: str = ""):
"""设置应用信息"""
app_info.info({
'version': version,
'build_time': build_time,
'git_commit': git_commit
})
# ============================================
# FastAPI middleware
# ============================================
class MetricsMiddleware(BaseHTTPMiddleware):
    """Middleware that collects HTTP metrics."""

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Skip monitoring of the metrics endpoint itself
        if request.url.path == '/metrics':
            return await call_next(request)
        # Track the active request
        http_requests_active.labels(method=request.method).inc()
        # Record the start time; perf_counter is monotonic, unlike time.time()
        start_time = time.perf_counter()
        status_code = 200  # default status code
        try:
            response = await call_next(request)
            status_code = response.status_code
            return response
        except Exception:
            status_code = 500
            raise
        finally:
            # Compute the duration
            duration = time.perf_counter() - start_time
            # Prefer the route template over the raw URL to keep label cardinality low.
            # Assumption: the matched route is stored in the ASGI scope under "route"
            # (FastAPI does this for its own routes); otherwise fall back to the raw path.
            route = request.scope.get("route")
            endpoint = getattr(route, "path", request.url.path)
            # Record the metrics
            record_http_request(
                method=request.method,
                endpoint=endpoint,
                status_code=status_code,
                duration=duration
            )
            # Decrement the active request count
            http_requests_active.labels(method=request.method).dec()


# ============================================
# Metrics endpoint
# ============================================
def get_metrics_response() -> Response:
    """Return the metrics in Prometheus text format."""
    return Response(
        content=generate_latest(),
        media_type=CONTENT_TYPE_LATEST
    )

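@ -0,0 +1,352 @
"""Application-level rate limiting module.
Supported rate limiting strategies:
1. Fixed window counter
2. Sliding window counter
3. Token bucket algorithm
"""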
import time
import asyncio
from typing import Dict, Optional, Tuple, Callable
from dataclasses import dataclass, field
from threading import Lock
from collections import deque
from fastapi import Request, HTTPException, Response
from starlette.middleware.base import BaseHTTPMiddleware
@dataclass
class RateLimitConfig:
"""限流配置"""
requests_per_minute: int = 60 # 每分钟请求数
burst_size: int = 10 # 突发请求数
window_size: int = 60 # 窗口大小(秒)
strategy: str = "sliding_window" # 限流策略: fixed_window, sliding_window, token_bucket
key_func: Optional[Callable[[Request], str]] = None # 自定义key生成函数
@dataclass
class FixedWindow:
"""固定窗口"""
count: int = 0
reset_time: float = field(default_factory=lambda: time.time() + 60)
@dataclass
class SlidingWindow:
"""滑动窗口"""
requests: deque = field(default_factory=lambda: deque())
def clean_old_requests(self, window_size: int):
"""清理过期的请求记录"""
now = time.time()
cutoff = now - window_size
while self.requests and self.requests[0] < cutoff:
self.requests.popleft()
@dataclass
class TokenBucket:
"""令牌桶"""
tokens: float = field(default_factory=float)
last_update: float = field(default_factory=time.time)
def update_tokens(self, rate_per_second: float, max_tokens: float):
"""更新令牌数量"""
now = time.time()
elapsed = now - self.last_update
self.tokens = min(max_tokens, self.tokens + elapsed * rate_per_second)
self.last_update = now
class RateLimiter:
"""限流器
支持多种限流策略默认使用滑动窗口算法
"""
def __init__(self, config: RateLimitConfig = None):
self.config = config or RateLimitConfig()
self.lock = Lock()
# 存储每个key的限流状态
self.fixed_windows: Dict[str, FixedWindow] = {}
self.sliding_windows: Dict[str, SlidingWindow] = {}
self.token_buckets: Dict[str, TokenBucket] = {}
# 启动清理任务
self._cleanup_task = None
def _get_key(self, request: Request) -> str:
"""生成限流key
默认使用客户端IP + 路径
"""
if self.config.key_func:
return self.config.key_func(request)
# 获取客户端IP
client_ip = request.client.host if request.client else "unknown"
# 检查X-Forwarded-For头
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
client_ip = forwarded.split(",")[0].strip()
# 使用IP + 路径作为key
return f"{client_ip}:{request.url.path}"
def _check_fixed_window(self, key: str) -> Tuple[bool, Dict]:
"""固定窗口限流检查
Returns:
(是否允许, 响应头信息)
"""
now = time.time()
window = self.fixed_windows.get(key)
if window is None or now > window.reset_time:
# 新窗口
self.fixed_windows[key] = FixedWindow(count=1, reset_time=now + self.config.window_size)
remaining = self.config.requests_per_minute - 1
return True, {
"X-RateLimit-Limit": str(self.config.requests_per_minute),
"X-RateLimit-Remaining": str(remaining),
"X-RateLimit-Reset": str(int(now + self.config.window_size))
}
if window.count >= self.config.requests_per_minute:
# 超过限制
return False, {
"X-RateLimit-Limit": str(self.config.requests_per_minute),
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset": str(int(window.reset_time)),
"Retry-After": str(int(window.reset_time - now))
}
# 允许请求
window.count += 1
remaining = self.config.requests_per_minute - window.count
return True, {
"X-RateLimit-Limit": str(self.config.requests_per_minute),
"X-RateLimit-Remaining": str(remaining),
"X-RateLimit-Reset": str(int(window.reset_time))
}
def _check_sliding_window(self, key: str) -> Tuple[bool, Dict]:
"""滑动窗口限流检查
Returns:
(是否允许, 响应头信息)
"""
now = time.time()
if key not in self.sliding_windows:
self.sliding_windows[key] = SlidingWindow()
window = self.sliding_windows[key]
window.clean_old_requests(self.config.window_size)
if len(window.requests) >= self.config.requests_per_minute:
# 超过限制
oldest = window.requests[0]
reset_time = oldest + self.config.window_size
return False, {
"X-RateLimit-Limit": str(self.config.requests_per_minute),
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset": str(int(reset_time)),
"Retry-After": str(int(reset_time - now))
}
# 允许请求
window.requests.append(now)
remaining = self.config.requests_per_minute - len(window.requests)
return True, {
"X-RateLimit-Limit": str(self.config.requests_per_minute),
"X-RateLimit-Remaining": str(remaining),
"X-RateLimit-Reset": str(int(now + self.config.window_size))
}
def _check_token_bucket(self, key: str) -> Tuple[bool, Dict]:
"""令牌桶限流检查
Returns:
(是否允许, 响应头信息)
"""
rate_per_second = self.config.requests_per_minute / 60.0
max_tokens = self.config.burst_size
if key not in self.token_buckets:
self.token_buckets[key] = TokenBucket(tokens=max_tokens)
bucket = self.token_buckets[key]
bucket.update_tokens(rate_per_second, max_tokens)
if bucket.tokens < 1:
# 令牌不足
wait_time = (1 - bucket.tokens) / rate_per_second
return False, {
"X-RateLimit-Limit": str(self.config.requests_per_minute),
"X-RateLimit-Remaining": "0",
"Retry-After": str(int(wait_time) + 1)
}
# 消耗令牌
bucket.tokens -= 1
remaining = int(bucket.tokens)
return True, {
"X-RateLimit-Limit": str(self.config.requests_per_minute),
"X-RateLimit-Remaining": str(remaining)
}
def is_allowed(self, request: Request) -> Tuple[bool, Dict]:
"""检查请求是否允许通过
Returns:
(是否允许, 响应头信息)
"""
with self.lock:
key = self._get_key(request)
if self.config.strategy == "fixed_window":
return self._check_fixed_window(key)
elif self.config.strategy == "token_bucket":
return self._check_token_bucket(key)
else:
return self._check_sliding_window(key)
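    def cleanup(self):
        """Remove expired rate limiting state."""
        now = time.time()
        with self.lock:
            # Remove expired fixed windows
            expired = [
                key for key, window in self.fixed_windows.items()
                if now > window.reset_time
            ]
            for key in expired:
                del self.fixed_windows[key]
            # Trim old requests from the sliding windows
            for window in self.sliding_windows.values():
                window.clean_old_requests(self.config.window_size)
            # Remove empty sliding windows
            empty = [
                key for key, window in self.sliding_windows.items()
                if not window.requests
            ]
            for key in empty:
                del self.sliding_windows[key]
            # Remove token buckets idle long enough to have refilled completely;
            # a fresh bucket starts full, so dropping them changes no decision.
            rate_per_second = self.config.requests_per_minute / 60.0
            if rate_per_second > 0:
                refill_time = self.config.burst_size / rate_per_second
                idle = [
                    key for key, bucket in self.token_buckets.items()
                    if now - bucket.last_update >= refill_time
                ]
                for key in idle:
                    del self.token_buckets[key]

    async def start_cleanup_task(self):
        """Run the periodic cleanup loop."""
        while True:
            await asyncio.sleep(300)  # clean up every 5 minutes
            self.cleanup()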
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Rate limiting middleware.

    Usage example:
        app.add_middleware(
            RateLimitMiddleware,
            config=RateLimitConfig(
                requests_per_minute=60,
                strategy="sliding_window"
            )
        )
    """

    def __init__(self, app, config: RateLimitConfig = None):
        super().__init__(app)
        self.limiter = RateLimiter(config)

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        from starlette.responses import JSONResponse
        # Skip exempt paths
        if request.url.path in ["/health", "/metrics", "/docs", "/redoc", "/openapi.json"]:
            return await call_next(request)
        # Check the rate limit
        allowed, headers = self.limiter.is_allowed(request)
        if not allowed:
            # Return the 429 directly: an HTTPException raised inside a
            # BaseHTTPMiddleware bypasses FastAPI's exception handlers
            # and would surface as a 500.
            return JSONResponse(
                status_code=429,
                content={"detail": "Too many requests"},
                headers=headers
            )
        # Process the request
        response = await call_next(request)
        # Add the rate limit response headers
        for key, value in headers.items():
            response.headers[key] = value
        return response
# 全局限流器实例(用于特定端点限流)
_default_limiter: Optional[RateLimiter] = None
def get_limiter(config: RateLimitConfig = None) -> RateLimiter:
"""获取全局限流器实例"""
global _default_limiter
if _default_limiter is None:
_default_limiter = RateLimiter(config)
return _default_limiter
def rate_limit(
    requests_per_minute: int = 60,
    strategy: str = "sliding_window",
    key_func: Optional[Callable[[Request], str]] = None
):
    """Decorator: add rate limiting to a specific endpoint.

    The decorated endpoint must accept the request as its first parameter.

    Usage example:
        @app.get("/api/data")
        @rate_limit(requests_per_minute=30)
        async def get_data(request: Request):
            return {"data": "..."}
    """
    import functools
    config = RateLimitConfig(
        requests_per_minute=requests_per_minute,
        strategy=strategy,
        key_func=key_func
    )
    limiter = RateLimiter(config)

    def decorator(func: Callable) -> Callable:
        # Preserve the endpoint's signature so FastAPI still injects `request`
        @functools.wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            allowed, headers = limiter.is_allowed(request)
            if not allowed:
                raise HTTPException(
                    status_code=429,
                    detail="Too many requests",
                    headers=headers
                )
            # Await the original function if it is a coroutine function
            if asyncio.iscoroutinefunction(func):
                response = await func(request, *args, **kwargs)
            else:
                response = func(request, *args, **kwargs)
            # Add the response headers
            if hasattr(response, 'headers'):
                for key, value in headers.items():
                    response.headers[key] = value
            return response
        return wrapper
    return decorator

@ -11,6 +11,8 @@ from app.api import router, admin_router
 from app.websocket import WebSocketServer
 from app.core.config import get_config, get_settings
 from app.core.logger import info, error, setup_logging
+from app.core.metrics import MetricsMiddleware, get_metrics_response, set_app_info
+from app.core.rate_limiter import RateLimitMiddleware, RateLimitConfig
 from app.repositories.database import init_db
@ -58,6 +60,19 @@ app.add_middleware(
     allow_headers=["*"],
 )
+# Add the Prometheus metrics middleware
+app.add_middleware(MetricsMiddleware)
+# Add the rate limiting middleware: 120 requests per minute, sliding window algorithm
+app.add_middleware(
+    RateLimitMiddleware,
+    config=RateLimitConfig(
+        requests_per_minute=120,  # 120 requests per minute
+        burst_size=20,  # burst of 20 requests
+        strategy="sliding_window"  # use the sliding window algorithm
+    )
+)
 # Register the API routes
 app.include_router(router, prefix="/v1")
 app.include_router(admin_router, prefix="/v1")
@ -74,6 +89,12 @@ async def websocket_endpoint(websocket):
     await ws_server.handle(websocket, client_id)
+@app.get("/metrics")
+async def metrics():
+    """Prometheus metrics endpoint."""
+    return get_metrics_response()
 # Admin console page HTML (full version)
 ADMIN_HTML = '''<!DOCTYPE html>
 <html lang="zh-CN">

@ -1,4 +1,14 @
 """Data quality monitoring module."""
-from .monitor import DataQualityMonitor, AlertSender, LogAlertSender
+from .monitor import DataQualityMonitor, CheckResult, QualityReport
+from .alert_channels import (
+    AlertChannel, AlertMessage, AlertManager,
+    LogAlertChannel, DingTalkAlertChannel, EmailAlertChannel, WebhookAlertChannel,
+    get_alert_manager, init_alert_manager
+)
-__all__ = ["DataQualityMonitor", "AlertSender", "LogAlertSender"]
+__all__ = [
+    "DataQualityMonitor", "CheckResult", "QualityReport",
+    "AlertChannel", "AlertMessage", "AlertManager",
+    "LogAlertChannel", "DingTalkAlertChannel", "EmailAlertChannel", "WebhookAlertChannel",
+    "get_alert_manager", "init_alert_manager"
+]

@ -0,0 +1,516 @
"""Alert channel module.
Supported alert channels:
- Log alerts (default)
- DingTalk robot
- Email
- Webhook
"""
import json
import smtplib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from typing import Dict, List, Optional
import httpx
from app.core.logger import info, error, warning
@dataclass
class AlertMessage:
"""告警消息"""
title: str
content: str
level: str = "warning" # info, warning, error, critical
timestamp: datetime = None
metadata: Dict = None
def __post_init__(self):
if self.timestamp is None:
self.timestamp = datetime.now()
if self.metadata is None:
self.metadata = {}
class AlertChannel(ABC):
"""告警通道基类"""
def __init__(self, name: str, enabled: bool = True):
self.name = name
self.enabled = enabled
@abstractmethod
async def send(self, message: AlertMessage) -> bool:
"""发送告警消息"""
pass
async def send_batch(self, messages: List[AlertMessage]) -> List[bool]:
"""批量发送告警"""
results = []
for msg in messages:
result = await self.send(msg)
results.append(result)
return results
class LogAlertChannel(AlertChannel):
"""日志告警通道"""
def __init__(self, enabled: bool = True):
super().__init__("log", enabled)
async def send(self, message: AlertMessage) -> bool:
"""发送日志告警"""
if not self.enabled:
return False
log_msg = f"[{message.level.upper()}] {message.title}: {message.content}"
if message.level == "info":
info(log_msg)
elif message.level == "warning":
warning(log_msg)
else:
error(log_msg)
return True
class DingTalkAlertChannel(AlertChannel):
"""钉钉机器人告警通道"""
def __init__(
self,
webhook_url: str,
secret: Optional[str] = None,
at_mobiles: Optional[List[str]] = None,
at_all: bool = False,
enabled: bool = True
):
super().__init__("dingtalk", enabled)
self.webhook_url = webhook_url
self.secret = secret
self.at_mobiles = at_mobiles or []
self.at_all = at_all
    def _generate_sign(self, timestamp: str) -> str:
        """Generate the DingTalk signature."""
        import base64  # previously imported only in send(), so b64encode raised NameError here
        import hmac
        import hashlib
        import urllib.parse
        if not self.secret:
            return ""
        string_to_sign = f"{timestamp}\n{self.secret}"
        hmac_code = hmac.new(
            self.secret.encode('utf-8'),
            string_to_sign.encode('utf-8'),
            digestmod=hashlib.sha256
        ).digest()
        sign = urllib.parse.quote_plus(base64.b64encode(hmac_code))
        return sign
def _build_markdown_message(self, message: AlertMessage) -> Dict:
"""构建Markdown格式的消息"""
# 根据级别选择颜色
color_map = {
"info": "#007bff",
"warning": "#ffc107",
"error": "#dc3545",
"critical": "#6f42c1"
}
color = color_map.get(message.level, "#6c757d")
# 构建@信息
at_text = ""
if self.at_all:
at_text = "@所有人 "
elif self.at_mobiles:
at_text = " ".join([f"@{mobile}" for mobile in self.at_mobiles])
content = f"""### {message.title} {at_text}
**告警级别:** <font color='{color}'>{message.level.upper()}</font>
**告警时间:** {message.timestamp.strftime('%Y-%m-%d %H:%M:%S')}
---
{message.content}
---
**详细信息:**
```json
{json.dumps(message.metadata, indent=2, ensure_ascii=False, default=str)}
```
"""
return {
"msgtype": "markdown",
"markdown": {
"title": message.title,
"text": content
},
"at": {
"atMobiles": self.at_mobiles,
"isAtAll": self.at_all
}
}
def _build_text_message(self, message: AlertMessage) -> Dict:
"""构建文本格式的消息"""
return {
"msgtype": "text",
"text": {
"content": f"[{message.level.upper()}] {message.title}\n\n{message.content}"
},
"at": {
"atMobiles": self.at_mobiles,
"isAtAll": self.at_all
}
}
async def send(self, message: AlertMessage, msg_type: str = "markdown") -> bool:
"""发送钉钉告警
Args:
message: 告警消息
msg_type: 消息类型 (markdown text)
"""
if not self.enabled or not self.webhook_url:
return False
try:
import base64
import time as time_module
timestamp = str(int(round(time_module.time() * 1000)))
sign = self._generate_sign(timestamp)
# 构建URL
url = self.webhook_url
if self.secret:
url = f"{self.webhook_url}&timestamp={timestamp}&sign={sign}"
# 构建消息
if msg_type == "markdown":
payload = self._build_markdown_message(message)
else:
payload = self._build_text_message(message)
# 发送请求
async with httpx.AsyncClient() as client:
response = await client.post(
url,
json=payload,
headers={"Content-Type": "application/json"},
timeout=10.0
)
if response.status_code == 200:
result = response.json()
if result.get("errcode") == 0:
info(f"DingTalk alert sent: {message.title}")
return True
else:
error(f"DingTalk API error: {result}")
return False
else:
error(f"DingTalk HTTP error: {response.status_code}")
return False
except Exception as e:
error(f"Failed to send DingTalk alert: {e}")
return False
class EmailAlertChannel(AlertChannel):
"""邮件告警通道"""
def __init__(
self,
smtp_host: str,
smtp_port: int,
username: str,
password: str,
from_addr: str,
to_addrs: List[str],
use_tls: bool = True,
enabled: bool = True
):
super().__init__("email", enabled)
self.smtp_host = smtp_host
self.smtp_port = smtp_port
self.username = username
self.password = password
self.from_addr = from_addr
self.to_addrs = to_addrs
self.use_tls = use_tls
def _build_html_content(self, message: AlertMessage) -> str:
"""构建HTML格式的邮件内容"""
# 根据级别选择颜色
color_map = {
"info": "#007bff",
"warning": "#ffc107",
"error": "#dc3545",
"critical": "#6f42c1"
}
color = color_map.get(message.level, "#6c757d")
metadata_html = ""
if message.metadata:
rows = ""
for key, value in message.metadata.items():
rows += f"<tr><td><strong>{key}</strong></td><td>{value}</td></tr>"
metadata_html = f"""
<h3>详细信息</h3>
<table border="1" cellpadding="5" cellspacing="0">
{rows}
</table>
"""
return f"""
<html>
<body>
<h2 style="color: {color};">[{message.level.upper()}] {message.title}</h2>
<p><strong>告警时间:</strong> {message.timestamp.strftime('%Y-%m-%d %H:%M:%S')}</p>
<hr>
<p>{message.content.replace(chr(10), '<br>')}</p>
<hr>
{metadata_html}
<p style="color: #666; font-size: 12px;">
本邮件由行情数据服务自动发送请勿回复
</p>
</body>
</html>
"""
    async def send(self, message: AlertMessage) -> bool:
        """Send an email alert."""
        if not self.enabled or not self.to_addrs:
            return False
        try:
            # Build the email
            msg = MIMEMultipart('alternative')
            msg['Subject'] = f"[{message.level.upper()}] {message.title}"
            msg['From'] = self.from_addr
            msg['To'] = ', '.join(self.to_addrs)
            # Attach the HTML content
            html_content = self._build_html_content(message)
            msg.attach(MIMEText(html_content, 'html', 'utf-8'))
            # Send the email; the blocking SMTP calls run in an executor
            import asyncio
            loop = asyncio.get_running_loop()

            def send_email():
                # The context manager closes the connection even if sending fails
                with smtplib.SMTP(self.smtp_host, self.smtp_port) as server:
                    if self.use_tls:
                        server.starttls()
                    server.login(self.username, self.password)
                    server.sendmail(self.from_addr, self.to_addrs, msg.as_string())

            await loop.run_in_executor(None, send_email)
            info(f"Email alert sent: {message.title}")
            return True
        except Exception as e:
            error(f"Failed to send email alert: {e}")
            return False

class WebhookAlertChannel(AlertChannel):
    """Webhook alert channel."""

    def __init__(
        self,
        webhook_url: str,
        headers: Optional[Dict[str, str]] = None,
        timeout: float = 10.0,
        enabled: bool = True
    ):
        super().__init__("webhook", enabled)
        self.webhook_url = webhook_url
        self.headers = headers or {"Content-Type": "application/json"}
        self.timeout = timeout

    async def send(self, message: AlertMessage) -> bool:
        """Send a webhook alert."""
        if not self.enabled or not self.webhook_url:
            return False
        try:
            payload = {
                "title": message.title,
                "content": message.content,
                "level": message.level,
                "timestamp": message.timestamp.isoformat(),
                "metadata": message.metadata
            }
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.webhook_url,
                    json=payload,
                    headers=self.headers,
                    timeout=self.timeout
                )
            # Treat any status below 400 as success
            if response.status_code < 400:
                info(f"Webhook alert sent: {message.title}")
                return True
            else:
                error(f"Webhook error: {response.status_code}")
                return False
        except Exception as e:
            error(f"Failed to send webhook alert: {e}")
            return False

class AlertManager:
    """Alert manager.

    Manages multiple alert channels, with level-based routing and fan-out to several channels.
    """

    def __init__(self):
        self.channels: Dict[str, AlertChannel] = {}
        self.level_routing = {
            "info": ["log"],
            "warning": ["log", "dingtalk"],
            "error": ["log", "dingtalk", "email"],
            "critical": ["log", "dingtalk", "email", "webhook"]
        }

    def register_channel(self, channel: AlertChannel):
        """Register an alert channel."""
        self.channels[channel.name] = channel
        info(f"Alert channel registered: {channel.name}")

    def configure_routing(self, level_routing: Dict[str, List[str]]):
        """Configure the alert routing rules."""
        self.level_routing = level_routing

    async def send(
        self,
        message: AlertMessage,
        channels: Optional[List[str]] = None
    ) -> Dict[str, bool]:
        """Send an alert.

        Args:
            message: the alert message
            channels: explicit channel list; if None, route by message level
        Returns:
            Send result per channel
        """
        # Determine the target channels
        target_channels = channels
        if target_channels is None:
            target_channels = self.level_routing.get(message.level, ["log"])
        # Send to each channel; unregistered channels are reported as failed
        results = {}
        for channel_name in target_channels:
            channel = self.channels.get(channel_name)
            if channel:
                results[channel_name] = await channel.send(message)
            else:
                warning(f"Alert channel not found: {channel_name}")
                results[channel_name] = False
        return results

    async def send_simple(
        self,
        title: str,
        content: str,
        level: str = "warning",
        **kwargs
    ) -> Dict[str, bool]:
        """Send a simple alert; extra keyword arguments become message metadata."""
        message = AlertMessage(
            title=title,
            content=content,
            level=level,
            metadata=kwargs
        )
        return await self.send(message)

# Global alert manager instance
_alert_manager: Optional[AlertManager] = None


def get_alert_manager() -> AlertManager:
    """Return the global alert manager, creating it on first use."""
    global _alert_manager
    if _alert_manager is None:
        _alert_manager = AlertManager()
        # Register the log channel by default
        _alert_manager.register_channel(LogAlertChannel())
    return _alert_manager


def init_alert_manager(config: Dict):
    """Initialize the alert manager from configuration."""
    global _alert_manager
    _alert_manager = AlertManager()
    # Register the log channel
    _alert_manager.register_channel(LogAlertChannel(
        enabled=config.get("log", {}).get("enabled", True)
    ))
    # Register the DingTalk channel
    dingtalk_config = config.get("dingtalk", {})
    if dingtalk_config.get("enabled"):
        _alert_manager.register_channel(DingTalkAlertChannel(
            webhook_url=dingtalk_config["webhook_url"],
            secret=dingtalk_config.get("secret"),
            at_mobiles=dingtalk_config.get("at_mobiles", []),
            at_all=dingtalk_config.get("at_all", False),
            enabled=True
        ))
    # Register the email channel
    email_config = config.get("email", {})
    if email_config.get("enabled"):
        _alert_manager.register_channel(EmailAlertChannel(
            smtp_host=email_config["smtp_host"],
            smtp_port=email_config["smtp_port"],
            username=email_config["username"],
            password=email_config["password"],
            from_addr=email_config["from_addr"],
            to_addrs=email_config["to_addrs"],
            use_tls=email_config.get("use_tls", True),
            enabled=True
        ))
    # Register the webhook channel
    webhook_config = config.get("webhook", {})
    if webhook_config.get("enabled"):
        _alert_manager.register_channel(WebhookAlertChannel(
            webhook_url=webhook_config["webhook_url"],
            headers=webhook_config.get("headers"),
            timeout=webhook_config.get("timeout", 10.0),
            enabled=True
        ))
    # Configure routing rules
    if "routing" in config:
        _alert_manager.configure_routing(config["routing"])
    return _alert_manager

@ -1,6 +1,5 @@
 """Data quality monitoring - counterpart of Go's internal/monitor/monitor.go"""
 import asyncio
-from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from datetime import datetime, timedelta
 from typing import List, Optional
@ -11,6 +10,7 @@ from sqlalchemy import text
 from app.repositories import StockRepository, FuturesRepository
 from app.models import Frequency
 from app.core.logger import info, error
+from app.monitor.alert_channels import AlertManager, AlertMessage, get_alert_manager
 @dataclass
@ -37,23 +37,6 @@ class QualityReport:
     pass_rate: float
-class AlertSender(ABC):
-    """Alert sender interface"""
-    @abstractmethod
-    def send_alert(self, title: str, content: str) -> bool:
-        """Send an alert"""
-        pass
-class LogAlertSender(AlertSender):
-    """Alert sender that writes to the log"""
-    def send_alert(self, title: str, content: str) -> bool:
-        info(f"[ALERT] {title}: {content}")
-        return True
 class DataQualityMonitor:
     """Data quality monitor"""
@ -62,12 +45,12 @@ class DataQualityMonitor:
         db: Session,
         stock_repo: StockRepository,
         futures_repo: FuturesRepository,
-        sender: Optional[AlertSender] = None
+        alert_manager: Optional[AlertManager] = None
     ):
         self.db = db
         self.stock_repo = stock_repo
         self.futures_repo = futures_repo
-        self.sender = sender or LogAlertSender()
+        self.alert_manager = alert_manager or get_alert_manager()
     async def daily_check(self, check_date: str):
         """Daily data quality check"""
@ -156,11 +139,17 @@ class DataQualityMonitor:
                 result.detail = f"Data missing: expected {expect_count}, actual {actual_count}"
                 # Send an alert
-                if self.sender:
-                    self.sender.send_alert(
-                        f"[{asset_type}] Data Missing Alert",
-                        f"Symbol: {symbol}, Date: {check_date}, Expected: {expect_count}, Actual: {actual_count}"
-                    )
+                if self.alert_manager:
+                    asyncio.create_task(self.alert_manager.send_simple(
+                        title=f"[{asset_type.upper()}] Data missing alert",
+                        content=f"Symbol: {symbol}, date: {check_date}, expected: {expect_count}, actual: {actual_count}",
+                        level="warning",
+                        asset_type=asset_type,
+                        symbol=symbol,
+                        check_date=check_date,
+                        expect_count=expect_count,
+                        actual_count=actual_count
+                    ))
         except Exception as e:
             result.status = "fail"

@ -58,10 +58,20 @@ class FuturesRepository:
         return items
     def _get_kline_model(self, freq: Frequency):
-        """Get the K-line model for a frequency"""
+        """Get the K-line model for a frequency.
+        Note: the database currently stores only 1-minute and daily K-lines;
+        other frequencies (5m/15m/30m/60m/1w/1month) fall back to the daily table.
+        """
         mapping = {
             Frequency.FREQ_1M: FuturesKLine1M,
             Frequency.FREQ_1D: FuturesKLine1D,
+            Frequency.FREQ_5M: FuturesKLine1D,  # falls back to the daily table
+            Frequency.FREQ_15M: FuturesKLine1D,
+            Frequency.FREQ_30M: FuturesKLine1D,
+            Frequency.FREQ_60M: FuturesKLine1D,
+            Frequency.FREQ_1W: FuturesKLine1D,
+            Frequency.FREQ_1MONTH: FuturesKLine1D,
         }
         return mapping.get(freq, FuturesKLine1D)

@ -210,3 +210,20 @@ class DataQualityCheck(Base):
     actual_count = Column(Integer, nullable=True, comment="Actual count")
     detail = Column(String(500), nullable=True, comment="Detail")
     created_at = Column(DateTime, default=datetime.now, comment="Creation time")
+class StockAdjustFactor(Base):
+    """Stock adjustment factor table"""
+    __tablename__ = "stock_adjust_factors"
+    id = Column(BigInteger, primary_key=True, autoincrement=True)
+    symbol_id = Column(String(20), nullable=False, index=True, comment="Symbol code")
+    trade_date = Column(String(10), nullable=False, index=True, comment="Trade date YYYY-MM-DD")
+    qfq_factor = Column(Numeric(18, 8), nullable=False, default=1.0, comment="Forward adjustment factor")
+    hfq_factor = Column(Numeric(18, 8), nullable=False, default=1.0, comment="Backward adjustment factor")
+    created_at = Column(DateTime, default=datetime.now, comment="Creation time")
+    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment="Update time")
+    __table_args__ = (
+        Index("idx_adj_factor_symbol_date", "symbol_id", "trade_date"),
+    )

@ -11,7 +11,7 @@ from app.models import (
 )
 from app.repositories.models import (
     StockSymbol, StockKLine1M, StockKLine5M, StockKLine1D,
-    StockTradingCalendar
+    StockTradingCalendar, StockAdjustFactor
 )
@ -258,3 +258,85 @@ class StockRepository:
             self.db.add(new_cal)
         self.db.commit()
+    def get_adjust_factors(
+        self,
+        symbol: str,
+        start_date: str,
+        end_date: str
+    ) -> List[dict]:
+        """Get the adjustment factors within a date range.
+        Args:
+            symbol: stock code
+            start_date: start date (YYYYMMDD)
+            end_date: end date (YYYYMMDD)
+        Returns:
+            List of factors; each item contains trade_date, qfq_factor, hfq_factor
+        """
+        # Convert YYYYMMDD to the stored YYYY-MM-DD format
+        start_fmt = f"{start_date[:4]}-{start_date[4:6]}-{start_date[6:]}"
+        end_fmt = f"{end_date[:4]}-{end_date[4:6]}-{end_date[6:]}"
+        results = self.db.query(StockAdjustFactor).filter(
+            StockAdjustFactor.symbol_id == symbol,
+            StockAdjustFactor.trade_date >= start_fmt,
+            StockAdjustFactor.trade_date <= end_fmt
+        ).order_by(StockAdjustFactor.trade_date.asc()).all()
+        return [
+            {
+                "trade_date": r.trade_date,
+                "qfq_factor": float(r.qfq_factor) if r.qfq_factor else 1.0,
+                "hfq_factor": float(r.hfq_factor) if r.hfq_factor else 1.0
+            }
+            for r in results
+        ]
+    def save_adjust_factors(self, symbol: str, factors: List[dict]) -> None:
+        """Save adjustment factors (insert new rows, update existing ones).
+        Args:
+            symbol: stock code
+            factors: list of factors; each item contains trade_date, qfq_factor, hfq_factor
+        """
+        for f in factors:
+            trade_date = f.get("trade_date")
+            existing = self.db.query(StockAdjustFactor).filter(
+                StockAdjustFactor.symbol_id == symbol,
+                StockAdjustFactor.trade_date == trade_date
+            ).first()
+            if existing:
+                existing.qfq_factor = f.get("qfq_factor", 1.0)
+                existing.hfq_factor = f.get("hfq_factor", 1.0)
+            else:
+                new_factor = StockAdjustFactor(
+                    symbol_id=symbol,
+                    trade_date=trade_date,
+                    qfq_factor=f.get("qfq_factor", 1.0),
+                    hfq_factor=f.get("hfq_factor", 1.0)
+                )
+                self.db.add(new_factor)
+        self.db.commit()
+    def get_latest_adjust_factor(self, symbol: str) -> Optional[dict]:
+        """Get the most recent adjustment factor.
+        Returns:
+            Dict containing qfq_factor and hfq_factor, or None if there is none
+        """
+        result = self.db.query(StockAdjustFactor).filter(
+            StockAdjustFactor.symbol_id == symbol
+        ).order_by(StockAdjustFactor.trade_date.desc()).first()
+        if result:
+            return {
+                "trade_date": result.trade_date,
+                "qfq_factor": float(result.qfq_factor) if result.qfq_factor else 1.0,
+                "hfq_factor": float(result.hfq_factor) if result.hfq_factor else 1.0
+            }
+        return None

@ -203,16 +203,30 @@ class AdapterService:
         # Get the latest config from config.json (kept in sync with the file)
         file_config = get_config()
         print(f"Using file config: {file_config}, adapter name: {name}")
-        if name == "amazingdata":
-            # Prefer the amazingdata config under stock
-            source_info = file_config.sources.stock.list["amazingdata"]
+        # Try to get the adapter config from the config file
+        adapter_config = None
+        # 1. Check the stock config first
+        if name in file_config.sources.stock.list:
+            source_info = file_config.sources.stock.list[name]
+            adapter_config = dict(source_info.config) if source_info else {}
+            print(f"Using stock config for {name}: {adapter_config}")
+        # 2. Then check the futures config
+        elif name in file_config.sources.futures.list:
+            source_info = file_config.sources.futures.list[name]
             adapter_config = dict(source_info.config) if source_info else {}
-            print(f"Using amazingdata config: {adapter_config}")
-            # Handle a port given as a string
-            if "port" in adapter_config and isinstance(adapter_config["port"], str):
-                adapter_config["port"] = int(adapter_config["port"]) if adapter_config["port"].strip() else 8600
+            print(f"Using futures config for {name}: {adapter_config}")
+        # 3. Fall back to the default config
         else:
             adapter_config = self.configs[name].get("config", {})
+            print(f"Using default config for {name}: {adapter_config}")
+        # Handle a port given as a string
+        if "port" in adapter_config and isinstance(adapter_config["port"], str):
+            adapter_config["port"] = int(adapter_config["port"]) if adapter_config["port"].strip() else 8600
         cfg = {"enabled": self.configs[name].get("enabled", False), "config": adapter_config}

@ -61,10 +61,16 @@ class FuturesService:
         # Make sure the adapter is connected
         adapter = adapter_service.get_active_adapter("futures")
         if not adapter:
-            # Try to connect to amazingdata
+            # Read the currently active adapter name from the config
+            from app.core.config import get_config
+            config = get_config()
+            active_source = config.sources.futures.active
+            # Try to connect to the configured adapter
+            info(f"Connecting to configured adapter: {active_source}")
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
-            loop.run_until_complete(adapter_service._connect_adapter("amazingdata"))
+            loop.run_until_complete(adapter_service._connect_adapter(active_source))
             loop.close()
             adapter = adapter_service.get_active_adapter("futures")
@ -164,7 +170,13 @@ class FuturesService:
         # Make sure the adapter is connected
         adapter = adapter_service.get_active_adapter("futures")
         if not adapter:
-            asyncio.run(adapter_service._connect_adapter("amazingdata"))
+            # Read the currently active adapter name from the config
+            from app.core.config import get_config
+            config = get_config()
+            active_source = config.sources.futures.active
+            info(f"Connecting to configured adapter: {active_source}")
+            asyncio.run(adapter_service._connect_adapter(active_source))
             adapter = adapter_service.get_active_adapter("futures")
         if not adapter:

@ -71,10 +71,16 @@ class StockService:
         # Make sure the adapter is connected
         adapter = adapter_service.get_active_adapter("stock")
         if not adapter:
-            # Try to connect to amazingdata
+            # Read the currently active adapter name from the config
+            from app.core.config import get_config
+            config = get_config()
+            active_source = config.sources.stock.active
+            # Try to connect to the configured adapter
+            info(f"Connecting to configured adapter: {active_source}")
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
-            loop.run_until_complete(adapter_service._connect_adapter("amazingdata"))
+            loop.run_until_complete(adapter_service._connect_adapter(active_source))
             loop.close()
             adapter = adapter_service.get_active_adapter("stock")
@ -151,13 +157,145 @@ class StockService:
     def _apply_adjust(
         self,
         symbol: str,
-        items: List,
+        items: List[KLineItem],
         adjust_type: AdjustType
-    ) -> List:
-        """Apply price adjustment (TODO: implement the adjustment logic)"""
-        # Adjustment requires fetching adjustment factors from the database
-        # Simplified here: return the raw data unchanged
-        return items
+    ) -> List[KLineItem]:
+        """Apply price adjustment.
+
+        How adjustment works:
+        - Forward adjustment (qfq): uses the latest price as the baseline and scales historical prices down proportionally
+        - Backward adjustment (hfq): uses the earliest historical price as the baseline and scales later prices up proportionally
+        """
+        if not items or adjust_type == AdjustType.NONE:
+            return items
+        try:
+            # Determine the date range
+            start_date = items[0].time.strftime("%Y%m%d")
+            end_date = items[-1].time.strftime("%Y%m%d")
+            # Load adjustment factors from the database
+            factors = self.repository.get_adjust_factors(symbol, start_date, end_date)
+            # If none are stored, try the adapter and cache the result
+            if not factors:
+                factors = self._fetch_adjust_factors_from_adapter(symbol, start_date, end_date)
+                if factors:
+                    self.repository.save_adjust_factors(symbol, factors)
+            # Index the factors by trade date for fast lookup
+            factor_map = {f["trade_date"]: f for f in factors}
+            # Apply the adjustment
+            adjusted_items = []
+            for item in items:
+                # Resolve the trade date
+                trade_date = getattr(item, 'trade_date', None)
+                if not trade_date and hasattr(item, 'time'):
+                    trade_date = item.time.strftime("%Y-%m-%d")
+                # Dates without a stored factor are left unadjusted (factor 1.0)
+                factor = factor_map.get(trade_date, {"qfq_factor": 1.0, "hfq_factor": 1.0})
+                # Pick the factor for the requested adjustment type
+                if adjust_type == AdjustType.QFQ:
+                    adj_factor = factor.get("qfq_factor", 1.0)
+                else:  # HFQ
+                    adj_factor = factor.get("hfq_factor", 1.0)
+                # Apply the factor to the price fields
+                adjusted_item = KLineItem(
+                    symbol=item.symbol,
+                    time=item.time,
+                    open=round(item.open * adj_factor, 4),
+                    high=round(item.high * adj_factor, 4),
+                    low=round(item.low * adj_factor, 4),
+                    close=round(item.close * adj_factor, 4),
+                    volume=item.volume,
+                    amount=round(item.amount * adj_factor, 4) if item.amount else item.amount,
+                    trade_date=getattr(item, 'trade_date', None),
+                    is_limit_up=getattr(item, 'is_limit_up', None),
+                    is_limit_down=getattr(item, 'is_limit_down', None),
+                    total_market_cap=getattr(item, 'total_market_cap', None),
+                    float_market_cap=getattr(item, 'float_market_cap', None),
+                    inst_holding_ratio=getattr(item, 'inst_holding_ratio', None),
+                    trading_days=getattr(item, 'trading_days', None),
+                    adj_factor=adj_factor
+                )
+                adjusted_items.append(adjusted_item)
+            return adjusted_items
+        except Exception as e:
+            error(f"Failed to apply adjust factor for {symbol}: {e}")
+            # On failure, return the raw data
+            return items
+    def _fetch_adjust_factors_from_adapter(
+        self,
+        symbol: str,
+        start_date: str,
+        end_date: str
+    ) -> List[dict]:
+        """Fetch adjustment factors from the adapter"""
+        try:
+            adapter_service = AdapterService()
+            adapter = adapter_service.get_active_adapter("stock")
+            if not adapter:
+                error("No active adapter available for fetching adjust factors")
+                return []
+            # Check whether the adapter supports fetching adjustment factors
+            if not hasattr(adapter, 'get_adj_factor'):
+                return []
+            # Fetch the forward and backward factors (async adapter calls)
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            try:
+                qfq_df = loop.run_until_complete(
+                    adapter.get_adj_factor([symbol])
+                )
+                hfq_df = loop.run_until_complete(
+                    adapter.get_backward_factor([symbol])
+                )
+            finally:
+                loop.close()
+            # Convert the DataFrames to a list
+            factors = []
+            # Normalize the date format
+            for idx in qfq_df.index:
+                date_obj = idx if hasattr(idx, 'strftime') else datetime.strptime(str(idx), "%Y%m%d")
+                date_str = date_obj.strftime("%Y-%m-%d")
+                date_key = date_obj.strftime("%Y%m%d")
+                # Keep only rows inside the requested range
+                if not (start_date <= date_key <= end_date):
+                    continue
+                qfq_factor = float(qfq_df.loc[idx, symbol]) if symbol in qfq_df.columns else 1.0
+                hfq_factor = float(hfq_df.loc[idx, symbol]) if symbol in hfq_df.columns else 1.0
+                # Make sure the factors are valid (x != x is true only for NaN)
+                if qfq_factor <= 0 or qfq_factor != qfq_factor:
+                    qfq_factor = 1.0
+                if hfq_factor <= 0 or hfq_factor != hfq_factor:
+                    hfq_factor = 1.0
+                factors.append({
+                    "trade_date": date_str,
+                    "qfq_factor": qfq_factor,
+                    "hfq_factor": hfq_factor
+                })
+            info(f"Fetched {len(factors)} adjust factors from adapter for {symbol}")
+            return factors
+        except Exception as e:
+            error(f"Failed to fetch adjust factors from adapter: {e}")
+            return []
     def list_symbols(self, req: SymbolListRequest) -> SymbolListData:
         """Query the symbol list"""
@ -196,7 +334,13 @@ class StockService:
         # Make sure the adapter is connected
         adapter = adapter_service.get_active_adapter("stock")
         if not adapter:
-            asyncio.run(adapter_service._connect_adapter("amazingdata"))
+            # Read the currently active adapter name from the config
+            from app.core.config import get_config
+            config = get_config()
+            active_source = config.sources.stock.active
+            info(f"Connecting to configured adapter: {active_source}")
+            asyncio.run(adapter_service._connect_adapter(active_source))
             adapter = adapter_service.get_active_adapter("stock")
         if not adapter:
@ -301,7 +445,13 @@ class StockService:
         # Make sure the adapter is connected
         adapter = adapter_service.get_active_adapter("stock")
         if not adapter:
-            asyncio.run(adapter_service._connect_adapter("amazingdata"))
+            # Read the currently active adapter name from the config
+            from app.core.config import get_config
+            config = get_config()
+            active_source = config.sources.stock.active
+            info(f"Connecting to configured adapter: {active_source}")
+            asyncio.run(adapter_service._connect_adapter(active_source))
             adapter = adapter_service.get_active_adapter("stock")
         if not adapter:

@ -20,32 +20,32 @@
}, },
"sources": { "sources": {
"stock": { "stock": {
"active": "custom", "active": "amazingdata",
"list": { "list": {
"custom": { "amazingdata": {
"type": "sdk", "type": "sdk",
"config": { "config": {
"username": "", "username": "11200008169",
"password": "", "password": "11200008169@2026",
"host": "", "host": "140.206.44.234",
"port": "", "port": "8600",
"local_path": "./custom_data_cache/", "local_path": "./amazing_data_cache/",
"use_local_cache": "true" "use_local_cache": "true"
} }
} }
} }
}, },
"futures": { "futures": {
"active": "custom", "active": "amazingdata",
"list": { "list": {
"custom": { "amazingdata": {
"type": "sdk", "type": "sdk",
"config": { "config": {
"username": "", "username": "",
"password": "", "password": "",
"host": "", "host": "",
"port": "", "port": "8600",
"local_path": "./custom_data_cache/", "local_path": "./amazing_data_cache/",
"use_local_cache": "true" "use_local_cache": "true"
} }
} }

@ -0,0 +1,269 @@
Metadata-Version: 2.4
Name: market-data-service
Version: 1.0.0
Summary: Unified market data service - Python implementation
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Requires-Dist: fastapi>=0.115.0
Requires-Dist: uvicorn[standard]>=0.32.0
Requires-Dist: python-socketio>=5.12.1
Requires-Dist: websockets>=14.1
Requires-Dist: sqlalchemy>=2.0.36
Requires-Dist: psycopg2-binary>=2.9.10
Requires-Dist: pandas>=2.2.3
Requires-Dist: numpy>=2.1.3
Requires-Dist: numba>=0.61.0
Requires-Dist: scipy>=1.15.0
Requires-Dist: pydantic>=2.10.0
Requires-Dist: pydantic-settings>=2.6.1
Requires-Dist: python-dotenv>=1.0.1
Requires-Dist: PyYAML>=6.0.2
Requires-Dist: httpx>=0.28.0
Requires-Dist: apscheduler>=3.11.0
Provides-Extra: dev
Requires-Dist: pytest>=8.3.4; extra == "dev"
Requires-Dist: pytest-asyncio>=0.24.0; extra == "dev"
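# Unified Market Data Service - Python Implementation
A Python version of the unified market data service. All interfaces and features match the Go version.
## Features
- **Multi-frequency K-lines**: 1m/5m/15m/30m/60m/1d/1w/1month
- **Stock price adjustment**: forward (qfq) and backward (hfq) adjustment
- **Hot-swappable data sources**: switch dynamically between multiple data sources such as Wind and Tushare
- **Dual-track design**: stock and futures interfaces are independent, with isolated data storage
- **Real-time WebSocket subscription**: real-time quote push
- **Data quality monitoring**: detects missing data automatically and raises alerts
- **Trading calendar**: query the trading calendars for stocks and futures
- **Futures contract lookup**: list the tradable contracts for a given product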
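
The price adjustment feature listed above comes down to multiplying each bar's prices by a per-date factor. The following is a minimal sketch of that idea, assuming bars are plain dicts and the factor table is keyed by `YYYY-MM-DD`; the `apply_adjust` helper and the sample numbers are illustrative and are not the service's actual API or data.

```python
from typing import Dict, List

# Only price fields are scaled; volume is left as-is in this sketch.
PRICE_FIELDS = ("open", "high", "low", "close")


def apply_adjust(bars: List[dict], factors: Dict[str, float]) -> List[dict]:
    """Return copies of the bars with price fields scaled by that date's factor.

    Dates missing from `factors` use 1.0, i.e. they are left unadjusted.
    """
    adjusted = []
    for bar in bars:
        factor = factors.get(bar["trade_date"], 1.0)
        new_bar = dict(bar)  # copy so the input bars are not mutated
        for field in PRICE_FIELDS:
            new_bar[field] = round(bar[field] * factor, 4)
        adjusted.append(new_bar)
    return adjusted


# Hypothetical data: a 2-for-1 split between the two trading days.
bars = [
    {"trade_date": "2024-03-01", "open": 10.0, "high": 11.0, "low": 9.5, "close": 10.5, "volume": 1000},
    {"trade_date": "2024-03-04", "open": 5.3, "high": 5.6, "low": 5.2, "close": 5.5, "volume": 2100},
]
qfq_factors = {"2024-03-01": 0.5, "2024-03-04": 1.0}
result = apply_adjust(bars, qfq_factors)
```

With forward adjustment the most recent bar keeps its price and earlier bars are scaled down, so the pre-split close of 10.5 becomes 5.25 and lines up with the post-split series.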
## Tech Stack
- **Language**: Python 3.10+
- **Web framework**: FastAPI
- **WebSocket**: FastAPI native WebSocket + python-socketio
- **Database**: PostgreSQL 15+ (SQLAlchemy ORM)
- **Data source**: Tushare (supported in the first phase)
## Project Structure
```
python_market_data_service/
├── app/
│   ├── __init__.py
│   ├── main.py                 # Main entry point
│   ├── api/                    # API routes
│   │   ├── __init__.py
│   │   ├── routes.py           # Main API routes
│   │   └── admin_routes.py     # Admin console routes
│   ├── core/                   # Core modules
│   │   ├── __init__.py
│   │   ├── config.py           # Configuration management
│   │   ├── errors.py           # Error definitions
│   │   └── logger.py           # Logging utilities
│   ├── models/                 # Data models
│   │   ├── __init__.py
│   │   ├── types.py            # Base types
│   │   └── admin_types.py      # Admin console types
│   ├── repositories/           # Data access layer
│   │   ├── __init__.py
│   │   ├── database.py         # Database connection
│   │   ├── models.py           # Database models
│   │   ├── stock_repository.py
│   │   └── futures_repository.py
│   ├── services/               # Business logic layer
│   │   ├── __init__.py
│   │   ├── stock_service.py
│   │   ├── futures_service.py
│   │   ├── admin_service.py
│   │   ├── config_service.py
│   │   ├── adapter_service.py
│   │   └── test_service.py
│   ├── adapters/               # Data source adapters
│   │   ├── __init__.py
│   │   ├── base.py             # Adapter base class
│   │   └── tushare_adapter.py
│   └── websocket/              # WebSocket service
│       ├── __init__.py
│       └── server.py
├── scripts/
│   └── sync_data.py            # Data sync tool
├── tests/                      # Test files
├── requirements.txt            # Dependency list
├── pyproject.toml              # Project configuration
└── README.md                   # This file
```
## Quick Start
### 1. Prerequisites
- Python 3.10+
- PostgreSQL 15+
- Tushare token (get one from the [Tushare website](https://tushare.pro))
### 2. Install dependencies
```bash
# Create a virtual environment
python -m venv venv
# Activate the virtual environment
# Windows:
venv\Scripts\activate
# Linux/Mac:
source venv/bin/activate
# Install dependencies
pip install -r requirements.txt
# Install Tushare (must be installed separately)
pip install tushare
```
### 3. Configure environment variables
```bash
# Windows PowerShell
$env:TUSHARE_TOKEN="your_tushare_token"
$env:DATABASE_URL="postgresql://user:password@localhost:5432/marketdata"
# Linux/Mac
export TUSHARE_TOKEN="your_tushare_token"
export DATABASE_URL="postgresql://user:password@localhost:5432/marketdata"
```
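
To check that both variables are visible to Python before starting the service, something like the following can help. It is a stdlib-only sketch; the `load_settings` helper is hypothetical and is not how the service itself reads configuration.

```python
import os
from typing import Mapping


def load_settings(env: Mapping[str, str] = os.environ) -> dict:
    """Read the two variables shown above and fail early if either is missing or empty."""
    required = ("TUSHARE_TOKEN", "DATABASE_URL")
    missing = [name for name in required if not env.get(name)]
    if missing:
        raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
    return {
        "tushare_token": env["TUSHARE_TOKEN"],
        "database_url": env["DATABASE_URL"],
    }


# Passing an explicit mapping keeps the example independent of the real environment.
settings = load_settings({
    "TUSHARE_TOKEN": "your_tushare_token",
    "DATABASE_URL": "postgresql://user:password@localhost:5432/marketdata",
})
```

Calling `load_settings()` with no argument reads the real process environment instead.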
### 4. Initialize the database
```bash
# Create the database (with psql or pgAdmin)
createdb marketdata
# The table schema is created automatically when the service starts
```
### 5. Start the service
```bash
# Development mode
python -m app.main
# Or with uvicorn
uvicorn app.main:app --reload --port 8080
```
The service starts at `http://localhost:8080`
- API docs: `http://localhost:8080/docs`
- Admin console: `http://localhost:8080/admin`
### 6. Sync base data
```bash
# Sync the stock list
python scripts/sync_data.py --type stocks
# Sync the futures list
python scripts/sync_data.py --type futures
# Sync the trading calendar
python scripts/sync_data.py --type calendar --start 20240101 --end 20241231
# Sync K-line data
python scripts/sync_data.py --type klines --symbol 000001.SZ --start 20240301 --end 20240307 --freq 1d
```
## API
### Stock endpoints
| Endpoint | Method | Description |
|------|------|------|
| `/v1/stock/klines/:symbol` | GET | Query K-line data |
| `/v1/stock/symbols` | GET | Query the symbol list |
| `/v1/stock/klines/batch` | POST | Batch K-line query |
| `/v1/stock/trading-dates` | GET | Get the trading calendar |
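
A request URL for the K-line endpoint in the table above can be assembled with the standard library. This is only a sketch: the path comes from the table, but the query parameter names used here (`freq`, `start`, `end`, `adjust`) are assumptions borrowed from the sync script flags and the feature list, so check the generated API docs at `/docs` for the real ones.

```python
from urllib.parse import quote, urlencode

BASE_URL = "http://localhost:8080"  # default address from the quick start section


def klines_url(symbol: str, **params: str) -> str:
    """Build a GET URL for /v1/stock/klines/:symbol with optional query parameters."""
    path = f"/v1/stock/klines/{quote(symbol, safe='')}"
    query = urlencode(params)
    return f"{BASE_URL}{path}?{query}" if query else f"{BASE_URL}{path}"


# Parameter names below are hypothetical placeholders.
url = klines_url("000001.SZ", freq="1d", start="20240301", end="20240307", adjust="qfq")
```

The resulting string can be passed to any HTTP client, for example `httpx.get(url)`, which is already among the project's dependencies.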
### Futures endpoints
| Endpoint | Method | Description |
|------|------|------|
| `/v1/futures/klines/:symbol` | GET | Query K-line data |
| `/v1/futures/symbols` | GET | Query the symbol list |
| `/v1/futures/klines/batch` | POST | Batch K-line query |
| `/v1/futures/continuous/:underlying` | GET | Query the main continuous contract (reserved) |
| `/v1/futures/trading-dates` | GET | Get the trading calendar |
| `/v1/futures/contracts` | GET | Get the contract list for a product |
### Admin endpoints
| Endpoint | Method | Description |
|------|------|------|
| `/v1/admin/source/status` | GET | Get data source status |
| `/v1/admin/source/switch` | POST | Switch data source |
| `/v1/admin/backfill` | POST | Backfill historical data |
| `/v1/admin/health` | GET | Health check |
### Admin console
After the service starts, open `http://localhost:8080/admin` to reach the admin console.
### Real-time WebSocket subscription
**Endpoint**: `ws://localhost:8080/v1/stream`
**Authentication**: pass `X-API-Key` in the request headers when connecting
**Client messages**:
```json
// Subscribe
{
"action": "subscribe",
"symbols": ["000001.SZ", "CU2504.SHFE"]
}
// Unsubscribe
{
"action": "unsubscribe",
"symbols": ["000001.SZ"]
}
```
**Server messages**:
```json
// Subscription acknowledgement
{
"type": "ack",
"action": "subscribe",
"symbols": ["000001.SZ", "CU2504.SHFE"],
"ts": "2025-03-07T12:30:00Z"
}
// Heartbeat
{
"type": "heartbeat",
"ts": "2025-03-07T12:30:30Z"
}
```
**Limit**: at most 100 subscribed symbols per connection
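
A client can build these messages with the standard library before sending them over whatever WebSocket client it uses. The sketch below only constructs the JSON text; the `build_message` helper is hypothetical, and its size check is a client-side approximation, since the real limit applies to the total number of subscriptions on a connection rather than to a single message.

```python
import json
from typing import Iterable

MAX_SYMBOLS_PER_CONNECTION = 100  # limit stated in the README


def build_message(action: str, symbols: Iterable[str]) -> str:
    """Serialize a subscribe/unsubscribe message in the format shown above."""
    if action not in ("subscribe", "unsubscribe"):
        raise ValueError(f"Unsupported action: {action}")
    unique = list(dict.fromkeys(symbols))  # drop duplicates, keep order
    if len(unique) > MAX_SYMBOLS_PER_CONNECTION:
        raise ValueError(
            f"{len(unique)} symbols exceeds the limit of {MAX_SYMBOLS_PER_CONNECTION}"
        )
    return json.dumps({"action": action, "symbols": unique})


message = build_message("subscribe", ["000001.SZ", "CU2504.SHFE", "000001.SZ"])
```

The returned string is what would be sent as a text frame after connecting to `ws://localhost:8080/v1/stream` with the `X-API-Key` header set.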
## Main differences from the Go version
1. **Web framework**: Gin -> FastAPI
2. **ORM**: raw SQL -> SQLAlchemy
3. **WebSocket**: Gorilla -> FastAPI native
4. **Configuration**: file + environment variables -> Pydantic Settings
5. **API docs**: Swagger/ReDoc generated automatically
## License
MIT

@ -0,0 +1,43 @@
README.md
pyproject.toml
app/__init__.py
app/main.py
app/adapters/__init__.py
app/adapters/amazingdata_adapter.py
app/adapters/base.py
app/api/__init__.py
app/api/admin_routes.py
app/api/routes.py
app/core/__init__.py
app/core/config.py
app/core/errors.py
app/core/logger.py
app/core/metrics.py
app/core/rate_limiter.py
app/models/__init__.py
app/models/admin_types.py
app/models/types.py
app/monitor/__init__.py
app/monitor/alert_channels.py
app/monitor/monitor.py
app/repositories/__init__.py
app/repositories/database.py
app/repositories/futures_repository.py
app/repositories/models.py
app/repositories/stock_repository.py
app/services/__init__.py
app/services/adapter_service.py
app/services/admin_service.py
app/services/config_service.py
app/services/futures_service.py
app/services/stock_service.py
app/services/test_service.py
app/websocket/__init__.py
app/websocket/server.py
market_data_service.egg-info/PKG-INFO
market_data_service.egg-info/SOURCES.txt
market_data_service.egg-info/dependency_links.txt
market_data_service.egg-info/requires.txt
market_data_service.egg-info/top_level.txt
tests/test_xysz_adapter.py
tests/test_xysz_integration.py

@ -0,0 +1,20 @@
fastapi>=0.115.0
uvicorn[standard]>=0.32.0
python-socketio>=5.12.1
websockets>=14.1
sqlalchemy>=2.0.36
psycopg2-binary>=2.9.10
pandas>=2.2.3
numpy>=2.1.3
numba>=0.61.0
scipy>=1.15.0
pydantic>=2.10.0
pydantic-settings>=2.6.1
python-dotenv>=1.0.1
PyYAML>=6.0.2
httpx>=0.28.0
apscheduler>=3.11.0
[dev]
pytest>=8.3.4
pytest-asyncio>=0.24.0

Binary file not shown.

Binary file not shown.

@ -29,6 +29,7 @@ aioredis==2.0.1
# Monitoring # Monitoring
apscheduler==3.11.0 apscheduler==3.11.0
prometheus-client==0.21.0
# Testing # Testing
pytest==8.3.4 pytest==8.3.4

@ -0,0 +1,53 @@
"""MySQL database initialization script.

Creates the database and the table schema.
"""
import sys
sys.path.insert(0, '.')
from sqlalchemy import create_engine, text
from app.core.config import get_config
from app.repositories.database import Base


def init_mysql():
    """Initialize the MySQL database."""
    config = get_config()
    db_config = config.database
    # Connect to the MySQL server without selecting a database
    server_url = f"mysql+pymysql://{db_config.user}:{db_config.password}@{db_config.host}:{db_config.port}"
    print(f"Connecting to MySQL server: {db_config.host}:{db_config.port}")
    engine = create_engine(server_url)
    # Create the database
    with engine.connect() as conn:
        conn.execute(text(f"CREATE DATABASE IF NOT EXISTS {db_config.database} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"))
        print(f"Database '{db_config.database}' created or exists")
    # Connect to the newly created database
    db_url = f"{server_url}/{db_config.database}"
    db_engine = create_engine(db_url)
    # Create all tables
    print("Creating tables...")
    Base.metadata.create_all(bind=db_engine)
    print("Tables created successfully!")
    # List the tables that now exist
    with db_engine.connect() as conn:
        result = conn.execute(text("SHOW TABLES"))
        tables = [row[0] for row in result]
        print(f"\nTables in database '{db_config.database}':")
        for table in tables:
            print(f"  - {table}")


if __name__ == "__main__":
    try:
        init_mysql()
        print("\nMySQL database initialization completed!")
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)