You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
347 lines
11 KiB
347 lines
11 KiB
"""
|
|
星耀数智(AmazingData)适配器单元测试
|
|
|
|
测试覆盖:
|
|
- 适配器初始化和配置
|
|
- 连接/断开连接
|
|
- 基础数据获取
|
|
- K线数据获取
|
|
- 财务数据获取
|
|
- 错误处理
|
|
"""
|
|
|
|
import unittest
|
|
import asyncio
|
|
import sys
|
|
import os
|
|
from datetime import date, datetime
|
|
from unittest.mock import Mock, patch, MagicMock
|
|
import pandas as pd
|
|
|
|
# 添加项目根目录到路径
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from app.adapters.amazingdata_adapter import (
|
|
AmazingDataAdapter, AmazingDataConfig, SecurityType, Market, Period
|
|
)
|
|
from app.adapters.base import KLineData, SymbolInfo, TradeCalData
|
|
|
|
|
|
class TestAmazingDataConfig(unittest.TestCase):
    """Tests for the AmazingDataConfig settings object."""

    def test_default_config(self):
        """Only credentials and endpoint given: optional fields take their defaults."""
        cfg = AmazingDataConfig(
            username='test_user', password='test_pass', host='localhost', port=8080,
        )

        self.assertEqual(cfg.username, 'test_user')
        # The cache directory and cache toggle were not passed, so these are the defaults.
        self.assertEqual(cfg.local_path, './amazing_data_cache/')
        self.assertTrue(cfg.use_local_cache)

    def test_custom_config(self):
        """Explicitly supplied optional fields override the defaults."""
        overrides = dict(
            local_path='/custom/path/',
            use_local_cache=False,
        )
        cfg = AmazingDataConfig(
            username='user', password='pass', host='192.168.1.1', port=9090, **overrides,
        )

        self.assertEqual(cfg.port, 9090)
        self.assertEqual(cfg.local_path, overrides['local_path'])
        self.assertFalse(cfg.use_local_cache)
|
|
|
|
|
|
class TestAmazingDataAdapter(unittest.TestCase):
    """Synchronous tests for AmazingDataAdapter helpers and initial state."""

    def setUp(self):
        """Create a fresh, unconnected adapter plus a sample connection config."""
        self.adapter = AmazingDataAdapter()
        self.test_config = {
            'username': 'test_user',
            'password': 'test_pass',
            'host': 'localhost',
            'port': 8080,
            'local_path': './test_cache/',
            'use_local_cache': False,
        }

    def tearDown(self):
        """Close the adapter if a test left it connected."""
        if self.adapter._connected:
            # asyncio.get_event_loop() is deprecated when no loop is running and
            # raises once a previous IsolatedAsyncioTestCase has cleared the
            # current loop; asyncio.run() uses its own private loop instead.
            asyncio.run(self.adapter.close())

    def test_initial_state(self):
        """A new adapter has no config, no SDK handle, and is neither logged in nor connected."""
        adapter = self.adapter
        self.assertIsNone(adapter.config)
        self.assertIsNone(adapter._ad)
        self.assertFalse(adapter._is_logged_in)
        self.assertFalse(adapter._connected)

    def test_format_date(self):
        """_format_date turns every supported date representation into the int 20240101."""
        expected = 20240101
        cases = (
            20240101,              # int
            '2024-01-01',          # dash-separated string
            '2024/01/01',          # slash-separated string
            '20240101',            # digits-only string
            date(2024, 1, 1),      # date object
            datetime(2024, 1, 1),  # datetime object
        )
        for raw in cases:
            with self.subTest(raw=raw):
                self.assertEqual(self.adapter._format_date(raw), expected)

    def test_format_date_invalid(self):
        """Unsupported inputs raise ValueError."""
        for bad in (None, []):
            with self.subTest(value=bad):
                with self.assertRaises(ValueError):
                    self.adapter._format_date(bad)

    def test_check_login_not_logged_in(self):
        """_check_login raises RuntimeError before the adapter has logged in."""
        with self.assertRaises(RuntimeError) as context:
            self.adapter._check_login()
        self.assertIn('未连接到数据源', str(context.exception))

    def test_check_login_logged_in(self):
        """_check_login does not raise once the adapter is logged in."""
        # Previously _check_login itself was patched, so the call below hit a
        # Mock and verified nothing. Exercise the real method instead.
        # NOTE(review): this state mirrors the logged-in fixtures used by the
        # fetch tests; assumed sufficient for _check_login -- confirm.
        self.adapter._is_logged_in = True
        self.adapter._ad = Mock()
        self.adapter._check_login()  # must not raise
