You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

392 lines
13 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
"""
===================================
PytdxFetcher - TongDaXin (TDX) data source (Priority 2)
===================================
Data source: TongDaXin quote servers, accessed through the pytdx library.
Characteristics: free, no API token required, connects directly to the quote servers.
Advantages (per the original author): real-time data, stable, no quota limits.
Key strategies:
1. Automatic fail-over across multiple servers
2. Move on / reconnect when a connection attempt times out
3. Exponential-backoff retry after failures
Note: the priority defaults to 2 but can be overridden with the
PYTDX_PRIORITY environment variable.
"""
import logging
import re
from contextlib import contextmanager
from datetime import datetime
from typing import Optional, Generator, List, Tuple
import pandas as pd
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
before_sleep_log,
)
from .base import BaseFetcher, DataFetchError, STANDARD_COLUMNS
import os
logger = logging.getLogger(__name__)


# 1-5 capital letters, optionally followed by "." and one share-class letter.
_US_TICKER_PATTERN = re.compile(r'^[A-Z]{1,5}(\.[A-Z])?$')


def _is_us_code(stock_code: str) -> bool:
    """
    Return True when ``stock_code`` looks like a US ticker symbol.

    Accepted shape: 1-5 Latin capital letters such as 'AAPL' or 'TSLA',
    optionally followed by a dot and a single class letter such as 'BRK.B'.
    Surrounding whitespace is ignored, and the input is upper-cased before
    matching, so lower-case tickers are accepted too.
    """
    normalized = stock_code.strip().upper()
    return _US_TICKER_PATTERN.match(normalized) is not None
