You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
buffer_platform/futures_data_collector.py

428 lines
14 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
期货/股票多周期数据获取与技术指标计算脚本
"""
import argparse
import json
import os
import warnings
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Sequence

import akshare as ak
import pandas as pd
# Silence every warning process-wide (presumably to keep the console output
# of this CLI script clean).
warnings.filterwarnings(action='ignore')
# NOTE(review): intent unclear from this file -- presumably meant to reset an
# akshare-side cache; confirm akshare actually reads ``ak.cache``.
ak.cache = {}
# Output directory for the JSON files: ``data/`` next to this script.  It is
# created at import time so later writes never fail on a missing directory.
DATA_DIR = os.path.join(os.path.split(os.path.abspath(__file__))[0], 'data')
os.makedirs(name=DATA_DIR, exist_ok=True)
def calculate_ma(df: pd.DataFrame, periods: Sequence[int] = (10, 20)) -> pd.DataFrame:
    """Add simple moving averages of the ``close`` column to *df*.

    One column named ``MA{period}`` is written per entry in *periods*.
    Because ``min_periods=1`` is used, the first ``period - 1`` rows hold the
    mean of however many closes are available so far instead of NaN.

    Args:
        df: Frame with a numeric ``close`` column; modified in place.
        periods: Window lengths, in rows.  The default is an immutable tuple
            so it can never be shared and mutated across calls; any list of
            ints is still accepted.

    Returns:
        The same *df* object with the new ``MA*`` columns added.
    """
    for period in periods:
        df[f'MA{period}'] = df['close'].rolling(window=period, min_periods=1).mean()
    return df
def calculate_macd(df: pd.DataFrame, fast: int = 12, slow: int = 26, signal: int = 9) -> pd.DataFrame:
    """Add MACD indicator columns computed from ``close`` to *df*.

    Columns written (in place):
        ``macd_dif``       -- EMA(close, fast) - EMA(close, slow)
        ``macd_dea``       -- EMA(macd_dif, signal)
        ``macd_histogram`` -- 2 * (macd_dif - macd_dea)
        ``macd_signal``    -- 'bullish' when DIF > DEA and the histogram is
                              positive, 'bearish' when DIF < DEA and the
                              histogram is negative, otherwise 'neutral'
                              (which includes rows with NaN inputs).

    All EMAs use ``adjust=False`` (the recursive form seeded with the first
    value).

    Args:
        df: Frame with a numeric ``close`` column; modified in place.
        fast: Span of the fast EMA.
        slow: Span of the slow EMA.
        signal: Span of the EMA applied to DIF to produce DEA.

    Returns:
        The same *df* object with the MACD columns added.
    """
    close = df['close']
    dif = close.ewm(span=fast, adjust=False).mean() - close.ewm(span=slow, adjust=False).mean()
    dea = dif.ewm(span=signal, adjust=False).mean()
    histogram = (dif - dea) * 2
    df['macd_dif'] = dif
    df['macd_dea'] = dea
    df['macd_histogram'] = histogram

    # Vectorised labelling instead of a row-wise ``df.apply(axis=1)``: that
    # ran a Python lambda per row, and on an empty frame ``apply`` may return
    # a DataFrame rather than a Series, which cannot fill a single column.
    labels = pd.Series('neutral', index=df.index, dtype=object)
    labels.loc[(dif > dea) & (histogram > 0)] = 'bullish'
    labels.loc[(dif < dea) & (histogram < 0)] = 'bearish'
    df['macd_signal'] = labels
    return df
def get_current_time() -> datetime:
    """Return the current Beijing time (UTC+8) as a naive datetime, truncated to seconds.

    The clock is read explicitly in UTC+8 instead of via a bare
    ``datetime.now()``, which returns the *host's* local time and is only
    Beijing time when the machine happens to run in UTC+8.  The tzinfo is
    then dropped so the value keeps comparing against naive ``datetime``
    columns produced by ``pd.to_datetime`` (pandas refuses to compare
    tz-aware and tz-naive datetimes).

    NOTE(review): assumes the upstream bar timestamps are China local time --
    confirm against the data source.
    """
    beijing_tz = timezone(timedelta(hours=8))
    return datetime.now(beijing_tz).replace(tzinfo=None, microsecond=0)
def filter_future_data(df: pd.DataFrame, current_time: datetime = None) -> pd.DataFrame:
    """Drop rows whose ``datetime`` lies after *current_time*.

    Args:
        df: Bar frame.  If it has a ``datetime`` column, that column is
            converted with ``pd.to_datetime`` in place (on the caller's frame).
        current_time: Cut-off; rows at or before it are kept.  Defaults to
            ``get_current_time()``.

    Returns:
        *df* itself when it has no ``datetime`` column, otherwise a filtered
        copy (original index preserved).  The number of dropped rows is
        printed when non-zero.
    """
    cutoff = get_current_time() if current_time is None else current_time
    if 'datetime' not in df.columns:
        return df
    df['datetime'] = pd.to_datetime(df['datetime'])
    keep = df['datetime'] <= cutoff
    # Rows failing the test (including NaT timestamps) are the ones dropped.
    dropped = int((~keep).sum())
    kept = df.loc[keep].copy()
    if dropped > 0:
        print(f" 过滤了 {dropped} 条未来数据")
    return kept
def extend_night_session_data(df: pd.DataFrame, symbol: str, period: str) -> pd.DataFrame:
    """Sort bars by time and warn when the night session may be incomplete.

    Despite its name, this function does not fetch anything: it only prints a
    warning when the latest bar falls in the night-session window
    (21:00 through 02:30) and no bar stamped exactly 02:30 exists anywhere in
    *df*.  *symbol* and *period* are accepted but unused.

    Args:
        df: Bar frame; its ``datetime`` column, if present, is converted with
            ``pd.to_datetime`` in place (on the caller's frame).
        symbol: Unused.
        period: Unused.

    Returns:
        *df* unchanged when empty or lacking ``datetime``; otherwise a copy
        sorted by ``datetime`` with a fresh 0..n-1 index.
    """
    if df.empty or 'datetime' not in df.columns:
        return df
    df['datetime'] = pd.to_datetime(df['datetime'])
    ordered = df.sort_values('datetime').reset_index(drop=True)

    latest = ordered['datetime'].iloc[-1]
    hour, minute = latest.hour, latest.minute
    in_night_window = hour >= 21 or hour < 2 or (hour == 2 and minute <= 30)
    if not in_night_window:
        return ordered

    reached_0230 = any(ts.hour == 2 and ts.minute == 30 for ts in ordered['datetime'])
    if not reached_0230:
        print(" 注意: 夜盘数据可能不完整缺少02:30及之前的数据")
    return ordered
def get_minute_data(symbol: str, period: str) -> pd.DataFrame:
    """Fetch futures minute bars via ``ak.futures_zh_minute_sina`` and normalise them.

    Args:
        symbol: Futures contract code passed straight to akshare
            (e.g. ``SN2504``).
        period: Bar length in minutes as a string, e.g. ``"5"`` or ``"60"``.

    Returns:
        Frame with ``datetime``/``open``/``high``/``low``/``close``/``volume``
        columns, bars after the current time removed and rows sorted by
        ``extend_night_session_data``.  A warning is printed when fewer than
        50 bars remain.  Any exception is printed and an empty DataFrame is
        returned instead.
    """
    try:
        now = get_current_time()
        bars = ak.futures_zh_minute_sina(symbol=symbol, period=period)
        # Only the timestamp column needs renaming; the OHLCV names already match.
        bars = bars.rename(columns={'day': 'datetime'})
        for field in ('open', 'high', 'low', 'close', 'volume'):
            bars[field] = pd.to_numeric(bars[field], errors='coerce')
        bars['datetime'] = pd.to_datetime(bars['datetime'])
        bars = filter_future_data(bars, now)
        bars = extend_night_session_data(bars, symbol, period)
        bar_count = len(bars)
        if bar_count < 50:
            print(f" 警告: {period}分钟只获取到{bar_count}根K线建议检查数据源")
        return bars
    except Exception as exc:
        print(f" 获取{period}分钟数据失败: {exc}")
        return pd.DataFrame()
def get_daily_data(symbol: str, days: int = 60) -> pd.DataFrame:
    """Fetch futures daily bars via ``ak.futures_zh_daily_sina`` and keep the last *days*.

    Args:
        symbol: Futures contract code passed straight to akshare.
        days: Number of most recent bars to return.

    Returns:
        Time-sorted frame with ``datetime``/``open``/``high``/``low``/
        ``close``/``volume`` columns, bars after the current time removed, and
        a fresh 0..n-1 index.  Any exception is printed and an empty
        DataFrame is returned instead.
    """
    try:
        now = get_current_time()
        daily = ak.futures_zh_daily_sina(symbol=symbol).rename(columns={'date': 'datetime'})
        for field in ('open', 'high', 'low', 'close', 'volume'):
            daily[field] = pd.to_numeric(daily[field], errors='coerce')
        daily['datetime'] = pd.to_datetime(daily['datetime'])
        daily = daily.sort_values('datetime').reset_index(drop=True)
        daily = filter_future_data(daily, now)
        return daily.tail(days).reset_index(drop=True)
    except Exception as exc:
        print(f" 获取日K数据失败: {exc}")
        return pd.DataFrame()
def _sina_stock_symbol(symbol: str) -> str:
    """Return *symbol* with a Sina-style exchange prefix (``sh``/``sz``/``bj``).

    Input that already carries a prefix (e.g. ``sh600000``) is passed through,
    lower-cased, instead of being double-prefixed.  Bare codes are routed by
    their leading digits:

    * Shanghai: ``6`` (A shares), ``5`` (funds/ETFs), ``900`` (B shares)
    * Beijing:  ``4``, ``8``, ``92``
    * Shenzhen: everything else (``00``/``30`` A shares, ``200`` B shares,
      ``15``/``16`` funds)

    NOTE(review): confirm the Sina minute endpoint serves ``bj``-prefixed codes.
    """
    code = symbol.strip()
    exchange = code[:2].lower()
    if exchange in ('sh', 'sz', 'bj'):
        return exchange + code[2:]
    if code.startswith(('5', '6', '900')):
        return f"sh{code}"
    if code.startswith(('4', '8', '92')):
        return f"bj{code}"
    return f"sz{code}"


def get_stock_minute_data(symbol: str, period: str) -> pd.DataFrame:
    """Fetch A-share minute bars via ``ak.stock_zh_a_minute`` and normalise them.

    Args:
        symbol: Stock code, bare (``600000``) or already prefixed
            (``sh600000``); converted with ``_sina_stock_symbol``.
        period: Bar length in minutes as a string, e.g. ``"5"``.

    Returns:
        Frame with ``datetime``/``open``/``high``/``low``/``close``/``volume``
        columns, bars after the current time removed.  A warning is printed
        when fewer than 50 bars remain.  Any exception is printed and an
        empty DataFrame is returned instead.
    """
    try:
        current_time = get_current_time()
        df = ak.stock_zh_a_minute(symbol=_sina_stock_symbol(symbol), period=period)
        # Only the timestamp column needs renaming; the OHLCV names already match.
        df = df.rename(columns={'day': 'datetime'})
        for col in ['open', 'high', 'low', 'close', 'volume']:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        df['datetime'] = pd.to_datetime(df['datetime'])
        df = filter_future_data(df, current_time)
        if len(df) < 50:
            print(f" 警告: {period}分钟只获取到{len(df)}根K线建议检查数据源")
        return df
    except Exception as e:
        print(f" 获取{period}分钟数据失败: {e}")
        return pd.DataFrame()
def get_stock_daily_data(symbol: str, days: int = 60) -> pd.DataFrame:
    """Fetch A-share daily bars via ``ak.stock_zh_a_hist`` and keep the last *days*.

    Args:
        symbol: Bare stock code (e.g. ``000001``), passed unchanged to akshare.
        days: Number of most recent trading days to return.

    Returns:
        Time-sorted frame with ``datetime``/``open``/``high``/``low``/
        ``close``/``volume`` columns, bars after the current time removed, and
        a fresh 0..n-1 index.  Any exception is printed and an empty
        DataFrame is returned instead.
    """
    try:
        current_time = get_current_time()
        end_date = current_time.strftime('%Y%m%d')
        # *days* trading days need more than *days* calendar days.  Doubling
        # alone is too tight for small *days* when the window overlaps a
        # multi-day market closure (e.g. Spring Festival / National Day), so a
        # fixed cushion is added; surplus rows are trimmed by tail(days).
        lookback_days = days * 2 + 30
        start_date = (current_time - timedelta(days=lookback_days)).strftime('%Y%m%d')
        df = ak.stock_zh_a_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date)
        df = df.rename(columns={
            '日期': 'datetime',
            '开盘': 'open',
            '最高': 'high',
            '最低': 'low',
            '收盘': 'close',
            '成交量': 'volume'
        })
        for col in ['open', 'high', 'low', 'close', 'volume']:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        df['datetime'] = pd.to_datetime(df['datetime'])
        df = df.sort_values('datetime').reset_index(drop=True)
        df = filter_future_data(df, current_time)
        df = df.tail(days).reset_index(drop=True)
        return df
    except Exception as e:
        print(f" 获取日K数据失败: {e}")
        return pd.DataFrame()
def process_data(df: pd.DataFrame, timeframe: str) -> List[Dict]:
    """Add MA/MACD indicators to *df* and format its last 50 bars as candle dicts.

    Args:
        df: OHLCV frame with a ``datetime`` column; indicator columns are
            added to it in place by ``calculate_ma`` / ``calculate_macd``.
        timeframe: Label of the bar period (e.g. ``"5min"``).  Currently
            unused; kept for the public signature.

    Returns:
        Up to 50 candle dicts, oldest first, or ``[]`` when *df* has fewer
        than 10 rows.  Missing numeric values become ``None`` (OHLC, MA) or
        ``0`` (volume, MACD) so that ``json.dumps`` never emits the
        non-standard ``NaN`` token.
    """
    if df.empty or len(df) < 10:
        return []
    df = calculate_ma(df)
    df = calculate_macd(df)

    def rounded(value, digits, missing):
        # NaN can appear after pd.to_numeric(..., errors='coerce') upstream;
        # ``value`` is None when row.get() finds no such column.
        if value is None or pd.isna(value):
            return missing
        return round(float(value), digits)

    candles = []
    for _, row in df.tail(50).iterrows():
        volume = row['volume']
        candles.append({
            "time": str(row['datetime']),
            "open": rounded(row['open'], 2, None),
            "high": rounded(row['high'], 2, None),
            "low": rounded(row['low'], 2, None),
            "close": rounded(row['close'], 2, None),
            "volume": 0 if pd.isna(volume) else int(volume),
            "ma10": rounded(row.get('MA10'), 2, None),
            "ma20": rounded(row.get('MA20'), 2, None),
            "macd_dif": rounded(row.get('macd_dif'), 4, 0),
            "macd_dea": rounded(row.get('macd_dea'), 4, 0),
            "macd_histogram": rounded(row.get('macd_histogram'), 4, 0),
        })
    return candles
# (display name, akshare period argument) for each minute timeframe, fetched
# in this order; the first success supplies ``current_price``.
_MINUTE_TIMEFRAMES = (
    ("60min", "60"),
    ("30min", "30"),
    ("15min", "15"),
    ("5min", "5"),
)


def _collect_data(symbol: str, data_type: str, label: str, fetch_minute, fetch_daily) -> Dict:
    """Shared driver behind ``collect_futures_data`` and ``collect_stock_data``.

    Fetches every timeframe in ``_MINUTE_TIMEFRAMES`` plus 60 daily bars,
    keeps each timeframe that yields at least 50 bars, and formats it with
    ``process_data``.  Progress is printed per timeframe; an exception in one
    timeframe is printed and does not stop the others.

    Args:
        symbol: Instrument code forwarded to the fetchers.
        data_type: Value stored under ``"type"`` in the result.
        label: Instrument kind shown in the progress banner.
        fetch_minute: ``fetch_minute(symbol, period) -> DataFrame``.
        fetch_daily: ``fetch_daily(symbol, days=...) -> DataFrame``.

    Returns:
        Dict with ``symbol``, ``type``, ``current_price`` (last close of the
        first minute timeframe that succeeded, else ``None``), ``timestamp``
        and ``timeframes`` (name -> list of candle dicts).
    """
    # Read the clock once so the banner and the JSON timestamp agree, and use
    # get_current_time() rather than a separate datetime.now(): the timestamp
    # is labelled +08:00, and get_current_time() is the script's Beijing clock.
    now = get_current_time()
    print(f"\n正在获取{label} {symbol} 的多周期数据...")
    print(f"当前时间: {now.strftime('%Y-%m-%d %H:%M:%S')}")
    print("-" * 50)
    result = {
        "symbol": symbol,
        "type": data_type,
        "current_price": None,
        "timestamp": now.strftime("%Y-%m-%dT%H:%M:%S+08:00"),
        "timeframes": {}
    }

    def add_timeframe(tf_name, fetch, sets_price):
        # Fetch one timeframe and store it when it has enough bars.
        print(f"获取 {tf_name} 数据...")
        try:
            df = fetch()
            if not df.empty and len(df) >= 50:
                candles = process_data(df, tf_name)
                if candles:
                    result["timeframes"][tf_name] = candles
                    if sets_price and result["current_price"] is None:
                        result["current_price"] = candles[-1]["close"]
                    print(f" [OK] 成功获取 {len(candles)} 根K线")
            else:
                print(f" [FAIL] 数据不足或获取失败 (获取到{len(df)}根)")
        except Exception as e:
            print(f" [ERROR] 错误: {e}")

    for tf_name, tf_period in _MINUTE_TIMEFRAMES:
        add_timeframe(tf_name, lambda period=tf_period: fetch_minute(symbol, period), True)
    # Daily bars never set current_price, matching the minute-first priority.
    add_timeframe("daily", lambda: fetch_daily(symbol, days=60), False)
    print("-" * 50)
    return result


def collect_futures_data(symbol: str) -> Dict:
    """Collect multi-timeframe futures data (60/30/15/5-minute and daily) for *symbol*."""
    return _collect_data(symbol, "futures", "期货", get_minute_data, get_daily_data)


def collect_stock_data(symbol: str) -> Dict:
    """Collect multi-timeframe A-share data (60/30/15/5-minute and daily) for *symbol*."""
    return _collect_data(symbol, "stock", "股票", get_stock_minute_data, get_stock_daily_data)
def main():
    """Command-line entry point: fetch multi-timeframe data and save it as JSON.

    Parses ``--symbol`` (required), ``--type`` (``futures``/``stock``,
    default ``futures``) and ``--output`` (file name inside ``DATA_DIR``;
    default ``<symbol>_<YYYYmmdd_HHMMSS>.json``).  Prints the JSON to stdout
    and writes it as UTF-8.

    Raises:
        SystemExit: with status 1 when no timeframe could be fetched, so
            shell callers can detect the failure (returning normally would
            exit with status 0).
    """
    parser = argparse.ArgumentParser(description='期货/股票多周期数据获取与技术指标计算')
    parser.add_argument('--symbol', type=str, required=True,
                        help='代码,期货如 SN2504(沪锡), 股票如 000001(平安银行)')
    parser.add_argument('--type', type=str, default='futures', choices=['futures', 'stock'],
                        help='数据类型futures(期货)、stock(股票),默认为 futures')
    parser.add_argument('--output', type=str, default=None,
                        help='输出JSON文件名默认为 代码_时间戳.json')
    args = parser.parse_args()

    collect = collect_stock_data if args.type == 'stock' else collect_futures_data
    data = collect(args.symbol)

    if not data["timeframes"]:
        print("\n错误: 未能获取到任何数据,请检查代码是否正确")
        if args.type == 'stock':
            print("常见股票代码示例:")
            print(" 000001 - 平安银行")
            print(" 600000 - 浦发银行")
            print(" 000858 - 五粮液")
            print(" 600519 - 贵州茅台")
        else:
            print("常见期货合约代码示例:")
            print(" SN2504 - 沪锡2504")
            print(" AG2506 - 沪银2506")
            print(" LC2505 - 碳酸锂2505")
            print(" NI2505 - 沪镍2505")
        raise SystemExit(1)

    print("\n" + "="*60)
    print("JSON 输出:")
    print("="*60)
    json_output = json.dumps(data, ensure_ascii=False, indent=2)
    print(json_output)

    if args.output:
        filename = os.path.join(DATA_DIR, args.output)
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = os.path.join(DATA_DIR, f"{data['symbol']}_{timestamp}.json")
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(json_output)

    print("\n" + "="*60)
    print(f"[OK] 数据已保存到: {filename}")
    print(f"[OK] 共获取 {len(data['timeframes'])} 个周期数据")
    print("="*60)


if __name__ == "__main__":
    main()