|
|
|
|
|
|
class TestAmazingDataAdapterAsync(unittest.IsolatedAsyncioTestCase):
    """Tests for the adapter's async connect/close lifecycle."""

    async def asyncSetUp(self):
        """Create a fresh adapter and a minimal connection config."""
        self.adapter = AmazingDataAdapter()
        self.test_config = {
            'username': 'test_user',
            'password': 'test_pass',
            'host': 'localhost',
            'port': 8080,
        }

    async def asyncTearDown(self):
        """Close the adapter if a test left it connected."""
        if self.adapter._connected:
            await self.adapter.close()

    @patch('app.adapters.amazingdata_adapter.AmazingDataAdapter._do_login')
    async def test_connect_success(self, mock_login):
        """connect() marks the adapter connected and stores a config when the SDK imports."""
        mock_login.return_value = None

        # Stand-in for the AmazingData SDK module. The data classes are the Mock
        # *class*, so calling them (e.g. BaseData(...)) returns a fresh Mock instance.
        mock_ad = Mock()
        mock_ad.BaseData = Mock
        mock_ad.InfoData = Mock
        mock_ad.MarketData = Mock
        mock_ad.constant.Period = Mock()

        with patch.dict('sys.modules', {'AmazingData': mock_ad}):
            await self.adapter.connect(self.test_config)

        self.assertTrue(self.adapter._connected)
        self.assertIsNotNone(self.adapter.config)

    async def test_connect_import_error(self):
        """connect() raises RuntimeError when the AmazingData SDK cannot be imported."""
        # A None entry in sys.modules makes `import AmazingData` fail with ImportError.
        # Use the context-manager form: before Python 3.10, patch.dict used as a
        # decorator did not wrap coroutine functions, so the patch was removed
        # before the coroutine body ran.
        with patch.dict('sys.modules', {'AmazingData': None}):
            with self.assertRaises(RuntimeError) as context:
                await self.adapter.connect(self.test_config)
        self.assertIn('AmazingData SDK 未安装', str(context.exception))

    async def test_close_not_connected(self):
        """close() on a never-connected adapter does not raise and leaves it logged out."""
        await self.adapter.close()
        self.assertFalse(self.adapter._is_logged_in)
|
|
|
|
|
|
class TestFetchKlines(unittest.IsolatedAsyncioTestCase):
    """Tests for fetch_klines against a mocked MarketData backend."""

    async def asyncSetUp(self):
        """Build a logged-in adapter whose SDK handle and MarketData are Mocks."""
        adapter = AmazingDataAdapter()
        adapter._is_logged_in = True

        sdk = Mock()
        sdk.constant.Period = Mock()
        sdk.constant.Period.daily = Mock()
        sdk.constant.Period.daily.value = 'daily'
        adapter._ad = sdk

        self.mock_market_data = Mock()
        adapter._market_data = self.mock_market_data
        self.adapter = adapter

    async def test_fetch_klines_empty_result(self):
        """An empty query_kline response yields an empty list."""
        self.mock_market_data.query_kline.return_value = {}

        bars = await self.adapter.fetch_klines(
            symbol='000001.SZ', start='20240101', end='20241231', freq='1d',
        )

        self.assertEqual(bars, [])

    async def test_fetch_klines_with_data(self):
        """Each DataFrame row for the symbol becomes one KLineData record."""
        trade_days = pd.to_datetime(['2024-01-01', '2024-01-02'])
        frame = pd.DataFrame(
            {
                'open': [10.0, 11.0],
                'high': [11.0, 12.0],
                'low': [9.0, 10.0],
                'close': [10.5, 11.5],
                'volume': [10000, 20000],
                'amount': [105000, 230000],
            },
            index=trade_days,
        )
        self.mock_market_data.query_kline.return_value = {'000001.SZ': frame}

        bars = await self.adapter.fetch_klines(
            symbol='000001.SZ', start='20240101', end='20240102', freq='1d',
        )

        self.assertEqual(len(bars), 2)
        first = bars[0]
        self.assertIsInstance(first, KLineData)
        self.assertEqual(first.symbol, '000001.SZ')
        self.assertEqual(first.open, 10.0)