class PytdxFetcher(BaseFetcher):
    """
    TongDaXin (TDX) data source implemented on top of ``pytdx``.

    Priority defaults to 2 (same tier as Tushare). It can be overridden with
    the ``PYTDX_PRIORITY`` environment variable.

    Connection strategy:
    - Servers are tried in list order, starting from the last one that
      connected successfully. The first server that accepts the connection
      is used.
    - Every query opens its own connection and closes it when done.
    - ``_fetch_raw_data`` is wrapped in a tenacity retry with exponential
      backoff.

    Supported: daily K-line history, real-time quotes and security-name
    lookup for Shanghai/Shenzhen codes. US tickers are rejected.
    """

    name = "PytdxFetcher"
    priority = int(os.getenv("PYTDX_PRIORITY", "2"))

    # Built-in TDX quote servers as (host, port); the comment names the city.
    DEFAULT_HOSTS = [
        ("119.147.212.81", 7709),  # Shenzhen
        ("112.74.214.43", 7727),  # Shenzhen
        ("221.231.141.60", 7709),  # Shanghai
        ("101.227.73.20", 7709),  # Shanghai
        ("101.227.77.254", 7709),  # Shanghai
        ("14.215.128.18", 7709),  # Guangzhou
        ("59.173.18.140", 7709),  # Wuhan
        ("180.153.39.51", 7709),  # Hangzhou
    ]

    # Largest number of bars requested from get_security_bars() in one call.
    _MAX_BARS_PER_REQUEST = 800
    # Safety cap on get_security_list() pages fetched per market.
    _MAX_SECURITY_LIST_PAGES = 100
    # Code prefixes listed on the Shanghai exchange: main board (60),
    # STAR market (68), B shares (900) and funds/ETFs (5).
    _SH_PREFIXES = ('60', '68', '900', '5')
    # Explicit exchange tags accepted in codes, mapped to pytdx market ids.
    _EXCHANGE_TAGS = (('sh', 1), ('sz', 0))

    def __init__(self, hosts: Optional[List[Tuple[str, int]]] = None):
        """
        Initialise the fetcher. No connection is opened here.

        Args:
            hosts: Server list ``[(host, port), ...]``. Defaults to
                ``DEFAULT_HOSTS`` when omitted or empty.
        """
        # Copy the list so later mutation of the caller's list (or of the
        # class default) cannot change this instance's server rotation.
        self._hosts = list(hosts or self.DEFAULT_HOSTS)
        self._api = None
        self._connected = False
        self._current_host_idx = 0  # index of the last host that connected
        self._stock_list_cache = None  # {(market, code): name}, filled lazily
        self._stock_name_cache = {}  # {stock_code as passed by caller: name}

    def _get_pytdx(self):
        """
        Import ``pytdx`` lazily.

        The import happens on first use, so the module still loads when
        pytdx is not installed.

        Returns:
            The ``TdxHq_API`` class, or None if pytdx is missing (a warning
            is logged).
        """
        try:
            from pytdx.hq import TdxHq_API
            return TdxHq_API
        except ImportError:
            logger.warning("pytdx 未安装,请运行: pip install pytdx")
            return None

    @contextmanager
    def _pytdx_session(self) -> Generator:
        """
        Context manager that yields a connected ``TdxHq_API`` instance.

        Hosts are tried round-robin starting at ``self._current_host_idx``.
        The index of the host that connects is remembered for the next
        session. The connection is always closed on exit, including when
        the body raises.

        Raises:
            DataFetchError: pytdx is not installed, or no host accepted a
                connection.

        Example:
            with self._pytdx_session() as api:
                ...  # run queries here
        """
        TdxHq_API = self._get_pytdx()
        if TdxHq_API is None:
            raise DataFetchError("pytdx 库未安装")
        api = TdxHq_API()
        connected = False
        try:
            host_count = len(self._hosts)
            for i in range(host_count):
                host_idx = (self._current_host_idx + i) % host_count
                host, port = self._hosts[host_idx]
                try:
                    if api.connect(host, port, time_out=5):
                        connected = True
                        self._current_host_idx = host_idx
                        logger.debug(f"Pytdx 连接成功: {host}:{port}")
                        break
                except Exception as e:
                    # Any error here just means "this host is unusable";
                    # move on to the next one.
                    logger.debug(f"Pytdx 连接 {host}:{port} 失败: {e}")
                    continue
            if not connected:
                raise DataFetchError("Pytdx 无法连接任何服务器")
            yield api
        finally:
            # Always release the socket, even when connecting failed.
            try:
                api.disconnect()
                logger.debug("Pytdx 连接已断开")
            except Exception as e:
                logger.warning(f"Pytdx 断开连接时出错: {e}")

    def _get_market_code(self, stock_code: str) -> Tuple[int, str]:
        """
        Resolve the pytdx market id and bare code for ``stock_code``.

        pytdx market ids: 0 = Shenzhen, 1 = Shanghai.

        An explicit exchange tag is honoured first. It may be a prefix
        ('sh600519', 'SZ000001', 'sh.600519') or a suffix ('600519.SH',
        '000001sz'), in any letter case. This matters for codes such as
        000001, which exist on both exchanges.

        Without a tag, the market is inferred from the code prefix (see
        ``_SH_PREFIXES``). Anything else is treated as Shenzhen.

        Args:
            stock_code: Stock code, with or without an exchange tag.

        Returns:
            A ``(market, code)`` tuple.
        """
        code = stock_code.strip()
        lowered = code.lower()
        for tag, market in self._EXCHANGE_TAGS:
            if lowered.startswith(tag):
                rest = code[len(tag):].lstrip('.')
            elif lowered.endswith(tag):
                rest = code[:-len(tag)].rstrip('.')
            else:
                continue
            # Only accept the tag when what remains is a numeric code, so
            # tickers that merely start with "SH"/"SZ" are not mangled.
            if rest.isdigit():
                return market, rest
        if code.startswith(self._SH_PREFIXES):
            return 1, code  # Shanghai
        return 0, code  # Shenzhen

    # NOTE(review): this retry only fires on ConnectionError / TimeoutError.
    # Every failure inside the method (including "no server reachable" from
    # _pytdx_session) surfaces as DataFetchError, so the retry is inactive
    # unless DataFetchError subclasses one of those. Its definition is not
    # visible here; confirm before relying on the retry.
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=30),
        retry=retry_if_exception_type((ConnectionError, TimeoutError)),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
    def _fetch_raw_data(self, stock_code: str, start_date: str, end_date: str) -> pd.DataFrame:
        """
        Fetch raw daily bars for ``stock_code`` within [start_date, end_date].

        Steps:
        1. Reject US tickers (unsupported) so DataFetcherManager can fall
           back to another source.
        2. Resolve the market id and bare code.
        3. Page through get_security_bars() from the latest bar back far
           enough to reach start_date.
        4. Keep only bars whose calendar date is in [start_date, end_date].

        Args:
            stock_code: Stock code, with or without an exchange tag.
            start_date: Inclusive start date, 'YYYY-MM-DD'.
            end_date: Inclusive end date, 'YYYY-MM-DD'.

        Returns:
            The raw pytdx DataFrame sorted by ``datetime``. It may be empty
            if no bar falls inside the range.

        Raises:
            DataFetchError: US ticker, no data returned, or any pytdx failure.
            ValueError: A date is not in 'YYYY-MM-DD' format.
        """
        if _is_us_code(stock_code):
            raise DataFetchError(f"PytdxFetcher 不支持美股 {stock_code},请使用 AkshareFetcher 或 YfinanceFetcher")

        market, code = self._get_market_code(stock_code)

        start_dt = datetime.strptime(start_date, '%Y-%m-%d')
        end_dt = datetime.strptime(end_date, '%Y-%m-%d')
        # Bars are counted back from the most recent one (start=0). The
        # request must therefore reach from *today* back to start_date, not
        # just span start_date..end_date, or past ranges come back truncated.
        # 5/7 of calendar days plus a buffer is an upper bound on trading days.
        count = max((datetime.now() - start_dt).days * 5 // 7 + 10, 30)

        logger.debug(f"调用 Pytdx get_security_bars(market={market}, code={code}, count={count})")

        with self._pytdx_session() as api:
            try:
                bars = self._fetch_daily_bars(api, market, code, count)
                if not bars:
                    raise DataFetchError(f"Pytdx 未查询到 {stock_code} 的数据")

                df = api.to_df(bars)
                df['datetime'] = pd.to_datetime(df['datetime'])
                # Pages could overlap if a new bar appears between requests;
                # also put bars in chronological order for pct_chg later.
                df = (
                    df.drop_duplicates(subset='datetime')
                    .sort_values('datetime', kind='stable')
                    .reset_index(drop=True)
                )

                # Daily bars may carry a time-of-day component, so compare on
                # the calendar date; otherwise the end_date bar is dropped.
                bar_dates = df['datetime'].dt.normalize()
                in_range = (bar_dates >= start_dt) & (bar_dates <= end_dt)
                return df[in_range].reset_index(drop=True)
            except DataFetchError:
                raise
            except Exception as e:
                raise DataFetchError(f"Pytdx 获取数据失败: {e}") from e

    def _fetch_daily_bars(self, api, market: int, code: str, count: int) -> list:
        """
        Request up to ``count`` of the most recent daily bars.

        One get_security_bars() call is limited to ``_MAX_BARS_PER_REQUEST``
        bars. Older history is reached by moving the ``start`` offset
        (0 = latest bar) back one page at a time. Paging stops early when the
        server returns an empty or short page.

        Returns:
            All bars received, in no guaranteed overall order (the caller
            sorts them).
        """
        bars = []
        offset = 0
        while offset < count:
            batch_size = min(self._MAX_BARS_PER_REQUEST, count - offset)
            # category: 9 = daily, 0 = 5-min, 1 = 15-min, 2 = 30-min, 3 = 1-hour
            batch = api.get_security_bars(
                category=9,
                market=market,
                code=code,
                start=offset,
                count=batch_size,
            )
            if not batch:
                break
            bars.extend(batch)
            offset += len(batch)
            if len(batch) < batch_size:
                break  # fewer bars than requested: no older history left
        return bars

    def _normalize_data(self, df: pd.DataFrame, stock_code: str) -> pd.DataFrame:
        """
        Map raw pytdx bars onto the project's standard columns.

        Raw columns used: datetime, open, high, low, close, vol, amount.
        Renamed: datetime -> date, vol -> volume.

        ``pct_chg`` is derived from consecutive closes because pytdx does not
        supply it. The first row has no previous close in the frame and
        gets 0.

        Args:
            df: Raw frame returned by ``_fetch_raw_data``.
            stock_code: Code written into the ``code`` column.

        Returns:
            A copy containing ``code`` plus whichever ``STANDARD_COLUMNS``
            are present.
        """
        df = df.copy()
        df = df.rename(columns={'datetime': 'date', 'vol': 'volume'})

        if 'pct_chg' not in df.columns and 'close' in df.columns:
            df['pct_chg'] = (df['close'].pct_change() * 100).fillna(0).round(2)

        df['code'] = stock_code

        keep_cols = ['code'] + STANDARD_COLUMNS
        return df[[col for col in keep_cols if col in df.columns]]

    def _load_security_names(self, api) -> dict:
        """
        Download the security lists of both markets.

        The second argument of get_security_list() is a start offset, and
        each call returns one page. The method keeps requesting until a page
        comes back empty, capped at ``_MAX_SECURITY_LIST_PAGES`` per market.

        Returns:
            ``{(market, code): name}`` with market 0 = Shenzhen, 1 = Shanghai.
        """
        names = {}
        for market in (0, 1):
            offset = 0
            for _ in range(self._MAX_SECURITY_LIST_PAGES):
                page = api.get_security_list(market, offset)
                if not page:
                    break
                for item in page:
                    names[(market, item['code'])] = item['name']
                offset += len(page)
        return names

    def get_stock_name(self, stock_code: str) -> Optional[str]:
        """
        Look up a stock's name.

        Results are cached per ``stock_code``. The full security list is
        downloaded once and cached only if non-empty, so a transient failure
        is retried on the next call. Lookups are keyed by (market, code),
        so codes that exist on both exchanges resolve correctly.

        Args:
            stock_code: Stock code, with or without an exchange tag.

        Returns:
            The name, or None for US tickers, unknown codes or errors
            (errors are logged as warnings).
        """
        if stock_code in self._stock_name_cache:
            return self._stock_name_cache[stock_code]
        if _is_us_code(stock_code):
            return None

        try:
            market, code = self._get_market_code(stock_code)
            with self._pytdx_session() as api:
                names = self._stock_list_cache
                if names is None:
                    names = self._load_security_names(api)
                    if names:
                        self._stock_list_cache = names

                name = names.get((market, code))
                if not name:
                    # Fallback kept from the original implementation.
                    # NOTE(review): unverified that get_finance_info ever
                    # returns a 'name' key.
                    finance_info = api.get_finance_info(market, code)
                    if finance_info and 'name' in finance_info:
                        name = finance_info['name']

                if name:
                    self._stock_name_cache[stock_code] = name
                    return name
        except Exception as e:
            logger.warning(f"Pytdx 获取股票名称失败 {stock_code}: {e}")
        return None

    def get_realtime_quote(self, stock_code: str) -> Optional[dict]:
        """
        Fetch a real-time quote snapshot.

        Args:
            stock_code: Stock code, with or without an exchange tag.

        Returns:
            A dict with code, name, price, open, high, low, pre_close,
            volume, amount and five-level bid_prices/ask_prices. Returns None
            for US tickers, empty responses or errors (errors are logged as
            warnings).
        """
        if _is_us_code(stock_code):
            return None
        try:
            market, code = self._get_market_code(stock_code)
            with self._pytdx_session() as api:
                data = api.get_security_quotes([(market, code)])
            if not data:
                return None
            quote = data[0]
            return {
                'code': stock_code,
                # Fall back to a previously cached name if the quote has none.
                'name': quote.get('name') or self._stock_name_cache.get(stock_code, ''),
                'price': quote.get('price', 0),
                'open': quote.get('open', 0),
                'high': quote.get('high', 0),
                'low': quote.get('low', 0),
                'pre_close': quote.get('last_close', 0),
                'volume': quote.get('vol', 0),
                'amount': quote.get('amount', 0),
                'bid_prices': [quote.get(f'bid{i}', 0) for i in range(1, 6)],
                'ask_prices': [quote.get(f'ask{i}', 0) for i in range(1, 6)],
            }
        except Exception as e:
            logger.warning(f"Pytdx 获取实时行情失败 {stock_code}: {e}")
        return None
if __name__ == "__main__":
    # Manual smoke test: fetch daily history, the name and a live quote for
    # one Shanghai stock and print each result. Any failure is printed.
    logging.basicConfig(level=logging.DEBUG)
    fetcher = PytdxFetcher()
    symbol = '600519'  # Kweichow Moutai
    try:
        df = fetcher.get_daily_data(symbol)
        print(f"获取成功,共 {len(df)} 条数据")
        print(df.tail())

        name = fetcher.get_stock_name(symbol)
        print(f"股票名称: {name}")

        quote = fetcher.get_realtime_quote(symbol)
        print(f"实时行情: {quote}")
    except Exception as e:
        print(f"获取失败: {e}")