feat: 初始化工程代码

master
Lxy 2 weeks ago
commit 4eaee5c594

@ -0,0 +1,119 @@
# 数据缓冲平台
基于 FastAPI + SQLite + APScheduler 的行情数据缓冲服务。
## 功能
1. **数据采集缓存** - 复用现有采集脚本,自动缓存到 SQLite
2. **批量获取接口** - 指定品种+周期,批量拉取并缓存
3. **定时任务管理** - 创建/启动/停止/删除自动轮询任务
4. **最新数据接口** - 从缓存中快速获取最新数据
## 快速启动
```bash
cd buffer_platform
pip install -r requirements.txt
# 启动服务(默认端口 8600
python -m app.main
# 或指定端口
BUFFER_PORT=9000 python -m app.main
```
## API 接口
### 数据接口 `/api/v1/data`
| 方法 | 路径 | 说明 |
|------|------|------|
| POST | `/data/batch-fetch` | 批量获取并缓存(智能缓存) |
| GET | `/data/latest/{symbol}` | 从缓存获取最新数据 |
| GET | `/data/latest/{symbol}/{period}` | 获取指定周期最新数据 |
| GET | `/data/cache-status/{symbol}` | 查看缓存状态 |
### 品种配置接口 `/api/v1/config`
| 方法 | 路径 | 说明 |
|------|------|------|
| GET | `/config` | 获取当前品种配置 |
| POST | `/config/upload` | 上传品种配置文件JSON |
| POST | `/config/batch-fetch-all` | 根据配置批量获取所有品种数据 |
| POST | `/config/batch-tasks` | 根据配置批量创建定时任务 |
### 定时任务接口 `/api/v1/tasks`
| 方法 | 路径 | 说明 |
|------|------|------|
| POST | `/tasks` | 创建并启动定时任务 |
| GET | `/tasks` | 列出所有任务 |
| POST | `/tasks/{id}/start` | 启动任务 |
| POST | `/tasks/{id}/stop` | 停止任务 |
| POST | `/tasks/{id}/update-interval` | 更新轮询间隔 |
| DELETE | `/tasks/{id}` | 删除任务 |
### UI 页面
| 路径 | 说明 |
|------|------|
| `/ui` | 品种配置管理页面(上传文件、批量获取、批量任务) |
| `/docs` | Swagger API 文档 |
### 使用示例
```bash
# 启动服务(默认端口 8600
cd buffer_platform
python -m app.main
# 访问 UI 管理页面
open http://localhost:8600/ui
# 上传品种配置文件
curl -X POST http://localhost:8600/api/v1/config/upload \
-F "file=@symbols_config.json"
# 查看当前配置
curl http://localhost:8600/api/v1/config
# 根据配置批量获取所有品种数据
curl -X POST 'http://localhost:8600/api/v1/config/batch-fetch-all?periods=5min,15min,60min&data_type=futures'
# 根据配置批量创建定时任务每5分钟自动采集
curl -X POST 'http://localhost:8600/api/v1/config/batch-tasks?periods=5min,15min,60min&interval_seconds=300&data_type=futures'
# 批量获取(手动指定品种)
curl -X POST http://localhost:8600/api/v1/data/batch-fetch \
-H "Content-Type: application/json" \
-d '{"symbols": ["SN2504", "AG2506"], "periods": ["5min", "15min"]}'
# 获取最新缓存
curl http://localhost:8600/api/v1/data/latest/SN2504
# 创建单个定时任务
curl -X POST http://localhost:8600/api/v1/tasks \
-H "Content-Type: application/json" \
-d '{"symbol": "SN2504", "periods": ["5min", "15min", "60min"], "interval_seconds": 300}'
```
## 项目结构
```
buffer_platform/
├── app/
│ ├── main.py # FastAPI 入口
│ ├── config.py # 配置
│ ├── database.py # 数据库连接
│ ├── models.py # ORM 模型
│ ├── schemas.py # 请求/响应模型
│ ├── api/
│ │ ├── data.py # 数据接口
│ │ └── tasks.py # 任务接口
│ └── services/
│ ├── collector.py # 采集服务
│ ├── cache.py # 缓存服务
│ └── scheduler.py # 调度服务
├── data/ # SQLite 数据库文件
└── requirements.txt
```

@ -0,0 +1 @@
# 数据缓冲平台