|
|
|
|
|
|
class TestFetchSymbols(unittest.IsolatedAsyncioTestCase):
    """Tests for fetch_symbols against a mocked BaseData backend."""

    async def asyncSetUp(self):
        """Build a logged-in adapter whose SDK handle and BaseData are Mocks."""
        adapter = AmazingDataAdapter()
        adapter._is_logged_in = True
        adapter._ad = Mock()

        self.mock_base_data = Mock()
        adapter._base_data = self.mock_base_data
        self.adapter = adapter

    async def test_fetch_stock_symbols(self):
        """Stock codes are converted to SymbolInfo with the exchange taken from the suffix."""
        base = self.mock_base_data
        base.get_code_list.return_value = ['000001.SZ', '600000.SH']
        base.get_code_info.return_value = pd.DataFrame(
            {'symbol': ['平安银行', '浦发银行']},
            index=['000001.SZ', '600000.SH'],
        )

        symbols = await self.adapter.fetch_symbols('stock')

        self.assertEqual(len(symbols), 2)
        head = symbols[0]
        self.assertIsInstance(head, SymbolInfo)
        self.assertEqual(head.symbol_id, '000001.SZ')
        self.assertEqual(head.exchange, 'SZ')

    async def test_fetch_futures_symbols(self):
        """Futures codes produce one entry each, with an upper-cased underlying."""
        self.mock_base_data.get_future_code_list.return_value = ['cu2501', 'al2502']

        symbols = await self.adapter.fetch_symbols('futures')

        self.assertEqual(len(symbols), 2)
        self.assertEqual(symbols[0].underlying, 'CU')
|
|
|
|
|
|
class TestTradingCalendar(unittest.IsolatedAsyncioTestCase):
    """Tests for fetch_trading_calendar against a mocked BaseData backend."""

    async def asyncSetUp(self):
        """Build a logged-in adapter whose SDK handle and BaseData are Mocks."""
        adapter = AmazingDataAdapter()
        adapter._is_logged_in = True
        adapter._ad = Mock()

        self.mock_base_data = Mock()
        adapter._base_data = self.mock_base_data
        self.adapter = adapter

    async def test_fetch_calendar(self):
        """Each date returned by get_calendar becomes one TradeCalData record."""
        self.mock_base_data.get_calendar.return_value = [20240101, 20240102, 20240103]

        calendar = await self.adapter.fetch_trading_calendar(
            exchange='SH', start='20240101', end='20240103',
        )

        self.assertEqual(len(calendar), 3)
        self.assertIsInstance(calendar[0], TradeCalData)
|
|
|
|
|
|
class TestHealthCheck(unittest.IsolatedAsyncioTestCase):
    """Tests for the adapter's health_check coroutine."""

    async def asyncSetUp(self):
        """Start each test with a fresh, unconnected adapter."""
        self.adapter = AmazingDataAdapter()

    async def test_health_check_not_connected(self):
        """An unconnected adapter reports unhealthy."""
        self.assertFalse(await self.adapter.health_check())

    async def test_health_check_connected(self):
        """A connected, logged-in adapter with a responsive BaseData reports healthy."""
        adapter = self.adapter
        adapter._connected = True
        adapter._is_logged_in = True
        adapter._base_data = Mock()
        adapter._base_data.get_code_list.return_value = ['000001.SZ']

        self.assertTrue(await adapter.health_check())
|
|
|
|
|
|
class TestEnums(unittest.TestCase):
    """Pins the string values of the adapter's enum types."""

    def _assert_values(self, expectations):
        """Check each (member, expected_value) pair, reporting failures per member."""
        for member, expected in expectations:
            with self.subTest(member=member):
                self.assertEqual(member.value, expected)

    def test_security_type_values(self):
        """SecurityType members map to their EXTRA_* SDK identifiers."""
        self._assert_values([
            (SecurityType.STOCK_A, 'EXTRA_STOCK_A'),
            (SecurityType.ETF, 'EXTRA_ETF'),
            (SecurityType.FUTURE, 'EXTRA_FUTURE'),
        ])

    def test_market_values(self):
        """Market members map to their exchange codes."""
        self._assert_values([
            (Market.SH, 'SH'),
            (Market.SZ, 'SZ'),
            (Market.BJ, 'BJ'),
        ])

    def test_period_values(self):
        """Period members map to their lower-case period names."""
        self._assert_values([
            (Period.MIN1, 'min1'),
            (Period.DAILY, 'daily'),
            (Period.WEEKLY, 'weekly'),
        ])
|
|
|
|
|
|
if __name__ == '__main__':
    # Run every test case in this module, printing one line per test.
    unittest.main(verbosity=2)
|