You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
343 lines
10 KiB
343 lines
10 KiB
# -*- coding: utf-8 -*-
|
|
"""
|
|
===================================
|
|
命令分发器
|
|
===================================
|
|
|
|
负责解析命令、匹配处理器、分发执行。
|
|
"""
|
|
|
|
import logging
|
|
import time
|
|
from collections import defaultdict
|
|
from typing import Dict, List, Optional, Type, Callable
|
|
|
|
from bot.models import BotMessage, BotResponse
|
|
from bot.commands.base import BotCommand
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RateLimiter:
|
|
"""
|
|
简单的频率限制器
|
|
|
|
基于滑动窗口算法,限制每个用户的请求频率。
|
|
"""
|
|
|
|
def __init__(self, max_requests: int = 10, window_seconds: int = 60):
|
|
"""
|
|
Args:
|
|
max_requests: 窗口内最大请求数
|
|
window_seconds: 窗口时间(秒)
|
|
"""
|
|
self.max_requests = max_requests
|
|
self.window_seconds = window_seconds
|
|
self._requests: Dict[str, List[float]] = defaultdict(list)
|
|
|
|
def is_allowed(self, user_id: str) -> bool:
|
|
"""
|
|
检查用户是否允许请求
|
|
|
|
Args:
|
|
user_id: 用户标识
|
|
|
|
Returns:
|
|
是否允许
|
|
"""
|
|
now = time.time()
|
|
window_start = now - self.window_seconds
|
|
|
|
# 清理过期记录
|
|
self._requests[user_id] = [
|
|
t for t in self._requests[user_id]
|
|
if t > window_start
|
|
]
|
|
|
|
# 检查是否超限
|
|
if len(self._requests[user_id]) >= self.max_requests:
|
|
return False
|
|
|
|
# 记录本次请求
|
|
self._requests[user_id].append(now)
|
|
return True
|
|
|
|
def get_remaining(self, user_id: str) -> int:
|
|
"""获取剩余可用请求数"""
|
|
now = time.time()
|
|
window_start = now - self.window_seconds
|
|
|
|
# 清理过期记录
|
|
self._requests[user_id] = [
|
|
t for t in self._requests[user_id]
|
|
if t > window_start
|
|
]
|
|
|
|
return max(0, self.max_requests - len(self._requests[user_id]))
|
|
|
|
|
|
class CommandDispatcher:
|
|
"""
|
|
命令分发器
|
|
|
|
职责:
|
|
1. 注册和管理命令处理器
|
|
2. 解析消息中的命令和参数
|
|
3. 分发命令到对应处理器
|
|
4. 处理未知命令和错误
|
|
|
|
使用示例:
|
|
dispatcher = CommandDispatcher()
|
|
dispatcher.register(AnalyzeCommand())
|
|
dispatcher.register(HelpCommand())
|
|
|
|
response = dispatcher.dispatch(message)
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
command_prefix: str = "/",
|
|
rate_limit_requests: int = 10,
|
|
rate_limit_window: int = 60,
|
|
admin_users: Optional[List[str]] = None
|
|
):
|
|
"""
|
|
Args:
|
|
command_prefix: 命令前缀,默认 "/"
|
|
rate_limit_requests: 频率限制:窗口内最大请求数
|
|
rate_limit_window: 频率限制:窗口时间(秒)
|
|
admin_users: 管理员用户 ID 列表
|
|
"""
|
|
self.command_prefix = command_prefix
|
|
self.admin_users = set(admin_users or [])
|
|
|
|
self._commands: Dict[str, BotCommand] = {}
|
|
self._aliases: Dict[str, str] = {}
|
|
self._rate_limiter = RateLimiter(rate_limit_requests, rate_limit_window)
|
|
|
|
# 回调函数:获取帮助命令的命令列表
|
|
self._help_command_getter: Optional[Callable] = None
|
|
|
|
def register(self, command: BotCommand) -> None:
|
|
"""
|
|
注册命令
|
|
|
|
Args:
|
|
command: 命令实例
|
|
"""
|
|
name = command.name.lower()
|
|
|
|
if name in self._commands:
|
|
logger.warning(f"[Dispatcher] 命令 '{name}' 已存在,将被覆盖")
|
|
|
|
self._commands[name] = command
|
|
logger.debug(f"[Dispatcher] 注册命令: {name}")
|
|
|
|
# 注册别名
|
|
for alias in command.aliases:
|
|
alias_lower = alias.lower()
|
|
if alias_lower in self._aliases:
|
|
logger.warning(f"[Dispatcher] 别名 '{alias_lower}' 已存在,将被覆盖")
|
|
self._aliases[alias_lower] = name
|
|
logger.debug(f"[Dispatcher] 注册别名: {alias_lower} -> {name}")
|
|
|
|
def register_class(self, command_class: Type[BotCommand]) -> None:
|
|
"""
|
|
注册命令类(自动实例化)
|
|
|
|
Args:
|
|
command_class: 命令类
|
|
"""
|
|
self.register(command_class())
|
|
|
|
def unregister(self, name: str) -> bool:
|
|
"""
|
|
注销命令
|
|
|
|
Args:
|
|
name: 命令名称
|
|
|
|
Returns:
|
|
是否成功注销
|
|
"""
|
|
name = name.lower()
|
|
|
|
if name not in self._commands:
|
|
return False
|
|
|
|
command = self._commands.pop(name)
|
|
|
|
# 移除别名
|
|
for alias in command.aliases:
|
|
self._aliases.pop(alias.lower(), None)
|
|
|
|
logger.debug(f"[Dispatcher] 注销命令: {name}")
|
|
return True
|
|
|
|
def get_command(self, name: str) -> Optional[BotCommand]:
|
|
"""
|
|
获取命令
|
|
|
|
支持命令名和别名查询。
|
|
|
|
Args:
|
|
name: 命令名或别名
|
|
|
|
Returns:
|
|
命令实例,或 None
|
|
"""
|
|
name = name.lower()
|
|
|
|
# 先查命令名
|
|
if name in self._commands:
|
|
return self._commands[name]
|
|
|
|
# 再查别名
|
|
if name in self._aliases:
|
|
return self._commands.get(self._aliases[name])
|
|
|
|
return None
|
|
|
|
def list_commands(self, include_hidden: bool = False) -> List[BotCommand]:
|
|
"""
|
|
列出所有命令
|
|
|
|
Args:
|
|
include_hidden: 是否包含隐藏命令
|
|
|
|
Returns:
|
|
命令列表
|
|
"""
|
|
commands = list(self._commands.values())
|
|
|
|
if not include_hidden:
|
|
commands = [c for c in commands if not c.hidden]
|
|
|
|
return sorted(commands, key=lambda c: c.name)
|
|
|
|
def is_admin(self, user_id: str) -> bool:
|
|
"""检查用户是否是管理员"""
|
|
return user_id in self.admin_users
|
|
|
|
def add_admin(self, user_id: str) -> None:
|
|
"""添加管理员"""
|
|
self.admin_users.add(user_id)
|
|
|
|
def remove_admin(self, user_id: str) -> None:
|
|
"""移除管理员"""
|
|
self.admin_users.discard(user_id)
|
|
|
|
def dispatch(self, message: BotMessage) -> BotResponse:
|
|
"""
|
|
分发消息到对应命令
|
|
|
|
Args:
|
|
message: 消息对象
|
|
|
|
Returns:
|
|
响应对象
|
|
"""
|
|
# 1. 检查频率限制
|
|
if not self._rate_limiter.is_allowed(message.user_id):
|
|
remaining_time = self._rate_limiter.window_seconds
|
|
return BotResponse.error_response(
|
|
f"请求过于频繁,请 {remaining_time} 秒后再试"
|
|
)
|
|
|
|
# 2. 解析命令和参数
|
|
cmd_name, args = message.get_command_and_args(self.command_prefix)
|
|
|
|
if cmd_name is None:
|
|
# 不是命令,检查是否 @了机器人
|
|
if message.mentioned:
|
|
return BotResponse.text_response(
|
|
"你好!我是股票分析助手。\n"
|
|
f"发送 `{self.command_prefix}help` 查看可用命令。"
|
|
)
|
|
# 非命令消息,不处理
|
|
return BotResponse.text_response("")
|
|
|
|
logger.info(f"[Dispatcher] 收到命令: {cmd_name}, 参数: {args}, 用户: {message.user_name}")
|
|
|
|
# 3. 查找命令处理器
|
|
command = self.get_command(cmd_name)
|
|
|
|
if command is None:
|
|
return BotResponse.error_response(
|
|
f"未知命令: {cmd_name}\n"
|
|
f"发送 `{self.command_prefix}help` 查看可用命令。"
|
|
)
|
|
|
|
# 4. 检查权限
|
|
if command.admin_only and not self.is_admin(message.user_id):
|
|
return BotResponse.error_response("此命令需要管理员权限")
|
|
|
|
# 5. 验证参数
|
|
error_msg = command.validate_args(args)
|
|
if error_msg:
|
|
return BotResponse.error_response(
|
|
f"{error_msg}\n用法: `{command.usage}`"
|
|
)
|
|
|
|
# 6. 执行命令
|
|
try:
|
|
response = command.execute(message, args)
|
|
logger.info(f"[Dispatcher] 命令 {cmd_name} 执行成功")
|
|
return response
|
|
except Exception as e:
|
|
logger.error(f"[Dispatcher] 命令 {cmd_name} 执行失败: {e}")
|
|
logger.exception(e)
|
|
return BotResponse.error_response(f"命令执行失败: {str(e)[:100]}")
|
|
|
|
def set_help_command_getter(self, getter: Callable) -> None:
|
|
"""
|
|
设置帮助命令的命令列表获取器
|
|
|
|
用于让 HelpCommand 获取命令列表。
|
|
|
|
Args:
|
|
getter: 回调函数,返回命令列表
|
|
"""
|
|
self._help_command_getter = getter
|
|
|
|
|
|
# 全局分发器实例
|
|
_dispatcher: Optional[CommandDispatcher] = None
|
|
|
|
|
|
def get_dispatcher() -> CommandDispatcher:
|
|
"""
|
|
获取全局分发器实例
|
|
|
|
使用单例模式,首次调用时自动初始化并注册所有命令。
|
|
"""
|
|
global _dispatcher
|
|
|
|
if _dispatcher is None:
|
|
from src.config import get_config
|
|
|
|
config = get_config()
|
|
|
|
# 创建分发器
|
|
_dispatcher = CommandDispatcher(
|
|
command_prefix=getattr(config, 'bot_command_prefix', '/'),
|
|
rate_limit_requests=getattr(config, 'bot_rate_limit_requests', 10),
|
|
rate_limit_window=getattr(config, 'bot_rate_limit_window', 60),
|
|
admin_users=getattr(config, 'bot_admin_users', []),
|
|
)
|
|
|
|
# 自动注册所有命令
|
|
from bot.commands import ALL_COMMANDS
|
|
for command_class in ALL_COMMANDS:
|
|
_dispatcher.register_class(command_class)
|
|
|
|
logger.info(f"[Dispatcher] 初始化完成,已注册 {len(_dispatcher._commands)} 个命令")
|
|
|
|
return _dispatcher
|
|
|
|
|
|
def reset_dispatcher() -> None:
|
|
"""重置全局分发器(主要用于测试)"""
|
|
global _dispatcher
|
|
_dispatcher = None
|