You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
375 lines
8.0 KiB
375 lines
8.0 KiB
package websocket
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"log"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gorilla/websocket"
|
|
|
|
"market-data-service/api"
|
|
)
|
|
|
|
var upgrader = websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
return true // 允许所有来源,生产环境需要限制
|
|
},
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
}
|
|
|
|
// Hub WebSocket连接管理中心
|
|
type Hub struct {
|
|
// 已注册的客户端
|
|
clients map[*Client]bool
|
|
|
|
// 广播消息通道
|
|
broadcast chan []byte
|
|
|
|
// 注册请求通道
|
|
register chan *Client
|
|
|
|
// 注销请求通道
|
|
unregister chan *Client
|
|
|
|
// 标的订阅映射: symbol -> clients
|
|
subscriptions map[string]map[*Client]bool
|
|
|
|
// 保护subscriptions的锁
|
|
subMu sync.RWMutex
|
|
|
|
// 最大订阅标的数
|
|
maxSymbolsPerClient int
|
|
}
|
|
|
|
// NewHub 创建Hub
|
|
func NewHub() *Hub {
|
|
return &Hub{
|
|
clients: make(map[*Client]bool),
|
|
broadcast: make(chan []byte),
|
|
register: make(chan *Client),
|
|
unregister: make(chan *Client),
|
|
subscriptions: make(map[string]map[*Client]bool),
|
|
maxSymbolsPerClient: 100,
|
|
}
|
|
}
|
|
|
|
// Run 启动Hub
|
|
func (h *Hub) Run() {
|
|
for {
|
|
select {
|
|
case client := <-h.register:
|
|
h.clients[client] = true
|
|
log.Printf("Client registered, total: %d", len(h.clients))
|
|
|
|
case client := <-h.unregister:
|
|
if _, ok := h.clients[client]; ok {
|
|
delete(h.clients, client)
|
|
close(client.send)
|
|
// 清理订阅
|
|
h.removeAllSubscriptions(client)
|
|
log.Printf("Client unregistered, total: %d", len(h.clients))
|
|
}
|
|
|
|
case message := <-h.broadcast:
|
|
for client := range h.clients {
|
|
select {
|
|
case client.send <- message:
|
|
default:
|
|
// 发送缓冲满,关闭连接
|
|
close(client.send)
|
|
delete(h.clients, client)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Subscribe 客户端订阅标的
|
|
func (h *Hub) Subscribe(client *Client, symbols []string) error {
|
|
if len(client.subscriptions)+len(symbols) > h.maxSymbolsPerClient {
|
|
return api.ErrRateLimit
|
|
}
|
|
|
|
h.subMu.Lock()
|
|
defer h.subMu.Unlock()
|
|
|
|
for _, symbol := range symbols {
|
|
if _, ok := h.subscriptions[symbol]; !ok {
|
|
h.subscriptions[symbol] = make(map[*Client]bool)
|
|
}
|
|
h.subscriptions[symbol][client] = true
|
|
client.subscriptions[symbol] = true
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Unsubscribe 客户端取消订阅
|
|
func (h *Hub) Unsubscribe(client *Client, symbols []string) {
|
|
h.subMu.Lock()
|
|
defer h.subMu.Unlock()
|
|
|
|
for _, symbol := range symbols {
|
|
if clients, ok := h.subscriptions[symbol]; ok {
|
|
delete(clients, client)
|
|
if len(clients) == 0 {
|
|
delete(h.subscriptions, symbol)
|
|
}
|
|
}
|
|
delete(client.subscriptions, symbol)
|
|
}
|
|
}
|
|
|
|
// removeAllSubscriptions 移除客户端所有订阅
|
|
func (h *Hub) removeAllSubscriptions(client *Client) {
|
|
h.subMu.Lock()
|
|
defer h.subMu.Unlock()
|
|
|
|
for symbol := range client.subscriptions {
|
|
if clients, ok := h.subscriptions[symbol]; ok {
|
|
delete(clients, client)
|
|
if len(clients) == 0 {
|
|
delete(h.subscriptions, symbol)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// BroadcastToSymbol 向订阅了某标的的所有客户端广播
|
|
func (h *Hub) BroadcastToSymbol(symbol string, data []byte) {
|
|
h.subMu.RLock()
|
|
clients := h.subscriptions[symbol]
|
|
h.subMu.RUnlock()
|
|
|
|
for client := range clients {
|
|
select {
|
|
case client.send <- data:
|
|
default:
|
|
// 发送缓冲满,稍后处理
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetSubscriptionStats 获取订阅统计
|
|
func (h *Hub) GetSubscriptionStats() map[string]interface{} {
|
|
h.subMu.RLock()
|
|
defer h.subMu.RUnlock()
|
|
|
|
return map[string]interface{}{
|
|
"total_clients": len(h.clients),
|
|
"total_subscriptions": len(h.subscriptions),
|
|
}
|
|
}
|
|
|
|
// Client WebSocket客户端
|
|
type Client struct {
|
|
hub *Hub
|
|
conn *websocket.Conn
|
|
send chan []byte
|
|
|
|
// 已订阅的标的
|
|
subscriptions map[string]bool
|
|
subMu sync.RWMutex
|
|
}
|
|
|
|
// NewClient 创建客户端
|
|
func NewClient(hub *Hub, conn *websocket.Conn) *Client {
|
|
return &Client{
|
|
hub: hub,
|
|
conn: conn,
|
|
send: make(chan []byte, 256),
|
|
subscriptions: make(map[string]bool),
|
|
}
|
|
}
|
|
|
|
// ReadPump 读取客户端消息
|
|
func (c *Client) ReadPump() {
|
|
defer func() {
|
|
c.hub.unregister <- c
|
|
c.conn.Close()
|
|
}()
|
|
|
|
c.conn.SetReadLimit(512 * 1024) // 512KB
|
|
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
|
c.conn.SetPongHandler(func(string) error {
|
|
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
|
return nil
|
|
})
|
|
|
|
for {
|
|
_, message, err := c.conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
log.Printf("WebSocket error: %v", err)
|
|
}
|
|
break
|
|
}
|
|
|
|
// 处理客户端消息
|
|
c.handleMessage(message)
|
|
}
|
|
}
|
|
|
|
// WritePump 向客户端写入消息
|
|
func (c *Client) WritePump() {
|
|
ticker := time.NewTicker(30 * time.Second)
|
|
defer func() {
|
|
ticker.Stop()
|
|
c.conn.Close()
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case message, ok := <-c.send:
|
|
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
if !ok {
|
|
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
|
|
return
|
|
}
|
|
|
|
c.conn.WriteMessage(websocket.TextMessage, message)
|
|
|
|
case <-ticker.C:
|
|
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleMessage 处理客户端消息
|
|
func (c *Client) handleMessage(data []byte) {
|
|
var msg ClientMessage
|
|
if err := json.Unmarshal(data, &msg); err != nil {
|
|
c.sendError(1000, "Invalid message format")
|
|
return
|
|
}
|
|
|
|
switch msg.Action {
|
|
case "subscribe":
|
|
c.handleSubscribe(msg.Symbols)
|
|
case "unsubscribe":
|
|
c.handleUnsubscribe(msg.Symbols)
|
|
default:
|
|
c.sendError(1001, "Unknown action")
|
|
}
|
|
}
|
|
|
|
// handleSubscribe 处理订阅请求
|
|
func (c *Client) handleSubscribe(symbols []string) {
|
|
if len(symbols) == 0 {
|
|
c.sendError(1002, "Symbols cannot be empty")
|
|
return
|
|
}
|
|
|
|
if err := c.hub.Subscribe(c, symbols); err != nil {
|
|
c.sendError(1003, err.Error())
|
|
return
|
|
}
|
|
|
|
// 发送确认
|
|
ack := map[string]interface{}{
|
|
"type": "ack",
|
|
"action": "subscribe",
|
|
"symbols": symbols,
|
|
"ts": time.Now().Format(time.RFC3339),
|
|
}
|
|
data, _ := json.Marshal(ack)
|
|
c.send <- data
|
|
}
|
|
|
|
// handleUnsubscribe 处理取消订阅请求
|
|
func (c *Client) handleUnsubscribe(symbols []string) {
|
|
c.hub.Unsubscribe(c, symbols)
|
|
|
|
ack := map[string]interface{}{
|
|
"type": "ack",
|
|
"action": "unsubscribe",
|
|
"symbols": symbols,
|
|
"ts": time.Now().Format(time.RFC3339),
|
|
}
|
|
data, _ := json.Marshal(ack)
|
|
c.send <- data
|
|
}
|
|
|
|
// sendError 发送错误消息
|
|
func (c *Client) sendError(code int, message string) {
|
|
err := map[string]interface{}{
|
|
"type": "error",
|
|
"code": code,
|
|
"message": message,
|
|
"ts": time.Now().Format(time.RFC3339),
|
|
}
|
|
data, _ := json.Marshal(err)
|
|
c.send <- data
|
|
}
|
|
|
|
// ClientMessage is the JSON request a client sends over the socket, for
// example {"action":"subscribe","symbols":["AAPL"]}.
type ClientMessage struct {
	// Action selects the operation, e.g. "subscribe" or "unsubscribe".
	Action string `json:"action"`
	// Symbols lists the instruments the action applies to.
	Symbols []string `json:"symbols"`
}
|
|
|
|
// Server WebSocket服务器
|
|
type Server struct {
|
|
hub *Hub
|
|
}
|
|
|
|
// NewServer 创建WebSocket服务器
|
|
func NewServer(hub *Hub) *Server {
|
|
return &Server{hub: hub}
|
|
}
|
|
|
|
// HandleWebSocket 处理WebSocket连接
|
|
func (s *Server) HandleWebSocket(c *gin.Context) {
|
|
// 认证检查
|
|
apiKey := c.GetHeader("X-API-Key")
|
|
if apiKey == "" {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Missing API Key"})
|
|
return
|
|
}
|
|
|
|
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
|
if err != nil {
|
|
log.Printf("WebSocket upgrade failed: %v", err)
|
|
return
|
|
}
|
|
|
|
client := NewClient(s.hub, conn)
|
|
s.hub.register <- client
|
|
|
|
go client.WritePump()
|
|
go client.ReadPump()
|
|
}
|
|
|
|
// BroadcastTick 广播Tick数据
|
|
func (s *Server) BroadcastTick(symbol string, tick map[string]interface{}) {
|
|
data, err := json.Marshal(tick)
|
|
if err != nil {
|
|
return
|
|
}
|
|
s.hub.BroadcastToSymbol(symbol, data)
|
|
}
|
|
|
|
// BroadcastKLine 广播K线闭合数据
|
|
func (s *Server) BroadcastKLine(symbol string, freq string, kline map[string]interface{}) {
|
|
msg := map[string]interface{}{
|
|
"type": "klines",
|
|
"symbol": symbol,
|
|
"freq": freq,
|
|
"data": kline,
|
|
"ts": time.Now().Format(time.RFC3339),
|
|
}
|
|
data, err := json.Marshal(msg)
|
|
if err != nil {
|
|
return
|
|
}
|
|
s.hub.BroadcastToSymbol(symbol, data)
|
|
}
|