@ -0,0 +1,280 @@
"""
配置管理接口 - 品种配置文件上传批量获取批量任务创建
"""
import json
import logging
import shutil
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Body
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from pydantic import BaseModel
from app.database import get_db
from app.services.collector import fetch_symbol_data
from app.services.cache import save_market_data, check_cache_status, get_cached_data, create_task
from app.services.scheduler import add_job
from app.schemas import CandleItem, TimeframeData, SymbolDataResponse
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/config", tags=["品种配置"])
class BatchFetchRequest(BaseModel):
"""批量获取请求体"""
periods: Optional[str] = None
data_type: str = "futures"
selected_symbols: Optional[str] = None # 逗号分隔的合约代码
# 配置文件存储路径
CONFIG_DIR = Path(__file__).resolve().parent.parent.parent / "config"
CONFIG_FILE = CONFIG_DIR / "symbols_config.json"
def _ensure_config_dir():
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
@router.get("")
def get_config():
"""获取当前品种配置"""
_ensure_config_dir()
if not CONFIG_FILE.exists():
return {"futures": {}, "stock": {}}
with open(CONFIG_FILE, "r", encoding="utf-8") as f:
return json.load(f)
@router.post("/upload")
def upload_config(
file: Optional[UploadFile] = File(None),
json_config: Optional[dict] = Body(None, embed=False),
):
"""
上传品种配置文件JSON格式
格式示例:
{
"futures": {"沪银": "AG2606", "沪金": "AU2606"},
"stock": {"平安银行": "000001"}
}
"""
_ensure_config_dir()
try:
if file:
content = file.file.read()
data = json.loads(content)
elif json_config:
data = json_config
else:
raise HTTPException(status_code=400, detail="请提供配置文件或JSON数据")
if not isinstance(data, dict):
raise HTTPException(status_code=400, detail="配置文件必须是 JSON 对象")
with open(CONFIG_FILE, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4)
futures_count = len(data.get("futures", {}))
stock_count = len(data.get("stock", {}))
return {
"message": "配置文件上传成功",
"futures_symbols": futures_count,
"stock_symbols": stock_count,
"symbols": data,
}
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="无效的 JSON 格式")
@router.post("/batch-fetch-all")
def batch_fetch_all(
request: BatchFetchRequest,
db: Session = Depends(get_db),
):
"""
根据配置文件批量获取所有品种数据
智能缓存:已存在且有效的数据不重复请求
"""
periods = request.periods
data_type = request.data_type
selected_symbols = request.selected_symbols
_ensure_config_dir()
if not CONFIG_FILE.exists():
raise HTTPException(status_code=400, detail="请先上传品种配置文件")
with open(CONFIG_FILE, "r", encoding="utf-8") as f:
config = json.load(f)
symbols_dict = config.get(data_type, {})
if not symbols_dict:
raise HTTPException(status_code=400, detail=f"配置中没有 {data_type} 类型的品种")
# 如果指定了selected_symbols只获取这些合约
if selected_symbols:
# 解析逗号分隔的合约代码
symbol_list = [s.strip() for s in selected_symbols.split(",") if s.strip()]
symbols_dict = {name: code for name, code in symbols_dict.items() if code in symbol_list}
if not symbols_dict:
raise HTTPException(status_code=400, detail="选定的合约不在配置中")
period_list = [p.strip() for p in periods.split(",")] if periods else ["5min", "15min", "30min", "60min", "daily"]
results = {
"total": len(symbols_dict),
"success": [],
"failed": [],
"cached": [], # 命中缓存的
"details": {},
}
for name, symbol in symbols_dict.items():
logger.info(f"处理品种: {name} ({symbol})")
# 检查缓存
status = check_cache_status(db, symbol, data_type, period_list)
if status["all_valid"]:
results["cached"].append({"name": name, "symbol": symbol})
cached = get_cached_data(db, symbol, data_type, period_list)
timeframes = []
for p, candles in cached["timeframes"].items():
# 转换数据格式: time -> datetime
normalized_candles = []
for c in candles:
candle_dict = dict(c)
if 'time' in candle_dict and 'datetime' not in candle_dict:
candle_dict['datetime'] = candle_dict.pop('time')
normalized_candles.append(candle_dict)
timeframes.append(TimeframeData(
period=p,
candles=[CandleItem(**c) for c in normalized_candles],
candle_count=len(normalized_candles),
fetched_at=cached.get("timestamp", ""),
))
results["details"][symbol] = SymbolDataResponse(
symbol=symbol,
data_type=data_type,
current_price=cached.get("current_price"),
timeframes=timeframes,
source="cache",
)
results["success"].append({"name": name, "symbol": symbol})
continue
# 需要采集
need_fetch = status["missing_periods"]
logger.info(f"需要采集的周期: {need_fetch}")
result = fetch_symbol_data(symbol, data_type, need_fetch)
if result.get("timeframes"):
logger.info(f"采集到 {len(result['timeframes'])} 个周期的数据,开始保存")
save_market_data(db, symbol, result)
# 合并缓存和新数据
all_timeframes = {}
if status["valid_periods"]:
existing = get_cached_data(db, symbol, data_type, status["valid_periods"])
if existing:
all_timeframes.update(existing["timeframes"])
all_timeframes.update(result["timeframes"])
timeframes = []
for p in period_list:
candles = all_timeframes.get(p, [])
if candles:
# 转换数据格式: time -> datetime
normalized_candles = []
for c in candles:
candle_dict = dict(c)
if 'time' in candle_dict and 'datetime' not in candle_dict:
candle_dict['datetime'] = candle_dict.pop('time')
normalized_candles.append(candle_dict)
timeframes.append(TimeframeData(
period=p,
candles=[CandleItem(**c) for c in normalized_candles],
candle_count=len(normalized_candles),
fetched_at=result.get("timestamp", ""),
))
source = "live+cache" if status["valid_periods"] else "live"
results["details"][symbol] = SymbolDataResponse(
symbol=symbol,
data_type=data_type,
current_price=result.get("current_price"),
timeframes=timeframes,
source=source,
)
results["success"].append({"name": name, "symbol": symbol})
logger.info(f"采集成功: {symbol}")
else:
error_msg = result.get("error", "未知错误")
logger.error(f"采集失败: {symbol}, 错误: {error_msg}")
results["failed"].append({
"name": name,
"symbol": symbol,
"error": error_msg,
})
return results
@router.post("/batch-tasks")
def batch_create_tasks(
periods: Optional[str] = None,
interval_seconds: int = 300,
data_type: str = "futures",
db: Session = Depends(get_db),
):
"""
根据配置文件为所有品种批量创建定时任务
"""
_ensure_config_dir()
if not CONFIG_FILE.exists():
raise HTTPException(status_code=400, detail="请先上传品种配置文件")
with open(CONFIG_FILE, "r", encoding="utf-8") as f:
config = json.load(f)
symbols_dict = config.get(data_type, {})
if not symbols_dict:
raise HTTPException(status_code=400, detail=f"配置中没有 {data_type} 类型的品种")
period_list = [p.strip() for p in periods.split(",")] if periods else ["5min", "15min", "30min", "60min", "daily"]
results = {"total": len(symbols_dict), "created": [], "failed": []}
for name, symbol in symbols_dict.items():
try:
task = create_task(
db=db,
symbol=symbol,
data_type=data_type,
periods=period_list,
interval_seconds=interval_seconds,
)
job_id = add_job(task.id, task.interval_seconds)
task.job_id = job_id
db.commit()
db.refresh(task)
results["created"].append({
"name": name,
"symbol": symbol,
"task_id": task.id,
"job_id": job_id,
"interval": interval_seconds,
})
except Exception as e:
results["failed"].append({
"name": name,
"symbol": symbol,
"error": str(e),
})
return results

