commit
7cf4848f81
@ -0,0 +1,76 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// DataSourceAdapter 数据源适配器接口
|
||||
type DataSourceAdapter interface {
|
||||
// Connect 建立连接
|
||||
Connect(config map[string]string) error
|
||||
|
||||
// SubscribeTicks 订阅实时Tick
|
||||
SubscribeTicks(symbols []string, callback TickCallback) error
|
||||
|
||||
// FetchKLines 拉取历史K线
|
||||
FetchKLines(symbol, start, end, freq string) ([]KLineData, error)
|
||||
|
||||
// FetchSymbols 获取标的列表
|
||||
FetchSymbols(assetType string) ([]SymbolInfo, error)
|
||||
|
||||
// FetchTradingCalendar 获取交易日历
|
||||
FetchTradingCalendar(exchange, start, end string) ([]TradeCalData, error)
|
||||
|
||||
// HealthCheck 健康检查
|
||||
HealthCheck() error
|
||||
|
||||
// Close 关闭连接
|
||||
Close() error
|
||||
}
|
||||
|
||||
// TickCallback Tick数据回调
|
||||
type TickCallback func(symbol string, tick TickData)
|
||||
|
||||
// TickData Tick数据
|
||||
type TickData struct {
|
||||
Symbol string
|
||||
Price float64
|
||||
Volume int64
|
||||
Time int64
|
||||
}
|
||||
|
||||
// KLineData K线数据
|
||||
type KLineData struct {
|
||||
Symbol string
|
||||
Time int64
|
||||
Open float64
|
||||
High float64
|
||||
Low float64
|
||||
Close float64
|
||||
Volume int64
|
||||
Amount float64
|
||||
OpenInterest int64
|
||||
}
|
||||
|
||||
// SymbolInfo 标的信息
|
||||
type SymbolInfo struct {
|
||||
SymbolID string
|
||||
Name string
|
||||
Exchange string
|
||||
Underlying string
|
||||
ContractMonth string
|
||||
ListDate string
|
||||
DelistDate string
|
||||
}
|
||||
|
||||
// TradeCalData 交易日历数据
|
||||
type TradeCalData struct {
|
||||
Date time.Time
|
||||
IsTradingDay bool
|
||||
HasNightSession bool
|
||||
}
|
||||
|
||||
// AdapterFactory 适配器工厂
|
||||
type AdapterFactory interface {
|
||||
Create(name string) (DataSourceAdapter, error)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,254 @@
|
||||
// Package api 管理后台相关类型定义
|
||||
package api
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ============================================
|
||||
// 配置管理类型
|
||||
// ============================================
|
||||
|
||||
// ConfigType 配置类型
|
||||
type ConfigType string
|
||||
|
||||
const (
|
||||
ConfigTypeServer ConfigType = "server" // 服务器配置
|
||||
ConfigTypeDatabase ConfigType = "database" // 数据库配置
|
||||
ConfigTypeRedis ConfigType = "redis" // Redis配置
|
||||
ConfigTypeSource ConfigType = "source" // 数据源配置
|
||||
ConfigTypeMonitor ConfigType = "monitor" // 监控配置
|
||||
ConfigTypeLog ConfigType = "log" // 日志配置
|
||||
)
|
||||
|
||||
// ConfigItem 配置项
|
||||
type ConfigItem struct {
|
||||
Key string `json:"key"` // 配置键
|
||||
Value interface{} `json:"value"` // 配置值
|
||||
Type string `json:"type"` // 值类型: string/int/bool/json
|
||||
Description string `json:"description"` // 配置说明
|
||||
Editable bool `json:"editable"` // 是否可编辑
|
||||
Required bool `json:"required"` // 是否必填
|
||||
}
|
||||
|
||||
// ConfigSection 配置分组
|
||||
type ConfigSection struct {
|
||||
Name string `json:"name"` // 分组名称
|
||||
Type ConfigType `json:"type"` // 分组类型
|
||||
Description string `json:"description"` // 分组说明
|
||||
Items []ConfigItem `json:"items"` // 配置项列表
|
||||
}
|
||||
|
||||
// ConfigListRequest 获取配置列表请求
|
||||
type ConfigListRequest struct {
|
||||
Type ConfigType `json:"type" form:"type"` // 配置类型筛选
|
||||
}
|
||||
|
||||
// ConfigListData 配置列表响应
|
||||
type ConfigListData struct {
|
||||
Sections []ConfigSection `json:"sections"` // 配置分组列表
|
||||
Version string `json:"version"` // 配置版本
|
||||
Updated time.Time `json:"updated"` // 最后更新时间
|
||||
}
|
||||
|
||||
// ConfigUpdateRequest 更新配置请求
|
||||
type ConfigUpdateRequest struct {
|
||||
Type ConfigType `json:"type" validate:"required"` // 配置类型
|
||||
Items map[string]interface{} `json:"items" validate:"required"` // 更新的配置项
|
||||
}
|
||||
|
||||
// ConfigUpdateData 更新配置响应
|
||||
type ConfigUpdateData struct {
|
||||
Success bool `json:"success"` // 是否成功
|
||||
NeedRestart bool `json:"need_restart"` // 是否需要重启
|
||||
Message string `json:"message"` // 提示信息
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 适配器管理类型
|
||||
// ============================================
|
||||
|
||||
// AdapterInfo 适配器信息
|
||||
type AdapterInfo struct {
|
||||
Name string `json:"name"` // 适配器名称
|
||||
Type string `json:"type"` // 适配器类型
|
||||
Version string `json:"version"` // 版本
|
||||
Description string `json:"description"` // 描述
|
||||
Status AdapterStatus `json:"status"` // 状态
|
||||
Config map[string]string `json:"config"` // 当前配置
|
||||
LastError string `json:"last_error,omitempty"` // 最后错误
|
||||
UpdatedAt time.Time `json:"updated_at"` // 更新时间
|
||||
}
|
||||
|
||||
// AdapterStatus 适配器状态
|
||||
type AdapterStatus string
|
||||
|
||||
const (
|
||||
AdapterStatusActive AdapterStatus = "active" // 已激活
|
||||
AdapterStatusStandby AdapterStatus = "standby" // 待命
|
||||
AdapterStatusDisabled AdapterStatus = "disabled" // 已禁用
|
||||
AdapterStatusError AdapterStatus = "error" // 错误
|
||||
)
|
||||
|
||||
// AdapterListData 适配器列表响应
|
||||
type AdapterListData struct {
|
||||
Adapters []AdapterInfo `json:"adapters"` // 适配器列表
|
||||
}
|
||||
|
||||
// AdapterToggleRequest 启用/禁用适配器请求
|
||||
type AdapterToggleRequest struct {
|
||||
Name string `json:"name" validate:"required"` // 适配器名称
|
||||
Enable bool `json:"enable"` // 是否启用
|
||||
}
|
||||
|
||||
// AdapterConfigUpdateRequest 更新适配器配置请求
|
||||
type AdapterConfigUpdateRequest struct {
|
||||
Name string `json:"name" validate:"required"` // 适配器名称
|
||||
Config map[string]string `json:"config" validate:"required"` // 配置
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 系统管理类型
|
||||
// ============================================
|
||||
|
||||
// SystemStatusData 系统状态数据
|
||||
type SystemStatusData struct {
|
||||
Status string `json:"status"` // 系统状态: running/stopping/restarting
|
||||
Version string `json:"version"` // 系统版本
|
||||
StartTime time.Time `json:"start_time"` // 启动时间
|
||||
Uptime string `json:"uptime"` // 运行时长
|
||||
GoVersion string `json:"go_version"` // Go版本
|
||||
MemoryUsage MemoryInfo `json:"memory"` // 内存使用
|
||||
Goroutines int `json:"goroutines"` // Goroutine数量
|
||||
}
|
||||
|
||||
// MemoryInfo 内存信息
|
||||
type MemoryInfo struct {
|
||||
Alloc uint64 `json:"alloc"` // 已分配内存
|
||||
TotalAlloc uint64 `json:"total_alloc"` // 累计分配
|
||||
Sys uint64 `json:"sys"` // 系统内存
|
||||
NumGC uint32 `json:"num_gc"` // GC次数
|
||||
}
|
||||
|
||||
// RestartRequest 重启服务请求
|
||||
type RestartRequest struct {
|
||||
Force bool `json:"force"` // 是否强制重启
|
||||
}
|
||||
|
||||
// ReloadRequest 热加载配置请求
|
||||
type ReloadRequest struct {
|
||||
ConfigType ConfigType `json:"config_type"` // 指定加载的配置类型,空表示全部
|
||||
}
|
||||
|
||||
// ReloadData 热加载响应
|
||||
type ReloadData struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 接口测试类型
|
||||
// ============================================
|
||||
|
||||
// APITestCase 接口测试用例
|
||||
type APITestCase struct {
|
||||
ID string `json:"id"` // 用例ID
|
||||
Name string `json:"name"` // 用例名称
|
||||
Method string `json:"method"` // HTTP方法
|
||||
Path string `json:"path"` // 请求路径
|
||||
Description string `json:"description"` // 描述
|
||||
Params map[string]string `json:"params"` // 默认参数
|
||||
Body interface{} `json:"body"` // 请求体
|
||||
}
|
||||
|
||||
// APITestCategory 测试分类
|
||||
type APITestCategory struct {
|
||||
Name string `json:"name"` // 分类名称
|
||||
Items []APITestCase `json:"items"` // 测试用例
|
||||
}
|
||||
|
||||
// APITestListData 接口测试列表响应
|
||||
type APITestListData struct {
|
||||
Categories []APITestCategory `json:"categories"` // 分类列表
|
||||
BaseURL string `json:"base_url"` // 基础URL
|
||||
}
|
||||
|
||||
// APITestRequest 执行接口测试请求
|
||||
type APITestRequest struct {
|
||||
ID string `json:"id" validate:"required"` // 用例ID
|
||||
Params map[string]string `json:"params"` // 自定义参数
|
||||
Body interface{} `json:"body"` // 自定义请求体
|
||||
}
|
||||
|
||||
// APITestResult 接口测试结果
|
||||
type APITestResult struct {
|
||||
ID int `json:"id"` // 测试ID
|
||||
CaseID string `json:"case_id"` // 用例ID
|
||||
Name string `json:"name"` // 用例名称
|
||||
Success bool `json:"success"` // 是否成功
|
||||
StatusCode int `json:"status_code"` // HTTP状态码
|
||||
Latency int64 `json:"latency"` // 延迟(ms)
|
||||
Request interface{} `json:"request"` // 请求信息
|
||||
Response interface{} `json:"response"` // 响应信息
|
||||
Error string `json:"error,omitempty"` // 错误信息
|
||||
Timestamp time.Time `json:"timestamp"` // 测试时间
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// WebSocket测试类型
|
||||
// ============================================
|
||||
|
||||
// WSTestCase WebSocket测试用例
|
||||
type WSTestCase struct {
|
||||
ID string `json:"id"` // 用例ID
|
||||
Name string `json:"name"` // 用例名称
|
||||
Description string `json:"description"` // 描述
|
||||
Action string `json:"action"` // 动作类型
|
||||
Symbols []string `json:"symbols"` // 订阅标的
|
||||
}
|
||||
|
||||
// WSTestListData WebSocket测试列表响应
|
||||
type WSTestListData struct {
|
||||
Cases []WSTestCase `json:"cases"` // 测试用例
|
||||
WSURL string `json:"ws_url"` // WebSocket地址
|
||||
}
|
||||
|
||||
// WSTestRequest WebSocket测试请求
|
||||
type WSTestRequest struct {
|
||||
ID string `json:"id" validate:"required"` // 用例ID
|
||||
Symbols []string `json:"symbols"` // 自定义标的
|
||||
}
|
||||
|
||||
// WSTestResult WebSocket测试结果
|
||||
type WSTestResult struct {
|
||||
ID string `json:"id"` // 测试ID
|
||||
CaseID string `json:"case_id"` // 用例ID
|
||||
Success bool `json:"success"` // 是否成功
|
||||
Latency int64 `json:"latency"` // 连接延迟(ms)
|
||||
Messages []WSMessage `json:"messages"` // 收到的消息
|
||||
Error string `json:"error,omitempty"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// WSMessage WebSocket消息
|
||||
type WSMessage struct {
|
||||
Type string `json:"type"` // 消息类型
|
||||
Data interface{} `json:"data"` // 消息内容
|
||||
Timestamp time.Time `json:"timestamp"` // 时间
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 测试历史记录类型
|
||||
// ============================================
|
||||
|
||||
// TestHistoryRequest 获取测试历史请求
|
||||
type TestHistoryRequest struct {
|
||||
Type string `json:"type" form:"type"` // 测试类型: api/ws
|
||||
Limit int `json:"limit" form:"limit"` // 数量限制
|
||||
}
|
||||
|
||||
// TestHistoryData 测试历史数据
|
||||
type TestHistoryData struct {
|
||||
APITests []APITestResult `json:"api_tests"` // API测试历史
|
||||
WSTests []WSTestResult `json:"ws_tests"` // WebSocket测试历史
|
||||
}
|
||||
@ -0,0 +1,418 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
)
|
||||
|
||||
// Router API路由注册
|
||||
type Router struct {
|
||||
handler Handler
|
||||
}
|
||||
|
||||
// NewRouter 创建路由
|
||||
func NewRouter(handler Handler) *Router {
|
||||
return &Router{handler: handler}
|
||||
}
|
||||
|
||||
// Register 注册所有路由
|
||||
func (r *Router) Register(engine *gin.Engine) {
|
||||
// 公开接口(无需认证)
|
||||
public := engine.Group("/v1")
|
||||
{
|
||||
public.GET("/admin/health", r.healthCheck)
|
||||
}
|
||||
|
||||
// 需要认证的接口
|
||||
api := engine.Group("/v1")
|
||||
api.Use(r.authMiddleware())
|
||||
{
|
||||
// 股票接口
|
||||
stock := api.Group("/stock")
|
||||
{
|
||||
stock.GET("/klines/:symbol", r.queryStockKLines)
|
||||
stock.GET("/symbols", r.listStockSymbols)
|
||||
stock.POST("/klines/batch", r.batchQueryStockKLines)
|
||||
stock.GET("/trading-dates", r.getStockTradingDates)
|
||||
}
|
||||
|
||||
// 期货接口
|
||||
futures := api.Group("/futures")
|
||||
{
|
||||
futures.GET("/klines/:symbol", r.queryFuturesKLines)
|
||||
futures.GET("/symbols", r.listFuturesSymbols)
|
||||
futures.POST("/klines/batch", r.batchQueryFuturesKLines)
|
||||
futures.GET("/continuous/:underlying", r.queryContinuousKLines)
|
||||
futures.GET("/trading-dates", r.getFuturesTradingDates)
|
||||
futures.GET("/contracts", r.getFuturesContracts)
|
||||
}
|
||||
|
||||
// 管理接口
|
||||
admin := api.Group("/admin")
|
||||
{
|
||||
admin.GET("/source/status", r.getDataSourceStatus)
|
||||
admin.POST("/source/switch", r.switchDataSource)
|
||||
admin.POST("/backfill", r.backfillData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// authMiddleware API认证中间件
|
||||
func (r *Router) authMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
apiKey := c.GetHeader("X-API-Key")
|
||||
if apiKey == "" {
|
||||
c.JSON(http.StatusUnauthorized, ErrorResponse{
|
||||
Code: 401,
|
||||
Message: "缺少API Key",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: 验证API Key有效性
|
||||
// 可以将用户信息存入context
|
||||
c.Set("api_key", apiKey)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 股票接口实现
|
||||
// ============================================
|
||||
|
||||
func (r *Router) queryStockKLines(c *gin.Context) {
|
||||
var req KLineQueryRequest
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Code: 400,
|
||||
Message: "参数错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
req.Symbol = c.Param("symbol")
|
||||
|
||||
resp, err := r.handler.QueryKLines(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Code: 500,
|
||||
Message: "服务器内部错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (r *Router) listStockSymbols(c *gin.Context) {
|
||||
var req SymbolListRequest
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Code: 400,
|
||||
Message: "参数错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := r.handler.ListSymbols(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Code: 500,
|
||||
Message: "服务器内部错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (r *Router) batchQueryStockKLines(c *gin.Context) {
|
||||
var req BatchKLineRequest
|
||||
if err := c.ShouldBindBodyWith(&req, binding.JSON); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Code: 400,
|
||||
Message: "参数错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := r.handler.BatchQueryKLines(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Code: 500,
|
||||
Message: "服务器内部错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 期货接口实现
|
||||
// ============================================
|
||||
|
||||
func (r *Router) queryFuturesKLines(c *gin.Context) {
|
||||
var req KLineQueryRequest
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Code: 400,
|
||||
Message: "参数错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
req.Symbol = c.Param("symbol")
|
||||
|
||||
resp, err := r.handler.QueryKLines(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Code: 500,
|
||||
Message: "服务器内部错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (r *Router) listFuturesSymbols(c *gin.Context) {
|
||||
var req SymbolListRequest
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Code: 400,
|
||||
Message: "参数错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := r.handler.ListSymbols(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Code: 500,
|
||||
Message: "服务器内部错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (r *Router) batchQueryFuturesKLines(c *gin.Context) {
|
||||
var req BatchKLineRequest
|
||||
if err := c.ShouldBindBodyWith(&req, binding.JSON); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Code: 400,
|
||||
Message: "参数错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := r.handler.BatchQueryKLines(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Code: 500,
|
||||
Message: "服务器内部错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (r *Router) queryContinuousKLines(c *gin.Context) {
|
||||
underlying := c.Param("underlying")
|
||||
|
||||
var req KLineQueryRequest
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Code: 400,
|
||||
Message: "参数错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := r.handler.QueryContinuousKLines(c.Request.Context(), underlying, &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Code: 500,
|
||||
Message: "服务器内部错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 管理接口实现
|
||||
// ============================================
|
||||
|
||||
func (r *Router) getDataSourceStatus(c *gin.Context) {
|
||||
resp, err := r.handler.GetDataSourceStatus(c.Request.Context())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Code: 500,
|
||||
Message: "服务器内部错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (r *Router) switchDataSource(c *gin.Context) {
|
||||
var req SourceSwitchRequest
|
||||
if err := c.ShouldBindBodyWith(&req, binding.JSON); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Code: 400,
|
||||
Message: "参数错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := r.handler.SwitchDataSource(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnprocessableEntity, ErrorResponse{
|
||||
Code: 422,
|
||||
Message: "数据源切换失败",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (r *Router) backfillData(c *gin.Context) {
|
||||
var req BackfillRequest
|
||||
if err := c.ShouldBindBodyWith(&req, binding.JSON); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Code: 400,
|
||||
Message: "参数错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := r.handler.BackfillData(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Code: 500,
|
||||
Message: "补录任务失败",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusAccepted, resp)
|
||||
}
|
||||
|
||||
func (r *Router) healthCheck(c *gin.Context) {
|
||||
resp, err := r.handler.HealthCheck(c.Request.Context())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusServiceUnavailable, ErrorResponse{
|
||||
Code: 503,
|
||||
Message: "服务不可用",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 新增接口:交易日历和期货合约
|
||||
// ============================================
|
||||
|
||||
func (r *Router) getStockTradingDates(c *gin.Context) {
|
||||
var req TradingDatesRequest
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Code: 400,
|
||||
Message: "参数错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := r.handler.GetTradingDates(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Code: 500,
|
||||
Message: "服务器内部错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (r *Router) getFuturesTradingDates(c *gin.Context) {
|
||||
var req TradingDatesRequest
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Code: 400,
|
||||
Message: "参数错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := r.handler.GetTradingDates(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Code: 500,
|
||||
Message: "服务器内部错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (r *Router) getFuturesContracts(c *gin.Context) {
|
||||
var req FuturesContractsRequest
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Code: 400,
|
||||
Message: "参数错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := r.handler.GetContractsByUnderlying(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Code: 500,
|
||||
Message: "服务器内部错误",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
@ -0,0 +1,165 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"market-data-service/api"
|
||||
"market-data-service/internal/handler"
|
||||
"market-data-service/internal/monitor"
|
||||
"market-data-service/internal/repository"
|
||||
"market-data-service/internal/service"
|
||||
"market-data-service/internal/websocket"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 配置
|
||||
port := getEnv("PORT", "8080")
|
||||
dbURL := getEnv("DATABASE_URL", "postgres://user:password@localhost:5432/marketdata?sslmode=disable")
|
||||
configPath := getEnv("CONFIG_PATH", "./config.json")
|
||||
|
||||
// 设置运行模式
|
||||
ginMode := getEnv("GIN_MODE", "debug")
|
||||
gin.SetMode(ginMode)
|
||||
|
||||
// 连接数据库
|
||||
db, err := repository.NewDB(dbURL)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to connect to database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// 初始化配置服务
|
||||
configService, err := service.NewConfigService(configPath)
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to load config from %s: %v", configPath, err)
|
||||
configService, _ = service.NewConfigService("")
|
||||
}
|
||||
|
||||
// 初始化适配器服务
|
||||
adapterService := service.NewAdapterService()
|
||||
|
||||
// 初始化测试服务
|
||||
testService := service.NewTestService()
|
||||
|
||||
// 初始化Repository
|
||||
stockRepo := repository.NewStockRepository(db)
|
||||
futuresRepo := repository.NewFuturesRepository(db)
|
||||
|
||||
// 初始化Service
|
||||
stockService := service.NewStockService(stockRepo)
|
||||
futuresService := service.NewFuturesService(futuresRepo)
|
||||
adminService := service.NewAdminService(db)
|
||||
|
||||
// 初始化Handler
|
||||
h := handler.NewHandler(stockService, futuresService, adminService)
|
||||
|
||||
// 初始化管理后台Handler
|
||||
adminHandler := handler.NewAdminHandlerImpl(configService, adapterService, testService)
|
||||
|
||||
// 初始化WebSocket Hub
|
||||
hub := websocket.NewHub()
|
||||
go hub.Run()
|
||||
|
||||
// 初始化WebSocket Server
|
||||
wsServer := websocket.NewServer(hub)
|
||||
|
||||
// 初始化数据质量监控
|
||||
alertSender := &monitor.LogAlertSender{}
|
||||
dataMonitor := monitor.NewMonitor(db, stockRepo, futuresRepo, alertSender)
|
||||
|
||||
// 启动每日检查定时任务
|
||||
ctx := context.Background()
|
||||
dataMonitor.StartDailyCheckCron(ctx)
|
||||
|
||||
// 创建Gin引擎
|
||||
router := gin.New()
|
||||
router.Use(gin.Recovery())
|
||||
router.Use(loggerMiddleware())
|
||||
|
||||
// 注册API路由
|
||||
apiRouter := api.NewRouter(h)
|
||||
apiRouter.Register(router)
|
||||
|
||||
// 注册WebSocket路由
|
||||
router.GET("/v1/stream", wsServer.HandleWebSocket)
|
||||
|
||||
// 注册管理后台路由
|
||||
adminRouter := api.NewAdminRouter(h, adminHandler, adminHandler, adminHandler)
|
||||
adminRouter.Register(router)
|
||||
|
||||
// 创建HTTP服务器
|
||||
srv := &http.Server{
|
||||
Addr: ":" + port,
|
||||
Handler: router,
|
||||
}
|
||||
|
||||
// 启动服务器(异步)
|
||||
go func() {
|
||||
log.Printf("Server starting on port %s", port)
|
||||
log.Printf("Admin dashboard: http://localhost:%s/admin", port)
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("Server failed to start: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待中断信号
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
|
||||
<-quit
|
||||
|
||||
log.Println("Shutting down server...")
|
||||
|
||||
// 优雅关闭
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
log.Printf("Server forced to shutdown: %v", err)
|
||||
}
|
||||
|
||||
log.Println("Server exited")
|
||||
}
|
||||
|
||||
// loggerMiddleware 日志中间件
|
||||
func loggerMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
raw := c.Request.URL.RawQuery
|
||||
|
||||
c.Next()
|
||||
|
||||
latency := time.Since(start)
|
||||
clientIP := c.ClientIP()
|
||||
method := c.Request.Method
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
if raw != "" {
|
||||
path = path + "?" + raw
|
||||
}
|
||||
|
||||
log.Printf("[GIN] %v | %3d | %13v | %15s | %-7s %s",
|
||||
start.Format("2006/01/02 - 15:04:05"),
|
||||
statusCode,
|
||||
latency,
|
||||
clientIP,
|
||||
method,
|
||||
path,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func getEnv(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
@ -0,0 +1,255 @@
|
||||
// Package sync 数据同步工具
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"market-data-service/adapter/tushare"
|
||||
"market-data-service/api"
|
||||
"market-data-service/internal/repository"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var (
|
||||
syncType = flag.String("type", "", "同步类型: stocks, futures, calendar, klines")
|
||||
startDate = flag.String("start", "", "开始日期 YYYYMMDD")
|
||||
endDate = flag.String("end", "", "结束日期 YYYYMMDD")
|
||||
symbol = flag.String("symbol", "", "标的代码")
|
||||
underlying = flag.String("underlying", "", "期货品种代码")
|
||||
freq = flag.String("freq", "1d", "K线周期: 1m/5m/15m/30m/60m/1d")
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
if *syncType == "" {
|
||||
flag.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 配置
|
||||
tushareToken := os.Getenv("TUSHARE_TOKEN")
|
||||
if tushareToken == "" {
|
||||
log.Fatal("TUSHARE_TOKEN environment variable is required")
|
||||
}
|
||||
|
||||
dbURL := os.Getenv("DATABASE_URL")
|
||||
if dbURL == "" {
|
||||
dbURL = "postgres://user:password@localhost:5432/marketdata?sslmode=disable"
|
||||
}
|
||||
|
||||
// 连接数据库
|
||||
db, err := repository.NewDB(dbURL)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to connect to database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// 初始化Tushare客户端
|
||||
client := tushare.NewClient(tushareToken)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
switch *syncType {
|
||||
case "stocks":
|
||||
syncStocks(ctx, client, db)
|
||||
case "futures":
|
||||
syncFutures(ctx, client, db)
|
||||
case "calendar":
|
||||
syncCalendar(ctx, client, db, *startDate, *endDate)
|
||||
case "klines":
|
||||
syncKLines(ctx, client, db, *symbol, *startDate, *endDate, *freq)
|
||||
default:
|
||||
log.Fatalf("Unknown sync type: %s", *syncType)
|
||||
}
|
||||
}
|
||||
|
||||
// syncStocks 同步股票基础信息
|
||||
func syncStocks(ctx context.Context, client *tushare.Client, db *repository.DB) {
|
||||
log.Println("Syncing stock basic info...")
|
||||
|
||||
data, err := client.GetStockBasic()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get stock basic: %v", err)
|
||||
}
|
||||
|
||||
repo := repository.NewStockRepository(db)
|
||||
|
||||
symbols := make([]api.Symbol, 0, len(data))
|
||||
for _, d := range data {
|
||||
if d.ListStatus != "L" {
|
||||
continue // 只同步上市状态的
|
||||
}
|
||||
|
||||
listDate, _ := time.Parse("20060102", d.ListDate)
|
||||
|
||||
symbols = append(symbols, api.Symbol{
|
||||
SymbolID: d.TSCode,
|
||||
SymbolType: api.SymbolTypeStock,
|
||||
Exchange: api.Exchange(d.Exchange),
|
||||
Name: d.Name,
|
||||
NameEN: d.EnName,
|
||||
Industry: d.Industry,
|
||||
ListDate: &listDate,
|
||||
Status: "active",
|
||||
})
|
||||
}
|
||||
|
||||
if err := repo.SaveSymbols(ctx, symbols); err != nil {
|
||||
log.Fatalf("Failed to save symbols: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("Synced %d stocks", len(symbols))
|
||||
}
|
||||
|
||||
// syncFutures 同步期货基础信息
|
||||
func syncFutures(ctx context.Context, client *tushare.Client, db *repository.DB) {
|
||||
log.Println("Syncing futures basic info...")
|
||||
|
||||
data, err := client.GetFuturesBasic("")
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get futures basic: %v", err)
|
||||
}
|
||||
|
||||
repo := repository.NewFuturesRepository(db)
|
||||
|
||||
symbols := make([]api.Symbol, 0, len(data))
|
||||
for _, d := range data {
|
||||
listDate, _ := time.Parse("20060102", d.ListDate)
|
||||
delistDate, _ := time.Parse("20060102", d.DelistDate)
|
||||
|
||||
status := "active"
|
||||
if time.Now().After(delistDate) {
|
||||
status = "expired"
|
||||
}
|
||||
|
||||
symbols = append(symbols, api.Symbol{
|
||||
SymbolID: d.TSCode,
|
||||
SymbolType: api.SymbolTypeFutures,
|
||||
Exchange: api.Exchange(d.Exchange),
|
||||
Name: d.Name,
|
||||
Underlying: d.FutCode,
|
||||
ContractMonth: d.Symbol[len(d.FutCode):],
|
||||
ListDate: &listDate,
|
||||
DelistDate: &delistDate,
|
||||
Status: status,
|
||||
})
|
||||
}
|
||||
|
||||
if err := repo.SaveSymbols(ctx, symbols); err != nil {
|
||||
log.Fatalf("Failed to save symbols: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("Synced %d futures", len(symbols))
|
||||
}
|
||||
|
||||
// syncCalendar 同步交易日历
|
||||
func syncCalendar(ctx context.Context, client *tushare.Client, db *repository.DB, start, end string) {
|
||||
if start == "" {
|
||||
start = time.Now().AddDate(0, 0, -30).Format("20060102")
|
||||
}
|
||||
if end == "" {
|
||||
end = time.Now().AddDate(0, 6, 0).Format("20060102")
|
||||
}
|
||||
|
||||
log.Printf("Syncing trading calendar from %s to %s...", start, end)
|
||||
|
||||
// 同步股票交易日历(上交所)
|
||||
stockData, err := client.GetTradeCal("SSE", start, end, -1)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get stock calendar: %v", err)
|
||||
}
|
||||
|
||||
stockRepo := repository.NewStockRepository(db)
|
||||
stockDates := make([]api.TradeCalData, len(stockData))
|
||||
for i, d := range stockData {
|
||||
calDate, _ := time.Parse("20060102", d.CalDate)
|
||||
stockDates[i] = api.TradeCalData{
|
||||
Date: calDate,
|
||||
IsTradingDay: d.IsOpen == 1,
|
||||
}
|
||||
}
|
||||
|
||||
if err := stockRepo.SaveTradingCalendar(ctx, stockDates); err != nil {
|
||||
log.Fatalf("Failed to save stock calendar: %v", err)
|
||||
}
|
||||
|
||||
// 同步期货交易日历(使用相同的,实际可能需要单独配置)
|
||||
futuresRepo := repository.NewFuturesRepository(db)
|
||||
if err := futuresRepo.SaveTradingCalendar(ctx, stockDates); err != nil {
|
||||
log.Fatalf("Failed to save futures calendar: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("Synced %d calendar days", len(stockDates))
|
||||
}
|
||||
|
||||
// syncKLines 同步K线数据
|
||||
func syncKLines(ctx context.Context, client *tushare.Client, db *repository.DB, symbol, start, end, freq string) {
|
||||
if symbol == "" {
|
||||
log.Fatal("symbol is required for klines sync")
|
||||
}
|
||||
if start == "" {
|
||||
start = time.Now().AddDate(0, 0, -7).Format("20060102")
|
||||
}
|
||||
if end == "" {
|
||||
end = time.Now().Format("20060102")
|
||||
}
|
||||
|
||||
log.Printf("Syncing %s klines for %s from %s to %s...", freq, symbol, start, end)
|
||||
|
||||
adapter := tushare.NewAdapter()
|
||||
if err := adapter.Connect(map[string]string{"token": os.Getenv("TUSHARE_TOKEN")}); err != nil {
|
||||
log.Fatalf("Failed to connect adapter: %v", err)
|
||||
}
|
||||
|
||||
data, err := adapter.FetchKLines(symbol, start, end, freq)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to fetch klines: %v", err)
|
||||
}
|
||||
|
||||
// 转换为api.KLineItem并保存
|
||||
items := make([]api.KLineItem, len(data))
|
||||
for i, d := range data {
|
||||
ts := time.Unix(d.Time, 0)
|
||||
items[i] = api.KLineItem{
|
||||
Time: ts,
|
||||
Open: d.Open,
|
||||
High: d.High,
|
||||
Low: d.Low,
|
||||
Close: d.Close,
|
||||
Volume: d.Volume,
|
||||
Amount: d.Amount,
|
||||
}
|
||||
if d.OpenInterest > 0 {
|
||||
oi := d.OpenInterest
|
||||
items[i].OpenInterest = &oi
|
||||
}
|
||||
}
|
||||
|
||||
// 判断股票还是期货并保存
|
||||
if isStock(symbol) {
|
||||
repo := repository.NewStockRepository(db)
|
||||
if err := repo.SaveKLines(ctx, api.Frequency(freq), items); err != nil {
|
||||
log.Fatalf("Failed to save stock klines: %v", err)
|
||||
}
|
||||
} else {
|
||||
repo := repository.NewFuturesRepository(db)
|
||||
if err := repo.SaveKLines(ctx, api.Frequency(freq), symbol, items); err != nil {
|
||||
log.Fatalf("Failed to save futures klines: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("Synced %d klines", len(items))
|
||||
}
|
||||
|
||||
// isStock 判断是否为股票代码
|
||||
func isStock(symbol string) bool {
|
||||
return contains(symbol, ".SH") || contains(symbol, ".SZ") || contains(symbol, ".BJ")
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && s[len(s)-len(substr):] == substr
|
||||
}
|
||||
@ -0,0 +1,46 @@
|
||||
{
|
||||
"server": {
|
||||
"port": 8080,
|
||||
"mode": "debug",
|
||||
"api_key": "your-api-key-here"
|
||||
},
|
||||
"database": {
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"user": "postgres",
|
||||
"password": "password",
|
||||
"database": "marketdata"
|
||||
},
|
||||
"redis": {
|
||||
"host": "localhost",
|
||||
"port": 6379,
|
||||
"password": "",
|
||||
"db": 0
|
||||
},
|
||||
"sources": {
|
||||
"stock": {
|
||||
"active": "tushare",
|
||||
"list": {
|
||||
"tushare": {
|
||||
"type": "http",
|
||||
"config": {
|
||||
"token": "your-tushare-token",
|
||||
"base_url": "https://api.tushare.pro"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"futures": {
|
||||
"active": "tushare",
|
||||
"list": {
|
||||
"tushare": {
|
||||
"type": "http",
|
||||
"config": {
|
||||
"token": "your-tushare-token",
|
||||
"base_url": "https://api.tushare.pro"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,46 @@
|
||||
{
|
||||
"server": {
|
||||
"port": 8080,
|
||||
"mode": "debug",
|
||||
"api_key": "demo-api-key-2024"
|
||||
},
|
||||
"database": {
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"user": "postgres",
|
||||
"password": "postgres",
|
||||
"database": "marketdata"
|
||||
},
|
||||
"redis": {
|
||||
"host": "localhost",
|
||||
"port": 6379,
|
||||
"password": "",
|
||||
"db": 0
|
||||
},
|
||||
"sources": {
|
||||
"stock": {
|
||||
"active": "tushare",
|
||||
"list": {
|
||||
"tushare": {
|
||||
"type": "http",
|
||||
"config": {
|
||||
"token": "your-tushare-token-here",
|
||||
"base_url": "https://api.tushare.pro"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"futures": {
|
||||
"active": "tushare",
|
||||
"list": {
|
||||
"tushare": {
|
||||
"type": "http",
|
||||
"config": {
|
||||
"token": "your-tushare-token-here",
|
||||
"base_url": "https://api.tushare.pro"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,460 @@
|
||||
# 管理后台架构设计文档
|
||||
|
||||
## 重要说明
|
||||
|
||||
本文档适用于 **Go** 和 **Python** 双实现。两者架构设计保持一致,仅技术栈不同:
|
||||
- **Go**: Gin + 原生SQL + Goroutine
|
||||
- **Python**: FastAPI + SQLAlchemy + asyncio
|
||||
|
||||
---
|
||||
|
||||
## 1. 系统架构图
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────────────┐
|
||||
│ 客户端层 │
|
||||
│ ┌─────────────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Web Browser │ │
|
||||
│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │
|
||||
│ │ │ Dashboard │ │ Config │ │ Adapter │ │ Test │ │ │
|
||||
│ │ │ Page │ │ Page │ │ Page │ │ Page │ │ │
|
||||
│ │ └─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ │ │
|
||||
│ └─────────────────────────────────────────────────────────────────────┘ │
|
||||
│ http://localhost:8080/admin │
|
||||
└─────────────────────────────────────┬───────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────────────────┐
|
||||
│ 接入层 │
|
||||
│ ┌─────────────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Router (Gin/FastAPI) │ │
|
||||
│ │ │ │
|
||||
│ │ ┌──────────────┐ ┌──────────────┐ ┌──────────────────────────┐ │ │
|
||||
│ │ │ API Router │ │ Admin Router │ │ WebSocket Handler │ │ │
|
||||
│ │ │(api/router.go)│ │(api/admin_ │ │ (/v1/stream) │ │ │
|
||||
│ │ │ (routes.py) │ │ router.go) │ │ │ │ │
|
||||
│ │ │ │ │(admin_routes)│ │ │ │ │
|
||||
│ │ └──────────────┘ └──────────────┘ └──────────────────────────┘ │ │
|
||||
│ │ │ │ │ │
|
||||
│ └──────────┼────────────────┼─────────────────────────────────────────┘ │
|
||||
│ │ │ │
|
||||
└─────────────┼────────────────┼─────────────────────────────────────────────┘
|
||||
│ │
|
||||
▼ ▼
|
||||
┌─────────────────────────────────────────────────────────────────────────────┐
|
||||
│ 业务层 │
|
||||
│ ┌─────────────────────────────────────────────────────────────────────┐ │
|
||||
│ │ Service Layer │ │
|
||||
│ │ │ │
|
||||
│ │ ┌─────────────────┐ ┌─────────────────────────────────┐ │ │
|
||||
│ │ │ Stock/Futures │ │ AdminHandlerImpl │ │ │
|
||||
│ │ │ Services │ │(admin.go / admin_routes.py) │ │ │
|
||||
│ │ └─────────────────┘ └──────────────┬──────────────────┘ │ │
|
||||
│ │ │ │ │
|
||||
│ └──────────────────────────────────────────────┼──────────────────────┘ │
|
||||
│ │ │
|
||||
└─────────────────────────────────────────────────┼──────────────────────────┘
|
||||
│
|
||||
┌───────────────────────────────────┼───────────────────┐
|
||||
│ │ │
|
||||
▼ ▼ ▼
|
||||
┌─────────────────────────────────────────────────────────────────────────────┐
|
||||
│ 服务层 │
|
||||
│ │
|
||||
│ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ │
|
||||
│ │ ConfigService │ │AdapterService │ │ TestService │ │
|
||||
│ │(config.go) │ │(adapter.go) │ │ (test.go) │ │
|
||||
│ │(config_service)│ │(adapter_service)│ │(test_service) │ │
|
||||
│ │ │ │ │ │ │ │
|
||||
│ │ • 配置加载 │ │ • 适配器注册 │ │ • API测试 │ │
|
||||
│ │ • 热加载 │ │ • 启用/禁用 │ │ • WS测试 │ │
|
||||
│ │ • 状态监控 │ │ • 配置管理 │ │ • 历史记录 │ │
|
||||
│ └───────┬───────┘ └───────┬───────┘ └───────┬───────┘ │
|
||||
│ │ │ │ │
|
||||
│ └────────────────────┼────────────────────┘ │
|
||||
│ │ │
|
||||
└────────────────────────────────┼───────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────────────────┐
|
||||
│ 数据层 │
|
||||
│ │
|
||||
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌───────────┐ │
|
||||
│ │ config.json │ │ Database │ │ Adapter │ │ Memory │ │
|
||||
│ │ (文件) │ │ (PostgreSQL)│ │ Factory │ │ Cache │ │
|
||||
│ │ │ │ │ │ │ │ │ │
|
||||
│ │ 持久化配置 │ │ 数据源配置 │ │ 适配器实例 │ │ 测试历史 │ │
|
||||
│ └─────────────┘ └─────────────┘ └─────────────┘ └───────────┘ │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### 技术栈对比
|
||||
|
||||
| 层级 | Go实现 | Python实现 |
|
||||
|------|--------|------------|
|
||||
| 接入层 | Gin Router | FastAPI Router |
|
||||
| 业务层 | Go Interfaces | Python Protocols |
|
||||
| 服务层 | Go Structs + Methods | Python Classes |
|
||||
| 数据层 | database/sql | SQLAlchemy ORM |
|
||||
| 配置 | JSON + 自定义解析 | Pydantic Settings |
|
||||
|
||||
---
|
||||
|
||||
## 2. 模块关系图
|
||||
|
||||
### 配置管理模块
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ 配置管理模块 │
|
||||
│ ┌─────────────────────────────────────────────────────────┐ │
|
||||
│ │ ConfigService │ │
|
||||
│ │ (Go / Python) │ │
|
||||
│ │ ┌─────────────┐ ┌─────────────┐ ┌────────────┐ │ │
|
||||
│ │ │ Load Config │───►│ Store Config│◄───│UpdateConfig│ │ │
|
||||
│ │ │ (JSON File) │ │ (Memory) │ │ (API) │ │ │
|
||||
│ │ └─────────────┘ └──────┬──────┘ └────────────┘ │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ ┌─────────────┐ │ │
|
||||
│ │ │ Reload │ │ │
|
||||
│ │ │ (Hot Swap) │ │ │
|
||||
│ │ └──────┬──────┘ │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ ┌─────────────┐ │ │
|
||||
│ │ │ Callbacks │ │ │
|
||||
│ │ │ (Notify All)│ │ │
|
||||
│ │ └─────────────┘ │ │
|
||||
│ └─────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**Go实现**:
|
||||
```go
|
||||
// 使用 sync.RWMutex 保证并发安全
|
||||
type ConfigServiceImpl struct {
|
||||
config *config.Config
|
||||
mu sync.RWMutex
|
||||
callbacks map[api.ConfigType][]func()
|
||||
}
|
||||
```
|
||||
|
||||
**Python实现**:
|
||||
```python
|
||||
# 使用 threading.RLock 保证并发安全
|
||||
class ConfigService:
|
||||
def __init__(self):
|
||||
self.config = get_config()
|
||||
self.lock = threading.RLock()
|
||||
self.callbacks: Dict[ConfigType, List[Callable]] = {}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 适配器管理模块
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ 适配器管理模块 │
|
||||
│ ┌─────────────────────────────────────────────────────────┐ │
|
||||
│ │ AdapterService │ │
|
||||
│ │ (Go / Python) │ │
|
||||
│ │ ┌──────────────┐ ┌─────────────────────┐ │ │
|
||||
│ │ │ Register │────────►│ Factory Map │ │ │
|
||||
│ │ │ (Add Factory)│ │ [name]factoryFunc │ │ │
|
||||
│ │ └──────────────┘ └─────────────────────┘ │ │
|
||||
│ │ │ │
|
||||
│ │ ┌──────────────┐ ┌─────────────────────┐ │ │
|
||||
│ │ │ Toggle │────────►│ Active Adapters │ │ │
|
||||
│ │ │(Enable/Disable)│ │ [name]adapterInstance│ │ │
|
||||
│ │ └──────────────┘ └─────────────────────┘ │ │
|
||||
│ │ │ │
|
||||
│ │ ┌──────────────┐ ┌─────────────────────┐ │ │
|
||||
│ │ │Update Config │────────►│ Adapter Config │ │ │
|
||||
│ │ └──────────────┘ │ [name]configMap │ │ │
|
||||
│ │ └─────────────────────┘ │ │
|
||||
│ └─────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**Go实现**:
|
||||
```go
|
||||
type AdapterServiceImpl struct {
|
||||
factories map[string]AdapterFactory
|
||||
configs map[string]*adapterConfig
|
||||
activeAdapters map[string]adapter.DataSourceAdapter
|
||||
}
|
||||
```
|
||||
|
||||
**Python实现**:
|
||||
```python
|
||||
class AdapterService:
|
||||
def __init__(self):
|
||||
self.factories: Dict[str, Callable] = {}
|
||||
self.configs: Dict[str, dict] = {}
|
||||
self.active_adapters: Dict[str, DataSourceAdapter] = {}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 测试管理模块
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ 测试管理模块 │
|
||||
│ ┌─────────────────────────────────────────────────────────┐ │
|
||||
│ │ TestService │ │
|
||||
│ │ (Go / Python) │ │
|
||||
│ │ ┌────────────────┐ ┌──────────────────────┐ │ │
|
||||
│ │ │ API Test │ │ Test Cases (JSON) │ │ │
|
||||
│ │ │ • Build Request│◄───────│ • stock_klines │ │ │
|
||||
│ │ │ • Execute HTTP │ │ • futures_klines │ │ │
|
||||
│ │ │ • Parse Result │ │ • admin_health │ │ │
|
||||
│ │ └───────┬────────┘ └──────────────────────┘ │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ ┌────────────────┐ ┌──────────────────────┐ │ │
|
||||
│ │ │ WS Test │ │ Test History │ │ │
|
||||
│ │ │ • Dial WS │◄───────│ (In-Memory Cache) │ │ │
|
||||
│ │ │ • Send Msg │ │ • api_tests [] │ │ │
|
||||
│ │ │ • Recv Msg │ │ • ws_tests [] │ │ │
|
||||
│ │ └────────────────┘ └──────────────────────┘ │ │
|
||||
│ └─────────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. 数据流图
|
||||
|
||||
### 3.1 配置热加载流程
|
||||
|
||||
```
|
||||
User/Admin Admin API ConfigService File System
|
||||
│ │ │ │
|
||||
│ 1.修改 config.json │ │ │
|
||||
│────────────────────────────►│ │ │
|
||||
│ │ │ │
|
||||
│ 2.点击"热加载" │ │ │
|
||||
│────────────────────────────►│ │ │
|
||||
│ │ 3.POST /system/reload │ │
|
||||
│ │──────────────────────────►│ │
|
||||
│ │ │ 4.读取 config.json │
|
||||
│ │ │───────────────────────►│
|
||||
│ │ │◄───────────────────────│
|
||||
│ │ │ 5.解析并更新内存配置 │
|
||||
│ │ │ │
|
||||
│ │ │ 6.触发回调函数 │
|
||||
│ │◄──────────────────────────│ │
|
||||
│◄────────────────────────────│ 7.返回成功 │ │
|
||||
│ │ │ │
|
||||
```
|
||||
|
||||
### 3.2 适配器切换流程
|
||||
|
||||
```
|
||||
User Admin API AdapterService Adapter Instance
|
||||
│ │ │ │
|
||||
│ 1.选择适配器并点击"启用" │ │ │
|
||||
│──────────────────────────►│ │ │
|
||||
│ │ 2.POST /adapters/toggle │ │
|
||||
│ │─────────────────────────►│ │
|
||||
│ │ │ 3.创建新实例 │
|
||||
│ │ │─────────────────────►│
|
||||
│ │ │◄─────────────────────│
|
||||
│ │ │ 4.调用 Connect() │
|
||||
│ │ │ │
|
||||
│ │ │ 5.更新 activeAdapters│
|
||||
│ │◄─────────────────────────│ │
|
||||
│◄──────────────────────────│ 6.返回成功 │ │
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. 关键设计决策
|
||||
|
||||
### 4.1 配置存储
|
||||
|
||||
| 方案 | 优点 | 缺点 | Go选择 | Python选择 |
|
||||
|------|------|------|--------|------------|
|
||||
| JSON文件 | 简单、易编辑、无依赖 | 无事务支持 | ✅ | ✅ |
|
||||
| 数据库存储 | 支持事务、历史版本 | 增加依赖 | ❌ | ❌ |
|
||||
| etcd/consul | 分布式、高可用 | 引入新组件 | ❌ | ❌ |
|
||||
|
||||
**Python增强**: 使用 Pydantic 进行类型验证和自动解析
|
||||
|
||||
### 4.2 前端实现
|
||||
|
||||
| 方案 | 优点 | 缺点 | 选择 |
|
||||
|------|------|------|------|
|
||||
| 纯HTML/JS | 无依赖、部署简单 | 功能受限 | ✅ |
|
||||
| Vue/React | 功能强大、生态丰富 | 需构建、体积大 | ❌ |
|
||||
| 独立前端项目 | 前后端分离 | 部署复杂 | ❌ |
|
||||
|
||||
### 4.3 数据库访问
|
||||
|
||||
| 方案 | Go | Python | 说明 |
|
||||
|------|----|--------|------|
|
||||
| 原生SQL | ✅ database/sql | ❌ | Go标准方式 |
|
||||
| ORM | ❌ | ✅ SQLAlchemy | Python生态标准 |
|
||||
| SQL Builder | ⏳ 可选 | ⏳ 可选 | 未来考虑 |
|
||||
|
||||
---
|
||||
|
||||
## 5. 扩展点设计
|
||||
|
||||
### 5.1 新增适配器
|
||||
|
||||
**Go**:
|
||||
```go
|
||||
// 1. 实现适配器接口
|
||||
type MyAdapter struct { ... }
|
||||
|
||||
func (a *MyAdapter) Connect(config map[string]string) error { ... }
|
||||
func (a *MyAdapter) HealthCheck() error { ... }
|
||||
// ... 实现其他方法
|
||||
|
||||
// 2. 注册到服务
|
||||
func init() {
|
||||
adapterService.RegisterAdapter("myadapter", func() adapter.DataSourceAdapter {
|
||||
return &MyAdapter{}
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
**Python**:
|
||||
```python
|
||||
# 1. 实现适配器接口
|
||||
class MyAdapter(DataSourceAdapter):
|
||||
async def connect(self, config: dict) -> None:
|
||||
pass
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
return True
|
||||
# ... 实现其他方法
|
||||
|
||||
# 2. 注册到服务
|
||||
adapter_service.register_adapter("myadapter", lambda: MyAdapter())
|
||||
```
|
||||
|
||||
### 5.2 新增配置类型
|
||||
|
||||
**Go**:
|
||||
```go
|
||||
// 1. 扩展配置结构
|
||||
type Config struct {
|
||||
// ... 现有配置
|
||||
Custom CustomConfig `json:"custom"`
|
||||
}
|
||||
|
||||
// 2. 在 ConfigService 中添加处理逻辑
|
||||
```
|
||||
|
||||
**Python**:
|
||||
```python
|
||||
# 1. 扩展配置结构
|
||||
class Config(BaseModel):
|
||||
# ... 现有配置
|
||||
custom: CustomConfig = Field(default_factory=CustomConfig)
|
||||
|
||||
# 2. Pydantic 自动处理验证和解析
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. 部署架构
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ 生产环境部署 │
|
||||
│ │
|
||||
│ ┌─────────────────────────────────────────────────────┐ │
|
||||
│ │ Nginx / LB │ │
|
||||
│ │ (SSL终止、静态资源缓存) │ │
|
||||
│ └──────────────────┬──────────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌──────────┴──────────┐ │
|
||||
│ │ │ │
|
||||
│ ┌───────▼───────┐ ┌────────▼────────┐ │
|
||||
│ │ Market Data │ │ Market Data │ │
|
||||
│ │ Service 1 │ │ Service 2 │ │
|
||||
│ │ │ │ │ │
|
||||
│ │ • Go 或 │ │ • Go 或 │ │
|
||||
│ │ Python │ │ Python │ │
|
||||
│ │ • /v1/api │ │ • /v1/api │ │
|
||||
│ │ • /admin │ │ • /admin │ │
|
||||
│ │ • /v1/stream │ │ • /v1/stream │ │
|
||||
│ └───────┬───────┘ └────────┬────────┘ │
|
||||
│ │ │ │
|
||||
│ └──────────┬──────────┘ │
|
||||
│ │ │
|
||||
│ ┌──────────────────▼──────────────────┐ │
|
||||
│ │ PostgreSQL Cluster │ │
|
||||
│ │ │ │
|
||||
│ │ • data_source_config (数据源配置) │ │
|
||||
│ │ • 其他业务表... │ │
|
||||
│ └─────────────────────────────────────┘ │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### 混合部署建议
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ 推荐混合部署架构 │
|
||||
│ │
|
||||
│ ┌──────────────────┐ ┌──────────────────┐ │
|
||||
│ │ Go 实现 │ │ Python 实现 │ │
|
||||
│ │ (生产环境) │ │ (开发/测试) │ │
|
||||
│ │ │ │ │ │
|
||||
│ │ • 高并发API │ │ • 数据同步工具 │ │
|
||||
│ │ • WebSocket │ │ • 管理后台 │ │
|
||||
│ │ • 核心服务 │ │ • 快速原型 │ │
|
||||
│ └────────┬─────────┘ └────────┬─────────┘ │
|
||||
│ │ │ │
|
||||
│ └───────────┬───────────┘ │
|
||||
│ │ │
|
||||
│ ┌────────────▼────────────┐ │
|
||||
│ │ PostgreSQL │ │
|
||||
│ │ (共享数据库) │ │
|
||||
│ └─────────────────────────┘ │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. 监控点
|
||||
|
||||
| 模块 | 监控项 | Go方式 | Python方式 |
|
||||
|------|--------|--------|------------|
|
||||
| ConfigService | 配置加载耗时 | 日志 | 日志 |
|
||||
| ConfigService | 配置热加载次数 | 日志 | 日志 |
|
||||
| AdapterService | 适配器健康状态 | 接口/日志 | 接口/日志 |
|
||||
| AdapterService | 适配器切换次数 | 日志 | 日志 |
|
||||
| TestService | 测试执行次数 | 内存/日志 | 内存/日志 |
|
||||
| TestService | 测试成功率 | 计算 | 计算 |
|
||||
| WebSocket | 连接数 | metrics | metrics |
|
||||
| Database | 查询耗时 | SQL日志 | SQLAlchemy事件 |
|
||||
|
||||
---
|
||||
|
||||
## 8. 实现差异汇总
|
||||
|
||||
| 方面 | Go | Python |
|
||||
|------|----|--------|
|
||||
| 并发模型 | Goroutines | asyncio |
|
||||
| 锁机制 | sync.Mutex | threading.Lock/asyncio.Lock |
|
||||
| 数据库 | database/sql + pq | SQLAlchemy + psycopg2 |
|
||||
| Web框架 | Gin | FastAPI |
|
||||
| 类型系统 | Struct + Interface | Pydantic Model |
|
||||
| 配置解析 | encoding/json | Pydantic Settings |
|
||||
| 错误处理 | error interface | Exception |
|
||||
| 日志 | log | logging |
|
||||
|
||||
---
|
||||
|
||||
**文档结束**
|
||||
@ -0,0 +1,37 @@
|
||||
module market-data-service
|
||||
|
||||
go 1.21
|
||||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.1
|
||||
github.com/lib/pq v1.10.9
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/crypto v0.9.0 // indirect
|
||||
golang.org/x/net v0.10.0 // indirect
|
||||
golang.org/x/sys v0.8.0 // indirect
|
||||
golang.org/x/text v0.9.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
@ -0,0 +1,236 @@
|
||||
// Package handler 管理后台Handler实现
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"market-data-service/api"
|
||||
"market-data-service/internal/service"
|
||||
)
|
||||
|
||||
// AdminHandlerImpl 管理后台Handler实现
|
||||
type AdminHandlerImpl struct {
|
||||
configService service.ConfigService
|
||||
adapterService service.AdapterService
|
||||
testService service.TestService
|
||||
}
|
||||
|
||||
// NewAdminHandlerImpl 创建管理后台Handler
|
||||
func NewAdminHandlerImpl(
|
||||
configService service.ConfigService,
|
||||
adapterService service.AdapterService,
|
||||
testService service.TestService,
|
||||
) *AdminHandlerImpl {
|
||||
return &AdminHandlerImpl{
|
||||
configService: configService,
|
||||
adapterService: adapterService,
|
||||
testService: testService,
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure interfaces are implemented
|
||||
var _ api.ConfigHandler = (*AdminHandlerImpl)(nil)
|
||||
var _ api.AdapterHandler = (*AdminHandlerImpl)(nil)
|
||||
var _ api.TestHandler = (*AdminHandlerImpl)(nil)
|
||||
|
||||
// ============================================
|
||||
// 配置管理接口实现
|
||||
// ============================================
|
||||
|
||||
// GetConfigList 获取配置列表
|
||||
func (h *AdminHandlerImpl) GetConfigList(ctx context.Context, req *api.ConfigListRequest) (*api.Response, error) {
|
||||
data, err := h.configService.GetConfigList(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateConfig 更新配置
|
||||
func (h *AdminHandlerImpl) UpdateConfig(ctx context.Context, req *api.ConfigUpdateRequest) (*api.Response, error) {
|
||||
data, err := h.configService.UpdateConfig(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ReloadConfig 热加载配置
|
||||
func (h *AdminHandlerImpl) ReloadConfig(ctx context.Context, req *api.ReloadRequest) (*api.Response, error) {
|
||||
data, err := h.configService.ReloadConfig(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetSystemStatus 获取系统状态
|
||||
func (h *AdminHandlerImpl) GetSystemStatus(ctx context.Context) (*api.Response, error) {
|
||||
data, err := h.configService.GetSystemStatus(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 适配器管理接口实现
|
||||
// ============================================
|
||||
|
||||
// GetAdapterList 获取适配器列表
|
||||
func (h *AdminHandlerImpl) GetAdapterList(ctx context.Context) (*api.Response, error) {
|
||||
data, err := h.adapterService.GetAdapterList(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ToggleAdapter 启用/禁用适配器
|
||||
func (h *AdminHandlerImpl) ToggleAdapter(ctx context.Context, req *api.AdapterToggleRequest) (*api.Response, error) {
|
||||
if err := h.adapterService.ToggleAdapter(ctx, req); err != nil {
|
||||
return &api.Response{
|
||||
Code: 500,
|
||||
Message: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &api.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateAdapterConfig 更新适配器配置
|
||||
func (h *AdminHandlerImpl) UpdateAdapterConfig(ctx context.Context, req *api.AdapterConfigUpdateRequest) (*api.Response, error) {
|
||||
if err := h.adapterService.UpdateAdapterConfig(ctx, req); err != nil {
|
||||
return &api.Response{
|
||||
Code: 500,
|
||||
Message: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &api.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 测试管理接口实现
|
||||
// ============================================
|
||||
|
||||
// GetAPITestList 获取API测试列表
|
||||
func (h *AdminHandlerImpl) GetAPITestList(ctx context.Context) (*api.Response, error) {
|
||||
data, err := h.testService.GetAPITestList(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RunAPITest 执行API测试
|
||||
func (h *AdminHandlerImpl) RunAPITest(ctx context.Context, req *api.APITestRequest) (*api.Response, error) {
|
||||
// 获取基础URL
|
||||
baseURL := "http://localhost:8080"
|
||||
if cfg := h.configService.GetCurrentConfig(); cfg != nil && cfg.Server.Port != 0 {
|
||||
baseURL = fmt.Sprintf("http://localhost:%d", cfg.Server.Port)
|
||||
}
|
||||
|
||||
data, err := h.testService.RunAPITest(ctx, baseURL, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetWSTestList 获取WebSocket测试列表
|
||||
func (h *AdminHandlerImpl) GetWSTestList(ctx context.Context) (*api.Response, error) {
|
||||
data, err := h.testService.GetWSTestList(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 设置WebSocket URL
|
||||
wsURL := "ws://localhost:8080/v1/stream"
|
||||
if cfg := h.configService.GetCurrentConfig(); cfg != nil && cfg.Server.Port != 0 {
|
||||
wsURL = "ws://localhost:" + string(rune(cfg.Server.Port)) + "/v1/stream"
|
||||
}
|
||||
data.WSURL = wsURL
|
||||
|
||||
return &api.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RunWSTest 执行WebSocket测试
|
||||
func (h *AdminHandlerImpl) RunWSTest(ctx context.Context, req *api.WSTestRequest) (*api.Response, error) {
|
||||
// 获取WebSocket URL
|
||||
wsURL := "ws://localhost:8080/v1/stream"
|
||||
if cfg := h.configService.GetCurrentConfig(); cfg != nil && cfg.Server.Port != 0 {
|
||||
wsURL = fmt.Sprintf("ws://localhost:%d/v1/stream", cfg.Server.Port)
|
||||
}
|
||||
|
||||
data, err := h.testService.RunWSTest(ctx, wsURL, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetTestHistory 获取测试历史
|
||||
func (h *AdminHandlerImpl) GetTestHistory(ctx context.Context, req *api.TestHistoryRequest) (*api.Response, error) {
|
||||
data, err := h.testService.GetTestHistory(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
||||
@ -0,0 +1,22 @@
|
||||
package model
|
||||
|
||||
// Model 领域模型定义
|
||||
|
||||
// Symbol 标的模型
|
||||
type Symbol struct {
|
||||
SymbolID string
|
||||
Name string
|
||||
// TODO: 补充字段
|
||||
}
|
||||
|
||||
// KLine K线模型
|
||||
type KLine struct {
|
||||
Symbol string
|
||||
Open float64
|
||||
High float64
|
||||
Low float64
|
||||
Close float64
|
||||
Volume int64
|
||||
Amount float64
|
||||
// TODO: 补充字段
|
||||
}
|
||||
@ -0,0 +1,334 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"market-data-service/api"
|
||||
)
|
||||
|
||||
// FuturesRepository 期货数据仓库
|
||||
type FuturesRepository struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
// NewFuturesRepository 创建期货Repository
|
||||
func NewFuturesRepository(db *DB) *FuturesRepository {
|
||||
return &FuturesRepository{db: db}
|
||||
}
|
||||
|
||||
// GetKLines 获取K线数据
|
||||
func (r *FuturesRepository) GetKLines(ctx context.Context, symbol string, freq api.Frequency, start, end time.Time) ([]api.KLineItem, error) {
|
||||
tableName := fmt.Sprintf("futures.klines_%s", freq)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT ts, open, high, low, close, volume, amount, open_interest
|
||||
FROM %s
|
||||
WHERE symbol_id = $1 AND ts >= $2 AND ts <= $3
|
||||
ORDER BY ts ASC
|
||||
`, tableName)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, symbol, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []api.KLineItem
|
||||
for rows.Next() {
|
||||
var item api.KLineItem
|
||||
var oi sql.NullInt64
|
||||
if err := rows.Scan(
|
||||
&item.Time, &item.Open, &item.High, &item.Low, &item.Close, &item.Volume, &item.Amount, &oi); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oi.Valid {
|
||||
item.OpenInterest = &oi.Int64
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
return items, rows.Err()
|
||||
}
|
||||
|
||||
// SaveKLines 保存K线数据
|
||||
func (r *FuturesRepository) SaveKLines(ctx context.Context, freq api.Frequency, symbol string, items []api.KLineItem) error {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tableName := fmt.Sprintf("futures.klines_%s", freq)
|
||||
|
||||
// 使用批量插入
|
||||
valueStrs := make([]string, 0, len(items))
|
||||
args := make([]interface{}, 0, len(items)*8)
|
||||
argIdx := 1
|
||||
|
||||
for _, item := range items {
|
||||
valueStrs = append(valueStrs, fmt.Sprintf("($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d)",
|
||||
argIdx, argIdx+1, argIdx+2, argIdx+3, argIdx+4, argIdx+5, argIdx+6, argIdx+7))
|
||||
args = append(args, symbol, item.Time, item.Open, item.High, item.Low, item.Close, item.Volume, item.Amount)
|
||||
if item.OpenInterest != nil {
|
||||
args = append(args, *item.OpenInterest)
|
||||
} else {
|
||||
args = append(args, nil)
|
||||
}
|
||||
argIdx += 8
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
INSERT INTO %s (symbol_id, ts, open, high, low, close, volume, amount, open_interest)
|
||||
VALUES %s
|
||||
ON CONFLICT (symbol_id, ts) DO UPDATE SET
|
||||
open = EXCLUDED.open,
|
||||
high = EXCLUDED.high,
|
||||
low = EXCLUDED.low,
|
||||
close = EXCLUDED.close,
|
||||
volume = EXCLUDED.volume,
|
||||
amount = EXCLUDED.amount,
|
||||
open_interest = EXCLUDED.open_interest
|
||||
`, tableName, strings.Join(valueStrs, ","))
|
||||
|
||||
_, err := r.db.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// ListSymbols 查询标的列表
|
||||
func (r *FuturesRepository) ListSymbols(ctx context.Context, req *api.SymbolListRequest) ([]api.Symbol, int, error) {
|
||||
whereClause := "WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
argIdx := 1
|
||||
|
||||
if req.Exchange != "" {
|
||||
whereClause += fmt.Sprintf(" AND exchange = $%d", argIdx)
|
||||
args = append(args, req.Exchange)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if req.Underlying != "" {
|
||||
whereClause += fmt.Sprintf(" AND underlying = $%d", argIdx)
|
||||
args = append(args, req.Underlying)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if req.Keyword != "" {
|
||||
whereClause += fmt.Sprintf(" AND (symbol_id ILIKE $%d OR name ILIKE $%d)", argIdx, argIdx)
|
||||
args = append(args, "%"+req.Keyword+"%")
|
||||
argIdx++
|
||||
}
|
||||
|
||||
// 查询总数
|
||||
countQuery := "SELECT COUNT(*) FROM futures.symbols " + whereClause
|
||||
var total int
|
||||
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 查询数据
|
||||
query := fmt.Sprintf(`
|
||||
SELECT symbol_id, symbol_type, exchange, name, underlying, contract_month, list_date, delist_date, status
|
||||
FROM futures.symbols
|
||||
%s
|
||||
ORDER BY symbol_id
|
||||
LIMIT $%d OFFSET $%d
|
||||
`, whereClause, argIdx, argIdx+1)
|
||||
args = append(args, req.Size, (req.Page-1)*req.Size)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var symbols []api.Symbol
|
||||
for rows.Next() {
|
||||
var s api.Symbol
|
||||
var listDate, delistDate sql.NullTime
|
||||
if err := rows.Scan(
|
||||
&s.SymbolID, &s.SymbolType, &s.Exchange, &s.Name, &s.Underlying,
|
||||
&s.ContractMonth, &listDate, &delistDate, &s.Status); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if listDate.Valid {
|
||||
s.ListDate = &listDate.Time
|
||||
}
|
||||
if delistDate.Valid {
|
||||
s.DelistDate = &delistDate.Time
|
||||
}
|
||||
symbols = append(symbols, s)
|
||||
}
|
||||
|
||||
return symbols, total, rows.Err()
|
||||
}
|
||||
|
||||
// GetContractsByUnderlying 根据品种获取合约
|
||||
func (r *FuturesRepository) GetContractsByUnderlying(ctx context.Context, underlying string, exchange string) (*api.FuturesContractsData, error) {
|
||||
query := `
|
||||
SELECT symbol_id, symbol_type, exchange, name, underlying, contract_month, list_date, delist_date, status
|
||||
FROM futures.symbols
|
||||
WHERE underlying = $1 AND status = 'active'
|
||||
`
|
||||
args := []interface{}{underlying}
|
||||
|
||||
if exchange != "" {
|
||||
query += " AND exchange = $2"
|
||||
args = append(args, exchange)
|
||||
}
|
||||
|
||||
query += " ORDER BY contract_month ASC"
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var contracts []api.FuturesContractInfo
|
||||
for rows.Next() {
|
||||
var c api.FuturesContractInfo
|
||||
var listDate, delistDate sql.NullTime
|
||||
if err := rows.Scan(
|
||||
&c.SymbolID, &c.SymbolType, &c.Exchange, &c.Name, &c.Underlying,
|
||||
&c.ContractMonth, &listDate, &delistDate, &c.Status); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if listDate.Valid {
|
||||
c.ListDate = &listDate.Time
|
||||
}
|
||||
if delistDate.Valid {
|
||||
c.DelistDate = &delistDate.Time
|
||||
}
|
||||
contracts = append(contracts, c)
|
||||
}
|
||||
|
||||
return &api.FuturesContractsData{
|
||||
Underlying: underlying,
|
||||
Count: len(contracts),
|
||||
Items: contracts,
|
||||
}, rows.Err()
|
||||
}
|
||||
|
||||
// GetTradingDates 获取交易日历
|
||||
func (r *FuturesRepository) GetTradingDates(ctx context.Context, start, end string) (*api.TradingDatesData, error) {
|
||||
query := `
|
||||
SELECT trade_date
|
||||
FROM futures.trading_calendar
|
||||
WHERE trade_date >= $1 AND trade_date <= $2 AND is_trading_day = true
|
||||
ORDER BY trade_date ASC
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var dates []string
|
||||
for rows.Next() {
|
||||
var date string
|
||||
if err := rows.Scan(&date); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dates = append(dates, date)
|
||||
}
|
||||
|
||||
// 计算总天数
|
||||
startDate, _ := time.Parse("20060102", start)
|
||||
endDate, _ := time.Parse("20060102", end)
|
||||
totalDays := int(endDate.Sub(startDate).Hours()/24) + 1
|
||||
|
||||
return &api.TradingDatesData{
|
||||
Start: start,
|
||||
End: end,
|
||||
TotalDays: totalDays,
|
||||
TradingDays: len(dates),
|
||||
TradingDates: dates,
|
||||
}, rows.Err()
|
||||
}
|
||||
|
||||
// SaveSymbols 保存标的列表
|
||||
func (r *FuturesRepository) SaveSymbols(ctx context.Context, symbols []api.Symbol) error {
|
||||
if len(symbols) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, `
|
||||
INSERT INTO futures.symbols (symbol_id, symbol_type, exchange, name, underlying, contract_month, list_date, delist_date, status)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
ON CONFLICT (symbol_id) DO UPDATE SET
|
||||
name = EXCLUDED.name,
|
||||
underlying = EXCLUDED.underlying,
|
||||
contract_month = EXCLUDED.contract_month,
|
||||
list_date = EXCLUDED.list_date,
|
||||
delist_date = EXCLUDED.delist_date,
|
||||
status = EXCLUDED.status,
|
||||
updated_at = NOW()
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for _, s := range symbols {
|
||||
var listDate, delistDate interface{}
|
||||
if s.ListDate != nil {
|
||||
listDate = *s.ListDate
|
||||
}
|
||||
if s.DelistDate != nil {
|
||||
delistDate = *s.DelistDate
|
||||
}
|
||||
|
||||
_, err := stmt.ExecContext(ctx, s.SymbolID, s.SymbolType, s.Exchange, s.Name, s.Underlying,
|
||||
s.ContractMonth, listDate, delistDate, s.Status)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// SaveTradingCalendar 保存交易日历
|
||||
func (r *FuturesRepository) SaveTradingCalendar(ctx context.Context, dates []api.TradeCalData) error {
|
||||
if len(dates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, `
|
||||
INSERT INTO futures.trading_calendar (trade_date, is_trading_day, has_night_session, week_day)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (trade_date) DO UPDATE SET
|
||||
is_trading_day = EXCLUDED.is_trading_day,
|
||||
has_night_session = EXCLUDED.has_night_session,
|
||||
week_day = EXCLUDED.week_day,
|
||||
updated_at = NOW()
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for _, d := range dates {
|
||||
_, err := stmt.ExecContext(ctx, d.Date.Format("2006-01-02"), d.IsTradingDay, d.HasNightSession, int(d.Date.Weekday())+1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
@ -0,0 +1,13 @@
|
||||
package repository
|
||||
|
||||
// Repository 数据访问层接口定义
|
||||
|
||||
// StockRepository 股票数据仓库
|
||||
type StockRepository interface {
|
||||
// TODO: 定义股票相关数据访问方法
|
||||
}
|
||||
|
||||
// FuturesRepository 期货数据仓库
|
||||
type FuturesRepository interface {
|
||||
// TODO: 定义期货相关数据访问方法
|
||||
}
|
||||
@ -0,0 +1,291 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"market-data-service/api"
|
||||
)
|
||||
|
||||
// DB PostgreSQL连接
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
}
|
||||
|
||||
// NewDB 创建数据库连接
|
||||
func NewDB(connStr string) (*DB, error) {
|
||||
db, err := sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &DB{db}, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 股票Repository
|
||||
// ============================================
|
||||
|
||||
// StockRepository 股票数据仓库
|
||||
type StockRepository struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
// NewStockRepository 创建股票Repository
|
||||
func NewStockRepository(db *DB) *StockRepository {
|
||||
return &StockRepository{db: db}
|
||||
}
|
||||
|
||||
// GetKLines 获取K线数据
|
||||
func (r *StockRepository) GetKLines(ctx context.Context, symbol string, freq api.Frequency, start, end time.Time, adjust api.AdjustType) ([]api.KLineItem, error) {
|
||||
tableName := fmt.Sprintf("stock.klines_%s", freq)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT ts, open, high, low, close, volume, amount
|
||||
FROM %s
|
||||
WHERE symbol_id = $1 AND ts >= $2 AND ts <= $3
|
||||
ORDER BY ts ASC
|
||||
`, tableName)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, symbol, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []api.KLineItem
|
||||
for rows.Next() {
|
||||
var item api.KLineItem
|
||||
if err := rows.Scan(&item.Time, &item.Open, &item.High, &item.Low, &item.Close, &item.Volume, &item.Amount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
return items, rows.Err()
|
||||
}
|
||||
|
||||
// SaveKLines 保存K线数据
|
||||
func (r *StockRepository) SaveKLines(ctx context.Context, freq api.Frequency, items []api.KLineItem) error {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tableName := fmt.Sprintf("stock.klines_%s", freq)
|
||||
|
||||
// 使用批量插入
|
||||
valueStrs := make([]string, 0, len(items))
|
||||
args := make([]interface{}, 0, len(items)*7)
|
||||
argIdx := 1
|
||||
|
||||
for _, item := range items {
|
||||
valueStrs = append(valueStrs, fmt.Sprintf("($%d, $%d, $%d, $%d, $%d, $%d, $%d)",
|
||||
argIdx, argIdx+1, argIdx+2, argIdx+3, argIdx+4, argIdx+5, argIdx+6))
|
||||
args = append(args, item.Symbol, item.Time, item.Open, item.High, item.Low, item.Close, item.Volume, item.Amount)
|
||||
argIdx += 7
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
INSERT INTO %s (symbol_id, ts, open, high, low, close, volume, amount)
|
||||
VALUES %s
|
||||
ON CONFLICT (symbol_id, ts) DO UPDATE SET
|
||||
open = EXCLUDED.open,
|
||||
high = EXCLUDED.high,
|
||||
low = EXCLUDED.low,
|
||||
close = EXCLUDED.close,
|
||||
volume = EXCLUDED.volume,
|
||||
amount = EXCLUDED.amount
|
||||
`, tableName, strings.Join(valueStrs, ","))
|
||||
|
||||
_, err := r.db.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// ListSymbols 查询标的列表
|
||||
func (r *StockRepository) ListSymbols(ctx context.Context, req *api.SymbolListRequest) ([]api.Symbol, int, error) {
|
||||
whereClause := "WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
argIdx := 1
|
||||
|
||||
if req.Exchange != "" {
|
||||
whereClause += fmt.Sprintf(" AND exchange = $%d", argIdx)
|
||||
args = append(args, req.Exchange)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if req.Keyword != "" {
|
||||
whereClause += fmt.Sprintf(" AND (symbol_id ILIKE $%d OR name ILIKE $%d)", argIdx, argIdx)
|
||||
args = append(args, "%"+req.Keyword+"%")
|
||||
argIdx++
|
||||
}
|
||||
|
||||
// 查询总数
|
||||
countQuery := "SELECT COUNT(*) FROM stock.symbols " + whereClause
|
||||
var total int
|
||||
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 查询数据
|
||||
query := fmt.Sprintf(`
|
||||
SELECT symbol_id, symbol_type, exchange, name, name_en, list_date, delist_date, industry, status
|
||||
FROM stock.symbols
|
||||
%s
|
||||
ORDER BY symbol_id
|
||||
LIMIT $%d OFFSET $%d
|
||||
`, whereClause, argIdx, argIdx+1)
|
||||
args = append(args, req.Size, (req.Page-1)*req.Size)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var symbols []api.Symbol
|
||||
for rows.Next() {
|
||||
var s api.Symbol
|
||||
var listDate, delistDate sql.NullTime
|
||||
if err := rows.Scan(&s.SymbolID, &s.SymbolType, &s.Exchange, &s.Name, &s.NameEN,
|
||||
&listDate, &delistDate, &s.Industry, &s.Status); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if listDate.Valid {
|
||||
s.ListDate = &listDate.Time
|
||||
}
|
||||
if delistDate.Valid {
|
||||
s.DelistDate = &delistDate.Time
|
||||
}
|
||||
symbols = append(symbols, s)
|
||||
}
|
||||
|
||||
return symbols, total, rows.Err()
|
||||
}
|
||||
|
||||
// GetTradingDates 获取交易日历
|
||||
func (r *StockRepository) GetTradingDates(ctx context.Context, start, end string) (*api.TradingDatesData, error) {
|
||||
query := `
|
||||
SELECT trade_date
|
||||
FROM stock.trading_calendar
|
||||
WHERE trade_date >= $1 AND trade_date <= $2 AND is_trading_day = true
|
||||
ORDER BY trade_date ASC
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var dates []string
|
||||
for rows.Next() {
|
||||
var date string
|
||||
if err := rows.Scan(&date); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dates = append(dates, date)
|
||||
}
|
||||
|
||||
// 计算总天数
|
||||
startDate, _ := time.Parse("20060102", start)
|
||||
endDate, _ := time.Parse("20060102", end)
|
||||
totalDays := int(endDate.Sub(startDate).Hours()/24) + 1
|
||||
|
||||
return &api.TradingDatesData{
|
||||
Start: start,
|
||||
End: end,
|
||||
TotalDays: totalDays,
|
||||
TradingDays: len(dates),
|
||||
TradingDates: dates,
|
||||
}, rows.Err()
|
||||
}
|
||||
|
||||
// SaveSymbols 保存标的列表
|
||||
func (r *StockRepository) SaveSymbols(ctx context.Context, symbols []api.Symbol) error {
|
||||
if len(symbols) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, `
|
||||
INSERT INTO stock.symbols (symbol_id, symbol_type, exchange, name, name_en, list_date, delist_date, industry, status)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
ON CONFLICT (symbol_id) DO UPDATE SET
|
||||
name = EXCLUDED.name,
|
||||
name_en = EXCLUDED.name_en,
|
||||
list_date = EXCLUDED.list_date,
|
||||
delist_date = EXCLUDED.delist_date,
|
||||
industry = EXCLUDED.industry,
|
||||
status = EXCLUDED.status,
|
||||
updated_at = NOW()
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for _, s := range symbols {
|
||||
var listDate, delistDate interface{}
|
||||
if s.ListDate != nil {
|
||||
listDate = *s.ListDate
|
||||
}
|
||||
if s.DelistDate != nil {
|
||||
delistDate = *s.DelistDate
|
||||
}
|
||||
|
||||
_, err := stmt.ExecContext(ctx, s.SymbolID, s.SymbolType, s.Exchange, s.Name, s.NameEN,
|
||||
listDate, delistDate, s.Industry, s.Status)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// SaveTradingCalendar 保存交易日历
|
||||
func (r *StockRepository) SaveTradingCalendar(ctx context.Context, dates []api.TradeCalData) error {
|
||||
if len(dates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, `
|
||||
INSERT INTO stock.trading_calendar (trade_date, is_trading_day, week_day)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (trade_date) DO UPDATE SET
|
||||
is_trading_day = EXCLUDED.is_trading_day,
|
||||
week_day = EXCLUDED.week_day,
|
||||
updated_at = NOW()
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for _, d := range dates {
|
||||
_, err := stmt.ExecContext(ctx, d.Date.Format("2006-01-02"), d.IsTradingDay, int(d.Date.Weekday())+1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
@ -0,0 +1,305 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"market-data-service/adapter"
|
||||
"market-data-service/adapter/tushare"
|
||||
"market-data-service/api"
|
||||
)
|
||||
|
||||
// AdapterService 适配器管理服务接口
|
||||
type AdapterService interface {
|
||||
// GetAdapterList 获取适配器列表
|
||||
GetAdapterList(ctx context.Context) (*api.AdapterListData, error)
|
||||
|
||||
// ToggleAdapter 启用/禁用适配器
|
||||
ToggleAdapter(ctx context.Context, req *api.AdapterToggleRequest) error
|
||||
|
||||
// UpdateAdapterConfig 更新适配器配置
|
||||
UpdateAdapterConfig(ctx context.Context, req *api.AdapterConfigUpdateRequest) error
|
||||
|
||||
// GetActiveAdapter 获取当前激活的适配器
|
||||
GetActiveAdapter(assetClass string) (adapter.DataSourceAdapter, error)
|
||||
|
||||
// GetAvailableAdapters 获取所有可用的适配器名称
|
||||
GetAvailableAdapters() []string
|
||||
|
||||
// RegisterAdapter 注册适配器
|
||||
RegisterAdapter(name string, factory AdapterFactory)
|
||||
}
|
||||
|
||||
// AdapterFactory 适配器工厂函数
|
||||
type AdapterFactory func() adapter.DataSourceAdapter
|
||||
|
||||
// AdapterServiceImpl 适配器服务实现
|
||||
type AdapterServiceImpl struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// 已注册的适配器工厂
|
||||
factories map[string]AdapterFactory
|
||||
|
||||
// 适配器配置
|
||||
configs map[string]*adapterConfig
|
||||
|
||||
// 当前激活的适配器实例
|
||||
activeAdapters map[string]adapter.DataSourceAdapter
|
||||
|
||||
// 适配器元数据
|
||||
metadata map[string]*adapterMetadata
|
||||
}
|
||||
|
||||
// adapterConfig 适配器配置
|
||||
type adapterConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Config map[string]string `json:"config"`
|
||||
}
|
||||
|
||||
// adapterMetadata 适配器元数据
|
||||
type adapterMetadata struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Version string `json:"version"`
|
||||
Description string `json:"description"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// NewAdapterService 创建适配器服务
|
||||
func NewAdapterService() AdapterService {
|
||||
service := &AdapterServiceImpl{
|
||||
factories: make(map[string]AdapterFactory),
|
||||
configs: make(map[string]*adapterConfig),
|
||||
activeAdapters: make(map[string]adapter.DataSourceAdapter),
|
||||
metadata: make(map[string]*adapterMetadata),
|
||||
}
|
||||
|
||||
// 注册内置适配器
|
||||
service.registerBuiltinAdapters()
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
// registerBuiltinAdapters 注册内置适配器
|
||||
func (s *AdapterServiceImpl) registerBuiltinAdapters() {
|
||||
// 注册Tushare适配器
|
||||
s.RegisterAdapter("tushare", func() adapter.DataSourceAdapter {
|
||||
return tushare.NewAdapter()
|
||||
})
|
||||
|
||||
// 设置Tushare元数据
|
||||
s.metadata["tushare"] = &adapterMetadata{
|
||||
Name: "tushare",
|
||||
Type: "http",
|
||||
Version: "1.0.0",
|
||||
Description: "Tushare Pro 金融数据接口",
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
s.configs["tushare"] = &adapterConfig{
|
||||
Enabled: true,
|
||||
Config: map[string]string{
|
||||
"token": "",
|
||||
"base_url": "https://api.tushare.pro",
|
||||
},
|
||||
}
|
||||
|
||||
// 预留Wind适配器
|
||||
s.metadata["wind"] = &adapterMetadata{
|
||||
Name: "wind",
|
||||
Type: "ws",
|
||||
Version: "1.0.0",
|
||||
Description: "Wind 金融终端接口(预留)",
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
s.configs["wind"] = &adapterConfig{
|
||||
Enabled: false,
|
||||
Config: map[string]string{
|
||||
"host": "localhost",
|
||||
"port": "8081",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetAdapterList 获取适配器列表
|
||||
func (s *AdapterServiceImpl) GetAdapterList(ctx context.Context) (*api.AdapterListData, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
adapters := make([]api.AdapterInfo, 0, len(s.metadata))
|
||||
|
||||
for name, meta := range s.metadata {
|
||||
cfg, ok := s.configs[name]
|
||||
if !ok {
|
||||
cfg = &adapterConfig{Enabled: false, Config: make(map[string]string)}
|
||||
}
|
||||
|
||||
status := api.AdapterStatusDisabled
|
||||
if cfg.Enabled {
|
||||
status = api.AdapterStatusStandby
|
||||
// 检查是否是激活状态
|
||||
if _, active := s.activeAdapters[name]; active {
|
||||
status = api.AdapterStatusActive
|
||||
}
|
||||
}
|
||||
|
||||
adapters = append(adapters, api.AdapterInfo{
|
||||
Name: meta.Name,
|
||||
Type: meta.Type,
|
||||
Version: meta.Version,
|
||||
Description: meta.Description,
|
||||
Status: status,
|
||||
Config: cfg.Config,
|
||||
UpdatedAt: meta.UpdatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
return &api.AdapterListData{
|
||||
Adapters: adapters,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ToggleAdapter 启用/禁用适配器
|
||||
func (s *AdapterServiceImpl) ToggleAdapter(ctx context.Context, req *api.AdapterToggleRequest) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
cfg, ok := s.configs[req.Name]
|
||||
if !ok {
|
||||
return fmt.Errorf("adapter not found: %s", req.Name)
|
||||
}
|
||||
|
||||
cfg.Enabled = req.Enable
|
||||
|
||||
// 如果禁用,关闭适配器连接
|
||||
if !req.Enable {
|
||||
if adapter, ok := s.activeAdapters[req.Name]; ok {
|
||||
adapter.Close()
|
||||
delete(s.activeAdapters, req.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新元数据
|
||||
if meta, ok := s.metadata[req.Name]; ok {
|
||||
meta.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateAdapterConfig 更新适配器配置
|
||||
func (s *AdapterServiceImpl) UpdateAdapterConfig(ctx context.Context, req *api.AdapterConfigUpdateRequest) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
cfg, ok := s.configs[req.Name]
|
||||
if !ok {
|
||||
return fmt.Errorf("adapter not found: %s", req.Name)
|
||||
}
|
||||
|
||||
// 更新配置
|
||||
for k, v := range req.Config {
|
||||
cfg.Config[k] = v
|
||||
}
|
||||
|
||||
// 如果适配器已激活,重新连接
|
||||
if adapter, ok := s.activeAdapters[req.Name]; ok {
|
||||
adapter.Close()
|
||||
delete(s.activeAdapters, req.Name)
|
||||
|
||||
// 如果启用状态,重新连接
|
||||
if cfg.Enabled {
|
||||
if err := s.connectAdapter(req.Name); err != nil {
|
||||
return fmt.Errorf("failed to reconnect adapter: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 更新元数据
|
||||
if meta, ok := s.metadata[req.Name]; ok {
|
||||
meta.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveAdapter 获取当前激活的适配器
|
||||
func (s *AdapterServiceImpl) GetActiveAdapter(assetClass string) (adapter.DataSourceAdapter, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// 根据资产类别获取配置
|
||||
// 这里简化处理,实际应该从配置服务获取
|
||||
adapterName := "tushare"
|
||||
if assetClass == "futures" {
|
||||
adapterName = "tushare"
|
||||
}
|
||||
|
||||
// 检查是否已有激活的实例
|
||||
if adapter, ok := s.activeAdapters[adapterName]; ok {
|
||||
return adapter, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no active adapter for %s", assetClass)
|
||||
}
|
||||
|
||||
// GetAvailableAdapters 获取所有可用的适配器名称
|
||||
func (s *AdapterServiceImpl) GetAvailableAdapters() []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
names := make([]string, 0, len(s.metadata))
|
||||
for name, meta := range s.metadata {
|
||||
// 只返回有工厂的适配器(已实现的)
|
||||
if _, ok := s.factories[name]; ok {
|
||||
names = append(names, fmt.Sprintf("%s|%s", name, meta.Description))
|
||||
}
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// RegisterAdapter 注册适配器
|
||||
func (s *AdapterServiceImpl) RegisterAdapter(name string, factory AdapterFactory) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.factories[name] = factory
|
||||
}
|
||||
|
||||
// connectAdapter 连接适配器
|
||||
func (s *AdapterServiceImpl) connectAdapter(name string) error {
|
||||
factory, ok := s.factories[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("adapter factory not found: %s", name)
|
||||
}
|
||||
|
||||
cfg, ok := s.configs[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("adapter config not found: %s", name)
|
||||
}
|
||||
|
||||
adapter := factory()
|
||||
if err := adapter.Connect(cfg.Config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.activeAdapters[name] = adapter
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck 适配器健康检查
|
||||
func (s *AdapterServiceImpl) HealthCheck(name string) error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
adapter, ok := s.activeAdapters[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("adapter not active: %s", name)
|
||||
}
|
||||
|
||||
return adapter.HealthCheck()
|
||||
}
|
||||
@ -0,0 +1,477 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"market-data-service/api"
|
||||
"market-data-service/pkg/config"
|
||||
)
|
||||
|
||||
// ConfigService 配置管理服务接口
|
||||
type ConfigService interface {
|
||||
// GetConfigList 获取配置列表
|
||||
GetConfigList(ctx context.Context, req *api.ConfigListRequest) (*api.ConfigListData, error)
|
||||
|
||||
// UpdateConfig 更新配置
|
||||
UpdateConfig(ctx context.Context, req *api.ConfigUpdateRequest) (*api.ConfigUpdateData, error)
|
||||
|
||||
// ReloadConfig 热加载配置
|
||||
ReloadConfig(ctx context.Context, req *api.ReloadRequest) (*api.ReloadData, error)
|
||||
|
||||
// GetSystemStatus 获取系统状态
|
||||
GetSystemStatus(ctx context.Context) (*api.SystemStatusData, error)
|
||||
|
||||
// GetCurrentConfig 获取当前配置(内部使用)
|
||||
GetCurrentConfig() *config.Config
|
||||
}
|
||||
|
||||
// ConfigServiceImpl 配置服务实现
|
||||
type ConfigServiceImpl struct {
|
||||
configPath string
|
||||
config *config.Config
|
||||
mu sync.RWMutex
|
||||
|
||||
// 配置变更回调
|
||||
callbacks map[api.ConfigType][]func()
|
||||
cbMu sync.RWMutex
|
||||
|
||||
// 启动时间
|
||||
startTime time.Time
|
||||
|
||||
// 配置版本
|
||||
version string
|
||||
}
|
||||
|
||||
// NewConfigService 创建配置服务
|
||||
func NewConfigService(configPath string) (ConfigService, error) {
|
||||
cfg, err := config.Load(configPath)
|
||||
if err != nil {
|
||||
// 如果加载失败,使用默认配置
|
||||
cfg = getDefaultConfig()
|
||||
}
|
||||
|
||||
return &ConfigServiceImpl{
|
||||
configPath: configPath,
|
||||
config: cfg,
|
||||
callbacks: make(map[api.ConfigType][]func()),
|
||||
startTime: time.Now(),
|
||||
version: "1.0.0",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetConfigList 获取配置列表
|
||||
func (s *ConfigServiceImpl) GetConfigList(ctx context.Context, req *api.ConfigListRequest) (*api.ConfigListData, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
sections := []api.ConfigSection{}
|
||||
|
||||
// 服务器配置
|
||||
if req.Type == "" || req.Type == api.ConfigTypeServer {
|
||||
sections = append(sections, api.ConfigSection{
|
||||
Name: "服务器配置",
|
||||
Type: api.ConfigTypeServer,
|
||||
Description: "HTTP服务器相关配置",
|
||||
Items: []api.ConfigItem{
|
||||
{
|
||||
Key: "port",
|
||||
Value: s.config.Server.Port,
|
||||
Type: "int",
|
||||
Description: "服务端口",
|
||||
Editable: true,
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Key: "mode",
|
||||
Value: s.config.Server.Mode,
|
||||
Type: "string",
|
||||
Description: "运行模式: debug/release",
|
||||
Editable: true,
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Key: "api_key",
|
||||
Value: s.config.Server.APIKey,
|
||||
Type: "string",
|
||||
Description: "API认证密钥",
|
||||
Editable: true,
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// 数据库配置
|
||||
if req.Type == "" || req.Type == api.ConfigTypeDatabase {
|
||||
sections = append(sections, api.ConfigSection{
|
||||
Name: "数据库配置",
|
||||
Type: api.ConfigTypeDatabase,
|
||||
Description: "PostgreSQL数据库连接配置",
|
||||
Items: []api.ConfigItem{
|
||||
{
|
||||
Key: "host",
|
||||
Value: s.config.Database.Host,
|
||||
Type: "string",
|
||||
Description: "数据库主机地址",
|
||||
Editable: true,
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Key: "port",
|
||||
Value: s.config.Database.Port,
|
||||
Type: "int",
|
||||
Description: "数据库端口",
|
||||
Editable: true,
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Key: "user",
|
||||
Value: s.config.Database.User,
|
||||
Type: "string",
|
||||
Description: "数据库用户名",
|
||||
Editable: true,
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Key: "password",
|
||||
Value: "********",
|
||||
Type: "password",
|
||||
Description: "数据库密码",
|
||||
Editable: true,
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Key: "database",
|
||||
Value: s.config.Database.Database,
|
||||
Type: "string",
|
||||
Description: "数据库名",
|
||||
Editable: true,
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Redis配置
|
||||
if req.Type == "" || req.Type == api.ConfigTypeRedis {
|
||||
sections = append(sections, api.ConfigSection{
|
||||
Name: "Redis配置",
|
||||
Type: api.ConfigTypeRedis,
|
||||
Description: "Redis缓存配置",
|
||||
Items: []api.ConfigItem{
|
||||
{
|
||||
Key: "host",
|
||||
Value: s.config.Redis.Host,
|
||||
Type: "string",
|
||||
Description: "Redis主机地址",
|
||||
Editable: true,
|
||||
Required: false,
|
||||
},
|
||||
{
|
||||
Key: "port",
|
||||
Value: s.config.Redis.Port,
|
||||
Type: "int",
|
||||
Description: "Redis端口",
|
||||
Editable: true,
|
||||
Required: false,
|
||||
},
|
||||
{
|
||||
Key: "password",
|
||||
Value: "********",
|
||||
Type: "password",
|
||||
Description: "Redis密码",
|
||||
Editable: true,
|
||||
Required: false,
|
||||
},
|
||||
{
|
||||
Key: "db",
|
||||
Value: s.config.Redis.DB,
|
||||
Type: "int",
|
||||
Description: "Redis数据库编号",
|
||||
Editable: true,
|
||||
Required: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// 数据源配置
|
||||
if req.Type == "" || req.Type == api.ConfigTypeSource {
|
||||
sections = append(sections, api.ConfigSection{
|
||||
Name: "数据源配置",
|
||||
Type: api.ConfigTypeSource,
|
||||
Description: "股票和期货数据源配置",
|
||||
Items: []api.ConfigItem{
|
||||
{
|
||||
Key: "stock_active",
|
||||
Value: s.config.Sources.Stock.Active,
|
||||
Type: "string",
|
||||
Description: "股票数据源适配器",
|
||||
Editable: true,
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Key: "futures_active",
|
||||
Value: s.config.Sources.Futures.Active,
|
||||
Type: "string",
|
||||
Description: "期货数据源适配器",
|
||||
Editable: true,
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return &api.ConfigListData{
|
||||
Sections: sections,
|
||||
Version: s.version,
|
||||
Updated: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateConfig 更新配置
|
||||
func (s *ConfigServiceImpl) UpdateConfig(ctx context.Context, req *api.ConfigUpdateRequest) (*api.ConfigUpdateData, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
needRestart := false
|
||||
|
||||
switch req.Type {
|
||||
case api.ConfigTypeServer:
|
||||
if port, ok := req.Items["port"]; ok {
|
||||
s.config.Server.Port = int(port.(float64))
|
||||
needRestart = true
|
||||
}
|
||||
if mode, ok := req.Items["mode"]; ok {
|
||||
s.config.Server.Mode = mode.(string)
|
||||
}
|
||||
if apiKey, ok := req.Items["api_key"]; ok {
|
||||
s.config.Server.APIKey = apiKey.(string)
|
||||
}
|
||||
|
||||
case api.ConfigTypeDatabase:
|
||||
if host, ok := req.Items["host"]; ok {
|
||||
s.config.Database.Host = host.(string)
|
||||
needRestart = true
|
||||
}
|
||||
if port, ok := req.Items["port"]; ok {
|
||||
s.config.Database.Port = int(port.(float64))
|
||||
needRestart = true
|
||||
}
|
||||
if user, ok := req.Items["user"]; ok {
|
||||
s.config.Database.User = user.(string)
|
||||
needRestart = true
|
||||
}
|
||||
if password, ok := req.Items["password"]; ok && password.(string) != "********" {
|
||||
s.config.Database.Password = password.(string)
|
||||
needRestart = true
|
||||
}
|
||||
if database, ok := req.Items["database"]; ok {
|
||||
s.config.Database.Database = database.(string)
|
||||
needRestart = true
|
||||
}
|
||||
|
||||
case api.ConfigTypeSource:
|
||||
if stockActive, ok := req.Items["stock_active"]; ok {
|
||||
s.config.Sources.Stock.Active = stockActive.(string)
|
||||
}
|
||||
if futuresActive, ok := req.Items["futures_active"]; ok {
|
||||
s.config.Sources.Futures.Active = futuresActive.(string)
|
||||
}
|
||||
}
|
||||
|
||||
// 保存到文件
|
||||
if err := s.saveConfig(); err != nil {
|
||||
return &api.ConfigUpdateData{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("配置保存失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 触发回调
|
||||
s.triggerCallbacks(req.Type)
|
||||
|
||||
message := "配置更新成功"
|
||||
if needRestart {
|
||||
message += ",部分配置需要重启服务后生效"
|
||||
}
|
||||
|
||||
return &api.ConfigUpdateData{
|
||||
Success: true,
|
||||
NeedRestart: needRestart,
|
||||
Message: message,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ReloadConfig 热加载配置
|
||||
func (s *ConfigServiceImpl) ReloadConfig(ctx context.Context, req *api.ReloadRequest) (*api.ReloadData, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// 从文件重新加载
|
||||
cfg, err := config.Load(s.configPath)
|
||||
if err != nil {
|
||||
return &api.ReloadData{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("加载配置失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 根据类型选择性更新
|
||||
if req.ConfigType == "" {
|
||||
s.config = cfg
|
||||
} else {
|
||||
switch req.ConfigType {
|
||||
case api.ConfigTypeServer:
|
||||
s.config.Server = cfg.Server
|
||||
case api.ConfigTypeDatabase:
|
||||
s.config.Database = cfg.Database
|
||||
case api.ConfigTypeRedis:
|
||||
s.config.Redis = cfg.Redis
|
||||
case api.ConfigTypeSource:
|
||||
s.config.Sources = cfg.Sources
|
||||
}
|
||||
}
|
||||
|
||||
// 触发回调
|
||||
s.triggerCallbacks(req.ConfigType)
|
||||
|
||||
return &api.ReloadData{
|
||||
Success: true,
|
||||
Message: "配置热加载成功",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetSystemStatus 获取系统状态
|
||||
func (s *ConfigServiceImpl) GetSystemStatus(ctx context.Context) (*api.SystemStatusData, error) {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
|
||||
uptime := time.Since(s.startTime)
|
||||
uptimeStr := formatDuration(uptime)
|
||||
|
||||
return &api.SystemStatusData{
|
||||
Status: "running",
|
||||
Version: s.version,
|
||||
StartTime: s.startTime,
|
||||
Uptime: uptimeStr,
|
||||
GoVersion: runtime.Version(),
|
||||
MemoryUsage: api.MemoryInfo{
|
||||
Alloc: m.Alloc,
|
||||
TotalAlloc: m.TotalAlloc,
|
||||
Sys: m.Sys,
|
||||
NumGC: m.NumGC,
|
||||
},
|
||||
Goroutines: runtime.NumGoroutine(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetCurrentConfig 获取当前配置
|
||||
func (s *ConfigServiceImpl) GetCurrentConfig() *config.Config {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.config
|
||||
}
|
||||
|
||||
// RegisterCallback 注册配置变更回调
|
||||
func (s *ConfigServiceImpl) RegisterCallback(configType api.ConfigType, callback func()) {
|
||||
s.cbMu.Lock()
|
||||
defer s.cbMu.Unlock()
|
||||
|
||||
s.callbacks[configType] = append(s.callbacks[configType], callback)
|
||||
}
|
||||
|
||||
// triggerCallbacks 触发回调
|
||||
func (s *ConfigServiceImpl) triggerCallbacks(configType api.ConfigType) {
|
||||
s.cbMu.RLock()
|
||||
defer s.cbMu.RUnlock()
|
||||
|
||||
// 触发特定类型的回调
|
||||
if cbs, ok := s.callbacks[configType]; ok {
|
||||
for _, cb := range cbs {
|
||||
go cb()
|
||||
}
|
||||
}
|
||||
|
||||
// 触发通用回调
|
||||
if cbs, ok := s.callbacks[""]; ok {
|
||||
for _, cb := range cbs {
|
||||
go cb()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// saveConfig 保存配置到文件
|
||||
func (s *ConfigServiceImpl) saveConfig() error {
|
||||
if s.configPath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 确保目录存在
|
||||
dir := filepath.Dir(s.configPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 序列化为JSON
|
||||
data, err := json.MarshalIndent(s.config, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(s.configPath, data, 0644)
|
||||
}
|
||||
|
||||
// getDefaultConfig 获取默认配置
|
||||
func getDefaultConfig() *config.Config {
|
||||
return &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Port: 8080,
|
||||
Mode: "debug",
|
||||
APIKey: "default-api-key",
|
||||
},
|
||||
Database: config.DatabaseConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "user",
|
||||
Password: "password",
|
||||
Database: "marketdata",
|
||||
},
|
||||
Redis: config.RedisConfig{
|
||||
Host: "localhost",
|
||||
Port: 6379,
|
||||
Password: "",
|
||||
DB: 0,
|
||||
},
|
||||
Sources: config.SourcesConfig{
|
||||
Stock: config.SourceConfig{
|
||||
Active: "tushare",
|
||||
},
|
||||
Futures: config.SourceConfig{
|
||||
Active: "tushare",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// formatDuration 格式化持续时间
|
||||
func formatDuration(d time.Duration) string {
|
||||
days := int(d.Hours()) / 24
|
||||
hours := int(d.Hours()) % 24
|
||||
minutes := int(d.Minutes()) % 60
|
||||
|
||||
if days > 0 {
|
||||
return fmt.Sprintf("%d天%d小时%d分钟", days, hours, minutes)
|
||||
}
|
||||
if hours > 0 {
|
||||
return fmt.Sprintf("%d小时%d分钟", hours, minutes)
|
||||
}
|
||||
return fmt.Sprintf("%d分钟", minutes)
|
||||
}
|
||||
@ -0,0 +1,105 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"market-data-service/api"
|
||||
"market-data-service/internal/repository"
|
||||
)
|
||||
|
||||
// FuturesServiceImpl 期货服务实现
|
||||
type FuturesServiceImpl struct {
|
||||
repo *repository.FuturesRepository
|
||||
}
|
||||
|
||||
// NewFuturesService 创建期货服务
|
||||
func NewFuturesService(repo *repository.FuturesRepository) FuturesService {
|
||||
return &FuturesServiceImpl{
|
||||
repo: repo,
|
||||
}
|
||||
}
|
||||
|
||||
// QueryKLines 查询K线数据
|
||||
func (s *FuturesServiceImpl) QueryKLines(ctx context.Context, req *api.KLineQueryRequest) (*api.KLineData, error) {
|
||||
// 解析日期
|
||||
start, err := time.Parse("20060102", req.Start)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid start date: %w", err)
|
||||
}
|
||||
end, err := time.Parse("20060102", req.End)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid end date: %w", err)
|
||||
}
|
||||
end = end.Add(24 * time.Hour).Add(-time.Second)
|
||||
|
||||
// 获取K线数据
|
||||
items, err := s.repo.GetKLines(ctx, req.Symbol, req.Freq, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.KLineData{
|
||||
Symbol: req.Symbol,
|
||||
Freq: req.Freq,
|
||||
Count: len(items),
|
||||
Items: items,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListSymbols 查询标的列表
|
||||
func (s *FuturesServiceImpl) ListSymbols(ctx context.Context, req *api.SymbolListRequest) (*api.SymbolListData, error) {
|
||||
symbols, total, err := s.repo.ListSymbols(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.SymbolListData{
|
||||
Total: total,
|
||||
Page: req.Page,
|
||||
Size: req.Size,
|
||||
Items: symbols,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BatchQueryKLines 批量查询K线
|
||||
func (s *FuturesServiceImpl) BatchQueryKLines(ctx context.Context, req *api.BatchKLineRequest) (*api.BatchKLineData, error) {
|
||||
results := make([]api.BatchKLineResult, len(req.Symbols))
|
||||
|
||||
for i, symbol := range req.Symbols {
|
||||
singleReq := &api.KLineQueryRequest{
|
||||
Symbol: symbol,
|
||||
Start: req.Start,
|
||||
End: req.End,
|
||||
Freq: req.Freq,
|
||||
}
|
||||
|
||||
data, err := s.QueryKLines(ctx, singleReq)
|
||||
results[i] = api.BatchKLineResult{
|
||||
Symbol: symbol,
|
||||
Success: err == nil,
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
results[i].Error = err.Error()
|
||||
} else {
|
||||
results[i].Data = &api.KLineSubData{
|
||||
Count: data.Count,
|
||||
Items: data.Items,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &api.BatchKLineData{Results: results}, nil
|
||||
}
|
||||
|
||||
// GetTradingDates 获取交易日历
|
||||
func (s *FuturesServiceImpl) GetTradingDates(ctx context.Context, req *api.TradingDatesRequest) (*api.TradingDatesData, error) {
|
||||
return s.repo.GetTradingDates(ctx, req.Start, req.End)
|
||||
}
|
||||
|
||||
// GetContractsByUnderlying 根据品种获取合约
|
||||
func (s *FuturesServiceImpl) GetContractsByUnderlying(ctx context.Context, req *api.FuturesContractsRequest) (*api.FuturesContractsData, error) {
|
||||
return s.repo.GetContractsByUnderlying(ctx, req.Underlying, req.Exchange)
|
||||
}
|
||||
@ -0,0 +1,132 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"market-data-service/api"
|
||||
"market-data-service/internal/repository"
|
||||
)
|
||||
|
||||
// StockServiceImpl 股票服务实现
|
||||
type StockServiceImpl struct {
|
||||
repo *repository.StockRepository
|
||||
config *DataSourceConfig
|
||||
}
|
||||
|
||||
// DataSourceConfig 数据源配置
|
||||
type DataSourceConfig struct {
|
||||
Adapter interface{}
|
||||
}
|
||||
|
||||
// NewStockService 创建股票服务
|
||||
func NewStockService(repo *repository.StockRepository) StockService {
|
||||
return &StockServiceImpl{
|
||||
repo: repo,
|
||||
}
|
||||
}
|
||||
|
||||
// QueryKLines 查询K线数据
|
||||
func (s *StockServiceImpl) QueryKLines(ctx context.Context, req *api.KLineQueryRequest) (*api.KLineData, error) {
|
||||
// 解析日期
|
||||
start, err := time.Parse("20060102", req.Start)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid start date: %w", err)
|
||||
}
|
||||
end, err := time.Parse("20060102", req.End)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid end date: %w", err)
|
||||
}
|
||||
end = end.Add(24 * time.Hour).Add(-time.Second) // 包含结束日期全天
|
||||
|
||||
// 获取K线数据
|
||||
items, err := s.repo.GetKLines(ctx, req.Symbol, req.Freq, start, end, req.Adjust)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 处理复权
|
||||
if req.Adjust != api.AdjustNone {
|
||||
items = s.applyAdjust(ctx, req.Symbol, items, req.Adjust)
|
||||
}
|
||||
|
||||
return &api.KLineData{
|
||||
Symbol: req.Symbol,
|
||||
Freq: req.Freq,
|
||||
Adjust: req.Adjust,
|
||||
Count: len(items),
|
||||
Items: items,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// applyAdjust 应用复权
|
||||
func (s *StockServiceImpl) applyAdjust(ctx context.Context, symbol string, items []api.KLineItem, adjustType api.AdjustType) []api.KLineItem {
|
||||
// TODO: 实现复权计算
|
||||
// 1. 从数据库获取复权系数
|
||||
// 2. 根据前复权/后复权计算价格
|
||||
return items
|
||||
}
|
||||
|
||||
// ListSymbols 查询标的列表
|
||||
func (s *StockServiceImpl) ListSymbols(ctx context.Context, req *api.SymbolListRequest) (*api.SymbolListData, error) {
|
||||
symbols, total, err := s.repo.ListSymbols(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.SymbolListData{
|
||||
Total: total,
|
||||
Page: req.Page,
|
||||
Size: req.Size,
|
||||
Items: symbols,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BatchQueryKLines 批量查询K线
|
||||
func (s *StockServiceImpl) BatchQueryKLines(ctx context.Context, req *api.BatchKLineRequest) (*api.BatchKLineData, error) {
|
||||
results := make([]api.BatchKLineResult, len(req.Symbols))
|
||||
|
||||
for i, symbol := range req.Symbols {
|
||||
singleReq := &api.KLineQueryRequest{
|
||||
Symbol: symbol,
|
||||
Start: req.Start,
|
||||
End: req.End,
|
||||
Freq: req.Freq,
|
||||
Adjust: req.Adjust,
|
||||
}
|
||||
|
||||
data, err := s.QueryKLines(ctx, singleReq)
|
||||
results[i] = api.BatchKLineResult{
|
||||
Symbol: symbol,
|
||||
Success: err == nil,
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
results[i].Error = err.Error()
|
||||
} else {
|
||||
results[i].Data = &api.KLineSubData{
|
||||
Count: data.Count,
|
||||
Items: data.Items,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &api.BatchKLineData{Results: results}, nil
|
||||
}
|
||||
|
||||
// GetTradingDates 获取交易日历
|
||||
func (s *StockServiceImpl) GetTradingDates(ctx context.Context, req *api.TradingDatesRequest) (*api.TradingDatesData, error) {
|
||||
return s.repo.GetTradingDates(ctx, req.Start, req.End)
|
||||
}
|
||||
|
||||
// SyncSymbolsFromSource 从数据源同步标的列表
|
||||
func (s *StockServiceImpl) SyncSymbolsFromSource(ctx context.Context, adapter interface{ FetchSymbols(assetType string) ([]struct {
|
||||
SymbolID string
|
||||
Name string
|
||||
Exchange string
|
||||
}, error) }) error {
|
||||
// TODO: 实现从Tushare同步标的列表
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,519 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"market-data-service/api"
|
||||
)
|
||||
|
||||
// TestService 测试服务接口
|
||||
type TestService interface {
|
||||
// GetAPITestList 获取API测试列表
|
||||
GetAPITestList(ctx context.Context) (*api.APITestListData, error)
|
||||
|
||||
// RunAPITest 执行API测试
|
||||
RunAPITest(ctx context.Context, baseURL string, req *api.APITestRequest) (*api.APITestResult, error)
|
||||
|
||||
// GetWSTestList 获取WebSocket测试列表
|
||||
GetWSTestList(ctx context.Context) (*api.WSTestListData, error)
|
||||
|
||||
// RunWSTest 执行WebSocket测试
|
||||
RunWSTest(ctx context.Context, wsURL string, req *api.WSTestRequest) (*api.WSTestResult, error)
|
||||
|
||||
// GetTestHistory 获取测试历史
|
||||
GetTestHistory(ctx context.Context, req *api.TestHistoryRequest) (*api.TestHistoryData, error)
|
||||
}
|
||||
|
||||
// TestServiceImpl 测试服务实现
|
||||
type TestServiceImpl struct {
|
||||
mu sync.RWMutex
|
||||
apiHistory []api.APITestResult
|
||||
wsHistory []api.WSTestResult
|
||||
historySize int
|
||||
}
|
||||
|
||||
// NewTestService 创建测试服务
|
||||
func NewTestService() TestService {
|
||||
return &TestServiceImpl{
|
||||
apiHistory: make([]api.APITestResult, 0),
|
||||
wsHistory: make([]api.WSTestResult, 0),
|
||||
historySize: 100, // 保留最近100条记录
|
||||
}
|
||||
}
|
||||
|
||||
// GetAPITestList 获取API测试列表
|
||||
func (s *TestServiceImpl) GetAPITestList(ctx context.Context) (*api.APITestListData, error) {
|
||||
categories := []api.APITestCategory{
|
||||
{
|
||||
Name: "股票接口",
|
||||
Items: []api.APITestCase{
|
||||
{
|
||||
ID: "stock_klines",
|
||||
Name: "查询股票K线",
|
||||
Method: "GET",
|
||||
Path: "/v1/stock/klines/{symbol}",
|
||||
Description: "查询指定股票的K线数据",
|
||||
Params: map[string]string{
|
||||
"symbol": "000001.SZ",
|
||||
"start": time.Now().AddDate(0, 0, -30).Format("20060102"),
|
||||
"end": time.Now().Format("20060102"),
|
||||
"freq": "1d",
|
||||
"adjust": "qfq",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "stock_symbols",
|
||||
Name: "查询股票列表",
|
||||
Method: "GET",
|
||||
Path: "/v1/stock/symbols",
|
||||
Description: "获取所有可用股票标的",
|
||||
Params: map[string]string{
|
||||
"page": "1",
|
||||
"size": "20",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "stock_batch",
|
||||
Name: "批量查询股票K线",
|
||||
Method: "POST",
|
||||
Path: "/v1/stock/klines/batch",
|
||||
Description: "批量查询多只股票K线",
|
||||
Body: map[string]interface{}{
|
||||
"symbols": []string{"000001.SZ", "000002.SZ"},
|
||||
"start": time.Now().AddDate(0, 0, -7).Format("20060102"),
|
||||
"end": time.Now().Format("20060102"),
|
||||
"freq": "1d",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "stock_calendar",
|
||||
Name: "查询交易日历",
|
||||
Method: "GET",
|
||||
Path: "/v1/stock/trading-dates",
|
||||
Description: "查询股票交易日历",
|
||||
Params: map[string]string{
|
||||
"start": time.Now().AddDate(0, 0, -30).Format("20060102"),
|
||||
"end": time.Now().AddDate(0, 0, 30).Format("20060102"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "期货接口",
|
||||
Items: []api.APITestCase{
|
||||
{
|
||||
ID: "futures_klines",
|
||||
Name: "查询期货K线",
|
||||
Method: "GET",
|
||||
Path: "/v1/futures/klines/{symbol}",
|
||||
Description: "查询指定期货合约的K线数据",
|
||||
Params: map[string]string{
|
||||
"symbol": "CU2504.SHFE",
|
||||
"start": time.Now().AddDate(0, 0, -30).Format("20060102"),
|
||||
"end": time.Now().Format("20060102"),
|
||||
"freq": "1d",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "futures_symbols",
|
||||
Name: "查询期货列表",
|
||||
Method: "GET",
|
||||
Path: "/v1/futures/symbols",
|
||||
Description: "获取所有可用期货标的",
|
||||
Params: map[string]string{
|
||||
"page": "1",
|
||||
"size": "20",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "futures_batch",
|
||||
Name: "批量查询期货K线",
|
||||
Method: "POST",
|
||||
Path: "/v1/futures/klines/batch",
|
||||
Description: "批量查询多个期货合约K线",
|
||||
Body: map[string]interface{}{
|
||||
"symbols": []string{"CU2504.SHFE", "RB2505.SHFE"},
|
||||
"start": time.Now().AddDate(0, 0, -7).Format("20060102"),
|
||||
"end": time.Now().Format("20060102"),
|
||||
"freq": "1d",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "futures_contracts",
|
||||
Name: "查询合约列表",
|
||||
Method: "GET",
|
||||
Path: "/v1/futures/contracts",
|
||||
Description: "根据品种查询可交易合约",
|
||||
Params: map[string]string{
|
||||
"underlying": "CU",
|
||||
"exchange": "SHFE",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "futures_calendar",
|
||||
Name: "查询期货交易日历",
|
||||
Method: "GET",
|
||||
Path: "/v1/futures/trading-dates",
|
||||
Description: "查询期货交易日历",
|
||||
Params: map[string]string{
|
||||
"start": time.Now().AddDate(0, 0, -30).Format("20060102"),
|
||||
"end": time.Now().AddDate(0, 0, 30).Format("20060102"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "管理接口",
|
||||
Items: []api.APITestCase{
|
||||
{
|
||||
ID: "admin_health",
|
||||
Name: "健康检查",
|
||||
Method: "GET",
|
||||
Path: "/v1/admin/health",
|
||||
Description: "检查服务健康状态",
|
||||
Params: map[string]string{},
|
||||
},
|
||||
{
|
||||
ID: "admin_source_status",
|
||||
Name: "数据源状态",
|
||||
Method: "GET",
|
||||
Path: "/v1/admin/source/status",
|
||||
Description: "获取当前数据源状态",
|
||||
Params: map[string]string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return &api.APITestListData{
|
||||
Categories: categories,
|
||||
BaseURL: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RunAPITest 执行API测试
|
||||
func (s *TestServiceImpl) RunAPITest(ctx context.Context, baseURL string, req *api.APITestRequest) (*api.APITestResult, error) {
|
||||
// 获取测试用例
|
||||
testList, _ := s.GetAPITestList(ctx)
|
||||
|
||||
var testCase *api.APITestCase
|
||||
for _, cat := range testList.Categories {
|
||||
for _, item := range cat.Items {
|
||||
if item.ID == req.ID {
|
||||
testCase = &item
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if testCase == nil {
|
||||
return nil, fmt.Errorf("test case not found: %s", req.ID)
|
||||
}
|
||||
|
||||
// 合并参数
|
||||
params := make(map[string]string)
|
||||
for k, v := range testCase.Params {
|
||||
params[k] = v
|
||||
}
|
||||
for k, v := range req.Params {
|
||||
params[k] = v
|
||||
}
|
||||
|
||||
// 构建URL
|
||||
url := baseURL + testCase.Path
|
||||
for k, v := range params {
|
||||
url = strings.Replace(url, "{"+k+"}", v, -1)
|
||||
}
|
||||
|
||||
// 添加查询参数
|
||||
if testCase.Method == "GET" && len(params) > 0 {
|
||||
queryParts := []string{}
|
||||
for k, v := range params {
|
||||
if !strings.Contains(testCase.Path, "{"+k+"}") {
|
||||
queryParts = append(queryParts, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
}
|
||||
if len(queryParts) > 0 {
|
||||
url += "?" + strings.Join(queryParts, "&")
|
||||
}
|
||||
}
|
||||
|
||||
// 准备请求体
|
||||
var body interface{}
|
||||
if req.Body != nil {
|
||||
body = req.Body
|
||||
} else {
|
||||
body = testCase.Body
|
||||
}
|
||||
|
||||
// 创建HTTP客户端
|
||||
client := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// 构建请求
|
||||
var httpReq *http.Request
|
||||
var err error
|
||||
|
||||
if body != nil && testCase.Method != "GET" {
|
||||
jsonBody, _ := json.Marshal(body)
|
||||
httpReq, err = http.NewRequestWithContext(ctx, testCase.Method, url, bytes.NewBuffer(jsonBody))
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
} else {
|
||||
httpReq, err = http.NewRequestWithContext(ctx, testCase.Method, url, nil)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq.Header.Set("X-API-Key", "test-api-key")
|
||||
|
||||
// 执行请求
|
||||
startTime := time.Now()
|
||||
resp, err := client.Do(httpReq)
|
||||
latency := time.Since(startTime).Milliseconds()
|
||||
|
||||
result := &api.APITestResult{
|
||||
ID: int(time.Now().Unix()),
|
||||
CaseID: req.ID,
|
||||
Name: testCase.Name,
|
||||
Latency: latency,
|
||||
Timestamp: time.Now(),
|
||||
Request: map[string]interface{}{
|
||||
"method": testCase.Method,
|
||||
"url": url,
|
||||
"body": body,
|
||||
},
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
result.Success = false
|
||||
result.Error = err.Error()
|
||||
s.addAPIHistory(result)
|
||||
return result, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
result.StatusCode = resp.StatusCode
|
||||
result.Success = resp.StatusCode >= 200 && resp.StatusCode < 300
|
||||
|
||||
// 解析响应
|
||||
var respBody interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&respBody); err == nil {
|
||||
result.Response = respBody
|
||||
} else {
|
||||
result.Response = map[string]string{"raw": "非JSON响应"}
|
||||
}
|
||||
|
||||
s.addAPIHistory(result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetWSTestList 获取WebSocket测试列表
|
||||
func (s *TestServiceImpl) GetWSTestList(ctx context.Context) (*api.WSTestListData, error) {
|
||||
cases := []api.WSTestCase{
|
||||
{
|
||||
ID: "ws_subscribe_stock",
|
||||
Name: "订阅股票行情",
|
||||
Description: "订阅单只股票实时行情",
|
||||
Action: "subscribe",
|
||||
Symbols: []string{"000001.SZ"},
|
||||
},
|
||||
{
|
||||
ID: "ws_subscribe_futures",
|
||||
Name: "订阅期货行情",
|
||||
Description: "订阅单个期货合约实时行情",
|
||||
Action: "subscribe",
|
||||
Symbols: []string{"CU2504.SHFE"},
|
||||
},
|
||||
{
|
||||
ID: "ws_subscribe_multi",
|
||||
Name: "批量订阅",
|
||||
Description: "同时订阅多个标的",
|
||||
Action: "subscribe",
|
||||
Symbols: []string{"000001.SZ", "000002.SZ", "CU2504.SHFE"},
|
||||
},
|
||||
{
|
||||
ID: "ws_unsubscribe",
|
||||
Name: "取消订阅",
|
||||
Description: "取消订阅标的",
|
||||
Action: "unsubscribe",
|
||||
Symbols: []string{"000001.SZ"},
|
||||
},
|
||||
}
|
||||
|
||||
return &api.WSTestListData{
|
||||
Cases: cases,
|
||||
WSURL: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RunWSTest 执行WebSocket测试
|
||||
func (s *TestServiceImpl) RunWSTest(ctx context.Context, wsURL string, req *api.WSTestRequest) (*api.WSTestResult, error) {
|
||||
// 获取测试用例
|
||||
testList, _ := s.GetWSTestList(ctx)
|
||||
|
||||
var testCase *api.WSTestCase
|
||||
for _, item := range testList.Cases {
|
||||
if item.ID == req.ID {
|
||||
testCase = &item
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if testCase == nil {
|
||||
return nil, fmt.Errorf("test case not found: %s", req.ID)
|
||||
}
|
||||
|
||||
// 使用自定义标的
|
||||
symbols := testCase.Symbols
|
||||
if len(req.Symbols) > 0 {
|
||||
symbols = req.Symbols
|
||||
}
|
||||
|
||||
result := &api.WSTestResult{
|
||||
ID: fmt.Sprintf("ws_%d", time.Now().Unix()),
|
||||
CaseID: req.ID,
|
||||
Timestamp: time.Now(),
|
||||
Messages: []api.WSMessage{},
|
||||
}
|
||||
|
||||
// 连接WebSocket
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
// 将 http:// 或 https:// 替换为 ws:// 或 wss://
|
||||
if strings.HasPrefix(wsURL, "https://") {
|
||||
wsURL = strings.Replace(wsURL, "https://", "wss://", 1)
|
||||
} else if strings.HasPrefix(wsURL, "http://") {
|
||||
wsURL = strings.Replace(wsURL, "http://", "ws://", 1)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
conn, resp, err := dialer.Dial(wsURL, http.Header{
|
||||
"X-API-Key": []string{"test-api-key"},
|
||||
})
|
||||
result.Latency = time.Since(startTime).Milliseconds()
|
||||
|
||||
if err != nil {
|
||||
result.Success = false
|
||||
if resp != nil {
|
||||
result.Error = fmt.Sprintf("连接失败,状态码: %d", resp.StatusCode)
|
||||
} else {
|
||||
result.Error = fmt.Sprintf("连接失败: %v", err)
|
||||
}
|
||||
s.addWSHistory(result)
|
||||
return result, nil
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
result.Success = true
|
||||
|
||||
// 发送订阅消息
|
||||
msg := map[string]interface{}{
|
||||
"action": testCase.Action,
|
||||
"symbols": symbols,
|
||||
}
|
||||
|
||||
if err := conn.WriteJSON(msg); err != nil {
|
||||
result.Error = fmt.Sprintf("发送消息失败: %v", err)
|
||||
s.addWSHistory(result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 等待响应
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
for i := 0; i < 3; i++ { // 最多读取3条消息
|
||||
var msgData map[string]interface{}
|
||||
if err := conn.ReadJSON(&msgData); err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
result.Messages = append(result.Messages, api.WSMessage{
|
||||
Type: "received",
|
||||
Data: msgData,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
s.addWSHistory(result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetTestHistory 获取测试历史
|
||||
func (s *TestServiceImpl) GetTestHistory(ctx context.Context, req *api.TestHistoryRequest) (*api.TestHistoryData, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
limit := req.Limit
|
||||
if limit <= 0 || limit > len(s.apiHistory) {
|
||||
limit = len(s.apiHistory)
|
||||
}
|
||||
|
||||
// 获取最近的数据
|
||||
apiTests := make([]api.APITestResult, 0)
|
||||
wsTests := make([]api.WSTestResult, 0)
|
||||
|
||||
if req.Type == "" || req.Type == "api" {
|
||||
start := len(s.apiHistory) - limit
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
apiTests = append(apiTests, s.apiHistory[start:]...)
|
||||
}
|
||||
|
||||
if req.Type == "" || req.Type == "ws" {
|
||||
wsLimit := req.Limit
|
||||
if wsLimit <= 0 || wsLimit > len(s.wsHistory) {
|
||||
wsLimit = len(s.wsHistory)
|
||||
}
|
||||
start := len(s.wsHistory) - wsLimit
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
wsTests = append(wsTests, s.wsHistory[start:]...)
|
||||
}
|
||||
|
||||
return &api.TestHistoryData{
|
||||
APITests: apiTests,
|
||||
WSTests: wsTests,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// addAPIHistory 添加API测试历史
|
||||
func (s *TestServiceImpl) addAPIHistory(result *api.APITestResult) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.apiHistory = append(s.apiHistory, *result)
|
||||
|
||||
// 限制历史记录数量
|
||||
if len(s.apiHistory) > s.historySize {
|
||||
s.apiHistory = s.apiHistory[len(s.apiHistory)-s.historySize:]
|
||||
}
|
||||
}
|
||||
|
||||
// addWSHistory 添加WebSocket测试历史
|
||||
func (s *TestServiceImpl) addWSHistory(result *api.WSTestResult) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.wsHistory = append(s.wsHistory, *result)
|
||||
|
||||
// 限制历史记录数量
|
||||
if len(s.wsHistory) > s.historySize {
|
||||
s.wsHistory = s.wsHistory[len(s.wsHistory)-s.historySize:]
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,374 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"market-data-service/api"
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // 允许所有来源,生产环境需要限制
|
||||
},
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
}
|
||||
|
||||
// Hub WebSocket连接管理中心
|
||||
type Hub struct {
|
||||
// 已注册的客户端
|
||||
clients map[*Client]bool
|
||||
|
||||
// 广播消息通道
|
||||
broadcast chan []byte
|
||||
|
||||
// 注册请求通道
|
||||
register chan *Client
|
||||
|
||||
// 注销请求通道
|
||||
unregister chan *Client
|
||||
|
||||
// 标的订阅映射: symbol -> clients
|
||||
subscriptions map[string]map[*Client]bool
|
||||
|
||||
// 保护subscriptions的锁
|
||||
subMu sync.RWMutex
|
||||
|
||||
// 最大订阅标的数
|
||||
maxSymbolsPerClient int
|
||||
}
|
||||
|
||||
// NewHub 创建Hub
|
||||
func NewHub() *Hub {
|
||||
return &Hub{
|
||||
clients: make(map[*Client]bool),
|
||||
broadcast: make(chan []byte),
|
||||
register: make(chan *Client),
|
||||
unregister: make(chan *Client),
|
||||
subscriptions: make(map[string]map[*Client]bool),
|
||||
maxSymbolsPerClient: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// Run 启动Hub
|
||||
func (h *Hub) Run() {
|
||||
for {
|
||||
select {
|
||||
case client := <-h.register:
|
||||
h.clients[client] = true
|
||||
log.Printf("Client registered, total: %d", len(h.clients))
|
||||
|
||||
case client := <-h.unregister:
|
||||
if _, ok := h.clients[client]; ok {
|
||||
delete(h.clients, client)
|
||||
close(client.send)
|
||||
// 清理订阅
|
||||
h.removeAllSubscriptions(client)
|
||||
log.Printf("Client unregistered, total: %d", len(h.clients))
|
||||
}
|
||||
|
||||
case message := <-h.broadcast:
|
||||
for client := range h.clients {
|
||||
select {
|
||||
case client.send <- message:
|
||||
default:
|
||||
// 发送缓冲满,关闭连接
|
||||
close(client.send)
|
||||
delete(h.clients, client)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe 客户端订阅标的
|
||||
func (h *Hub) Subscribe(client *Client, symbols []string) error {
|
||||
if len(client.subscriptions)+len(symbols) > h.maxSymbolsPerClient {
|
||||
return api.ErrRateLimit
|
||||
}
|
||||
|
||||
h.subMu.Lock()
|
||||
defer h.subMu.Unlock()
|
||||
|
||||
for _, symbol := range symbols {
|
||||
if _, ok := h.subscriptions[symbol]; !ok {
|
||||
h.subscriptions[symbol] = make(map[*Client]bool)
|
||||
}
|
||||
h.subscriptions[symbol][client] = true
|
||||
client.subscriptions[symbol] = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unsubscribe 客户端取消订阅
|
||||
func (h *Hub) Unsubscribe(client *Client, symbols []string) {
|
||||
h.subMu.Lock()
|
||||
defer h.subMu.Unlock()
|
||||
|
||||
for _, symbol := range symbols {
|
||||
if clients, ok := h.subscriptions[symbol]; ok {
|
||||
delete(clients, client)
|
||||
if len(clients) == 0 {
|
||||
delete(h.subscriptions, symbol)
|
||||
}
|
||||
}
|
||||
delete(client.subscriptions, symbol)
|
||||
}
|
||||
}
|
||||
|
||||
// removeAllSubscriptions 移除客户端所有订阅
|
||||
func (h *Hub) removeAllSubscriptions(client *Client) {
|
||||
h.subMu.Lock()
|
||||
defer h.subMu.Unlock()
|
||||
|
||||
for symbol := range client.subscriptions {
|
||||
if clients, ok := h.subscriptions[symbol]; ok {
|
||||
delete(clients, client)
|
||||
if len(clients) == 0 {
|
||||
delete(h.subscriptions, symbol)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastToSymbol 向订阅了某标的的所有客户端广播
|
||||
func (h *Hub) BroadcastToSymbol(symbol string, data []byte) {
|
||||
h.subMu.RLock()
|
||||
clients := h.subscriptions[symbol]
|
||||
h.subMu.RUnlock()
|
||||
|
||||
for client := range clients {
|
||||
select {
|
||||
case client.send <- data:
|
||||
default:
|
||||
// 发送缓冲满,稍后处理
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetSubscriptionStats 获取订阅统计
|
||||
func (h *Hub) GetSubscriptionStats() map[string]interface{} {
|
||||
h.subMu.RLock()
|
||||
defer h.subMu.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_clients": len(h.clients),
|
||||
"total_subscriptions": len(h.subscriptions),
|
||||
}
|
||||
}
|
||||
|
||||
// Client WebSocket客户端
|
||||
type Client struct {
|
||||
hub *Hub
|
||||
conn *websocket.Conn
|
||||
send chan []byte
|
||||
|
||||
// 已订阅的标的
|
||||
subscriptions map[string]bool
|
||||
subMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewClient 创建客户端
|
||||
func NewClient(hub *Hub, conn *websocket.Conn) *Client {
|
||||
return &Client{
|
||||
hub: hub,
|
||||
conn: conn,
|
||||
send: make(chan []byte, 256),
|
||||
subscriptions: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// ReadPump 读取客户端消息
|
||||
func (c *Client) ReadPump() {
|
||||
defer func() {
|
||||
c.hub.unregister <- c
|
||||
c.conn.Close()
|
||||
}()
|
||||
|
||||
c.conn.SetReadLimit(512 * 1024) // 512KB
|
||||
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
c.conn.SetPongHandler(func(string) error {
|
||||
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
return nil
|
||||
})
|
||||
|
||||
for {
|
||||
_, message, err := c.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
log.Printf("WebSocket error: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// 处理客户端消息
|
||||
c.handleMessage(message)
|
||||
}
|
||||
}
|
||||
|
||||
// WritePump 向客户端写入消息
|
||||
func (c *Client) WritePump() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
c.conn.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-c.send:
|
||||
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if !ok {
|
||||
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||||
return
|
||||
}
|
||||
|
||||
c.conn.WriteMessage(websocket.TextMessage, message)
|
||||
|
||||
case <-ticker.C:
|
||||
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleMessage 处理客户端消息
|
||||
func (c *Client) handleMessage(data []byte) {
|
||||
var msg ClientMessage
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
c.sendError(1000, "Invalid message format")
|
||||
return
|
||||
}
|
||||
|
||||
switch msg.Action {
|
||||
case "subscribe":
|
||||
c.handleSubscribe(msg.Symbols)
|
||||
case "unsubscribe":
|
||||
c.handleUnsubscribe(msg.Symbols)
|
||||
default:
|
||||
c.sendError(1001, "Unknown action")
|
||||
}
|
||||
}
|
||||
|
||||
// handleSubscribe 处理订阅请求
|
||||
func (c *Client) handleSubscribe(symbols []string) {
|
||||
if len(symbols) == 0 {
|
||||
c.sendError(1002, "Symbols cannot be empty")
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.hub.Subscribe(c, symbols); err != nil {
|
||||
c.sendError(1003, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 发送确认
|
||||
ack := map[string]interface{}{
|
||||
"type": "ack",
|
||||
"action": "subscribe",
|
||||
"symbols": symbols,
|
||||
"ts": time.Now().Format(time.RFC3339),
|
||||
}
|
||||
data, _ := json.Marshal(ack)
|
||||
c.send <- data
|
||||
}
|
||||
|
||||
// handleUnsubscribe 处理取消订阅请求
|
||||
func (c *Client) handleUnsubscribe(symbols []string) {
|
||||
c.hub.Unsubscribe(c, symbols)
|
||||
|
||||
ack := map[string]interface{}{
|
||||
"type": "ack",
|
||||
"action": "unsubscribe",
|
||||
"symbols": symbols,
|
||||
"ts": time.Now().Format(time.RFC3339),
|
||||
}
|
||||
data, _ := json.Marshal(ack)
|
||||
c.send <- data
|
||||
}
|
||||
|
||||
// sendError 发送错误消息
|
||||
func (c *Client) sendError(code int, message string) {
|
||||
err := map[string]interface{}{
|
||||
"type": "error",
|
||||
"code": code,
|
||||
"message": message,
|
||||
"ts": time.Now().Format(time.RFC3339),
|
||||
}
|
||||
data, _ := json.Marshal(err)
|
||||
c.send <- data
|
||||
}
|
||||
|
||||
// ClientMessage 客户端消息结构
|
||||
type ClientMessage struct {
|
||||
Action string `json:"action"`
|
||||
Symbols []string `json:"symbols"`
|
||||
}
|
||||
|
||||
// Server WebSocket服务器
|
||||
type Server struct {
|
||||
hub *Hub
|
||||
}
|
||||
|
||||
// NewServer 创建WebSocket服务器
|
||||
func NewServer(hub *Hub) *Server {
|
||||
return &Server{hub: hub}
|
||||
}
|
||||
|
||||
// HandleWebSocket 处理WebSocket连接
|
||||
func (s *Server) HandleWebSocket(c *gin.Context) {
|
||||
// 认证检查
|
||||
apiKey := c.GetHeader("X-API-Key")
|
||||
if apiKey == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Missing API Key"})
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Printf("WebSocket upgrade failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
client := NewClient(s.hub, conn)
|
||||
s.hub.register <- client
|
||||
|
||||
go client.WritePump()
|
||||
go client.ReadPump()
|
||||
}
|
||||
|
||||
// BroadcastTick 广播Tick数据
|
||||
func (s *Server) BroadcastTick(symbol string, tick map[string]interface{}) {
|
||||
data, err := json.Marshal(tick)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.hub.BroadcastToSymbol(symbol, data)
|
||||
}
|
||||
|
||||
// BroadcastKLine 广播K线闭合数据
|
||||
func (s *Server) BroadcastKLine(symbol string, freq string, kline map[string]interface{}) {
|
||||
msg := map[string]interface{}{
|
||||
"type": "klines",
|
||||
"symbol": symbol,
|
||||
"freq": freq,
|
||||
"data": kline,
|
||||
"ts": time.Now().Format(time.RFC3339),
|
||||
}
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.hub.BroadcastToSymbol(symbol, data)
|
||||
}
|
||||
@ -0,0 +1,73 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Config 配置管理
|
||||
type Config struct {
|
||||
Server ServerConfig `json:"server" yaml:"server"`
|
||||
Database DatabaseConfig `json:"database" yaml:"database"`
|
||||
Redis RedisConfig `json:"redis" yaml:"redis"`
|
||||
Sources SourcesConfig `json:"sources" yaml:"sources"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Port int `json:"port" yaml:"port"`
|
||||
Mode string `json:"mode" yaml:"mode"` // debug/release
|
||||
APIKey string `json:"api_key" yaml:"api_key"`
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Host string `json:"host" yaml:"host"`
|
||||
Port int `json:"port" yaml:"port"`
|
||||
User string `json:"user" yaml:"user"`
|
||||
Password string `json:"password" yaml:"password"`
|
||||
Database string `json:"database" yaml:"database"`
|
||||
}
|
||||
|
||||
type RedisConfig struct {
|
||||
Host string `json:"host" yaml:"host"`
|
||||
Port int `json:"port" yaml:"port"`
|
||||
Password string `json:"password" yaml:"password"`
|
||||
DB int `json:"db" yaml:"db"`
|
||||
}
|
||||
|
||||
type SourcesConfig struct {
|
||||
Stock SourceConfig `json:"stock" yaml:"stock"`
|
||||
Futures SourceConfig `json:"futures" yaml:"futures"`
|
||||
}
|
||||
|
||||
type SourceConfig struct {
|
||||
Active string `json:"active" yaml:"active"`
|
||||
Sources map[string]SourceInfo `json:"list" yaml:"list"`
|
||||
}
|
||||
|
||||
type SourceInfo struct {
|
||||
Type string `json:"type" yaml:"type"`
|
||||
Config map[string]string `json:"config" yaml:"config"`
|
||||
}
|
||||
|
||||
// Load 加载配置
|
||||
func Load(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if cfg.Server.Port == 0 {
|
||||
cfg.Server.Port = 8080
|
||||
}
|
||||
if cfg.Server.Mode == "" {
|
||||
cfg.Server.Mode = "debug"
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
@ -0,0 +1,36 @@
|
||||
package errors
|
||||
|
||||
import "errors"
|
||||
|
||||
// 业务错误定义
|
||||
|
||||
var (
|
||||
// 参数错误
|
||||
ErrInvalidParam = errors.New("参数错误")
|
||||
ErrInvalidSymbol = errors.New("无效的标的代码")
|
||||
ErrInvalidDate = errors.New("无效的日期格式")
|
||||
|
||||
// 数据错误
|
||||
ErrSymbolNotFound = errors.New("标的不存在")
|
||||
ErrDataNotFound = errors.New("数据不存在")
|
||||
ErrDataSourceUnavailable = errors.New("数据源不可用")
|
||||
|
||||
// 权限错误
|
||||
ErrUnauthorized = errors.New("未授权")
|
||||
ErrRateLimit = errors.New("请求过于频繁")
|
||||
|
||||
// 系统错误
|
||||
ErrInternal = errors.New("服务器内部错误")
|
||||
)
|
||||
|
||||
// ErrorCode 错误码
|
||||
type ErrorCode int
|
||||
|
||||
const (
|
||||
CodeOK ErrorCode = 0
|
||||
CodeBadRequest ErrorCode = 400
|
||||
CodeUnauthorized ErrorCode = 401
|
||||
CodeNotFound ErrorCode = 404
|
||||
CodeRateLimit ErrorCode = 429
|
||||
CodeInternal ErrorCode = 500
|
||||
)
|
||||
@ -0,0 +1,20 @@
|
||||
package logger
|
||||
|
||||
import "log"
|
||||
|
||||
// Logger 日志工具
|
||||
|
||||
// Info 信息日志
|
||||
func Info(format string, v ...interface{}) {
|
||||
log.Printf("[INFO] "+format, v...)
|
||||
}
|
||||
|
||||
// Error 错误日志
|
||||
func Error(format string, v ...interface{}) {
|
||||
log.Printf("[ERROR] "+format, v...)
|
||||
}
|
||||
|
||||
// Debug 调试日志
|
||||
func Debug(format string, v ...interface{}) {
|
||||
log.Printf("[DEBUG] "+format, v...)
|
||||
}
|
||||
@ -0,0 +1,2 @@
|
||||
"""Market Data Service - Python实现"""
|
||||
__version__ = "1.0.0"
|
||||
@ -0,0 +1,13 @@
|
||||
"""数据源适配器模块"""
|
||||
from .base import DataSourceAdapter, TickData, KLineData, SymbolInfo, TradeCalData, TickCallback
|
||||
from .tushare_adapter import TushareAdapter
|
||||
|
||||
__all__ = [
|
||||
"DataSourceAdapter",
|
||||
"TickData",
|
||||
"KLineData",
|
||||
"SymbolInfo",
|
||||
"TradeCalData",
|
||||
"TickCallback",
|
||||
"TushareAdapter",
|
||||
]
|
||||
@ -0,0 +1,102 @@
|
||||
"""数据源适配器基类 - 对应Go的adapter/adapter.go"""
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class TickData:
|
||||
"""Tick数据"""
|
||||
symbol: str
|
||||
price: float
|
||||
volume: int
|
||||
time: int # Unix时间戳
|
||||
|
||||
|
||||
@dataclass
|
||||
class KLineData:
|
||||
"""K线数据"""
|
||||
symbol: str
|
||||
time: int # Unix时间戳
|
||||
open: float
|
||||
high: float
|
||||
low: float
|
||||
close: float
|
||||
volume: int
|
||||
amount: float
|
||||
open_interest: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class SymbolInfo:
|
||||
"""标的信息"""
|
||||
symbol_id: str
|
||||
name: str
|
||||
exchange: str
|
||||
underlying: str = "" # 期货品种代码
|
||||
contract_month: str = ""
|
||||
list_date: str = ""
|
||||
delist_date: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TradeCalData:
|
||||
"""交易日历数据"""
|
||||
date: datetime
|
||||
is_trading_day: bool
|
||||
has_night_session: bool = False
|
||||
|
||||
|
||||
# Tick数据回调类型
|
||||
TickCallback = Callable[[str, TickData], None]
|
||||
|
||||
|
||||
class DataSourceAdapter(ABC):
|
||||
"""数据源适配器接口"""
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self, config: dict) -> None:
|
||||
"""建立连接"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def subscribe_ticks(self, symbols: List[str], callback: TickCallback) -> None:
|
||||
"""订阅实时Tick"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def fetch_klines(
|
||||
self,
|
||||
symbol: str,
|
||||
start: str,
|
||||
end: str,
|
||||
freq: str
|
||||
) -> List[KLineData]:
|
||||
"""拉取历史K线"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def fetch_symbols(self, asset_type: str) -> List[SymbolInfo]:
|
||||
"""获取标的列表"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def fetch_trading_calendar(
|
||||
self,
|
||||
exchange: str,
|
||||
start: str,
|
||||
end: str
|
||||
) -> List[TradeCalData]:
|
||||
"""获取交易日历"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def health_check(self) -> bool:
|
||||
"""健康检查"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""关闭连接"""
|
||||
pass
|
||||
@ -0,0 +1,5 @@
|
||||
"""API路由模块"""
|
||||
from .routes import router
|
||||
from .admin_routes import admin_router
|
||||
|
||||
__all__ = ["router", "admin_router"]
|
||||
@ -0,0 +1,232 @@
|
||||
"""管理后台API路由 - 对应Go的api/admin_router.go"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Header, Query
|
||||
from typing import Optional
|
||||
|
||||
from app.models import (
|
||||
Response, ConfigListRequest, ConfigUpdateRequest,
|
||||
ReloadRequest, AdapterToggleRequest, AdapterConfigUpdateRequest,
|
||||
APITestRequest, WSTestRequest, TestHistoryRequest
|
||||
)
|
||||
from app.services import ConfigService, AdapterService, TestService
|
||||
from app.core.config import get_config
|
||||
|
||||
admin_router = APIRouter()
|
||||
|
||||
# 服务实例
|
||||
config_service = ConfigService()
|
||||
adapter_service = AdapterService()
|
||||
test_service = TestService()
|
||||
|
||||
|
||||
def verify_admin_token(x_admin_token: Optional[str] = Header(None)):
|
||||
"""验证Admin Token"""
|
||||
# TODO: 实现Token验证
|
||||
return x_admin_token
|
||||
|
||||
|
||||
# ============================================
|
||||
# 系统管理接口
|
||||
# ============================================
|
||||
|
||||
@admin_router.get("/admin/system/status", response_model=Response)
|
||||
def get_system_status(
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""获取系统状态"""
|
||||
try:
|
||||
data = config_service.get_system_status()
|
||||
return Response(code=0, message="success", data=data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/admin/system/reload", response_model=Response)
|
||||
def reload_config(
|
||||
req: Optional[ReloadRequest] = None,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""热加载配置"""
|
||||
try:
|
||||
if req is None:
|
||||
req = ReloadRequest()
|
||||
data = config_service.reload_config(req)
|
||||
return Response(code=0, message="success", data=data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/admin/system/restart", response_model=Response)
|
||||
def restart_service(
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""重启服务"""
|
||||
# TODO: 实现服务重启逻辑
|
||||
return Response(
|
||||
code=0,
|
||||
message="重启命令已发送",
|
||||
data={"status": "restarting"}
|
||||
)
|
||||
|
||||
|
||||
# ============================================
|
||||
# 配置管理接口
|
||||
# ============================================
|
||||
|
||||
@admin_router.get("/admin/config", response_model=Response)
|
||||
def get_config_list(
|
||||
type: Optional[str] = Query(None, description="配置类型筛选"),
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""获取配置列表"""
|
||||
try:
|
||||
from app.models import ConfigType
|
||||
req = ConfigListRequest()
|
||||
if type:
|
||||
req.type = ConfigType(type)
|
||||
|
||||
data = config_service.get_config_list(req)
|
||||
return Response(code=0, message="success", data=data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.put("/admin/config", response_model=Response)
|
||||
def update_config(
|
||||
req: ConfigUpdateRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""更新配置"""
|
||||
try:
|
||||
data = config_service.update_config(req)
|
||||
return Response(code=0, message="success", data=data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/admin/config/reload", response_model=Response)
|
||||
def reload_config_endpoint(
|
||||
req: Optional[ReloadRequest] = None,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""热加载配置"""
|
||||
return reload_config(req, token)
|
||||
|
||||
|
||||
# ============================================
|
||||
# 适配器管理接口
|
||||
# ============================================
|
||||
|
||||
@admin_router.get("/admin/adapters", response_model=Response)
|
||||
def get_adapter_list(
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""获取适配器列表"""
|
||||
try:
|
||||
data = adapter_service.get_adapter_list()
|
||||
return Response(code=0, message="success", data=data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/admin/adapters/toggle", response_model=Response)
|
||||
def toggle_adapter(
|
||||
req: AdapterToggleRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""切换适配器状态"""
|
||||
try:
|
||||
adapter_service.toggle_adapter(req)
|
||||
return Response(code=0, message="success")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.put("/admin/adapters/config", response_model=Response)
|
||||
def update_adapter_config(
|
||||
req: AdapterConfigUpdateRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""更新适配器配置"""
|
||||
try:
|
||||
adapter_service.update_adapter_config(req)
|
||||
return Response(code=0, message="success")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ============================================
|
||||
# 测试管理接口
|
||||
# ============================================
|
||||
|
||||
@admin_router.get("/admin/tests/api", response_model=Response)
|
||||
def get_api_test_list(
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""获取API测试列表"""
|
||||
try:
|
||||
data = test_service.get_api_test_list()
|
||||
# 设置基础URL
|
||||
config = get_config()
|
||||
data.base_url = f"http://localhost:{config.server.port}"
|
||||
return Response(code=0, message="success", data=data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/admin/tests/api/run", response_model=Response)
|
||||
async def run_api_test(
|
||||
req: APITestRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""执行API测试"""
|
||||
try:
|
||||
config = get_config()
|
||||
base_url = f"http://localhost:{config.server.port}"
|
||||
data = await test_service.run_api_test(base_url, req)
|
||||
return Response(code=0, message="success", data=data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.get("/admin/tests/ws", response_model=Response)
|
||||
def get_ws_test_list(
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""获取WebSocket测试列表"""
|
||||
try:
|
||||
data = test_service.get_ws_test_list()
|
||||
config = get_config()
|
||||
data.ws_url = f"ws://localhost:{config.server.port}/v1/stream"
|
||||
return Response(code=0, message="success", data=data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/admin/tests/ws/run", response_model=Response)
|
||||
async def run_ws_test(
|
||||
req: WSTestRequest,
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""执行WebSocket测试"""
|
||||
try:
|
||||
config = get_config()
|
||||
ws_url = f"ws://localhost:{config.server.port}/v1/stream"
|
||||
data = await test_service.run_ws_test(ws_url, req)
|
||||
return Response(code=0, message="success", data=data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.get("/admin/tests/history", response_model=Response)
|
||||
def get_test_history(
|
||||
type: Optional[str] = Query(None, description="测试类型"),
|
||||
limit: int = Query(default=20, ge=1, le=100, description="数量限制"),
|
||||
token: str = Depends(verify_admin_token)
|
||||
):
|
||||
"""获取测试历史"""
|
||||
try:
|
||||
req = TestHistoryRequest(type=type, limit=limit)
|
||||
data = test_service.get_test_history(req)
|
||||
return Response(code=0, message="success", data=data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@ -0,0 +1,131 @@
|
||||
"""配置管理模块"""
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
from functools import lru_cache
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class ServerConfig(BaseModel):
|
||||
"""服务器配置"""
|
||||
port: int = 8080
|
||||
mode: str = "debug" # debug/release
|
||||
api_key: str = "demo-api-key-2024"
|
||||
|
||||
|
||||
class DatabaseConfig(BaseModel):
|
||||
"""数据库配置"""
|
||||
host: str = "localhost"
|
||||
port: int = 5432
|
||||
user: str = "postgres"
|
||||
password: str = "postgres"
|
||||
database: str = "marketdata"
|
||||
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
return f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
|
||||
|
||||
|
||||
class RedisConfig(BaseModel):
|
||||
"""Redis配置"""
|
||||
host: str = "localhost"
|
||||
port: int = 6379
|
||||
password: str = ""
|
||||
db: int = 0
|
||||
|
||||
|
||||
class SourceInfo(BaseModel):
|
||||
"""数据源信息"""
|
||||
type: str = "http"
|
||||
config: Dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class SourceConfig(BaseModel):
|
||||
"""源配置"""
|
||||
active: str = "tushare"
|
||||
list: Dict[str, SourceInfo] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class SourcesConfig(BaseModel):
|
||||
"""数据源配置"""
|
||||
stock: SourceConfig = Field(default_factory=lambda: SourceConfig(
|
||||
active="tushare",
|
||||
list={"tushare": SourceInfo(type="http", config={"base_url": "https://api.tushare.pro"})}
|
||||
))
|
||||
futures: SourceConfig = Field(default_factory=lambda: SourceConfig(
|
||||
active="tushare",
|
||||
list={"tushare": SourceInfo(type="http", config={"base_url": "https://api.tushare.pro"})}
|
||||
))
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
"""主配置类"""
|
||||
server: ServerConfig = Field(default_factory=ServerConfig)
|
||||
database: DatabaseConfig = Field(default_factory=DatabaseConfig)
|
||||
redis: RedisConfig = Field(default_factory=RedisConfig)
|
||||
sources: SourcesConfig = Field(default_factory=SourcesConfig)
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""环境变量配置"""
|
||||
port: int = Field(default=8080, alias="PORT")
|
||||
database_url: Optional[str] = Field(default=None, alias="DATABASE_URL")
|
||||
tushare_token: Optional[str] = Field(default=None, alias="TUSHARE_TOKEN")
|
||||
api_key: Optional[str] = Field(default=None, alias="API_KEY")
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = True
|
||||
|
||||
|
||||
def load_config(config_path: str = "./config.json") -> Config:
|
||||
"""从文件加载配置"""
|
||||
if not os.path.exists(config_path):
|
||||
return Config()
|
||||
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return Config.model_validate(data)
|
||||
|
||||
|
||||
def save_config(config: Config, config_path: str = "./config.json"):
|
||||
"""保存配置到文件"""
|
||||
os.makedirs(os.path.dirname(config_path) or '.', exist_ok=True)
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(config.model_dump(), f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
# 全局配置实例
|
||||
_config: Optional[Config] = None
|
||||
|
||||
|
||||
def get_config() -> Config:
|
||||
"""获取当前配置"""
|
||||
global _config
|
||||
if _config is None:
|
||||
_config = load_config()
|
||||
return _config
|
||||
|
||||
|
||||
def set_config(config: Config):
|
||||
"""设置全局配置"""
|
||||
global _config
|
||||
_config = config
|
||||
|
||||
|
||||
def reload_config(config_path: str = "./config.json") -> Config:
|
||||
"""重新加载配置"""
|
||||
global _config
|
||||
_config = load_config(config_path)
|
||||
return _config
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> Settings:
|
||||
"""获取环境变量设置"""
|
||||
return Settings()
|
||||
@ -0,0 +1,78 @@
|
||||
"""错误定义模块"""
|
||||
from enum import IntEnum
|
||||
from typing import Optional, Any, Dict
|
||||
|
||||
|
||||
class ErrorCode(IntEnum):
|
||||
"""错误码"""
|
||||
OK = 0
|
||||
BAD_REQUEST = 400
|
||||
UNAUTHORIZED = 401
|
||||
NOT_FOUND = 404
|
||||
RATE_LIMIT = 429
|
||||
INTERNAL = 500
|
||||
|
||||
|
||||
class AppException(Exception):
|
||||
"""应用异常基类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
code: ErrorCode = ErrorCode.INTERNAL,
|
||||
detail: Optional[str] = None
|
||||
):
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.detail = detail
|
||||
super().__init__(message)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"code": int(self.code),
|
||||
"message": self.message,
|
||||
"detail": self.detail
|
||||
}
|
||||
|
||||
|
||||
# 参数错误
|
||||
class InvalidParamError(AppException):
|
||||
def __init__(self, message: str = "参数错误", detail: Optional[str] = None):
|
||||
super().__init__(message, ErrorCode.BAD_REQUEST, detail)
|
||||
|
||||
|
||||
class InvalidSymbolError(AppException):
|
||||
def __init__(self, message: str = "无效的标的代码"):
|
||||
super().__init__(message, ErrorCode.BAD_REQUEST)
|
||||
|
||||
|
||||
class InvalidDateError(AppException):
|
||||
def __init__(self, message: str = "无效的日期格式"):
|
||||
super().__init__(message, ErrorCode.BAD_REQUEST)
|
||||
|
||||
|
||||
# 数据错误
|
||||
class SymbolNotFoundError(AppException):
|
||||
def __init__(self, message: str = "标的不存在"):
|
||||
super().__init__(message, ErrorCode.NOT_FOUND)
|
||||
|
||||
|
||||
class DataNotFoundError(AppException):
|
||||
def __init__(self, message: str = "数据不存在"):
|
||||
super().__init__(message, ErrorCode.NOT_FOUND)
|
||||
|
||||
|
||||
class DataSourceUnavailableError(AppException):
|
||||
def __init__(self, message: str = "数据源不可用"):
|
||||
super().__init__(message, ErrorCode.INTERNAL)
|
||||
|
||||
|
||||
# 权限错误
|
||||
class UnauthorizedError(AppException):
|
||||
def __init__(self, message: str = "未授权"):
|
||||
super().__init__(message, ErrorCode.UNAUTHORIZED)
|
||||
|
||||
|
||||
class RateLimitError(AppException):
|
||||
def __init__(self, message: str = "请求过于频繁"):
|
||||
super().__init__(message, ErrorCode.RATE_LIMIT)
|
||||
@ -0,0 +1,47 @@
|
||||
"""日志工具模块"""
|
||||
import logging
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def setup_logging(
|
||||
level: int = logging.INFO,
|
||||
format_string: Optional[str] = None
|
||||
) -> logging.Logger:
|
||||
"""设置日志配置"""
|
||||
if format_string is None:
|
||||
format_string = "%(asctime)s | %(levelname)-8s | %(message)s"
|
||||
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format=format_string,
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stdout)
|
||||
]
|
||||
)
|
||||
|
||||
return logging.getLogger("market_data")
|
||||
|
||||
|
||||
# 全局logger实例
|
||||
logger = setup_logging()
|
||||
|
||||
|
||||
def info(msg: str, *args, **kwargs):
|
||||
"""信息日志"""
|
||||
logger.info(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def error(msg: str, *args, **kwargs):
|
||||
"""错误日志"""
|
||||
logger.error(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def debug(msg: str, *args, **kwargs):
|
||||
"""调试日志"""
|
||||
logger.debug(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def warning(msg: str, *args, **kwargs):
|
||||
"""警告日志"""
|
||||
logger.warning(msg, *args, **kwargs)
|
||||
@ -0,0 +1,132 @@
|
||||
"""数据模型模块"""
|
||||
from .types import (
|
||||
Frequency,
|
||||
AdjustType,
|
||||
AssetClass,
|
||||
SymbolType,
|
||||
Exchange,
|
||||
DataSourceStatus,
|
||||
KLineItem,
|
||||
KLineData,
|
||||
KLineQueryRequest,
|
||||
BatchKLineRequest,
|
||||
BatchKLineResult,
|
||||
BatchKLineData,
|
||||
KLineSubData,
|
||||
Symbol,
|
||||
SymbolListRequest,
|
||||
SymbolListData,
|
||||
DataSourceInfo,
|
||||
DataSourceStatusData,
|
||||
SourceSwitchRequest,
|
||||
BackfillRequest,
|
||||
TradingDatesRequest,
|
||||
TradingDatesData,
|
||||
FuturesContractsRequest,
|
||||
FuturesContractInfo,
|
||||
FuturesContractsData,
|
||||
Response,
|
||||
ErrorResponse,
|
||||
SuccessResponse,
|
||||
HealthResponse,
|
||||
)
|
||||
from .admin_types import (
|
||||
ConfigType,
|
||||
ConfigItem,
|
||||
ConfigSection,
|
||||
ConfigListRequest,
|
||||
ConfigListData,
|
||||
ConfigUpdateRequest,
|
||||
ConfigUpdateData,
|
||||
AdapterInfo,
|
||||
AdapterStatus,
|
||||
AdapterListData,
|
||||
AdapterToggleRequest,
|
||||
AdapterConfigUpdateRequest,
|
||||
SystemStatusData,
|
||||
MemoryInfo,
|
||||
RestartRequest,
|
||||
ReloadRequest,
|
||||
ReloadData,
|
||||
APITestCase,
|
||||
APITestCategory,
|
||||
APITestListData,
|
||||
APITestRequest,
|
||||
APITestResult,
|
||||
WSTestCase,
|
||||
WSTestListData,
|
||||
WSTestRequest,
|
||||
WSTestResult,
|
||||
WSMessage,
|
||||
TestHistoryRequest,
|
||||
TestHistoryData,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 基础类型
|
||||
"Frequency",
|
||||
"AdjustType",
|
||||
"AssetClass",
|
||||
"SymbolType",
|
||||
"Exchange",
|
||||
"DataSourceStatus",
|
||||
# K线数据
|
||||
"KLineItem",
|
||||
"KLineData",
|
||||
"KLineQueryRequest",
|
||||
"BatchKLineRequest",
|
||||
"BatchKLineResult",
|
||||
"BatchKLineData",
|
||||
"KLineSubData",
|
||||
# 标的
|
||||
"Symbol",
|
||||
"SymbolListRequest",
|
||||
"SymbolListData",
|
||||
# 数据源
|
||||
"DataSourceInfo",
|
||||
"DataSourceStatusData",
|
||||
"SourceSwitchRequest",
|
||||
"BackfillRequest",
|
||||
# 交易日历
|
||||
"TradingDatesRequest",
|
||||
"TradingDatesData",
|
||||
# 期货
|
||||
"FuturesContractsRequest",
|
||||
"FuturesContractInfo",
|
||||
"FuturesContractsData",
|
||||
# 响应
|
||||
"Response",
|
||||
"ErrorResponse",
|
||||
"SuccessResponse",
|
||||
"HealthResponse",
|
||||
# 管理后台
|
||||
"ConfigType",
|
||||
"ConfigItem",
|
||||
"ConfigSection",
|
||||
"ConfigListRequest",
|
||||
"ConfigListData",
|
||||
"ConfigUpdateRequest",
|
||||
"ConfigUpdateData",
|
||||
"AdapterInfo",
|
||||
"AdapterStatus",
|
||||
"AdapterListData",
|
||||
"AdapterToggleRequest",
|
||||
"AdapterConfigUpdateRequest",
|
||||
"SystemStatusData",
|
||||
"MemoryInfo",
|
||||
"RestartRequest",
|
||||
"ReloadRequest",
|
||||
"ReloadData",
|
||||
"APITestCase",
|
||||
"APITestCategory",
|
||||
"APITestListData",
|
||||
"APITestRequest",
|
||||
"APITestResult",
|
||||
"WSTestCase",
|
||||
"WSTestListData",
|
||||
"WSTestRequest",
|
||||
"WSTestResult",
|
||||
"WSMessage",
|
||||
"TestHistoryRequest",
|
||||
"TestHistoryData",
|
||||
]
|
||||
@ -0,0 +1,250 @@
|
||||
"""管理后台类型定义 - 对应Go的api/admin_types.go"""
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any, Literal
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ============================================
|
||||
# 配置管理类型
|
||||
# ============================================
|
||||
|
||||
class ConfigType(str, Enum):
|
||||
"""配置类型"""
|
||||
SERVER = "server"
|
||||
DATABASE = "database"
|
||||
REDIS = "redis"
|
||||
SOURCE = "source"
|
||||
MONITOR = "monitor"
|
||||
LOG = "log"
|
||||
|
||||
|
||||
class ConfigItem(BaseModel):
|
||||
"""配置项"""
|
||||
key: str = Field(..., description="配置键")
|
||||
value: Any = Field(..., description="配置值")
|
||||
type: str = Field(..., description="值类型: string/int/bool/json")
|
||||
description: str = Field(..., description="配置说明")
|
||||
editable: bool = Field(default=True, description="是否可编辑")
|
||||
required: bool = Field(default=True, description="是否必填")
|
||||
|
||||
|
||||
class ConfigSection(BaseModel):
|
||||
"""配置分组"""
|
||||
name: str = Field(..., description="分组名称")
|
||||
type: ConfigType = Field(..., description="分组类型")
|
||||
description: str = Field(..., description="分组说明")
|
||||
items: List[ConfigItem] = Field(default_factory=list, description="配置项列表")
|
||||
|
||||
|
||||
class ConfigListRequest(BaseModel):
|
||||
"""获取配置列表请求"""
|
||||
type: Optional[ConfigType] = Field(None, description="配置类型筛选")
|
||||
|
||||
|
||||
class ConfigListData(BaseModel):
|
||||
"""配置列表响应"""
|
||||
sections: List[ConfigSection] = Field(default_factory=list, description="配置分组列表")
|
||||
version: str = Field(default="1.0.0", description="配置版本")
|
||||
updated: datetime = Field(default_factory=datetime.now, description="最后更新时间")
|
||||
|
||||
|
||||
class ConfigUpdateRequest(BaseModel):
|
||||
"""更新配置请求"""
|
||||
type: ConfigType = Field(..., description="配置类型")
|
||||
items: Dict[str, Any] = Field(..., description="更新的配置项")
|
||||
|
||||
|
||||
class ConfigUpdateData(BaseModel):
|
||||
"""更新配置响应"""
|
||||
success: bool = Field(..., description="是否成功")
|
||||
need_restart: bool = Field(default=False, description="是否需要重启")
|
||||
message: str = Field(..., description="提示信息")
|
||||
|
||||
|
||||
# ============================================
|
||||
# 适配器管理类型
|
||||
# ============================================
|
||||
|
||||
class AdapterStatus(str, Enum):
|
||||
"""适配器状态"""
|
||||
ACTIVE = "active"
|
||||
STANDBY = "standby"
|
||||
DISABLED = "disabled"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class AdapterInfo(BaseModel):
|
||||
"""适配器信息"""
|
||||
name: str = Field(..., description="适配器名称")
|
||||
type: str = Field(..., description="适配器类型")
|
||||
version: str = Field(..., description="版本")
|
||||
description: str = Field(..., description="描述")
|
||||
status: AdapterStatus = Field(..., description="状态")
|
||||
config: Dict[str, str] = Field(default_factory=dict, description="当前配置")
|
||||
last_error: Optional[str] = Field(None, description="最后错误")
|
||||
updated_at: datetime = Field(default_factory=datetime.now, description="更新时间")
|
||||
|
||||
|
||||
class AdapterListData(BaseModel):
|
||||
"""适配器列表响应"""
|
||||
adapters: List[AdapterInfo] = Field(default_factory=list, description="适配器列表")
|
||||
|
||||
|
||||
class AdapterToggleRequest(BaseModel):
|
||||
"""启用/禁用适配器请求"""
|
||||
name: str = Field(..., description="适配器名称")
|
||||
enable: bool = Field(..., description="是否启用")
|
||||
|
||||
|
||||
class AdapterConfigUpdateRequest(BaseModel):
|
||||
"""更新适配器配置请求"""
|
||||
name: str = Field(..., description="适配器名称")
|
||||
config: Dict[str, str] = Field(..., description="配置")
|
||||
|
||||
|
||||
# ============================================
|
||||
# 系统管理类型
|
||||
# ============================================
|
||||
|
||||
class MemoryInfo(BaseModel):
|
||||
"""内存信息"""
|
||||
alloc: int = Field(..., description="已分配内存")
|
||||
total_alloc: int = Field(..., description="累计分配")
|
||||
sys: int = Field(..., description="系统内存")
|
||||
num_gc: int = Field(..., description="GC次数")
|
||||
|
||||
|
||||
class SystemStatusData(BaseModel):
|
||||
"""系统状态数据"""
|
||||
status: str = Field(..., description="系统状态")
|
||||
version: str = Field(..., description="系统版本")
|
||||
start_time: datetime = Field(..., description="启动时间")
|
||||
uptime: str = Field(..., description="运行时长")
|
||||
python_version: str = Field(..., description="Python版本")
|
||||
memory: MemoryInfo = Field(..., description="内存使用")
|
||||
threads: int = Field(..., description="线程数量")
|
||||
|
||||
|
||||
class RestartRequest(BaseModel):
|
||||
"""重启服务请求"""
|
||||
force: bool = Field(default=False, description="是否强制重启")
|
||||
|
||||
|
||||
class ReloadRequest(BaseModel):
|
||||
"""热加载配置请求"""
|
||||
config_type: Optional[ConfigType] = Field(None, description="指定配置类型")
|
||||
|
||||
|
||||
class ReloadData(BaseModel):
|
||||
"""热加载响应"""
|
||||
success: bool = Field(..., description="是否成功")
|
||||
message: str = Field(..., description="提示信息")
|
||||
|
||||
|
||||
# ============================================
|
||||
# 接口测试类型
|
||||
# ============================================
|
||||
|
||||
class APITestCase(BaseModel):
|
||||
"""接口测试用例"""
|
||||
id: str = Field(..., description="用例ID")
|
||||
name: str = Field(..., description="用例名称")
|
||||
method: str = Field(..., description="HTTP方法")
|
||||
path: str = Field(..., description="请求路径")
|
||||
description: str = Field(..., description="描述")
|
||||
params: Dict[str, str] = Field(default_factory=dict, description="默认参数")
|
||||
body: Optional[Any] = Field(None, description="请求体")
|
||||
|
||||
|
||||
class APITestCategory(BaseModel):
|
||||
"""测试分类"""
|
||||
name: str = Field(..., description="分类名称")
|
||||
items: List[APITestCase] = Field(default_factory=list, description="测试用例")
|
||||
|
||||
|
||||
class APITestListData(BaseModel):
|
||||
"""接口测试列表响应"""
|
||||
categories: List[APITestCategory] = Field(default_factory=list, description="分类列表")
|
||||
base_url: str = Field(default="", description="基础URL")
|
||||
|
||||
|
||||
class APITestRequest(BaseModel):
|
||||
"""执行接口测试请求"""
|
||||
id: str = Field(..., description="用例ID")
|
||||
params: Dict[str, str] = Field(default_factory=dict, description="自定义参数")
|
||||
body: Optional[Any] = Field(None, description="自定义请求体")
|
||||
|
||||
|
||||
class APITestResult(BaseModel):
|
||||
"""接口测试结果"""
|
||||
id: int = Field(..., description="测试ID")
|
||||
case_id: str = Field(..., description="用例ID")
|
||||
name: str = Field(..., description="用例名称")
|
||||
success: bool = Field(..., description="是否成功")
|
||||
status_code: int = Field(0, description="HTTP状态码")
|
||||
latency: int = Field(..., description="延迟(ms)")
|
||||
request: Any = Field(None, description="请求信息")
|
||||
response: Any = Field(None, description="响应信息")
|
||||
error: Optional[str] = Field(None, description="错误信息")
|
||||
timestamp: datetime = Field(default_factory=datetime.now, description="测试时间")
|
||||
|
||||
|
||||
# ============================================
|
||||
# WebSocket测试类型
|
||||
# ============================================
|
||||
|
||||
class WSTestCase(BaseModel):
|
||||
"""WebSocket测试用例"""
|
||||
id: str = Field(..., description="用例ID")
|
||||
name: str = Field(..., description="用例名称")
|
||||
description: str = Field(..., description="描述")
|
||||
action: str = Field(..., description="动作类型")
|
||||
symbols: List[str] = Field(default_factory=list, description="订阅标的")
|
||||
|
||||
|
||||
class WSTestListData(BaseModel):
|
||||
"""WebSocket测试列表响应"""
|
||||
cases: List[WSTestCase] = Field(default_factory=list, description="测试用例")
|
||||
ws_url: str = Field(default="", description="WebSocket地址")
|
||||
|
||||
|
||||
class WSTestRequest(BaseModel):
|
||||
"""WebSocket测试请求"""
|
||||
id: str = Field(..., description="用例ID")
|
||||
symbols: List[str] = Field(default_factory=list, description="自定义标的")
|
||||
|
||||
|
||||
class WSMessage(BaseModel):
|
||||
"""WebSocket消息"""
|
||||
type: str = Field(..., description="消息类型")
|
||||
data: Any = Field(None, description="消息内容")
|
||||
timestamp: datetime = Field(default_factory=datetime.now, description="时间")
|
||||
|
||||
|
||||
class WSTestResult(BaseModel):
|
||||
"""WebSocket测试结果"""
|
||||
id: str = Field(..., description="测试ID")
|
||||
case_id: str = Field(..., description="用例ID")
|
||||
success: bool = Field(..., description="是否成功")
|
||||
latency: int = Field(..., description="连接延迟(ms)")
|
||||
messages: List[WSMessage] = Field(default_factory=list, description="收到的消息")
|
||||
error: Optional[str] = Field(None, description="错误信息")
|
||||
timestamp: datetime = Field(default_factory=datetime.now, description="测试时间")
|
||||
|
||||
|
||||
# ============================================
|
||||
# 测试历史记录类型
|
||||
# ============================================
|
||||
|
||||
class TestHistoryRequest(BaseModel):
|
||||
"""获取测试历史请求"""
|
||||
type: Optional[str] = Field(None, description="测试类型: api/ws")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="数量限制")
|
||||
|
||||
|
||||
class TestHistoryData(BaseModel):
|
||||
"""测试历史数据"""
|
||||
api_tests: List[APITestResult] = Field(default_factory=list, description="API测试历史")
|
||||
ws_tests: List[WSTestResult] = Field(default_factory=list, description="WebSocket测试历史")
|
||||
@ -0,0 +1,4 @@
|
||||
"""数据质量监控模块"""
|
||||
from .monitor import DataQualityMonitor, AlertSender, LogAlertSender
|
||||
|
||||
__all__ = ["DataQualityMonitor", "AlertSender", "LogAlertSender"]
|
||||
@ -0,0 +1,13 @@
|
||||
"""数据访问层模块"""
|
||||
from .database import get_db, SessionLocal, engine, Base
|
||||
from .stock_repository import StockRepository
|
||||
from .futures_repository import FuturesRepository
|
||||
|
||||
__all__ = [
|
||||
"get_db",
|
||||
"SessionLocal",
|
||||
"engine",
|
||||
"Base",
|
||||
"StockRepository",
|
||||
"FuturesRepository",
|
||||
]
|
||||
@ -0,0 +1,268 @@
|
||||
"""期货数据仓库"""
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, or_
|
||||
|
||||
from app.models import (
|
||||
KLineItem, Symbol, SymbolListRequest, SymbolListData,
|
||||
TradingDatesData, TradeCalData, Frequency,
|
||||
FuturesContractsData, FuturesContractInfo
|
||||
)
|
||||
from app.repositories.models import (
|
||||
FuturesSymbol, FuturesKLine1M, FuturesKLine1D,
|
||||
FuturesTradingCalendar
|
||||
)
|
||||
|
||||
|
||||
class FuturesRepository:
|
||||
"""期货数据仓库"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_klines(
|
||||
self,
|
||||
symbol: str,
|
||||
freq: Frequency,
|
||||
start: datetime,
|
||||
end: datetime
|
||||
) -> List[KLineItem]:
|
||||
"""获取K线数据"""
|
||||
kline_model = self._get_kline_model(freq)
|
||||
|
||||
query = self.db.query(kline_model).filter(
|
||||
kline_model.symbol_id == symbol,
|
||||
kline_model.ts >= start,
|
||||
kline_model.ts <= end
|
||||
).order_by(kline_model.ts.asc())
|
||||
|
||||
results = query.all()
|
||||
|
||||
items = []
|
||||
for r in results:
|
||||
item = KLineItem(
|
||||
time=r.ts,
|
||||
open=float(r.open),
|
||||
high=float(r.high),
|
||||
low=float(r.low),
|
||||
close=float(r.close),
|
||||
volume=r.volume,
|
||||
amount=float(r.amount),
|
||||
open_interest=r.open_interest
|
||||
)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
|
||||
def _get_kline_model(self, freq: Frequency):
|
||||
"""根据周期获取K线模型"""
|
||||
mapping = {
|
||||
Frequency.FREQ_1M: FuturesKLine1M,
|
||||
Frequency.FREQ_1D: FuturesKLine1D,
|
||||
}
|
||||
return mapping.get(freq, FuturesKLine1D)
|
||||
|
||||
def save_klines(
|
||||
self,
|
||||
freq: Frequency,
|
||||
symbol: str,
|
||||
items: List[KLineItem]
|
||||
) -> None:
|
||||
"""保存K线数据"""
|
||||
if not items:
|
||||
return
|
||||
|
||||
kline_model = self._get_kline_model(freq)
|
||||
|
||||
for item in items:
|
||||
existing = self.db.query(kline_model).filter(
|
||||
kline_model.symbol_id == symbol,
|
||||
kline_model.ts == item.time
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
existing.open = item.open
|
||||
existing.high = item.high
|
||||
existing.low = item.low
|
||||
existing.close = item.close
|
||||
existing.volume = item.volume
|
||||
existing.amount = item.amount
|
||||
existing.open_interest = item.open_interest
|
||||
else:
|
||||
new_record = kline_model(
|
||||
symbol_id=symbol,
|
||||
ts=item.time,
|
||||
open=item.open,
|
||||
high=item.high,
|
||||
low=item.low,
|
||||
close=item.close,
|
||||
volume=item.volume,
|
||||
amount=item.amount,
|
||||
open_interest=item.open_interest
|
||||
)
|
||||
self.db.add(new_record)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
def list_symbols(
|
||||
self,
|
||||
req: SymbolListRequest
|
||||
) -> Tuple[List[Symbol], int]:
|
||||
"""查询标的列表"""
|
||||
query = self.db.query(FuturesSymbol)
|
||||
|
||||
# 筛选条件
|
||||
if req.exchange:
|
||||
query = query.filter(FuturesSymbol.exchange == req.exchange.value)
|
||||
|
||||
if req.underlying:
|
||||
query = query.filter(FuturesSymbol.underlying == req.underlying)
|
||||
|
||||
if req.keyword:
|
||||
keyword = f"%{req.keyword}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
FuturesSymbol.symbol_id.ilike(keyword),
|
||||
FuturesSymbol.name.ilike(keyword)
|
||||
)
|
||||
)
|
||||
|
||||
# 查询总数
|
||||
total = query.count()
|
||||
|
||||
# 分页查询
|
||||
results = query.order_by(FuturesSymbol.symbol_id).offset(
|
||||
(req.page - 1) * req.size
|
||||
).limit(req.size).all()
|
||||
|
||||
symbols = []
|
||||
for r in results:
|
||||
s = Symbol(
|
||||
symbol_id=r.symbol_id,
|
||||
symbol_type=r.symbol_type,
|
||||
exchange=r.exchange,
|
||||
name=r.name,
|
||||
underlying=r.underlying,
|
||||
contract_month=r.contract_month,
|
||||
list_date=r.list_date,
|
||||
delist_date=r.delist_date,
|
||||
status=r.status
|
||||
)
|
||||
symbols.append(s)
|
||||
|
||||
return symbols, total
|
||||
|
||||
def get_trading_dates(self, start: str, end: str) -> TradingDatesData:
|
||||
"""获取交易日历"""
|
||||
results = self.db.query(FuturesTradingCalendar).filter(
|
||||
FuturesTradingCalendar.trade_date >= start,
|
||||
FuturesTradingCalendar.trade_date <= end,
|
||||
FuturesTradingCalendar.is_trading_day == True
|
||||
).order_by(FuturesTradingCalendar.trade_date.asc()).all()
|
||||
|
||||
dates = [r.trade_date for r in results]
|
||||
|
||||
# 计算总天数
|
||||
start_date = datetime.strptime(start, "%Y%m%d")
|
||||
end_date = datetime.strptime(end, "%Y%m%d")
|
||||
total_days = (end_date - start_date).days + 1
|
||||
|
||||
return TradingDatesData(
|
||||
start=start,
|
||||
end=end,
|
||||
total_days=total_days,
|
||||
trading_days=len(dates),
|
||||
trading_dates=dates
|
||||
)
|
||||
|
||||
def get_contracts_by_underlying(
|
||||
self,
|
||||
underlying: str,
|
||||
exchange: Optional[str] = None
|
||||
) -> FuturesContractsData:
|
||||
"""根据品种获取合约"""
|
||||
query = self.db.query(FuturesSymbol).filter(
|
||||
FuturesSymbol.underlying == underlying,
|
||||
FuturesSymbol.status == "active"
|
||||
)
|
||||
|
||||
if exchange:
|
||||
query = query.filter(FuturesSymbol.exchange == exchange)
|
||||
|
||||
results = query.order_by(FuturesSymbol.contract_month.asc()).all()
|
||||
|
||||
contracts = []
|
||||
for r in results:
|
||||
c = FuturesContractInfo(
|
||||
symbol_id=r.symbol_id,
|
||||
exchange=r.exchange,
|
||||
name=r.name,
|
||||
underlying=r.underlying,
|
||||
contract_month=r.contract_month,
|
||||
list_date=r.list_date,
|
||||
delist_date=r.delist_date,
|
||||
status=r.status
|
||||
)
|
||||
contracts.append(c)
|
||||
|
||||
return FuturesContractsData(
|
||||
underlying=underlying,
|
||||
count=len(contracts),
|
||||
items=contracts
|
||||
)
|
||||
|
||||
def save_symbols(self, symbols: List[Symbol]) -> None:
|
||||
"""保存标的列表"""
|
||||
for s in symbols:
|
||||
existing = self.db.query(FuturesSymbol).filter(
|
||||
FuturesSymbol.symbol_id == s.symbol_id
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
existing.name = s.name
|
||||
existing.underlying = s.underlying
|
||||
existing.contract_month = s.contract_month
|
||||
existing.list_date = s.list_date
|
||||
existing.delist_date = s.delist_date
|
||||
existing.status = s.status
|
||||
else:
|
||||
new_symbol = FuturesSymbol(
|
||||
symbol_id=s.symbol_id,
|
||||
symbol_type=s.symbol_type.value if s.symbol_type else "futures",
|
||||
exchange=s.exchange.value if s.exchange else "",
|
||||
name=s.name,
|
||||
underlying=s.underlying or "",
|
||||
contract_month=s.contract_month or "",
|
||||
list_date=s.list_date,
|
||||
delist_date=s.delist_date,
|
||||
status=s.status
|
||||
)
|
||||
self.db.add(new_symbol)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
def save_trading_calendar(self, dates: List[TradeCalData]) -> None:
|
||||
"""保存交易日历"""
|
||||
for d in dates:
|
||||
date_str = d.date.strftime("%Y%m%d")
|
||||
|
||||
existing = self.db.query(FuturesTradingCalendar).filter(
|
||||
FuturesTradingCalendar.trade_date == date_str
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
existing.is_trading_day = d.is_trading_day
|
||||
existing.has_night_session = d.has_night_session
|
||||
existing.week_day = d.date.weekday() + 1
|
||||
else:
|
||||
new_cal = FuturesTradingCalendar(
|
||||
trade_date=date_str,
|
||||
is_trading_day=d.is_trading_day,
|
||||
has_night_session=d.has_night_session,
|
||||
week_day=d.date.weekday() + 1
|
||||
)
|
||||
self.db.add(new_cal)
|
||||
|
||||
self.db.commit()
|
||||
@ -0,0 +1,214 @@
|
||||
"""数据库模型定义"""
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import (
|
||||
Column, Integer, String, Float, DateTime,
|
||||
Boolean, Numeric, BigInteger, Index
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
|
||||
from app.repositories.database import Base
|
||||
|
||||
|
||||
# ============================================
|
||||
# 股票相关表
|
||||
# ============================================
|
||||
|
||||
class StockSymbol(Base):
|
||||
"""股票标的表"""
|
||||
__tablename__ = "symbols"
|
||||
__table_args__ = {"schema": "stock"}
|
||||
|
||||
symbol_id = Column(String(20), primary_key=True, index=True, comment="标的代码")
|
||||
symbol_type = Column(String(20), nullable=False, comment="标的类型")
|
||||
exchange = Column(String(10), nullable=False, index=True, comment="交易所")
|
||||
name = Column(String(100), nullable=False, comment="名称")
|
||||
name_en = Column(String(100), nullable=True, comment="英文名称")
|
||||
list_date = Column(DateTime, nullable=True, comment="上市日期")
|
||||
delist_date = Column(DateTime, nullable=True, comment="退市日期")
|
||||
industry = Column(String(50), nullable=True, comment="行业分类")
|
||||
status = Column(String(20), nullable=False, default="active", comment="状态")
|
||||
created_at = Column(DateTime, default=datetime.now, comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment="更新时间")
|
||||
|
||||
|
||||
class StockTradingCalendar(Base):
|
||||
"""股票交易日历表"""
|
||||
__tablename__ = "trading_calendar"
|
||||
__table_args__ = {"schema": "stock"}
|
||||
|
||||
trade_date = Column(String(8), primary_key=True, comment="交易日期")
|
||||
is_trading_day = Column(Boolean, nullable=False, comment="是否交易日")
|
||||
week_day = Column(Integer, nullable=True, comment="星期几")
|
||||
created_at = Column(DateTime, default=datetime.now, comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment="更新时间")
|
||||
|
||||
|
||||
class StockKLine1M(Base):
|
||||
"""股票1分钟K线"""
|
||||
__tablename__ = "klines_1m"
|
||||
__table_args__ = (
|
||||
Index("idx_stock_1m_symbol_ts", "symbol_id", "ts"),
|
||||
{"schema": "stock"}
|
||||
)
|
||||
|
||||
id = Column(BigInteger, primary_key=True, autoincrement=True)
|
||||
symbol_id = Column(String(20), nullable=False, index=True, comment="标的代码")
|
||||
ts = Column(DateTime, nullable=False, comment="时间戳")
|
||||
open = Column(Numeric(18, 4), nullable=False, comment="开盘价")
|
||||
high = Column(Numeric(18, 4), nullable=False, comment="最高价")
|
||||
low = Column(Numeric(18, 4), nullable=False, comment="最低价")
|
||||
close = Column(Numeric(18, 4), nullable=False, comment="收盘价")
|
||||
volume = Column(BigInteger, nullable=False, comment="成交量")
|
||||
amount = Column(Numeric(20, 4), nullable=False, comment="成交额")
|
||||
created_at = Column(DateTime, default=datetime.now, comment="创建时间")
|
||||
|
||||
|
||||
class StockKLine5M(Base):
|
||||
"""股票5分钟K线"""
|
||||
__tablename__ = "klines_5m"
|
||||
__table_args__ = (
|
||||
Index("idx_stock_5m_symbol_ts", "symbol_id", "ts"),
|
||||
{"schema": "stock"}
|
||||
)
|
||||
|
||||
id = Column(BigInteger, primary_key=True, autoincrement=True)
|
||||
symbol_id = Column(String(20), nullable=False, index=True, comment="标的代码")
|
||||
ts = Column(DateTime, nullable=False, comment="时间戳")
|
||||
open = Column(Numeric(18, 4), nullable=False, comment="开盘价")
|
||||
high = Column(Numeric(18, 4), nullable=False, comment="最高价")
|
||||
low = Column(Numeric(18, 4), nullable=False, comment="最低价")
|
||||
close = Column(Numeric(18, 4), nullable=False, comment="收盘价")
|
||||
volume = Column(BigInteger, nullable=False, comment="成交量")
|
||||
amount = Column(Numeric(20, 4), nullable=False, comment="成交额")
|
||||
created_at = Column(DateTime, default=datetime.now, comment="创建时间")
|
||||
|
||||
|
||||
class StockKLine1D(Base):
|
||||
"""股票日线K线"""
|
||||
__tablename__ = "klines_1d"
|
||||
__table_args__ = (
|
||||
Index("idx_stock_1d_symbol_ts", "symbol_id", "ts"),
|
||||
{"schema": "stock"}
|
||||
)
|
||||
|
||||
id = Column(BigInteger, primary_key=True, autoincrement=True)
|
||||
symbol_id = Column(String(20), nullable=False, index=True, comment="标的代码")
|
||||
ts = Column(DateTime, nullable=False, comment="时间戳")
|
||||
open = Column(Numeric(18, 4), nullable=False, comment="开盘价")
|
||||
high = Column(Numeric(18, 4), nullable=False, comment="最高价")
|
||||
low = Column(Numeric(18, 4), nullable=False, comment="最低价")
|
||||
close = Column(Numeric(18, 4), nullable=False, comment="收盘价")
|
||||
volume = Column(BigInteger, nullable=False, comment="成交量")
|
||||
amount = Column(Numeric(20, 4), nullable=False, comment="成交额")
|
||||
created_at = Column(DateTime, default=datetime.now, comment="创建时间")
|
||||
|
||||
|
||||
# ============================================
|
||||
# 期货相关表
|
||||
# ============================================
|
||||
|
||||
class FuturesSymbol(Base):
|
||||
"""期货合约表"""
|
||||
__tablename__ = "symbols"
|
||||
__table_args__ = {"schema": "futures"}
|
||||
|
||||
symbol_id = Column(String(20), primary_key=True, index=True, comment="合约代码")
|
||||
symbol_type = Column(String(20), nullable=False, comment="标的类型")
|
||||
exchange = Column(String(10), nullable=False, index=True, comment="交易所")
|
||||
name = Column(String(100), nullable=False, comment="名称")
|
||||
underlying = Column(String(10), nullable=False, index=True, comment="品种代码")
|
||||
contract_month = Column(String(6), nullable=False, comment="合约月份")
|
||||
list_date = Column(DateTime, nullable=True, comment="上市日期")
|
||||
delist_date = Column(DateTime, nullable=True, comment="退市日期")
|
||||
status = Column(String(20), nullable=False, default="active", comment="状态")
|
||||
created_at = Column(DateTime, default=datetime.now, comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment="更新时间")
|
||||
|
||||
|
||||
class FuturesTradingCalendar(Base):
|
||||
"""期货交易日历表"""
|
||||
__tablename__ = "trading_calendar"
|
||||
__table_args__ = {"schema": "futures"}
|
||||
|
||||
trade_date = Column(String(8), primary_key=True, comment="交易日期")
|
||||
is_trading_day = Column(Boolean, nullable=False, comment="是否交易日")
|
||||
has_night_session = Column(Boolean, default=False, comment="是否有夜盘")
|
||||
week_day = Column(Integer, nullable=True, comment="星期几")
|
||||
created_at = Column(DateTime, default=datetime.now, comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment="更新时间")
|
||||
|
||||
|
||||
class FuturesKLine1M(Base):
|
||||
"""期货1分钟K线"""
|
||||
__tablename__ = "klines_1m"
|
||||
__table_args__ = (
|
||||
Index("idx_futures_1m_symbol_ts", "symbol_id", "ts"),
|
||||
{"schema": "futures"}
|
||||
)
|
||||
|
||||
id = Column(BigInteger, primary_key=True, autoincrement=True)
|
||||
symbol_id = Column(String(20), nullable=False, index=True, comment="合约代码")
|
||||
ts = Column(DateTime, nullable=False, comment="时间戳")
|
||||
open = Column(Numeric(18, 4), nullable=False, comment="开盘价")
|
||||
high = Column(Numeric(18, 4), nullable=False, comment="最高价")
|
||||
low = Column(Numeric(18, 4), nullable=False, comment="最低价")
|
||||
close = Column(Numeric(18, 4), nullable=False, comment="收盘价")
|
||||
volume = Column(BigInteger, nullable=False, comment="成交量")
|
||||
amount = Column(Numeric(20, 4), nullable=False, comment="成交额")
|
||||
open_interest = Column(BigInteger, nullable=True, comment="持仓量")
|
||||
created_at = Column(DateTime, default=datetime.now, comment="创建时间")
|
||||
|
||||
|
||||
class FuturesKLine1D(Base):
|
||||
"""期货日线K线"""
|
||||
__tablename__ = "klines_1d"
|
||||
__table_args__ = (
|
||||
Index("idx_futures_1d_symbol_ts", "symbol_id", "ts"),
|
||||
{"schema": "futures"}
|
||||
)
|
||||
|
||||
id = Column(BigInteger, primary_key=True, autoincrement=True)
|
||||
symbol_id = Column(String(20), nullable=False, index=True, comment="合约代码")
|
||||
ts = Column(DateTime, nullable=False, comment="时间戳")
|
||||
open = Column(Numeric(18, 4), nullable=False, comment="开盘价")
|
||||
high = Column(Numeric(18, 4), nullable=False, comment="最高价")
|
||||
low = Column(Numeric(18, 4), nullable=False, comment="最低价")
|
||||
close = Column(Numeric(18, 4), nullable=False, comment="收盘价")
|
||||
volume = Column(BigInteger, nullable=False, comment="成交量")
|
||||
amount = Column(Numeric(20, 4), nullable=False, comment="成交额")
|
||||
open_interest = Column(BigInteger, nullable=True, comment="持仓量")
|
||||
created_at = Column(DateTime, default=datetime.now, comment="创建时间")
|
||||
|
||||
|
||||
# ============================================
|
||||
# 公共表
|
||||
# ============================================
|
||||
|
||||
class DataSourceConfig(Base):
|
||||
"""数据源配置表"""
|
||||
__tablename__ = "data_source_config"
|
||||
__table_args__ = {"schema": "public"}
|
||||
|
||||
asset_class = Column(String(20), primary_key=True, comment="资产类别")
|
||||
active_source = Column(String(50), nullable=False, comment="当前激活源")
|
||||
standby_sources = Column(ARRAY(String), nullable=True, comment="待命源列表")
|
||||
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment="更新时间")
|
||||
|
||||
|
||||
class DataQualityCheck(Base):
|
||||
"""数据质量检查表"""
|
||||
__tablename__ = "data_quality_checks"
|
||||
__table_args__ = {"schema": "stock"} # 也可以是futures
|
||||
|
||||
id = Column(BigInteger, primary_key=True, autoincrement=True)
|
||||
check_date = Column(String(8), nullable=False, index=True, comment="检查日期")
|
||||
symbol_id = Column(String(20), nullable=False, index=True, comment="标的代码")
|
||||
freq = Column(String(10), nullable=False, comment="周期")
|
||||
check_type = Column(String(20), nullable=False, comment="检查类型")
|
||||
status = Column(String(10), nullable=False, comment="状态 pass/fail")
|
||||
expect_count = Column(Integer, nullable=True, comment="期望数量")
|
||||
actual_count = Column(Integer, nullable=True, comment="实际数量")
|
||||
detail = Column(String(500), nullable=True, comment="详情")
|
||||
created_at = Column(DateTime, default=datetime.now, comment="创建时间")
|
||||
@ -0,0 +1,222 @@
|
||||
"""股票数据仓库"""
|
||||
from datetime import datetime, time
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, or_
|
||||
|
||||
from app.models import (
|
||||
KLineItem, Symbol, SymbolListRequest, SymbolListData,
|
||||
TradingDatesData, TradeCalData, AdjustType, Frequency
|
||||
)
|
||||
from app.repositories.models import (
|
||||
StockSymbol, StockKLine1M, StockKLine5M, StockKLine1D,
|
||||
StockTradingCalendar
|
||||
)
|
||||
|
||||
|
||||
class StockRepository:
|
||||
"""股票数据仓库"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_klines(
|
||||
self,
|
||||
symbol: str,
|
||||
freq: Frequency,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
adjust: AdjustType = AdjustType.NONE
|
||||
) -> List[KLineItem]:
|
||||
"""获取K线数据"""
|
||||
# 根据周期选择表
|
||||
kline_model = self._get_kline_model(freq)
|
||||
|
||||
query = self.db.query(kline_model).filter(
|
||||
kline_model.symbol_id == symbol,
|
||||
kline_model.ts >= start,
|
||||
kline_model.ts <= end
|
||||
).order_by(kline_model.ts.asc())
|
||||
|
||||
results = query.all()
|
||||
|
||||
items = []
|
||||
for r in results:
|
||||
item = KLineItem(
|
||||
time=r.ts,
|
||||
open=float(r.open),
|
||||
high=float(r.high),
|
||||
low=float(r.low),
|
||||
close=float(r.close),
|
||||
volume=r.volume,
|
||||
amount=float(r.amount)
|
||||
)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
|
||||
def _get_kline_model(self, freq: Frequency):
|
||||
"""根据周期获取K线模型"""
|
||||
mapping = {
|
||||
Frequency.FREQ_1M: StockKLine1M,
|
||||
Frequency.FREQ_5M: StockKLine5M,
|
||||
Frequency.FREQ_1D: StockKLine1D,
|
||||
}
|
||||
return mapping.get(freq, StockKLine1D)
|
||||
|
||||
def save_klines(self, freq: Frequency, items: List[KLineItem]) -> None:
|
||||
"""保存K线数据"""
|
||||
if not items:
|
||||
return
|
||||
|
||||
kline_model = self._get_kline_model(freq)
|
||||
|
||||
for item in items:
|
||||
# 使用upsert逻辑
|
||||
existing = self.db.query(kline_model).filter(
|
||||
kline_model.symbol_id == getattr(item, 'symbol', ''),
|
||||
kline_model.ts == item.time
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
existing.open = item.open
|
||||
existing.high = item.high
|
||||
existing.low = item.low
|
||||
existing.close = item.close
|
||||
existing.volume = item.volume
|
||||
existing.amount = item.amount
|
||||
else:
|
||||
new_record = kline_model(
|
||||
symbol_id=getattr(item, 'symbol', ''),
|
||||
ts=item.time,
|
||||
open=item.open,
|
||||
high=item.high,
|
||||
low=item.low,
|
||||
close=item.close,
|
||||
volume=item.volume,
|
||||
amount=item.amount
|
||||
)
|
||||
self.db.add(new_record)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
def list_symbols(
|
||||
self,
|
||||
req: SymbolListRequest
|
||||
) -> Tuple[List[Symbol], int]:
|
||||
"""查询标的列表"""
|
||||
query = self.db.query(StockSymbol)
|
||||
|
||||
# 筛选条件
|
||||
if req.exchange:
|
||||
query = query.filter(StockSymbol.exchange == req.exchange.value)
|
||||
|
||||
if req.keyword:
|
||||
keyword = f"%{req.keyword}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
StockSymbol.symbol_id.ilike(keyword),
|
||||
StockSymbol.name.ilike(keyword)
|
||||
)
|
||||
)
|
||||
|
||||
# 查询总数
|
||||
total = query.count()
|
||||
|
||||
# 分页查询
|
||||
results = query.order_by(StockSymbol.symbol_id).offset(
|
||||
(req.page - 1) * req.size
|
||||
).limit(req.size).all()
|
||||
|
||||
symbols = []
|
||||
for r in results:
|
||||
s = Symbol(
|
||||
symbol_id=r.symbol_id,
|
||||
symbol_type=r.symbol_type,
|
||||
exchange=r.exchange,
|
||||
name=r.name,
|
||||
name_en=r.name_en,
|
||||
list_date=r.list_date,
|
||||
delist_date=r.delist_date,
|
||||
industry=r.industry,
|
||||
status=r.status
|
||||
)
|
||||
symbols.append(s)
|
||||
|
||||
return symbols, total
|
||||
|
||||
def get_trading_dates(self, start: str, end: str) -> TradingDatesData:
|
||||
"""获取交易日历"""
|
||||
results = self.db.query(StockTradingCalendar).filter(
|
||||
StockTradingCalendar.trade_date >= start,
|
||||
StockTradingCalendar.trade_date <= end,
|
||||
StockTradingCalendar.is_trading_day == True
|
||||
).order_by(StockTradingCalendar.trade_date.asc()).all()
|
||||
|
||||
dates = [r.trade_date for r in results]
|
||||
|
||||
# 计算总天数
|
||||
start_date = datetime.strptime(start, "%Y%m%d")
|
||||
end_date = datetime.strptime(end, "%Y%m%d")
|
||||
total_days = (end_date - start_date).days + 1
|
||||
|
||||
return TradingDatesData(
|
||||
start=start,
|
||||
end=end,
|
||||
total_days=total_days,
|
||||
trading_days=len(dates),
|
||||
trading_dates=dates
|
||||
)
|
||||
|
||||
def save_symbols(self, symbols: List[Symbol]) -> None:
|
||||
"""保存标的列表"""
|
||||
for s in symbols:
|
||||
existing = self.db.query(StockSymbol).filter(
|
||||
StockSymbol.symbol_id == s.symbol_id
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
existing.name = s.name
|
||||
existing.name_en = s.name_en
|
||||
existing.list_date = s.list_date
|
||||
existing.delist_date = s.delist_date
|
||||
existing.industry = s.industry
|
||||
existing.status = s.status
|
||||
else:
|
||||
new_symbol = StockSymbol(
|
||||
symbol_id=s.symbol_id,
|
||||
symbol_type=s.symbol_type.value if s.symbol_type else "stock",
|
||||
exchange=s.exchange.value if s.exchange else "",
|
||||
name=s.name,
|
||||
name_en=s.name_en,
|
||||
list_date=s.list_date,
|
||||
delist_date=s.delist_date,
|
||||
industry=s.industry,
|
||||
status=s.status
|
||||
)
|
||||
self.db.add(new_symbol)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
def save_trading_calendar(self, dates: List[TradeCalData]) -> None:
|
||||
"""保存交易日历"""
|
||||
for d in dates:
|
||||
date_str = d.date.strftime("%Y%m%d")
|
||||
|
||||
existing = self.db.query(StockTradingCalendar).filter(
|
||||
StockTradingCalendar.trade_date == date_str
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
existing.is_trading_day = d.is_trading_day
|
||||
existing.week_day = d.date.weekday() + 1
|
||||
else:
|
||||
new_cal = StockTradingCalendar(
|
||||
trade_date=date_str,
|
||||
is_trading_day=d.is_trading_day,
|
||||
week_day=d.date.weekday() + 1
|
||||
)
|
||||
self.db.add(new_cal)
|
||||
|
||||
self.db.commit()
|
||||
@ -0,0 +1,16 @@
|
||||
"""业务服务层模块"""
|
||||
from .stock_service import StockService
|
||||
from .futures_service import FuturesService
|
||||
from .admin_service import AdminService
|
||||
from .config_service import ConfigService
|
||||
from .adapter_service import AdapterService
|
||||
from .test_service import TestService
|
||||
|
||||
__all__ = [
|
||||
"StockService",
|
||||
"FuturesService",
|
||||
"AdminService",
|
||||
"ConfigService",
|
||||
"AdapterService",
|
||||
"TestService",
|
||||
]
|
||||
@ -0,0 +1,194 @@
|
||||
"""适配器管理服务 - 对应Go的internal/service/adapter.go"""
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Callable
|
||||
from threading import RLock
|
||||
|
||||
from app.models import (
|
||||
AdapterListData, AdapterInfo, AdapterStatus,
|
||||
AdapterToggleRequest, AdapterConfigUpdateRequest
|
||||
)
|
||||
from app.adapters import DataSourceAdapter, TushareAdapter
|
||||
from app.core.logger import info, error
|
||||
|
||||
|
||||
class AdapterService:
|
||||
"""适配器管理服务"""
|
||||
|
||||
def __init__(self):
|
||||
self.lock = RLock()
|
||||
|
||||
# 已注册的适配器工厂
|
||||
self.factories: Dict[str, Callable[[], DataSourceAdapter]] = {}
|
||||
|
||||
# 适配器配置
|
||||
self.configs: Dict[str, dict] = {}
|
||||
|
||||
# 当前激活的适配器实例
|
||||
self.active_adapters: Dict[str, DataSourceAdapter] = {}
|
||||
|
||||
# 适配器元数据
|
||||
self.metadata: Dict[str, dict] = {}
|
||||
|
||||
# 注册内置适配器
|
||||
self._register_builtin_adapters()
|
||||
|
||||
def _register_builtin_adapters(self):
|
||||
"""注册内置适配器"""
|
||||
# 注册Tushare适配器
|
||||
self.register_adapter("tushare", lambda: TushareAdapter())
|
||||
|
||||
# 设置Tushare元数据
|
||||
self.metadata["tushare"] = {
|
||||
"name": "tushare",
|
||||
"type": "http",
|
||||
"version": "1.0.0",
|
||||
"description": "Tushare Pro 金融数据接口",
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
|
||||
# 默认配置
|
||||
self.configs["tushare"] = {
|
||||
"enabled": True,
|
||||
"config": {
|
||||
"token": "",
|
||||
"base_url": "https://api.tushare.pro"
|
||||
}
|
||||
}
|
||||
|
||||
# 预留Wind适配器
|
||||
self.metadata["wind"] = {
|
||||
"name": "wind",
|
||||
"type": "ws",
|
||||
"version": "1.0.0",
|
||||
"description": "Wind 金融终端接口(预留)",
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
|
||||
self.configs["wind"] = {
|
||||
"enabled": False,
|
||||
"config": {
|
||||
"host": "localhost",
|
||||
"port": "8081"
|
||||
}
|
||||
}
|
||||
|
||||
def get_adapter_list(self) -> AdapterListData:
|
||||
"""获取适配器列表"""
|
||||
with self.lock:
|
||||
adapters = []
|
||||
|
||||
for name, meta in self.metadata.items():
|
||||
cfg = self.configs.get(name, {"enabled": False, "config": {}})
|
||||
|
||||
# 确定状态
|
||||
if not cfg["enabled"]:
|
||||
status = AdapterStatus.DISABLED
|
||||
elif name in self.active_adapters:
|
||||
status = AdapterStatus.ACTIVE
|
||||
else:
|
||||
status = AdapterStatus.STANDBY
|
||||
|
||||
adapters.append(AdapterInfo(
|
||||
name=meta["name"],
|
||||
type=meta["type"],
|
||||
version=meta["version"],
|
||||
description=meta["description"],
|
||||
status=status,
|
||||
config=cfg["config"],
|
||||
updated_at=meta["updated_at"]
|
||||
))
|
||||
|
||||
return AdapterListData(adapters=adapters)
|
||||
|
||||
def toggle_adapter(self, req: AdapterToggleRequest) -> None:
|
||||
"""启用/禁用适配器"""
|
||||
with self.lock:
|
||||
if req.name not in self.configs:
|
||||
raise ValueError(f"Adapter not found: {req.name}")
|
||||
|
||||
self.configs[req.name]["enabled"] = req.enable
|
||||
|
||||
# 如果禁用,关闭适配器连接
|
||||
if not req.enable and req.name in self.active_adapters:
|
||||
adapter = self.active_adapters.pop(req.name)
|
||||
asyncio.create_task(adapter.close())
|
||||
|
||||
# 更新元数据
|
||||
if req.name in self.metadata:
|
||||
self.metadata[req.name]["updated_at"] = datetime.now()
|
||||
|
||||
def update_adapter_config(self, req: AdapterConfigUpdateRequest) -> None:
|
||||
"""更新适配器配置"""
|
||||
with self.lock:
|
||||
if req.name not in self.configs:
|
||||
raise ValueError(f"Adapter not found: {req.name}")
|
||||
|
||||
# 更新配置
|
||||
self.configs[req.name]["config"].update(req.config)
|
||||
|
||||
# 如果适配器已激活,重新连接
|
||||
if req.name in self.active_adapters:
|
||||
adapter = self.active_adapters.pop(req.name)
|
||||
asyncio.create_task(adapter.close())
|
||||
|
||||
# 如果启用状态,重新连接
|
||||
if self.configs[req.name]["enabled"]:
|
||||
asyncio.create_task(self._connect_adapter(req.name))
|
||||
|
||||
# 更新元数据
|
||||
if req.name in self.metadata:
|
||||
self.metadata[req.name]["updated_at"] = datetime.now()
|
||||
|
||||
def get_active_adapter(self, asset_class: str) -> Optional[DataSourceAdapter]:
|
||||
"""获取当前激活的适配器"""
|
||||
with self.lock:
|
||||
# 根据资产类别获取配置(简化处理)
|
||||
adapter_name = "tushare"
|
||||
|
||||
# 检查是否已有激活的实例
|
||||
if adapter_name in self.active_adapters:
|
||||
return self.active_adapters[adapter_name]
|
||||
|
||||
return None
|
||||
|
||||
def get_available_adapters(self) -> List[str]:
|
||||
"""获取所有可用的适配器名称"""
|
||||
with self.lock:
|
||||
names = []
|
||||
for name, meta in self.metadata.items():
|
||||
if name in self.factories:
|
||||
names.append(f"{name}|{meta['description']}")
|
||||
return names
|
||||
|
||||
def register_adapter(self, name: str, factory: Callable[[], DataSourceAdapter]):
|
||||
"""注册适配器"""
|
||||
with self.lock:
|
||||
self.factories[name] = factory
|
||||
|
||||
async def _connect_adapter(self, name: str):
|
||||
"""连接适配器"""
|
||||
with self.lock:
|
||||
if name not in self.factories:
|
||||
raise ValueError(f"Adapter factory not found: {name}")
|
||||
|
||||
if name not in self.configs:
|
||||
raise ValueError(f"Adapter config not found: {name}")
|
||||
|
||||
factory = self.factories[name]
|
||||
cfg = self.configs[name]
|
||||
|
||||
adapter = factory()
|
||||
await adapter.connect(cfg["config"])
|
||||
|
||||
with self.lock:
|
||||
self.active_adapters[name] = adapter
|
||||
|
||||
async def health_check(self, name: str) -> bool:
|
||||
"""适配器健康检查"""
|
||||
with self.lock:
|
||||
if name not in self.active_adapters:
|
||||
return False
|
||||
adapter = self.active_adapters[name]
|
||||
|
||||
return await adapter.health_check()
|
||||
@ -0,0 +1,332 @@
|
||||
"""配置管理服务 - 对应Go的internal/service/config.go"""
|
||||
import platform
|
||||
import psutil
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Callable, Dict, Any
|
||||
|
||||
from app.models import (
|
||||
ConfigListRequest, ConfigListData, ConfigSection, ConfigItem,
|
||||
ConfigUpdateRequest, ConfigUpdateData, ConfigType,
|
||||
ReloadRequest, ReloadData, SystemStatusData, MemoryInfo
|
||||
)
|
||||
from app.core.config import get_config, reload_config, save_config, Config
|
||||
from app.core.logger import info
|
||||
|
||||
|
||||
class ConfigService:
|
||||
"""配置管理服务"""
|
||||
|
||||
def __init__(self):
|
||||
self.config = get_config()
|
||||
self.start_time = datetime.now()
|
||||
self.version = "1.0.0"
|
||||
self.callbacks: Dict[ConfigType, List[Callable]] = {}
|
||||
self.lock = threading.RLock()
|
||||
|
||||
def get_config_list(self, req: ConfigListRequest) -> ConfigListData:
|
||||
"""获取配置列表"""
|
||||
sections = []
|
||||
|
||||
# 服务器配置
|
||||
if not req.type or req.type == ConfigType.SERVER:
|
||||
sections.append(ConfigSection(
|
||||
name="服务器配置",
|
||||
type=ConfigType.SERVER,
|
||||
description="HTTP服务器相关配置",
|
||||
items=[
|
||||
ConfigItem(
|
||||
key="port",
|
||||
value=self.config.server.port,
|
||||
type="int",
|
||||
description="服务端口",
|
||||
editable=True,
|
||||
required=True
|
||||
),
|
||||
ConfigItem(
|
||||
key="mode",
|
||||
value=self.config.server.mode,
|
||||
type="string",
|
||||
description="运行模式: debug/release",
|
||||
editable=True,
|
||||
required=True
|
||||
),
|
||||
ConfigItem(
|
||||
key="api_key",
|
||||
value=self.config.server.api_key,
|
||||
type="string",
|
||||
description="API认证密钥",
|
||||
editable=True,
|
||||
required=True
|
||||
),
|
||||
]
|
||||
))
|
||||
|
||||
# 数据库配置
|
||||
if not req.type or req.type == ConfigType.DATABASE:
|
||||
sections.append(ConfigSection(
|
||||
name="数据库配置",
|
||||
type=ConfigType.DATABASE,
|
||||
description="PostgreSQL数据库连接配置",
|
||||
items=[
|
||||
ConfigItem(
|
||||
key="host",
|
||||
value=self.config.database.host,
|
||||
type="string",
|
||||
description="数据库主机地址",
|
||||
editable=True,
|
||||
required=True
|
||||
),
|
||||
ConfigItem(
|
||||
key="port",
|
||||
value=self.config.database.port,
|
||||
type="int",
|
||||
description="数据库端口",
|
||||
editable=True,
|
||||
required=True
|
||||
),
|
||||
ConfigItem(
|
||||
key="user",
|
||||
value=self.config.database.user,
|
||||
type="string",
|
||||
description="数据库用户名",
|
||||
editable=True,
|
||||
required=True
|
||||
),
|
||||
ConfigItem(
|
||||
key="password",
|
||||
value="********",
|
||||
type="password",
|
||||
description="数据库密码",
|
||||
editable=True,
|
||||
required=True
|
||||
),
|
||||
ConfigItem(
|
||||
key="database",
|
||||
value=self.config.database.database,
|
||||
type="string",
|
||||
description="数据库名",
|
||||
editable=True,
|
||||
required=True
|
||||
),
|
||||
]
|
||||
))
|
||||
|
||||
# Redis配置
|
||||
if not req.type or req.type == ConfigType.REDIS:
|
||||
sections.append(ConfigSection(
|
||||
name="Redis配置",
|
||||
type=ConfigType.REDIS,
|
||||
description="Redis缓存配置",
|
||||
items=[
|
||||
ConfigItem(
|
||||
key="host",
|
||||
value=self.config.redis.host,
|
||||
type="string",
|
||||
description="Redis主机地址",
|
||||
editable=True,
|
||||
required=False
|
||||
),
|
||||
ConfigItem(
|
||||
key="port",
|
||||
value=self.config.redis.port,
|
||||
type="int",
|
||||
description="Redis端口",
|
||||
editable=True,
|
||||
required=False
|
||||
),
|
||||
ConfigItem(
|
||||
key="password",
|
||||
value="********",
|
||||
type="password",
|
||||
description="Redis密码",
|
||||
editable=True,
|
||||
required=False
|
||||
),
|
||||
ConfigItem(
|
||||
key="db",
|
||||
value=self.config.redis.db,
|
||||
type="int",
|
||||
description="Redis数据库编号",
|
||||
editable=True,
|
||||
required=False
|
||||
),
|
||||
]
|
||||
))
|
||||
|
||||
# 数据源配置
|
||||
if not req.type or req.type == ConfigType.SOURCE:
|
||||
sections.append(ConfigSection(
|
||||
name="数据源配置",
|
||||
type=ConfigType.SOURCE,
|
||||
description="股票和期货数据源配置",
|
||||
items=[
|
||||
ConfigItem(
|
||||
key="stock_active",
|
||||
value=self.config.sources.stock.active,
|
||||
type="string",
|
||||
description="股票数据源适配器",
|
||||
editable=True,
|
||||
required=True
|
||||
),
|
||||
ConfigItem(
|
||||
key="futures_active",
|
||||
value=self.config.sources.futures.active,
|
||||
type="string",
|
||||
description="期货数据源适配器",
|
||||
editable=True,
|
||||
required=True
|
||||
),
|
||||
]
|
||||
))
|
||||
|
||||
return ConfigListData(
|
||||
sections=sections,
|
||||
version=self.version,
|
||||
updated=datetime.now()
|
||||
)
|
||||
|
||||
def update_config(self, req: ConfigUpdateRequest) -> ConfigUpdateData:
|
||||
"""更新配置"""
|
||||
need_restart = False
|
||||
|
||||
with self.lock:
|
||||
if req.type == ConfigType.SERVER:
|
||||
if "port" in req.items:
|
||||
self.config.server.port = int(req.items["port"])
|
||||
need_restart = True
|
||||
if "mode" in req.items:
|
||||
self.config.server.mode = req.items["mode"]
|
||||
if "api_key" in req.items:
|
||||
self.config.server.api_key = req.items["api_key"]
|
||||
|
||||
elif req.type == ConfigType.DATABASE:
|
||||
if "host" in req.items:
|
||||
self.config.database.host = req.items["host"]
|
||||
need_restart = True
|
||||
if "port" in req.items:
|
||||
self.config.database.port = int(req.items["port"])
|
||||
need_restart = True
|
||||
if "user" in req.items:
|
||||
self.config.database.user = req.items["user"]
|
||||
need_restart = True
|
||||
if "password" in req.items:
|
||||
password = req.items["password"]
|
||||
if password != "********":
|
||||
self.config.database.password = password
|
||||
need_restart = True
|
||||
if "database" in req.items:
|
||||
self.config.database.database = req.items["database"]
|
||||
need_restart = True
|
||||
|
||||
elif req.type == ConfigType.SOURCE:
|
||||
if "stock_active" in req.items:
|
||||
self.config.sources.stock.active = req.items["stock_active"]
|
||||
if "futures_active" in req.items:
|
||||
self.config.sources.futures.active = req.items["futures_active"]
|
||||
|
||||
# 保存到文件
|
||||
try:
|
||||
save_config(self.config)
|
||||
self._trigger_callbacks(req.type)
|
||||
|
||||
message = "配置更新成功"
|
||||
if need_restart:
|
||||
message += ",部分配置需要重启服务后生效"
|
||||
|
||||
return ConfigUpdateData(
|
||||
success=True,
|
||||
need_restart=need_restart,
|
||||
message=message
|
||||
)
|
||||
except Exception as e:
|
||||
return ConfigUpdateData(
|
||||
success=False,
|
||||
need_restart=False,
|
||||
message=f"配置保存失败: {e}"
|
||||
)
|
||||
|
||||
def reload_config(self, req: ReloadRequest) -> ReloadData:
|
||||
"""热加载配置"""
|
||||
try:
|
||||
with self.lock:
|
||||
new_config = reload_config()
|
||||
|
||||
# 根据类型选择性更新
|
||||
if req.config_type is None:
|
||||
self.config = new_config
|
||||
else:
|
||||
if req.config_type == ConfigType.SERVER:
|
||||
self.config.server = new_config.server
|
||||
elif req.config_type == ConfigType.DATABASE:
|
||||
self.config.database = new_config.database
|
||||
elif req.config_type == ConfigType.REDIS:
|
||||
self.config.redis = new_config.redis
|
||||
elif req.config_type == ConfigType.SOURCE:
|
||||
self.config.sources = new_config.sources
|
||||
|
||||
self._trigger_callbacks(req.config_type)
|
||||
|
||||
return ReloadData(
|
||||
success=True,
|
||||
message="配置热加载成功"
|
||||
)
|
||||
except Exception as e:
|
||||
return ReloadData(
|
||||
success=False,
|
||||
message=f"加载配置失败: {e}"
|
||||
)
|
||||
|
||||
def get_system_status(self) -> SystemStatusData:
|
||||
"""获取系统状态"""
|
||||
# 获取内存信息
|
||||
mem = psutil.virtual_memory()
|
||||
|
||||
# 计算运行时长
|
||||
uptime = datetime.now() - self.start_time
|
||||
uptime_str = self._format_duration(uptime)
|
||||
|
||||
return SystemStatusData(
|
||||
status="running",
|
||||
version=self.version,
|
||||
start_time=self.start_time,
|
||||
uptime=uptime_str,
|
||||
python_version=platform.python_version(),
|
||||
memory=MemoryInfo(
|
||||
alloc=mem.used,
|
||||
total_alloc=mem.total,
|
||||
sys=mem.total,
|
||||
num_gc=0 # Python不需要显式GC计数
|
||||
),
|
||||
threads=threading.active_count()
|
||||
)
|
||||
|
||||
def _format_duration(self, d: timedelta) -> str:
|
||||
"""格式化持续时间"""
|
||||
days = d.days
|
||||
hours, remainder = divmod(d.seconds, 3600)
|
||||
minutes, _ = divmod(remainder, 60)
|
||||
|
||||
if days > 0:
|
||||
return f"{days}天{hours}小时{minutes}分钟"
|
||||
if hours > 0:
|
||||
return f"{hours}小时{minutes}分钟"
|
||||
return f"{minutes}分钟"
|
||||
|
||||
def register_callback(self, config_type: ConfigType, callback: Callable):
|
||||
"""注册配置变更回调"""
|
||||
with self.lock:
|
||||
if config_type not in self.callbacks:
|
||||
self.callbacks[config_type] = []
|
||||
self.callbacks[config_type].append(callback)
|
||||
|
||||
def _trigger_callbacks(self, config_type: Optional[ConfigType]):
|
||||
"""触发回调"""
|
||||
with self.lock:
|
||||
# 触发特定类型的回调
|
||||
if config_type and config_type in self.callbacks:
|
||||
for cb in self.callbacks[config_type]:
|
||||
try:
|
||||
cb()
|
||||
except Exception as e:
|
||||
info(f"Callback error: {e}")
|
||||
@ -0,0 +1,102 @@
|
||||
"""期货业务服务 - 对应Go的internal/service/futures.go"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import (
|
||||
KLineQueryRequest, KLineData, SymbolListRequest, SymbolListData,
|
||||
BatchKLineRequest, BatchKLineData, BatchKLineResult, KLineSubData,
|
||||
TradingDatesRequest, TradingDatesData,
|
||||
FuturesContractsRequest, FuturesContractsData
|
||||
)
|
||||
from app.repositories import FuturesRepository
|
||||
from app.core.logger import error
|
||||
|
||||
|
||||
class FuturesService:
|
||||
"""期货业务服务"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.repository = FuturesRepository(db)
|
||||
|
||||
def query_klines(self, req: KLineQueryRequest) -> KLineData:
|
||||
"""查询K线数据"""
|
||||
# 解析日期
|
||||
try:
|
||||
start = datetime.strptime(req.start, "%Y%m%d")
|
||||
end = datetime.strptime(req.end, "%Y%m%d")
|
||||
end = end + timedelta(days=1) - timedelta(seconds=1)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid date format: {e}")
|
||||
|
||||
# 获取K线数据
|
||||
items = self.repository.get_klines(req.symbol, req.freq, start, end)
|
||||
|
||||
return KLineData(
|
||||
symbol=req.symbol,
|
||||
freq=req.freq,
|
||||
count=len(items),
|
||||
items=items
|
||||
)
|
||||
|
||||
def list_symbols(self, req: SymbolListRequest) -> SymbolListData:
|
||||
"""查询标的列表"""
|
||||
if req.page <= 0:
|
||||
req.page = 1
|
||||
if req.size <= 0:
|
||||
req.size = 20
|
||||
if req.size > 100:
|
||||
req.size = 100
|
||||
|
||||
symbols, total = self.repository.list_symbols(req)
|
||||
|
||||
return SymbolListData(
|
||||
total=total,
|
||||
page=req.page,
|
||||
size=req.size,
|
||||
items=symbols
|
||||
)
|
||||
|
||||
def batch_query_klines(self, req: BatchKLineRequest) -> BatchKLineData:
|
||||
"""批量查询K线"""
|
||||
results = []
|
||||
|
||||
for symbol in req.symbols:
|
||||
single_req = KLineQueryRequest(
|
||||
symbol=symbol,
|
||||
start=req.start,
|
||||
end=req.end,
|
||||
freq=req.freq
|
||||
)
|
||||
|
||||
try:
|
||||
data = self.query_klines(single_req)
|
||||
results.append(BatchKLineResult(
|
||||
symbol=symbol,
|
||||
success=True,
|
||||
data=KLineSubData(count=data.count, items=data.items)
|
||||
))
|
||||
except Exception as e:
|
||||
error(f"Batch query failed for {symbol}: {e}")
|
||||
results.append(BatchKLineResult(
|
||||
symbol=symbol,
|
||||
success=False,
|
||||
error=str(e)
|
||||
))
|
||||
|
||||
return BatchKLineData(results=results)
|
||||
|
||||
def get_trading_dates(self, req: TradingDatesRequest) -> TradingDatesData:
|
||||
"""获取交易日历"""
|
||||
return self.repository.get_trading_dates(req.start, req.end)
|
||||
|
||||
def get_contracts_by_underlying(
|
||||
self,
|
||||
req: FuturesContractsRequest
|
||||
) -> FuturesContractsData:
|
||||
"""根据品种获取合约"""
|
||||
return self.repository.get_contracts_by_underlying(
|
||||
req.underlying,
|
||||
req.exchange
|
||||
)
|
||||
@ -0,0 +1,4 @@
|
||||
"""WebSocket服务模块"""
|
||||
from .server import WebSocketServer, ws_manager
|
||||
|
||||
__all__ = ["WebSocketServer", "ws_manager"]
|
||||
@ -0,0 +1,210 @@
|
||||
"""WebSocket服务 - 对应Go的internal/websocket/server.go"""
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, Set, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
from app.core.logger import info, error
|
||||
|
||||
|
||||
@dataclass
|
||||
class WSClient:
|
||||
"""WebSocket客户端"""
|
||||
id: str
|
||||
websocket: WebSocket
|
||||
subscriptions: Set[str] = field(default_factory=set)
|
||||
|
||||
async def send(self, message: dict):
|
||||
"""发送消息"""
|
||||
try:
|
||||
await self.websocket.send_json(message)
|
||||
except Exception as e:
|
||||
error(f"Failed to send message to client {self.id}: {e}")
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
"""WebSocket连接管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.clients: Dict[str, WSClient] = {}
|
||||
self.subscriptions: Dict[str, Set[str]] = {} # symbol -> set of client_ids
|
||||
self.max_symbols_per_client = 100
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def connect(self, websocket: WebSocket, client_id: str) -> WSClient:
|
||||
"""建立连接"""
|
||||
await websocket.accept()
|
||||
|
||||
client = WSClient(id=client_id, websocket=websocket)
|
||||
|
||||
async with self.lock:
|
||||
self.clients[client_id] = client
|
||||
|
||||
info(f"WebSocket client connected: {client_id}, total: {len(self.clients)}")
|
||||
return client
|
||||
|
||||
async def disconnect(self, client_id: str):
|
||||
"""断开连接"""
|
||||
async with self.lock:
|
||||
if client_id in self.clients:
|
||||
client = self.clients.pop(client_id)
|
||||
|
||||
# 清理订阅
|
||||
for symbol in client.subscriptions:
|
||||
if symbol in self.subscriptions:
|
||||
self.subscriptions[symbol].discard(client_id)
|
||||
if not self.subscriptions[symbol]:
|
||||
del self.subscriptions[symbol]
|
||||
|
||||
info(f"WebSocket client disconnected: {client_id}, total: {len(self.clients)}")
|
||||
|
||||
async def subscribe(self, client_id: str, symbols: list) -> bool:
|
||||
"""订阅标的"""
|
||||
async with self.lock:
|
||||
if client_id not in self.clients:
|
||||
return False
|
||||
|
||||
client = self.clients[client_id]
|
||||
|
||||
# 检查订阅数量限制
|
||||
if len(client.subscriptions) + len(symbols) > self.max_symbols_per_client:
|
||||
return False
|
||||
|
||||
for symbol in symbols:
|
||||
client.subscriptions.add(symbol)
|
||||
|
||||
if symbol not in self.subscriptions:
|
||||
self.subscriptions[symbol] = set()
|
||||
self.subscriptions[symbol].add(client_id)
|
||||
|
||||
return True
|
||||
|
||||
async def unsubscribe(self, client_id: str, symbols: list):
|
||||
"""取消订阅"""
|
||||
async with self.lock:
|
||||
if client_id not in self.clients:
|
||||
return
|
||||
|
||||
client = self.clients[client_id]
|
||||
|
||||
for symbol in symbols:
|
||||
client.subscriptions.discard(symbol)
|
||||
|
||||
if symbol in self.subscriptions:
|
||||
self.subscriptions[symbol].discard(client_id)
|
||||
if not self.subscriptions[symbol]:
|
||||
del self.subscriptions[symbol]
|
||||
|
||||
async def broadcast_to_symbol(self, symbol: str, message: dict):
|
||||
"""向订阅了某标的的所有客户端广播"""
|
||||
client_ids = set()
|
||||
|
||||
async with self.lock:
|
||||
if symbol in self.subscriptions:
|
||||
client_ids = self.subscriptions[symbol].copy()
|
||||
|
||||
# 在锁外发送消息
|
||||
for client_id in client_ids:
|
||||
if client_id in self.clients:
|
||||
try:
|
||||
await self.clients[client_id].send(message)
|
||||
except Exception as e:
|
||||
error(f"Failed to broadcast to {client_id}: {e}")
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""获取统计信息"""
|
||||
return {
|
||||
"total_clients": len(self.clients),
|
||||
"total_subscriptions": len(self.subscriptions)
|
||||
}
|
||||
|
||||
|
||||
# 全局WebSocket管理器实例
|
||||
ws_manager = WebSocketManager()
|
||||
|
||||
|
||||
class WebSocketServer:
|
||||
"""WebSocket服务器"""
|
||||
|
||||
def __init__(self):
|
||||
self.manager = ws_manager
|
||||
|
||||
async def handle(self, websocket: WebSocket, client_id: str):
|
||||
"""处理WebSocket连接"""
|
||||
client = await self.manager.connect(websocket, client_id)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# 接收消息
|
||||
data = await websocket.receive_text()
|
||||
|
||||
try:
|
||||
msg = json.loads(data)
|
||||
action = msg.get("action")
|
||||
symbols = msg.get("symbols", [])
|
||||
|
||||
if action == "subscribe":
|
||||
success = await self.manager.subscribe(client_id, symbols)
|
||||
if success:
|
||||
await client.send({
|
||||
"type": "ack",
|
||||
"action": "subscribe",
|
||||
"symbols": symbols,
|
||||
"ts": datetime.now().isoformat()
|
||||
})
|
||||
else:
|
||||
await client.send({
|
||||
"type": "error",
|
||||
"code": 1003,
|
||||
"message": "Too many subscriptions or subscription failed",
|
||||
"ts": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
elif action == "unsubscribe":
|
||||
await self.manager.unsubscribe(client_id, symbols)
|
||||
await client.send({
|
||||
"type": "ack",
|
||||
"action": "unsubscribe",
|
||||
"symbols": symbols,
|
||||
"ts": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
else:
|
||||
await client.send({
|
||||
"type": "error",
|
||||
"code": 1001,
|
||||
"message": "Unknown action",
|
||||
"ts": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
except json.JSONDecodeError:
|
||||
await client.send({
|
||||
"type": "error",
|
||||
"code": 1000,
|
||||
"message": "Invalid message format",
|
||||
"ts": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
await self.manager.disconnect(client_id)
|
||||
except Exception as e:
|
||||
error(f"WebSocket error for client {client_id}: {e}")
|
||||
await self.manager.disconnect(client_id)
|
||||
|
||||
async def send_heartbeat(self):
|
||||
"""发送心跳(可由定时任务调用)"""
|
||||
message = {
|
||||
"type": "heartbeat",
|
||||
"ts": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# 向所有客户端发送心跳
|
||||
clients_copy = list(self.manager.clients.values())
|
||||
for client in clients_copy:
|
||||
try:
|
||||
await client.send(message)
|
||||
except Exception:
|
||||
pass
|
||||
@ -0,0 +1,46 @@
|
||||
{
|
||||
"server": {
|
||||
"port": 8080,
|
||||
"mode": "debug",
|
||||
"api_key": "demo-api-key-2024"
|
||||
},
|
||||
"database": {
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"user": "postgres",
|
||||
"password": "postgres",
|
||||
"database": "marketdata"
|
||||
},
|
||||
"redis": {
|
||||
"host": "localhost",
|
||||
"port": 6379,
|
||||
"password": "",
|
||||
"db": 0
|
||||
},
|
||||
"sources": {
|
||||
"stock": {
|
||||
"active": "tushare",
|
||||
"list": {
|
||||
"tushare": {
|
||||
"type": "http",
|
||||
"config": {
|
||||
"token": "",
|
||||
"base_url": "https://api.tushare.pro"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"futures": {
|
||||
"active": "tushare",
|
||||
"list": {
|
||||
"tushare": {
|
||||
"type": "http",
|
||||
"config": {
|
||||
"token": "",
|
||||
"base_url": "https://api.tushare.pro"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,44 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "market-data-service"
|
||||
version = "1.0.0"
|
||||
description = "统一行情数据服务 - Python实现"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
dependencies = [
|
||||
"fastapi>=0.115.0",
|
||||
"uvicorn[standard]>=0.32.0",
|
||||
"python-socketio>=5.12.1",
|
||||
"websockets>=14.1",
|
||||
"sqlalchemy>=2.0.36",
|
||||
"psycopg2-binary>=2.9.10",
|
||||
"pandas>=2.2.3",
|
||||
"numpy>=2.1.3",
|
||||
"pydantic>=2.10.0",
|
||||
"pydantic-settings>=2.6.1",
|
||||
"python-dotenv>=1.0.1",
|
||||
"PyYAML>=6.0.2",
|
||||
"httpx>=0.28.0",
|
||||
"apscheduler>=3.11.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.3.4",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
include = ["app*"]
|
||||
@ -0,0 +1,38 @@
|
||||
# Web Framework
|
||||
fastapi==0.115.0
|
||||
uvicorn[standard]==0.32.0
|
||||
python-socketio==5.12.1
|
||||
websockets==14.1
|
||||
|
||||
# Database
|
||||
sqlalchemy==2.0.36
|
||||
psycopg2-binary==2.9.10
|
||||
alembic==1.14.0
|
||||
|
||||
# Data Processing
|
||||
pandas==2.2.3
|
||||
numpy==2.1.3
|
||||
|
||||
# Data Source
|
||||
# Note: tushare needs to be installed separately with: pip install tushare
|
||||
tushare==1.4.14
|
||||
|
||||
# Configuration
|
||||
pydantic==2.10.0
|
||||
pydantic-settings==2.6.1
|
||||
python-dotenv==1.0.1
|
||||
PyYAML==6.0.2
|
||||
|
||||
# Utilities
|
||||
python-multipart==0.0.19
|
||||
httpx==0.28.0
|
||||
aiohttp==3.11.10
|
||||
aioredis==2.0.1
|
||||
|
||||
# Monitoring
|
||||
apscheduler==3.11.0
|
||||
|
||||
# Testing
|
||||
pytest==8.3.4
|
||||
pytest-asyncio==0.24.0
|
||||
httpx==0.28.0
|
||||
@ -0,0 +1,78 @@
|
||||
# 安装脚本使用说明
|
||||
|
||||
本目录包含 Go 环境的自动安装脚本。
|
||||
|
||||
## 脚本列表
|
||||
|
||||
| 脚本 | 适用系统 | 说明 |
|
||||
|------|----------|------|
|
||||
| `install-go-windows.ps1` | Windows 10/11 | PowerShell 安装脚本 |
|
||||
| `install-go-linux.sh` | Linux/macOS | Bash 安装脚本 |
|
||||
|
||||
## 使用方法
|
||||
|
||||
### Windows
|
||||
|
||||
1. **以管理员身份打开 PowerShell**
|
||||
|
||||
2. **执行安装脚本**
|
||||
```powershell
|
||||
cd d:\fs_workspace\market-data-service\scripts
|
||||
.\install-go-windows.ps1
|
||||
```
|
||||
|
||||
3. **等待安装完成**
|
||||
|
||||
脚本会自动:
|
||||
- 下载 Go 1.21.6
|
||||
- 执行安装
|
||||
- 配置环境变量
|
||||
- 设置国内镜像
|
||||
|
||||
4. **重新打开 PowerShell**,验证安装
|
||||
```powershell
|
||||
go version
|
||||
```
|
||||
|
||||
### Linux / macOS
|
||||
|
||||
1. **打开终端**
|
||||
|
||||
2. **执行安装脚本**
|
||||
```bash
|
||||
cd /path/to/market-data-service/scripts
|
||||
chmod +x install-go-linux.sh
|
||||
./install-go-linux.sh
|
||||
```
|
||||
|
||||
3. **使环境变量生效**
|
||||
```bash
|
||||
source ~/.bashrc # Linux Bash
|
||||
source ~/.zshrc # macOS Zsh
|
||||
```
|
||||
|
||||
4. **验证安装**
|
||||
```bash
|
||||
go version
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 脚本需要管理员/Root 权限
|
||||
- 安装过程中需要联网下载
|
||||
- 安装完成后需要重新打开终端
|
||||
|
||||
## 安装后步骤
|
||||
|
||||
安装完成后,返回项目目录启动服务:
|
||||
|
||||
```bash
|
||||
cd d:\fs_workspace\market-data-service
|
||||
go mod download
|
||||
go run ./cmd/server
|
||||
```
|
||||
|
||||
然后访问管理后台:
|
||||
```
|
||||
http://localhost:8080/admin
|
||||
```
|
||||
@ -0,0 +1,116 @@
|
||||
# 修复 Go 依赖脚本
|
||||
param()
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
Write-Host "==============================================" -ForegroundColor Cyan
|
||||
Write-Host " 修复 Go 依赖问题 " -ForegroundColor Cyan
|
||||
Write-Host "==============================================" -ForegroundColor Cyan
|
||||
Write-Host ""
|
||||
|
||||
# 检查 Go 是否安装
|
||||
Write-Host "[1/5] 检查 Go 环境..." -ForegroundColor Yellow
|
||||
$GoCmd = Get-Command go -ErrorAction SilentlyContinue
|
||||
if (-not $GoCmd) {
|
||||
Write-Error "未找到 Go 命令,请先安装 Go"
|
||||
exit 1
|
||||
}
|
||||
|
||||
$GoVersion = & go version
|
||||
Write-Host " Go 版本: $GoVersion" -ForegroundColor Green
|
||||
|
||||
# 检查 GOPROXY
|
||||
Write-Host "[2/5] 检查 GOPROXY 设置..." -ForegroundColor Yellow
|
||||
$GoProxy = & go env GOPROXY
|
||||
Write-Host " 当前 GOPROXY: $GoProxy" -ForegroundColor Gray
|
||||
|
||||
if ($GoProxy -ne "https://goproxy.cn,direct") {
|
||||
Write-Host " 设置国内镜像..." -ForegroundColor Yellow
|
||||
& go env -w GOPROXY="https://goproxy.cn,direct"
|
||||
Write-Host " GOPROXY 已设置为 https://goproxy.cn,direct" -ForegroundColor Green
|
||||
}
|
||||
|
||||
# 检查 GOPATH
|
||||
Write-Host "[3/5] 检查 GOPATH..." -ForegroundColor Yellow
|
||||
$GoPath = & go env GOPATH
|
||||
Write-Host " GOPATH: $GoPath" -ForegroundColor Gray
|
||||
|
||||
# 进入项目目录
|
||||
Write-Host "[4/5] 进入项目目录..." -ForegroundColor Yellow
|
||||
$ProjectDir = "d:\fs_workspace\market-data-service"
|
||||
if (-not (Test-Path $ProjectDir)) {
|
||||
Write-Error "项目目录不存在: $ProjectDir"
|
||||
exit 1
|
||||
}
|
||||
|
||||
Set-Location $ProjectDir
|
||||
Write-Host " 当前目录: $(Get-Location)" -ForegroundColor Green
|
||||
|
||||
# 清理缓存
|
||||
Write-Host "[5/5] 修复依赖..." -ForegroundColor Yellow
|
||||
|
||||
# 删除旧的模块缓存
|
||||
Write-Host " 清理模块缓存..." -ForegroundColor Gray
|
||||
Remove-Item -Path "go.sum" -ErrorAction SilentlyContinue
|
||||
Remove-Item -Path "$GoPath\pkg\mod\cache" -Recurse -Force -ErrorAction SilentlyContinue
|
||||
|
||||
# 设置环境变量
|
||||
$env:GOPROXY = "https://goproxy.cn,direct"
|
||||
|
||||
# 运行 go mod tidy
|
||||
Write-Host " 运行 go mod tidy..." -ForegroundColor Yellow
|
||||
try {
|
||||
& go mod tidy -v
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Host " go mod tidy 成功" -ForegroundColor Green
|
||||
} else {
|
||||
Write-Warning "go mod tidy 返回非零退出码: $LASTEXITCODE"
|
||||
}
|
||||
} catch {
|
||||
Write-Warning "go mod tidy 执行出错: $_"
|
||||
}
|
||||
|
||||
# 下载依赖
|
||||
Write-Host " 运行 go mod download..." -ForegroundColor Yellow
|
||||
try {
|
||||
& go mod download -x
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Host " go mod download 成功" -ForegroundColor Green
|
||||
} else {
|
||||
Write-Warning "go mod download 返回非零退出码: $LASTEXITCODE"
|
||||
}
|
||||
} catch {
|
||||
Write-Warning "go mod download 执行出错: $_"
|
||||
}
|
||||
|
||||
# 验证
|
||||
Write-Host " 验证依赖..." -ForegroundColor Yellow
|
||||
try {
|
||||
& go list -m all | Out-Null
|
||||
Write-Host " 依赖验证成功" -ForegroundColor Green
|
||||
} catch {
|
||||
Write-Warning "依赖验证失败: $_"
|
||||
}
|
||||
|
||||
Write-Host ""
|
||||
Write-Host "==============================================" -ForegroundColor Cyan
|
||||
Write-Host " 依赖修复完成 " -ForegroundColor Cyan
|
||||
Write-Host "==============================================" -ForegroundColor Cyan
|
||||
Write-Host ""
|
||||
|
||||
# 检查 go.sum 是否存在
|
||||
if (Test-Path "go.sum") {
|
||||
Write-Host "✓ go.sum 文件已生成" -ForegroundColor Green
|
||||
} else {
|
||||
Write-Warning "✗ go.sum 文件未生成"
|
||||
}
|
||||
|
||||
Write-Host ""
|
||||
Write-Host "接下来可以尝试编译项目:" -ForegroundColor Yellow
|
||||
Write-Host " go build ./cmd/server" -ForegroundColor Gray
|
||||
Write-Host ""
|
||||
Write-Host "或者直接运行:" -ForegroundColor Yellow
|
||||
Write-Host " go run ./cmd/server" -ForegroundColor Gray
|
||||
Write-Host ""
|
||||
|
||||
Read-Host "按 Enter 键退出"
|
||||
@ -0,0 +1,173 @@
|
||||
#!/bin/bash
|
||||
# Go 环境安装脚本 (Linux/macOS)
|
||||
# 适用于 Ubuntu/CentOS/macOS
|
||||
|
||||
set -e
|
||||
|
||||
GO_VERSION="1.21.6"
|
||||
INSTALL_DIR="/usr/local"
|
||||
|
||||
# 颜色定义
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
CYAN='\033[0;36m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# 检测操作系统
|
||||
OS=""
|
||||
ARCH=$(uname -m)
|
||||
|
||||
if [[ "$OSTYPE" == "linux-gnu"* ]]; then
|
||||
OS="linux"
|
||||
if [[ "$ARCH" == "x86_64" ]]; then
|
||||
ARCH="amd64"
|
||||
elif [[ "$ARCH" == "aarch64" ]]; then
|
||||
ARCH="arm64"
|
||||
fi
|
||||
elif [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
OS="darwin"
|
||||
if [[ "$ARCH" == "x86_64" ]]; then
|
||||
ARCH="amd64"
|
||||
elif [[ "$ARCH" == "arm64" ]]; then
|
||||
ARCH="arm64"
|
||||
fi
|
||||
else
|
||||
echo -e "${RED}不支持的操作系统: $OSTYPE${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo -e "${CYAN}==============================================${NC}"
|
||||
echo -e "${CYAN} Go ${GO_VERSION} 环境安装脚本 (${OS}/${ARCH}) ${NC}"
|
||||
echo -e "${CYAN}==============================================${NC}"
|
||||
echo ""
|
||||
|
||||
# 检查当前 Go 版本
|
||||
echo -e "${YELLOW}[1/6] 检查当前 Go 版本...${NC}"
|
||||
if command -v go &> /dev/null; then
|
||||
CURRENT_VERSION=$(go version)
|
||||
echo -e "${GREEN} 检测到已安装: $CURRENT_VERSION${NC}"
|
||||
|
||||
read -p " 是否重新安装? (y/N): " response
|
||||
if [[ ! "$response" =~ ^[Yy]$ ]]; then
|
||||
echo -e "${GRAY} 跳过安装${NC}"
|
||||
exit 0
|
||||
fi
|
||||
else
|
||||
echo -e " 未检测到 Go"
|
||||
fi
|
||||
|
||||
# 下载安装包
|
||||
echo -e "${YELLOW}[2/6] 下载 Go ${GO_VERSION} ...${NC}"
|
||||
DOWNLOAD_URL="https://go.dev/dl/go${GO_VERSION}.${OS}-${ARCH}.tar.gz"
|
||||
TEMP_FILE="/tmp/go${GO_VERSION}.${OS}-${ARCH}.tar.gz"
|
||||
|
||||
echo " 下载地址: $DOWNLOAD_URL"
|
||||
|
||||
if command -v wget &> /dev/null; then
|
||||
wget -q --show-progress "$DOWNLOAD_URL" -O "$TEMP_FILE"
|
||||
elif command -v curl &> /dev/null; then
|
||||
curl -L --progress-bar "$DOWNLOAD_URL" -o "$TEMP_FILE"
|
||||
else
|
||||
echo -e "${RED}错误: 需要 wget 或 curl${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "$TEMP_FILE" ]; then
|
||||
echo -e "${RED}下载失败${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
FILE_SIZE=$(du -h "$TEMP_FILE" | cut -f1)
|
||||
echo -e "${GREEN} 下载完成: $FILE_SIZE${NC}"
|
||||
|
||||
# 删除旧版本
|
||||
echo -e "${YELLOW}[3/6] 清理旧版本...${NC}"
|
||||
if [ -d "$INSTALL_DIR/go" ]; then
|
||||
echo " 删除旧版本..."
|
||||
sudo rm -rf "$INSTALL_DIR/go"
|
||||
fi
|
||||
|
||||
# 解压安装
|
||||
echo -e "${YELLOW}[4/6] 安装 Go...${NC}"
|
||||
echo " 解压到 $INSTALL_DIR ..."
|
||||
sudo tar -C "$INSTALL_DIR" -xzf "$TEMP_FILE"
|
||||
|
||||
if [ ! -d "$INSTALL_DIR/go/bin" ]; then
|
||||
echo -e "${RED}安装失败${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo -e "${GREEN} 安装完成${NC}"
|
||||
|
||||
# 配置环境变量
|
||||
echo -e "${YELLOW}[5/6] 配置环境变量...${NC}"
|
||||
|
||||
# 检测 shell
|
||||
SHELL_NAME=$(basename "$SHELL")
|
||||
RC_FILE=""
|
||||
|
||||
if [[ "$SHELL_NAME" == "bash" ]]; then
|
||||
RC_FILE="$HOME/.bashrc"
|
||||
elif [[ "$SHELL_NAME" == "zsh" ]]; then
|
||||
RC_FILE="$HOME/.zshrc"
|
||||
else
|
||||
RC_FILE="$HOME/.profile"
|
||||
fi
|
||||
|
||||
# 检查是否已配置
|
||||
if ! grep -q "export PATH=.*go/bin" "$RC_FILE" 2>/dev/null; then
|
||||
echo "" >> "$RC_FILE"
|
||||
echo "# Go 环境配置" >> "$RC_FILE"
|
||||
echo "export PATH=\$PATH:$INSTALL_DIR/go/bin" >> "$RC_FILE"
|
||||
echo "export GOPATH=\$HOME/go" >> "$RC_FILE"
|
||||
echo "export PATH=\$PATH:\$GOPATH/bin" >> "$RC_FILE"
|
||||
echo "export GOPROXY=https://goproxy.cn,direct" >> "$RC_FILE"
|
||||
echo -e "${GREEN} 环境变量已添加到 $RC_FILE${NC}"
|
||||
else
|
||||
echo -e " 环境变量已存在"
|
||||
fi
|
||||
|
||||
# 创建 GOPATH 目录
|
||||
mkdir -p "$HOME/go/bin"
|
||||
mkdir -p "$HOME/go/pkg"
|
||||
mkdir -p "$HOME/go/src"
|
||||
|
||||
# 验证安装
|
||||
echo -e "${YELLOW}[6/6] 验证安装...${NC}"
|
||||
|
||||
export PATH=$PATH:$INSTALL_DIR/go/bin
|
||||
export GOPATH=$HOME/go
|
||||
|
||||
if command -v go &> /dev/null; then
|
||||
VERSION=$(go version)
|
||||
echo -e "${GREEN} Go 版本: $VERSION${NC}"
|
||||
|
||||
GOPATH_VAL=$(go env GOPATH)
|
||||
echo -e "${GREEN} GOPATH: $GOPATH_VAL${NC}"
|
||||
|
||||
GOPROXY_VAL=$(go env GOPROXY)
|
||||
echo -e "${GREEN} GOPROXY: $GOPROXY_VAL${NC}"
|
||||
else
|
||||
echo -e "${RED} 验证失败${NC}"
|
||||
fi
|
||||
|
||||
# 清理
|
||||
echo "[清理] 删除安装包..."
|
||||
rm -f "$TEMP_FILE"
|
||||
|
||||
echo ""
|
||||
echo -e "${CYAN}==============================================${NC}"
|
||||
echo -e "${CYAN} Go 安装完成! ${NC}"
|
||||
echo -e "${CYAN}==============================================${NC}"
|
||||
echo ""
|
||||
echo -e "${YELLOW}请运行以下命令使环境变量生效:${NC}"
|
||||
echo " source $RC_FILE"
|
||||
echo ""
|
||||
echo -e "${YELLOW}然后验证安装:${NC}"
|
||||
echo " go version"
|
||||
echo ""
|
||||
echo -e "${YELLOW}接下来可以启动行情数据服务:${NC}"
|
||||
echo " cd d:\fs_workspace\market-data-service"
|
||||
echo " go run ./cmd/server"
|
||||
echo ""
|
||||
Loading…
Reference in new issue