# -*- coding: utf-8 -*-
"""Integration tests for the backtest service and repository.

These tests run against a temporary SQLite DB (the same approach as other
tests) and validate idempotency/force semantics, result field correctness,
summary creation, and query methods.
"""

import os
import tempfile
import unittest
from datetime import date, datetime
from typing import Optional
from unittest import mock

from src.config import Config
from src.core.backtest_engine import OVERALL_SENTINEL_CODE
from src.services.backtest_service import BacktestService
from src.storage import AnalysisHistory, BacktestResult, BacktestSummary, DatabaseManager, StockDaily


class BacktestServiceTestCase(unittest.TestCase):
    """End-to-end checks of ``BacktestService`` against a throwaway SQLite DB."""

    def setUp(self) -> None:
        """Create an isolated DB seeded with one analysis (600519) and its bars.

        All global state touched here is registered with ``addCleanup``:
        the temp dir, the environment variables, and the ``Config`` and
        ``DatabaseManager`` singletons. That way it is undone even if
        ``setUp`` fails partway, and nothing leaks into other test modules.
        Cleanups run in LIFO order, so the DB singleton is reset before the
        temp dir holding the SQLite file is deleted.
        """
        self._temp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self._temp_dir.cleanup)
        self._db_path = os.path.join(self._temp_dir.name, "test_backtest_service.db")

        # patch.dict restores the previous values (or removes the keys) when
        # stopped; plain os.environ assignment leaked them into later tests.
        env_patch = mock.patch.dict(
            os.environ,
            {"DATABASE_PATH": self._db_path, "BACKTEST_EVAL_WINDOW_DAYS": "3"},
        )
        env_patch.start()
        self.addCleanup(env_patch.stop)

        # Drop the cached Config so it is rebuilt from the patched environment,
        # and put back whatever instance was cached before this test.
        self.addCleanup(setattr, Config, "_instance", getattr(Config, "_instance", None))
        Config._instance = None

        DatabaseManager.reset_instance()
        self.addCleanup(DatabaseManager.reset_instance)
        self.db = DatabaseManager.get_instance()

        # Backdate the analysis so it is old enough even for the service's
        # default min_age_days=14 (the tests below also pass min_age_days=0).
        old_created_at = datetime(2024, 1, 1, 0, 0, 0)
        with self.db.get_session() as session:
            session.add(
                AnalysisHistory(
                    query_id="q1",
                    code="600519",
                    name="贵州茅台",
                    report_type="simple",
                    sentiment_score=80,
                    operation_advice="买入",
                    trend_prediction="看多",
                    analysis_summary="test",
                    stop_loss=95.0,
                    take_profit=110.0,
                    created_at=old_created_at,
                    context_snapshot='{"enhanced_context": {"date": "2024-01-01"}}',
                )
            )
            # Analysis-day close (entry price 100.0).
            session.add(
                StockDaily(
                    code="600519",
                    date=date(2024, 1, 1),
                    open=100.0,
                    high=101.0,
                    low=99.0,
                    close=100.0,
                )
            )
            # Three forward bars; the first one (2024-01-02, high=111) already
            # reaches take_profit=110.
            session.add_all(
                [
                    StockDaily(code="600519", date=date(2024, 1, 2), high=111.0, low=100.0, close=105.0),
                    StockDaily(code="600519", date=date(2024, 1, 3), high=108.0, low=103.0, close=106.0),
                    StockDaily(code="600519", date=date(2024, 1, 4), high=109.0, low=104.0, close=107.0),
                ]
            )
            session.commit()

    # ------------------------------------------------------------------ helpers

    def _count_results(self) -> int:
        """Return the number of BacktestResult rows currently stored."""
        with self.db.get_session() as session:
            return session.query(BacktestResult).count()

    def _run_backtest(
        self,
        service: BacktestService,
        *,
        code: Optional[str] = "600519",
        force: bool = False,
    ):
        """Run ``service.run_backtest`` with the window/age/limit every test uses.

        Args:
            service: The service instance to drive. It is passed in so a test
                can reuse one instance across several runs.
            code: Stock code to evaluate; ``None`` evaluates all stocks.
            force: Forwarded unchanged to ``run_backtest``.

        Returns:
            Whatever ``run_backtest`` returns (the tests index it as a stats
            mapping, e.g. ``stats["saved"]``).
        """
        return service.run_backtest(
            code=code,
            force=force,
            eval_window_days=3,
            min_age_days=0,
            limit=10,
        )

    def _run_and_get_result(self) -> BacktestResult:
        """Run one backtest for 600519 and return the single BacktestResult row."""
        self._run_backtest(BacktestService(self.db))
        with self.db.get_session() as session:
            # NOTE(review): the row is read after the session closes; this relies
            # on get_session not expiring loaded attributes on exit — confirm.
            return session.query(BacktestResult).one()

    # -------------------------------------------------------------------- tests

    def test_force_semantics(self) -> None:
        """A non-forced rerun is a no-op; a forced rerun replaces the row in place."""
        service = BacktestService(self.db)

        stats1 = self._run_backtest(service)
        self.assertEqual(stats1["saved"], 1)
        self.assertEqual(self._count_results(), 1)

        # Non-force should be idempotent.
        stats2 = self._run_backtest(service)
        self.assertEqual(stats2["saved"], 0)
        self.assertEqual(self._count_results(), 1)

        # Force should replace the existing result without unique-constraint errors.
        stats3 = self._run_backtest(service, force=True)
        self.assertEqual(stats3["saved"], 1)
        self.assertEqual(self._count_results(), 1)

    def test_result_fields_correct(self) -> None:
        """Verify the BacktestResult row contains the correct evaluation values."""
        result = self._run_and_get_result()

        self.assertEqual(result.eval_status, "completed")
        self.assertEqual(result.code, "600519")
        self.assertEqual(result.analysis_date, date(2024, 1, 1))
        self.assertEqual(result.operation_advice, "买入")
        self.assertEqual(result.position_recommendation, "long")
        self.assertEqual(result.direction_expected, "up")

        # Prices: entry at the analysis-day close, end at the last forward close.
        self.assertAlmostEqual(result.start_price, 100.0)
        self.assertAlmostEqual(result.end_close, 107.0)
        self.assertAlmostEqual(result.stock_return_pct, 7.0)

        # Direction & outcome.
        self.assertEqual(result.outcome, "win")
        self.assertTrue(result.direction_correct)

        # Target hits: the first forward bar (2024-01-02) has high=111 >= take_profit=110.
        self.assertTrue(result.hit_take_profit)
        self.assertFalse(result.hit_stop_loss)
        self.assertEqual(result.first_hit, "take_profit")
        self.assertEqual(result.first_hit_trading_days, 1)
        self.assertEqual(result.first_hit_date, date(2024, 1, 2))

        # Simulated execution exits at the take-profit level.
        self.assertAlmostEqual(result.simulated_entry_price, 100.0)
        self.assertAlmostEqual(result.simulated_exit_price, 110.0)
        self.assertEqual(result.simulated_exit_reason, "take_profit")
        self.assertAlmostEqual(result.simulated_return_pct, 10.0)

    def test_summaries_created_after_run(self) -> None:
        """Verify both overall and per-stock BacktestSummary rows are created."""
        self._run_backtest(BacktestService(self.db))

        with self.db.get_session() as session:
            # The overall summary is stored under the sentinel code.
            overall = session.query(BacktestSummary).filter(
                BacktestSummary.scope == "overall",
                BacktestSummary.code == OVERALL_SENTINEL_CODE,
            ).first()
            self.assertIsNotNone(overall)
            self.assertEqual(overall.total_evaluations, 1)
            self.assertEqual(overall.completed_count, 1)
            self.assertEqual(overall.win_count, 1)
            self.assertEqual(overall.loss_count, 0)
            self.assertAlmostEqual(overall.win_rate_pct, 100.0)

            # Stock-level summary.
            stock = session.query(BacktestSummary).filter(
                BacktestSummary.scope == "stock",
                BacktestSummary.code == "600519",
            ).first()
            self.assertIsNotNone(stock)
            self.assertEqual(stock.total_evaluations, 1)
            self.assertEqual(stock.completed_count, 1)
            self.assertEqual(stock.win_count, 1)

    def test_get_summary_overall_returns_sentinel_as_none(self) -> None:
        """Verify get_summary translates the overall sentinel code back to None."""
        service = BacktestService(self.db)
        self._run_backtest(service)

        summary = service.get_summary(scope="overall", code=None)
        self.assertIsNotNone(summary)
        self.assertIsNone(summary["code"])
        self.assertEqual(summary["scope"], "overall")
        self.assertEqual(summary["win_count"], 1)

    def test_get_recent_evaluations(self) -> None:
        """Verify get_recent_evaluations returns correct paginated results."""
        service = BacktestService(self.db)
        self._run_backtest(service)

        data = service.get_recent_evaluations(code="600519", limit=10, page=1)
        self.assertEqual(data["total"], 1)
        self.assertEqual(data["page"], 1)
        self.assertEqual(data["limit"], 10)
        self.assertEqual(len(data["items"]), 1)

        item = data["items"][0]
        self.assertEqual(item["code"], "600519")
        self.assertEqual(item["outcome"], "win")
        self.assertEqual(item["direction_expected"], "up")
        self.assertTrue(item["direction_correct"])

    def test_multi_stock_summaries(self) -> None:
        """Verify separate summaries for multiple stocks and a correct overall aggregate."""
        old_created_at = datetime(2024, 1, 1, 0, 0, 0)
        with self.db.get_session() as session:
            # Second stock with sell advice; the price drops (a win for cash/down).
            session.add(
                AnalysisHistory(
                    query_id="q2",
                    code="000001",
                    name="平安银行",
                    report_type="simple",
                    sentiment_score=30,
                    operation_advice="卖出",
                    trend_prediction="看空",
                    analysis_summary="test2",
                    stop_loss=None,
                    take_profit=None,
                    created_at=old_created_at,
                    context_snapshot='{"enhanced_context": {"date": "2024-01-01"}}',
                )
            )
            session.add(
                StockDaily(code="000001", date=date(2024, 1, 1), open=10.0, high=10.2, low=9.8, close=10.0)
            )
            session.add_all(
                [
                    StockDaily(code="000001", date=date(2024, 1, 2), high=10.0, low=9.5, close=9.6),
                    StockDaily(code="000001", date=date(2024, 1, 3), high=9.7, low=9.3, close=9.4),
                    StockDaily(code="000001", date=date(2024, 1, 4), high=9.5, low=9.0, close=9.1),
                ]
            )
            session.commit()

        stats = self._run_backtest(BacktestService(self.db), code=None)
        self.assertEqual(stats["saved"], 2)
        self.assertEqual(stats["completed"], 2)

        with self.db.get_session() as session:
            # Each stock has its own summary.
            s1 = session.query(BacktestSummary).filter(
                BacktestSummary.scope == "stock", BacktestSummary.code == "600519"
            ).first()
            s2 = session.query(BacktestSummary).filter(
                BacktestSummary.scope == "stock", BacktestSummary.code == "000001"
            ).first()
            self.assertIsNotNone(s1)
            self.assertIsNotNone(s2)
            self.assertEqual(s1.win_count, 1)
            self.assertEqual(s2.win_count, 1)

            # The overall summary aggregates both stocks.
            overall = session.query(BacktestSummary).filter(
                BacktestSummary.scope == "overall",
                BacktestSummary.code == OVERALL_SENTINEL_CODE,
            ).first()
            self.assertIsNotNone(overall)
            self.assertEqual(overall.total_evaluations, 2)
            self.assertEqual(overall.completed_count, 2)
            self.assertEqual(overall.win_count, 2)


if __name__ == "__main__":
    unittest.main()