@ -0,0 +1,208 @@
"""
数据接口 - 批量获取 / 获取最新缓存
"""
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
from sqlalchemy.orm import Session
from app.database import get_db
from app.schemas import (
BatchFetchRequest,
BatchFetchResponse,
LatestDataResponse,
CandleItem,
TimeframeData,
SymbolDataResponse,
)
from app.services.collector import fetch_symbol_data, fetch_batch
from app.services.cache import (
save_market_data,
get_cached_data,
get_latest_cached,
check_cache_status,
)
from app.config import CACHE_TTL_SECONDS
from datetime import datetime
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/data", tags=["数据"])
@router.post("/batch-fetch", response_model=BatchFetchResponse)
def batch_fetch(req: BatchFetchRequest, db: Session = Depends(get_db)):
"""
批量获取指定品种指定周期的数据
智能缓存已存在且有效的数据不重复请求
"""
symbols = req.symbols
periods = req.periods
data_type = req.data_type
success = []
failed = []
details = {}
for sym in symbols:
status = check_cache_status(db, sym, data_type, periods)
if status["all_valid"]:
logger.info(f"[{sym}] 缓存全部命中,跳过采集")
cached = get_cached_data(db, sym, data_type, periods)
timeframes = []
for p, candles in cached["timeframes"].items():
timeframes.append(TimeframeData(
period=p,
candles=[CandleItem(**c) for c in candles],
candle_count=len(candles),
fetched_at=cached.get("timestamp", ""),
))
details[sym] = SymbolDataResponse(
symbol=sym,
data_type=data_type,
current_price=cached.get("current_price"),
timeframes=timeframes,
source="cache",
)
success.append(sym)
continue
need_fetch = status["missing_periods"]
logger.info(f"[{sym}] 缓存部分缺失,需要采集: {need_fetch}")
result = fetch_symbol_data(sym, data_type, need_fetch)
if result.get("timeframes"):
save_market_data(db, sym, result)
success.append(sym)
all_timeframes = {}
if status["valid_periods"]:
existing = get_cached_data(db, sym, data_type, status["valid_periods"])
if existing:
all_timeframes.update(existing["timeframes"])
all_timeframes.update(result["timeframes"])
timeframes = []
for p in periods:
candles = all_timeframes.get(p, [])
if candles:
timeframes.append(TimeframeData(
period=p,
candles=[CandleItem(**c) for c in candles],
candle_count=len(candles),
fetched_at=result.get("timestamp", ""),
))
details[sym] = SymbolDataResponse(
symbol=sym,
data_type=data_type,
current_price=result.get("current_price"),
timeframes=timeframes,
source="live+cache",
)
else:
failed.append(sym)
details[sym] = {"error": result.get("error", "未知错误")}
return BatchFetchResponse(
success=success,
failed=failed,
details=details,
)
@router.get("/latest/{symbol}", response_model=SymbolDataResponse)
def get_latest(
symbol: str,
data_type: str = "futures",
period: Optional[str] = None,
db: Session = Depends(get_db),
):
"""
从缓存获取最新数据
可指定单个 period不指定则返回所有已缓存周期
"""
cached = get_cached_data(db, symbol, data_type, [period] if period else None)
if not cached:
raise HTTPException(status_code=404, detail=f"未找到 {symbol} 的缓存数据")
timeframes = []
for p, candles in cached["timeframes"].items():
# 转换数据格式: time -> datetime
normalized_candles = []
for c in candles:
candle_dict = dict(c)
if 'time' in candle_dict and 'datetime' not in candle_dict:
candle_dict['datetime'] = candle_dict.pop('time')
normalized_candles.append(candle_dict)
timeframes.append(TimeframeData(
period=p,
candles=[CandleItem(**c) for c in normalized_candles],
candle_count=len(normalized_candles),
fetched_at=cached.get("timestamp", ""),
))
return SymbolDataResponse(
symbol=symbol,
data_type=data_type,
current_price=cached.get("current_price"),
timeframes=timeframes,
source="cache" if cached.get("is_fresh", False) else "cache_stale",
)
@router.get("/latest/{symbol}/{period}")
def get_latest_by_period(
symbol: str,
period: str,
data_type: str = "futures",
db: Session = Depends(get_db),
):
"""
获取缓存中指定品种+周期的最新数据
返回单个周期的 K 线
"""
cached = get_cached_data(db, symbol, data_type, [period])
if not cached:
raise HTTPException(status_code=404, detail=f"未找到 {symbol} {period} 的缓存")
candles = cached["timeframes"].get(period, [])
return {
"symbol": symbol,
"period": period,
"data_type": data_type,
"candles": candles,
"candle_count": len(candles),
"current_price": cached.get("current_price"),
"fetched_at": cached.get("timestamp"),
"is_fresh": cached.get("is_fresh", False),
}
@router.get("/cache-status/{symbol}")
def cache_status(symbol: str, db: Session = Depends(get_db)):
"""查看品种的缓存状态"""
records = get_latest_cached(db, symbol)
if not records:
return {"symbol": symbol, "cached_periods": [], "status": "no_data"}
now = datetime.now()
periods_info = []
for r in records:
age_seconds = (now - r.fetched_at).total_seconds()
periods_info.append({
"period": r.period,
"candle_count": r.candle_count,
"fetched_at": r.fetched_at.isoformat(),
"age_seconds": round(age_seconds, 0),
"is_fresh": age_seconds < CACHE_TTL_SECONDS,
})
return {
"symbol": symbol,
"cached_periods": periods_info,
"status": "ok",
}

@ -0,0 +1,161 @@
"""
定时任务接口 - 创建/启动/停止/删除/列表
"""
import logging
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.database import get_db
from app.schemas import (
CreateTaskRequest,
TaskInfo,
TaskListResponse,
)
from app.services.cache import (
create_task,
list_tasks,
get_task,
disable_task,
enable_task,
delete_task,
)
from app.services.scheduler import (
add_job,
remove_job,
is_job_running,
get_all_jobs,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/tasks", tags=["定时任务"])
@router.post("", response_model=TaskInfo)
def create_new_task(req: CreateTaskRequest, db: Session = Depends(get_db)):
"""
创建并启动一个定时采集任务
输入品种合约和轮询时长自动开始定时获取数据
"""
task = create_task(
db=db,
symbol=req.symbol,
data_type=req.data_type,
periods=req.periods,
interval_seconds=req.interval_seconds,
)
# 注册到调度器
job_id = add_job(task.id, task.interval_seconds)
task.job_id = job_id
db.commit()
db.refresh(task)
return _to_task_info(task)
@router.get("", response_model=TaskListResponse)
def list_all_tasks(db: Session = Depends(get_db)):
"""列出所有定时任务"""
tasks = list_tasks(db)
job_status = get_all_jobs()
task_infos = []
for t in tasks:
running = is_job_running(t.id) if t.enabled else False
task_infos.append(TaskInfo(
id=t.id,
symbol=t.symbol,
data_type=t.data_type,
periods=t.periods.split(",") if t.periods else [],
interval_seconds=t.interval_seconds,
enabled=t.enabled,
running=running,
last_run=t.last_run.isoformat() if t.last_run else None,
last_status=t.last_status,
created_at=t.created_at.isoformat(),
updated_at=t.updated_at.isoformat(),
))
return TaskListResponse(tasks=task_infos, total=len(task_infos))
@router.post("/{task_id}/stop", response_model=TaskInfo)
def stop_task(task_id: int, db: Session = Depends(get_db)):
"""停止定时任务(从调度器移除,但保留配置)"""
task = get_task(db, task_id)
if not task:
raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
remove_job(task_id)
task = disable_task(db, task_id)
return _to_task_info(task)
@router.post("/{task_id}/start", response_model=TaskInfo)
def start_task(task_id: int, db: Session = Depends(get_db)):
"""重新启动已停止的定时任务"""
task = get_task(db, task_id)
if not task:
raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
enable_task(db, task_id)
add_job(task.id, task.interval_seconds)
db.refresh(task)
return _to_task_info(task)
@router.delete("/{task_id}")
def delete_existing_task(task_id: int, db: Session = Depends(get_db)):
"""删除定时任务(同时从调度器移除)"""
task = get_task(db, task_id)
if not task:
raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
remove_job(task_id)
delete_task(db, task_id)
return {"message": f"任务 {task_id} 已删除"}
@router.post("/{task_id}/update-interval", response_model=TaskInfo)
def update_interval(
task_id: int,
interval_seconds: int,
db: Session = Depends(get_db),
):
"""更新任务的轮询间隔"""
task = get_task(db, task_id)
if not task:
raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
task.interval_seconds = interval_seconds
task.updated_at = task.updated_at.__class__.now()
db.commit()
db.refresh(task)
# 如果任务正在运行,更新调度器
if task.enabled and is_job_running(task_id):
remove_job(task_id)
add_job(task.id, task.interval_seconds)
return _to_task_info(task)
def _to_task_info(task) -> TaskInfo:
"""ORM -> Pydantic"""
return TaskInfo(
id=task.id,
symbol=task.symbol,
data_type=task.data_type,
periods=task.periods.split(",") if task.periods else [],
interval_seconds=task.interval_seconds,
enabled=task.enabled,
running=is_job_running(task.id),
last_run=task.last_run.isoformat() if task.last_run else None,
last_status=task.last_status,
created_at=task.created_at.isoformat(),
updated_at=task.updated_at.isoformat(),
)

@ -0,0 +1,36 @@
"""
数据缓冲平台 - 配置
"""
import os
from pathlib import Path
# 项目根目录
BASE_DIR = Path(__file__).resolve().parent.parent.parent
# 数据库路径
DB_PATH = Path(os.getenv(
"BUFFER_DB_PATH",
str(Path(__file__).resolve().parent.parent / "data" / "buffer.db")
))
# 原始采集脚本路径
COLLECTOR_SCRIPT = os.getenv(
"COLLECTOR_SCRIPT",
str(BASE_DIR / "market_data_colector_platform" / "futures_data_collector.py")
)
# FastAPI 服务配置
HOST = os.getenv("BUFFER_HOST", "0.0.0.0")
PORT = int(os.getenv("BUFFER_PORT", "8600"))
# 数据缓存
CACHE_TTL_SECONDS = int(os.getenv("CACHE_TTL", "300")) # 默认5分钟过期
# 并发采集
MAX_WORKERS = int(os.getenv("MAX_WORKERS", "2"))
# 日志
LOG_LEVEL = os.getenv("BUFFER_LOG_LEVEL", "INFO")
# 调度器
SCHEDULER_MAX_INSTANCES = 1 # 同一任务不允许重叠执行

@ -0,0 +1,28 @@
"""
数据缓冲平台 - 数据库连接
"""
from pathlib import Path
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from app.config import DB_PATH
# 确保数据目录存在
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
engine = create_engine(
f"sqlite:///{DB_PATH}",
connect_args={"check_same_thread": False},
pool_pre_ping=True,
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
def get_db():
"""获取数据库会话"""
db = SessionLocal()
try:
yield db
finally:
db.close()

@ -0,0 +1,108 @@
"""
数据缓冲平台 - FastAPI 主入口
"""
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from app.database import engine, Base
from app.config import HOST, PORT, LOG_LEVEL
from app.api import data, tasks, config
from app.services.scheduler import start_scheduler, stop_scheduler
# 配置日志
logging.basicConfig(
level=getattr(logging, LOG_LEVEL.upper(), logging.INFO),
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
# 启动时:建表 + 启动调度器
logger.info("创建数据库表...")
Base.metadata.create_all(bind=engine)
logger.info("启动定时调度器...")
start_scheduler()
# 恢复已启用的任务
from app.database import SessionLocal
from app.services.cache import list_tasks
from app.services.scheduler import add_job
db = SessionLocal()
try:
enabled_tasks = [t for t in list_tasks(db) if t.enabled]
for t in enabled_tasks:
add_job(t.id, t.interval_seconds)
logger.info(f"恢复定时任务: {t.symbol} (每 {t.interval_seconds}s)")
finally:
db.close()
logger.info(f"数据缓冲平台已启动 http://{HOST}:{PORT}")
yield
# 关闭时
logger.info("停止调度器...")
stop_scheduler()
app = FastAPI(
title="数据缓冲平台",
description="期货/股票行情数据缓存与定时采集平台",
version="1.0.0",
lifespan=lifespan,
)
# CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 静态文件服务
STATIC_DIR = Path(__file__).resolve().parent / "static"
STATIC_DIR.mkdir(parents=True, exist_ok=True)
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
@app.get("/ui")
def ui_page():
"""品种配置管理页面"""
return FileResponse(str(STATIC_DIR / "index.html"))
# 注册路由
app.include_router(data.router, prefix="/api/v1")
app.include_router(tasks.router, prefix="/api/v1")
app.include_router(config.router, prefix="/api/v1")
@app.get("/api/v1/health")
def health():
return {"status": "ok", "service": "market-data-buffer"}
@app.get("/")
def root():
return {
"message": "数据缓冲平台 API",
"docs": "/docs",
"health": "/api/v1/health",
}
if __name__ == "__main__":
import uvicorn
uvicorn.run("app.main:app", host=HOST, port=PORT, reload=True)

@ -0,0 +1,52 @@
"""
数据缓冲平台 - 数据模型 (SQLAlchemy ORM)
"""
from datetime import datetime
from sqlalchemy import Column, String, Integer, Float, Text, DateTime, Boolean, Index, UniqueConstraint
from app.database import Base
class MarketData(Base):
"""缓存的市场数据表"""
__tablename__ = "market_data"
id = Column(Integer, primary_key=True, autoincrement=True)
symbol = Column(String(32), nullable=False, index=True, comment="品种合约代码")
data_type = Column(String(16), nullable=False, default="futures", comment="数据类型: futures/stock")
period = Column(String(16), nullable=False, index=True, comment="周期: 5min/15min/30min/60min/daily")
# K线数据以 JSON 字符串形式存储
candles_json = Column(Text, nullable=False, comment="K线数据JSON")
current_price = Column(Float, nullable=True, comment="当前价格")
fetched_at = Column(DateTime, nullable=False, default=datetime.now, index=True, comment="获取时间")
candle_count = Column(Integer, default=0, comment="K线数量")
__table_args__ = (
UniqueConstraint("symbol", "data_type", "period", name="uq_symbol_period"),
)
def __repr__(self):
return f"<MarketData {self.symbol} {self.period} candles={self.candle_count}>"
class ScheduledTask(Base):
"""定时任务配置表"""
__tablename__ = "scheduled_tasks"
id = Column(Integer, primary_key=True, autoincrement=True)
symbol = Column(String(32), nullable=False, comment="品种合约代码")
data_type = Column(String(16), nullable=False, default="futures", comment="数据类型")
periods = Column(String(256), nullable=False, comment="周期列表(逗号分隔), 如 5min,15min,60min")
interval_seconds = Column(Integer, nullable=False, default=300, comment="轮询间隔(秒)")
enabled = Column(Boolean, nullable=False, default=True, comment="是否启用")
job_id = Column(String(64), nullable=True, unique=True, comment="APScheduler job_id")
last_run = Column(DateTime, nullable=True, comment="最后执行时间")
last_status = Column(String(16), nullable=True, comment="最后状态: success/failed")
created_at = Column(DateTime, nullable=False, default=datetime.now)
updated_at = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now)
__table_args__ = (
UniqueConstraint("symbol", "data_type", name="uq_task_symbol"),
)
def __repr__(self):
return f"<Task {self.symbol} every {self.interval_seconds}s enabled={self.enabled}>"

@ -0,0 +1,102 @@
"""Pydantic 数据校验模型"""
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel, Field
# ===== 采集请求 =====
class BatchFetchRequest(BaseModel):
"""批量获取数据请求"""
symbols: List[str] = Field(..., description="品种合约列表,如 ['SN2504', 'AG2506']")
data_type: str = Field(default="futures", description="数据类型: futures / stock")
periods: List[str] = Field(
default=["5min", "15min", "30min", "60min", "daily"],
description="周期列表: 5min / 15min / 30min / 60min / daily"
)
# ===== 数据响应 =====
class CandleItem(BaseModel):
"""单根K线"""
datetime: str
open: float
high: float
low: float
close: float
volume: float
class TimeframeData(BaseModel):
"""一个周期的数据"""
period: str
candles: List[CandleItem]
candle_count: int
fetched_at: str
class SymbolDataResponse(BaseModel):
"""单个品种的数据响应"""
symbol: str
data_type: str
current_price: Optional[float] = None
timeframes: List[TimeframeData]
source: str = "cache|live"
class BatchFetchResponse(BaseModel):
"""批量获取响应"""
success: List[str] = Field(default_factory=list, description="成功的品种")
failed: List[str] = Field(default_factory=list, description="失败的品种")
details: dict = Field(default_factory=dict, description="每个品种的详细数据")
class LatestDataResponse(BaseModel):
"""获取最新数据响应"""
symbol: str
data_type: str
period: str
candles: List[CandleItem]
candle_count: int
current_price: Optional[float] = None
fetched_at: str
is_fresh: bool = Field(description="数据是否在缓存有效期内")
# ===== 定时任务 =====
class CreateTaskRequest(BaseModel):
"""创建定时任务请求"""
symbol: str = Field(..., description="品种合约代码")
data_type: str = Field(default="futures", description="数据类型")
periods: List[str] = Field(
default=["5min", "15min", "30min", "60min", "daily"],
description="需要定时获取的周期"
)
interval_seconds: int = Field(
default=300,
ge=30,
le=86400,
description="轮询间隔(秒),范围 30~86400"
)
class TaskInfo(BaseModel):
"""任务信息"""
id: int
symbol: str
data_type: str
periods: List[str]
interval_seconds: int
enabled: bool
running: bool = Field(description="当前是否正在运行")
last_run: Optional[str] = None
last_status: Optional[str] = None
created_at: str
updated_at: str
class TaskListResponse(BaseModel):
tasks: List[TaskInfo]
total: int

@ -0,0 +1,259 @@
"""
缓存服务 - SQLite 数据库操作
"""
import json
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional
from sqlalchemy.orm import Session
from app.models import MarketData, ScheduledTask
from app.config import CACHE_TTL_SECONDS
logger = logging.getLogger(__name__)
# ===== 市场数据缓存 =====
def is_cache_valid(
db: Session,
symbol: str,
data_type: str,
period: str,
ttl_seconds: int = CACHE_TTL_SECONDS,
) -> bool:
"""检查指定品种+周期的缓存是否在有效期内"""
record = db.query(MarketData).filter_by(
symbol=symbol,
data_type=data_type,
period=period,
).first()
if not record:
return False
age = (datetime.now() - record.fetched_at).total_seconds()
return age < ttl_seconds
def check_cache_status(
db: Session,
symbol: str,
data_type: str,
periods: List[str],
ttl_seconds: int = CACHE_TTL_SECONDS,
) -> dict:
"""
检查一组周期的缓存状态
Returns:
{
"all_valid": bool, # 所有周期都有有效缓存
"valid_periods": [...],
"missing_periods": [...],
}
"""
valid = []
missing = []
for p in periods:
if is_cache_valid(db, symbol, data_type, p, ttl_seconds):
valid.append(p)
else:
missing.append(p)
return {
"all_valid": len(missing) == 0,
"valid_periods": valid,
"missing_periods": missing,
}
def save_market_data(db: Session, symbol: str, data: Dict) -> MarketData:
"""
保存采集结果到缓存
Args:
symbol: 品种代码
data: 采集脚本返回的完整数据
Returns:
保存的 MarketData 记录
"""
now = datetime.now()
# 按 period 拆分存储(每个周期一条记录)
for period, candles in data.get("timeframes", {}).items():
record = db.query(MarketData).filter_by(
symbol=symbol,
data_type=data.get("type", "futures"),
period=period,
).first()
candles_json = json.dumps(candles, ensure_ascii=False)
if record:
record.candles_json = candles_json
record.current_price = data.get("current_price")
record.fetched_at = now
record.candle_count = len(candles)
else:
record = MarketData(
symbol=symbol,
data_type=data.get("type", "futures"),
period=period,
candles_json=candles_json,
current_price=data.get("current_price"),
fetched_at=now,
candle_count=len(candles),
)
db.add(record)
db.commit()
logger.info(f"缓存已更新: {symbol}, {len(data.get('timeframes', {}))} 个周期")
# 返回最新的一条作为代表
return db.query(MarketData).filter_by(
symbol=symbol,
data_type=data.get("type", "futures"),
).order_by(MarketData.fetched_at.desc()).first()
def get_latest_cached(
db: Session,
symbol: str,
data_type: str = "futures",
period: Optional[str] = None,
) -> List[MarketData]:
"""获取最新缓存数据"""
query = db.query(MarketData).filter_by(symbol=symbol, data_type=data_type)
if period:
query = query.filter_by(period=period)
return query.order_by(MarketData.fetched_at.desc()).all()
def get_cached_data(
db: Session,
symbol: str,
data_type: str = "futures",
periods: Optional[List[str]] = None,
) -> Optional[Dict]:
"""
从缓存中获取完整的多周期数据
Returns:
与采集脚本相同格式的数据 None
"""
query = db.query(MarketData).filter_by(symbol=symbol, data_type=data_type)
if periods:
query = query.filter(MarketData.period.in_(periods))
records = query.all()
if not records:
return None
# 检查缓存是否过期
now = datetime.now()
newest = max(r.fetched_at for r in records)
is_fresh = (now - newest).total_seconds() < CACHE_TTL_SECONDS
timeframes = {}
current_price = None
for r in records:
timeframes[r.period] = json.loads(r.candles_json)
if current_price is None:
current_price = r.current_price
return {
"symbol": symbol,
"type": data_type,
"current_price": current_price,
"timestamp": newest.isoformat(),
"timeframes": timeframes,
"is_fresh": is_fresh,
"fetched_at": newest.isoformat(),
}
# ===== 定时任务管理 =====
def create_task(
db: Session,
symbol: str,
data_type: str,
periods: List[str],
interval_seconds: int,
) -> ScheduledTask:
"""创建定时任务配置"""
existing = db.query(ScheduledTask).filter_by(
symbol=symbol, data_type=data_type
).first()
if existing:
existing.periods = ",".join(periods)
existing.interval_seconds = interval_seconds
existing.enabled = True
existing.updated_at = datetime.now()
db.commit()
db.refresh(existing)
return existing
task = ScheduledTask(
symbol=symbol,
data_type=data_type,
periods=",".join(periods),
interval_seconds=interval_seconds,
enabled=True,
)
db.add(task)
db.commit()
db.refresh(task)
return task
def list_tasks(db: Session) -> List[ScheduledTask]:
"""列出所有任务"""
return db.query(ScheduledTask).order_by(ScheduledTask.created_at.desc()).all()
def get_task(db: Session, task_id: int) -> Optional[ScheduledTask]:
"""获取单个任务"""
return db.query(ScheduledTask).filter_by(id=task_id).first()
def disable_task(db: Session, task_id: int) -> Optional[ScheduledTask]:
"""禁用任务"""
task = db.query(ScheduledTask).filter_by(id=task_id).first()
if task:
task.enabled = False
task.updated_at = datetime.now()
db.commit()
db.refresh(task)
return task
def enable_task(db: Session, task_id: int) -> Optional[ScheduledTask]:
"""启用任务"""
task = db.query(ScheduledTask).filter_by(id=task_id).first()
if task:
task.enabled = True
task.updated_at = datetime.now()
db.commit()
db.refresh(task)
return task
def delete_task(db: Session, task_id: int) -> bool:
"""删除任务"""
task = db.query(ScheduledTask).filter_by(id=task_id).first()
if task:
db.delete(task)
db.commit()
return True
return False
def update_task_status(
db: Session, task_id: int, status: str
) -> None:
"""更新任务执行状态"""
task = db.query(ScheduledTask).filter_by(id=task_id).first()
if task:
task.last_run = datetime.now()
task.last_status = status
db.commit()

@ -0,0 +1,82 @@
"""
数据采集服务 - 包装原始采集脚本
"""
import json
import logging
import sys
import os
from datetime import datetime
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
# 获取原始采集脚本路径 (buffer_platform/app/services -> buffer_platform -> parent = market_data_colector_platform)
SCRIPT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
if SCRIPT_DIR not in sys.path:
sys.path.insert(0, SCRIPT_DIR)
logger.info(f"已添加采集脚本路径到sys.path: {SCRIPT_DIR}")
def fetch_symbol_data(
symbol: str,
data_type: str = "futures",
periods: Optional[List[str]] = None,
max_workers: int = 2,
) -> Dict:
"""
获取单个品种的多周期数据
返回格式:
{
"symbol": "SN2504",
"type": "futures",
"current_price": 12345.0,
"timestamp": "2025-01-15T10:30:00+08:00",
"timeframes": {
"5min": [{"datetime": ..., "open": ..., ...}, ...],
...
}
}
"""
try:
from futures_data_collector import collect_futures_data, collect_stock_data
if data_type == "stock":
result = collect_stock_data(symbol)
else:
result = collect_futures_data(symbol)
# 如果指定了周期,只保留需要的
if periods:
filtered = {}
for p in periods:
if p in result.get("timeframes", {}):
filtered[p] = result["timeframes"][p]
result["timeframes"] = filtered
return result
except Exception as e:
logger.error(f"采集 {symbol} 数据失败: {e}")
return {
"symbol": symbol,
"type": data_type,
"current_price": None,
"timestamp": datetime.now().isoformat(),
"timeframes": {},
"error": str(e),
}
def fetch_batch(
symbols: List[str],
data_type: str = "futures",
periods: Optional[List[str]] = None,
max_workers: int = 2,
) -> Dict[str, Dict]:
"""批量获取多个品种数据(串行,避免过度并发)"""
results = {}
for sym in symbols:
logger.info(f"开始采集 {sym} ...")
results[sym] = fetch_symbol_data(sym, data_type, periods, max_workers)
return results

@ -0,0 +1,138 @@
"""
调度服务 - APScheduler 管理定时采集任务
"""
import logging
from datetime import datetime
from typing import Dict, Optional
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.interval import IntervalTrigger
from apscheduler.executors.pool import ThreadPoolExecutor
from sqlalchemy.orm import Session
from app.database import SessionLocal
from app.services.collector import fetch_symbol_data
from app.services.cache import save_market_data, update_task_status
from app.config import SCHEDULER_MAX_INSTANCES, MAX_WORKERS
logger = logging.getLogger(__name__)
scheduler = BackgroundScheduler(
executors={"default": ThreadPoolExecutor(max_workers=MAX_WORKERS)},
job_defaults={
"max_instances": SCHEDULER_MAX_INSTANCES,
"misfire_grace_time": 60,
},
)
def job_handler(task_id: int):
"""
定时任务的执行函数
每个任务独立创建 DB session避免跨线程问题
"""
db: Session = SessionLocal()
try:
from app.services.cache import get_task
task = get_task(db, task_id)
if not task or not task.enabled:
logger.warning(f"任务 {task_id} 不存在或已禁用,停止执行")
return
periods = task.periods.split(",") if task.periods else []
logger.info(f"[定时任务] 开始采集 {task.symbol} (periods={periods})")
result = fetch_symbol_data(
symbol=task.symbol,
data_type=task.data_type,
periods=periods,
max_workers=MAX_WORKERS,
)
if result.get("timeframes"):
save_market_data(db, task.symbol, result)
update_task_status(db, task_id, "success")
logger.info(f"[定时任务] {task.symbol} 采集成功")
else:
update_task_status(db, task_id, "failed")
logger.error(f"[定时任务] {task.symbol} 采集失败: {result.get('error')}")
except Exception as e:
logger.error(f"[定时任务] 执行异常 task_id={task_id}: {e}")
try:
update_task_status(db, task_id, "failed")
except Exception:
pass
finally:
db.close()
def start_scheduler():
"""启动调度器"""
if not scheduler.running:
scheduler.start()
logger.info("调度器已启动")
def stop_scheduler():
"""停止调度器"""
if scheduler.running:
scheduler.shutdown(wait=False)
logger.info("调度器已停止")
def add_job(task_id: int, interval_seconds: int) -> str:
"""
添加定时任务到调度器
Returns:
job_id
"""
job_id = f"task_{task_id}"
# 如果已存在,先移除
if scheduler.get_job(job_id):
scheduler.remove_job(job_id)
scheduler.add_job(
func=job_handler,
trigger=IntervalTrigger(seconds=interval_seconds),
args=[task_id],
id=job_id,
name=f"auto_collect_{task_id}",
replace_existing=True,
)
logger.info(f"已添加定时任务: job_id={job_id}, interval={interval_seconds}s")
return job_id
def remove_job(task_id: int) -> bool:
"""移除定时任务"""
job_id = f"task_{task_id}"
job = scheduler.get_job(job_id)
if job:
scheduler.remove_job(job_id)
logger.info(f"已移除定时任务: {job_id}")
return True
return False
def is_job_running(task_id: int) -> bool:
"""检查任务是否正在调度器中运行"""
job_id = f"task_{task_id}"
return scheduler.get_job(job_id) is not None
def get_all_jobs() -> Dict[str, dict]:
"""获取所有活跃任务信息"""
jobs = scheduler.get_jobs()
result = {}
for job in jobs:
nrt = getattr(job, 'next_run_time', None)
result[job.id] = {
"name": job.name,
"next_run_time": nrt.isoformat() if nrt else None,
"trigger": str(job.trigger),
}
return result

File diff suppressed because it is too large Load Diff

@ -0,0 +1,36 @@
{
"futures": {
"原油": "SC2606",
"燃油": "FU2606",
"低硫燃油": "LU2607",
"沪银": "AG2606",
"沪金": "AU2606",
"沪铜": "CU2606",
"沪镍": "NI2606",
"沪锡": "SN2606",
"沪铝": "AL2606",
"沪锌": "PB2606",
"氧化铝": "AO2609",
"工业硅": "SI2609",
"多晶硅": "PS2606",
"碳酸锂": "LC2609",
"纯碱": "SA2609",
"烧碱": "SH2607",
"玻璃": "FG2609",
"橡胶": "RU2609",
"合成橡胶": "BR2606",
"20号胶": "NR2607",
"螺纹钢": "RB2610",
"铁矿石": "I2609",
"焦煤": "JM2606",
"焦炭": "J2606",
"PTA": "TA2609",
"棕榈油": "P2609",
"豆粕": "M2609",
"白糖": "SR2609",
"棉花": "CF2609",
"甲醇": "MA2609",
"尿素": "UR2609",
"中证1000": "IM2606"
}
}

Binary file not shown.

@ -0,0 +1,11 @@
fastapi>=0.110.0
uvicorn>=0.29.0
sqlalchemy>=2.0.0
aiosqlite>=0.20.0
pydantic>=2.0.0
apscheduler>=3.10.0
akshare>=1.14.0
pandas>=2.0.0
tenacity>=8.2.0
requests>=2.31.0
httpx>=0.27.0
Loading…
Cancel
Save