diff --git a/backtest.py b/backtest.py deleted file mode 100644 index 538315b..0000000 --- a/backtest.py +++ /dev/null @@ -1,283 +0,0 @@ -#!/usr/bin/env python3 -""" -量化回测主程序 - -使用方法: - python backtest.py --code 000001.SZ --start-date 2024-01-01 --end-date 2025-12-31 --strategy-file strategy.py -""" - -import argparse -import importlib.util -import os -import sys - -import pandas as pd - -# 数据库配置(直接硬编码,开发环境) -DB_HOST = "81.71.3.24" -DB_PORT = 6785 -DB_NAME = "leopard_dev" -DB_USER = "leopard" -DB_PASSWORD = "9NEzFzovnddf@PyEP?e*AYAWnCyd7UhYwQK$pJf>7?ccFiN^x4$eKEZ5~E<7<+~X" - - -def load_data_from_db(code, start_date, end_date): - """ - 从数据库加载历史数据 - - 参数: - code: 股票代码(如 '000001.SZ') - start_date: 开始日期(如 '2024-01-01') - end_date: 结束日期(如 '2025-12-31') - - 返回: - DataFrame, 包含列: [Open, High, Low, Close, Volume, factor] - """ - import sqlalchemy - import urllib.parse - - # 构建连接字符串(URL 编码密码中的特殊字符) - encoded_password = urllib.parse.quote_plus(DB_PASSWORD) - conn_str = ( - f"postgresql://{DB_USER}:{encoded_password}@{DB_HOST}:{DB_PORT}/{DB_NAME}" - ) - engine = sqlalchemy.create_engine(conn_str) - - try: - # 构建 SQL 查询 - query = f""" -SELECT trade_date, - open * factor AS "Open", - close * factor AS "Close", - high * factor AS "High", - low * factor AS "Low", - volume AS "Volume", - COALESCE(factor, 1.0) AS factor -FROM leopard_daily daily - LEFT JOIN leopard_stock stock ON stock.id = daily.stock_id -WHERE stock.code = '{code}' - AND daily.trade_date BETWEEN '{start_date} 00:00:00' - AND '{end_date} 23:59:59' -ORDER BY daily.trade_date - """ - - # 执行查询 - df = pd.read_sql(query, engine) - - if len(df) == 0: - raise ValueError(f"未找到股票 {code} 在指定时间范围内的数据") - - # 潬换日期并设置为索引 - df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y-%m-%d") - df.set_index("trade_date", inplace=True) - - return df - - finally: - # 清理连接 - engine.dispose() - - -def load_strategy(strategy_file): - """ - 动态加载策略文件 - - 参数: - strategy_file: 策略文件路径 (如 'strategy.py' 或 'strategies/macd.py') - - 返回: - (calculate_indicators, strategy_class) 元组 - """ - # 获取模块名 - module_name = strategy_file.replace(".py", "").replace("/", ".") - spec = importlib.util.spec_from_file_location(module_name, strategy_file) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - # 接口验证 - if not hasattr(module, "calculate_indicators"): - raise AttributeError(f"策略文件 {strategy_file} 缺少 calculate_indicators 函数") - - if not hasattr(module, "get_strategy"): - raise AttributeError(f"策略文件 {strategy_file} 缺少 get_strategy 函数") - - calculate_indicators = module.calculate_indicators - strategy_class = module.get_strategy() - - # 验证 get_strategy 返回的是类 - if not isinstance(strategy_class, type): - raise TypeError("get_strategy() 必须返回一个类") - - # 验证策略类继承自 backtesting.Strategy - from backtesting import Strategy - - if not issubclass(strategy_class, Strategy): - raise TypeError("策略类必须继承 backtesting.Strategy") - - return calculate_indicators, strategy_class - - -def apply_color_scheme(): - """ - 应用颜色方案:红涨绿跌(中国股市风格) - """ - import backtesting._plotting as plotting - from bokeh.colors.named import tomato, lime - - plotting.BULL_COLOR = tomato - plotting.BEAR_COLOR = lime - - -def parse_arguments(): - """ - 解析命令行参数 - - 返回: - args: 命名空间对象 - """ - parser = argparse.ArgumentParser(description="量化回测工具", formatter_class=argparse.RawDescriptionHelpFormatter) - - # 必需参数 - parser.add_argument("--code", type=str, required=True, help="股票代码 (如: 000001.SZ)") - parser.add_argument("--start-date", type=str, required=True, help="回测开始日期 (格式: YYYY-MM-DD)") - parser.add_argument("--end-date", type=str, required=True, help="回测结束日期 (格式: YYYY-MM-DD)") - parser.add_argument("--strategy-file", type=str, required=True, help="策略文件路径 (如: strategy.py)", ) - - # 可选参数 - parser.add_argument("--cash", type=float, default=100000, help="初始资金 (默认: 100000)") - parser.add_argument("--commission", type=float, default=0.002, help="手续费率 (默认: 0.002)") - parser.add_argument("--output", type=str, default=None, help="HTML 输出文件路径 (可选)") - parser.add_argument("--warmup-days", type=int, default=365, help="预热天数 (默认: 365,约一年)") - - return parser.parse_args() - - -def print_stats(stats): - """ - 打印回测统计结果 - - 参数: - stats: backtesting 库返回的统计对象 - """ - print("=" * 60) - print("回测结果") - - indicator_name_mapping = { - # 'Start': '回测开始时间', - # 'End': '回测结束时间', - # 'Duration': '回测持续时长', - # 'Exposure Time [%]': '持仓时间占比(%)', - "Equity Final [$]": "最终收益", - "Equity Peak [$]": "峰值收益", - "Return [%]": "总收益率(%)", - "Buy & Hold Return [%]": "买入并持有收益率(%)", - "Return (Ann.) [%]": "年化收益率(%)", - "Volatility (Ann.) [%]": "年化波动率(%)", - # 'CAGR [%]': '复合年均增长率(%)', - # 'Sharpe Ratio': '夏普比率', - "Sortino Ratio": "索提诺比率", - "Calmar Ratio": "卡尔玛比率", - # 'Alpha [%]': '阿尔法系数(%)', - # 'Beta': '贝塔系数', - "Max. Drawdown [%]": "最大回撤(%)", - "Avg. Drawdown [%]": "平均回撤(%)", - "Max. Drawdown Duration": "最大回撤持续时长", - "Avg. Drawdown Duration": "平均回撤持续时长", - "# Trades": "总交易次数", - "Win Rate [%]": "胜率(%)", - # 'Best Trade [%]': '最佳单笔交易收益率(%)', - # 'Worst Trade [%]': '最差单笔交易收益率(%)', - # 'Avg. Trade [%]': '平均单笔交易收益率(%)', - # 'Max. Trade Duration': '单笔交易最长持有时长', - # 'Avg. Trade Duration': '单笔交易平均持有时长', - # 'Profit Factor': '盈利因子', - # 'Expectancy [%]': '期望收益(%)', - "SQN": "系统质量数", - # 'Kelly Criterion': '凯利准则', - } - for k, v in stats.items(): - if k in indicator_name_mapping: - cn_name = indicator_name_mapping.get(k, k) - if isinstance(v, (int, float)): - if "%" in cn_name or k in [ - "Sharpe Ratio", - "Sortino Ratio", - "Calmar Ratio", - "Profit Factor", - ]: - formatted_value = f"{v:.2f}" - elif "$" in cn_name: - formatted_value = f"{v:.2f}" - elif "次数" in cn_name: - formatted_value = f"{v:.0f}" - else: - formatted_value = f"{v:.4f}" - else: - formatted_value = str(v) - print(f"{cn_name}: {formatted_value}") - - print("=" * 60) - - -def main(): - """ - 主函数:编排完整回测流程 - """ - try: - # 解析参数 - args = parse_arguments() - - apply_color_scheme() - - # 计算预热开始日期(回测开始日期往前推 warmup_days 天) - warmup_start_date = ( - pd.to_datetime(args.start_date) - pd.Timedelta(days=args.warmup_days) - ).strftime("%Y-%m-%d") - - # 加载数据(包含预热期间) - print(f"加载股票数据(含预热): {args.code} ({warmup_start_date} ~ {args.end_date}),预热天数: {args.warmup_days}") - data = load_data_from_db(args.code, warmup_start_date, args.end_date) - print(f"数据加载完成,共 {len(data)} 条记录(含预热)") - - # 加载策略 - calculate_indicators, strategy_class = load_strategy(args.strategy_file) - - # 计算指标(在扩展数据上计算,确保长周期指标有足够历史数据) - data = calculate_indicators(data) - print(f"指标计算完成") - - # 截取回测期间的数据(去掉预热期间) - data = data.loc[args.start_date : args.end_date] - print(f"回测数据范围: {args.start_date} ~ {args.end_date},共 {len(data)} 条记录") - - # 执行回测 - from backtesting import Backtest - - bt = Backtest( - data, - strategy_class, - cash=args.cash, - commission=args.commission, - finalize_trades=True, - ) - stats = bt.run() - - # 输出结果 - print_stats(stats) - - # 生成图表 - if args.output: - os.makedirs(os.path.dirname(args.output), exist_ok=True) - bt.plot(filename=args.output, open_browser=False) - print(f"图表已保存到: {args.output}") - - except Exception as e: - print(f"\n错误: {e}") - import traceback - - traceback.print_exc() - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/backtest_command.py b/backtest_command.py new file mode 100755 index 0000000..3989dc7 --- /dev/null +++ b/backtest_command.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +import argparse +import sys + +import tabulate + +import backtest_core + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="量化回测工具", formatter_class=argparse.RawDescriptionHelpFormatter) + + parser.add_argument("--codes", type=str, nargs="+", required=True, help="股票代码列表 (如: 000001.SZ 600000.SH)", ) + parser.add_argument("--start-date", type=str, required=True, help="回测开始日期 (格式: YYYY-MM-DD)", ) + parser.add_argument("--end-date", type=str, required=True, help="回测结束日期 (格式: YYYY-MM-DD)", ) + parser.add_argument("--strategy-file", type=str, required=True, help="策略文件路径 (如: strategy.py)", ) + parser.add_argument("--cash", type=float, default=100000, help="初始资金 (默认: 100000)", ) + parser.add_argument("--commission", type=float, default=0.002, help="手续费率 (默认: 0.002)", ) + parser.add_argument("--warmup-days", type=int, default=365, help="预热天数 (默认: 365,约一年)", ) + parser.add_argument("--output-dir", type=str, default=None, help="HTML 图表输出目录 (可选,为每个股票生成 {code}.html)", ) + + return parser.parse_args() + + +def format_single_result(result: backtest_core.BacktestResult): + print("=" * 60) + print(f"股票代码: {result.code}") + print("=" * 60) + + indicator_mapping = { + "最终收益": f"{result.equity_final:.2f}", + "峰值收益": f"{result.equity_peak:.2f}", + "总收益率(%)": f"{result.return_pct:.2f}", + "买入并持有收益率(%)": f"{result.buy_hold_return_pct:.2f}", + "年化收益率(%)": f"{result.return_ann_pct:.2f}", + "年化波动率(%)": f"{result.volatility_ann_pct:.2f}", + "索提诺比率": f"{result.sortino_ratio:.2f}", + "卡尔玛比率": f"{result.calmar_ratio:.2f}", + "最大回撤(%)": f"{result.max_drawdown_pct:.2f}", + "平均回撤(%)": f"{result.avg_drawdown_pct:.2f}", + "最大回撤持续时长": f"{result.max_drawdown_duration:.0f} 天", + "平均回撤持续时长": f"{result.avg_drawdown_duration:.0f} 天", + "总交易次数": f"{result.num_trades:.0f}", + "胜率(%)": f"{result.win_rate_pct:.2f}", + "系统质量数": f"{result.sqn:.2f}", + } + + for name, value in indicator_mapping.items(): + print(f"{name}: {value}") + + print("=" * 60) + + +def format_batch_results(results: list[backtest_core.BacktestResult]): + table_data = [] + for result in results: + table_data.append( + [ + result.code, + f"{result.return_pct:.2f}", + f"{result.buy_hold_return_pct:.2f}", + f"{result.return_ann_pct:.2f}", + f"{result.volatility_ann_pct:.2f}", + f"{result.win_rate_pct:.2f}", + f"{result.max_drawdown_pct:.2f}", + f"{result.sortino_ratio:.2f}", + f"{result.num_trades:.0f}", + f"{result.sqn:.2f}", + ] + ) + + headers = [ + "股票代码", + "收益率%", + "买入持有%", + "年化收益%", + "年化波动%", + "胜率%", + "最大回撤%", + "索提诺比率", + "交易次数", + "SQN", + ] + print(tabulate.tabulate(table_data, headers=headers, tablefmt="grid")) + + +def main(): + args = parse_arguments() + + try: + results = backtest_core.run_batch_backtest( + codes=args.codes, + start_date=args.start_date, + end_date=args.end_date, + strategy_file=args.strategy_file, + cash=args.cash, + commission=args.commission, + warmup_days=args.warmup_days, + output_dir=args.output_dir, + show_progress=True, + ) + + if len(results) == 1: + format_single_result(results[0]) + else: + format_batch_results(results) + + if args.output_dir: + print(f"\n图表已保存到: {args.output_dir}/") + + except Exception as e: + print(f"\n错误: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/backtest_core.py b/backtest_core.py new file mode 100644 index 0000000..ae0ac43 --- /dev/null +++ b/backtest_core.py @@ -0,0 +1,208 @@ +import dataclasses +import importlib.util +import os +from typing import Optional + +import pandas as pd +from tqdm import tqdm + +import config + + +@dataclasses.dataclass +class BacktestResult: + code: str + equity_final: float + equity_peak: float + return_pct: float + buy_hold_return_pct: float + return_ann_pct: float + volatility_ann_pct: float + sortino_ratio: float + calmar_ratio: float + max_drawdown_pct: float + avg_drawdown_pct: float + max_drawdown_duration: float + avg_drawdown_duration: float + num_trades: int + win_rate_pct: float + sqn: float + + +def load_data_from_db(code: str, start_date: str, end_date: str) -> pd.DataFrame: + import sqlalchemy + import urllib.parse + + encoded_password = urllib.parse.quote_plus(config.DB_PASSWORD) + conn_str = f"postgresql://{config.DB_USER}:{encoded_password}@{config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}" + engine = sqlalchemy.create_engine(conn_str) + + try: + query = f""" +SELECT trade_date, + open * factor AS "Open", + close * factor AS "Close", + high * factor AS "High", + low * factor AS "Low", + volume AS "Volume", + COALESCE(factor, 1.0) AS factor +FROM leopard_daily daily + LEFT JOIN leopard_stock stock ON stock.id = daily.stock_id +WHERE stock.code = '{code}' + AND daily.trade_date BETWEEN '{start_date} 00:00:00' + AND '{end_date} 23:59:59' +ORDER BY daily.trade_date + """ + + df = pd.read_sql(query, engine) + + if len(df) == 0: + raise ValueError(f"未找到股票 {code} 在指定时间范围内的数据") + + df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y-%m-%d") + df.set_index("trade_date", inplace=True) + + return df + + finally: + engine.dispose() + + +def load_strategy(strategy_file: str): + module_name = strategy_file.replace(".py", "").replace("/", ".") + spec = importlib.util.spec_from_file_location(module_name, strategy_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + if not hasattr(module, "calculate_indicators"): + raise AttributeError(f"策略文件 {strategy_file} 缺少 calculate_indicators 函数") + + if not hasattr(module, "get_strategy"): + raise AttributeError(f"策略文件 {strategy_file} 缺少 get_strategy 函数") + + calculate_indicators = module.calculate_indicators + strategy_class = module.get_strategy() + + if not isinstance(strategy_class, type): + raise TypeError("get_strategy() 必须返回一个类") + + from backtesting import Strategy + + if not issubclass(strategy_class, Strategy): + raise TypeError("策略类必须继承 backtesting.Strategy") + + return calculate_indicators, strategy_class + + +def apply_color_scheme(): + import backtesting._plotting as plotting + + plotting.BULL_COLOR = config.BULL_COLOR + plotting.BEAR_COLOR = config.BEAR_COLOR + + +def run_backtest( + code: str, + start_date: str, + end_date: str, + strategy_file: str, + cash: float = config.DEFAULT_CASH, + commission: float = config.DEFAULT_COMMISSION, + warmup_days: int = config.DEFAULT_WARMUP_DAYS, + output_dir: Optional[str] = None, +) -> BacktestResult: + warmup_start_date = (pd.to_datetime(start_date) - pd.Timedelta(days=warmup_days)).strftime("%Y-%m-%d") + + data = load_data_from_db(code, warmup_start_date, end_date) + + calculate_indicators, strategy_class = load_strategy(strategy_file) + + data = calculate_indicators(data) + + data = data.loc[start_date:end_date] + + from backtesting import Backtest + + bt = Backtest(data, strategy_class, cash=cash, commission=commission, finalize_trades=True) + stats = bt.run() + + apply_color_scheme() + + if output_dir: + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, f"{code}.html") + bt.plot(filename=output_path, open_browser=False) + + def _safe_float(value, default=0): + if value is None: + return default + try: + return float(value) + except (TypeError, ValueError): + return default + + def _safe_int(value, default=0): + if value is None: + return default + try: + return int(value) + except (TypeError, ValueError): + return default + + def _safe_timedelta(value, default=0): + if value is None: + return default + try: + return float(value.total_seconds() / 86400) + except (TypeError, AttributeError): + return default + + return BacktestResult( + code=code, + equity_final=_safe_float(stats.get("Equity Final [$]"), 0), + equity_peak=_safe_float(stats.get("Equity Peak [$]"), 0), + return_pct=_safe_float(stats.get("Return [%]"), 0), + buy_hold_return_pct=_safe_float(stats.get("Buy & Hold Return [%]"), 0), + return_ann_pct=_safe_float(stats.get("Return (Ann.) [%]"), 0), + volatility_ann_pct=_safe_float(stats.get("Volatility (Ann.) [%]"), 0), + sortino_ratio=_safe_float(stats.get("Sortino Ratio"), 0), + calmar_ratio=_safe_float(stats.get("Calmar Ratio"), 0), + max_drawdown_pct=_safe_float(stats.get("Max. Drawdown [%]"), 0), + avg_drawdown_pct=_safe_float(stats.get("Avg. Drawdown [%]"), 0), + max_drawdown_duration=_safe_timedelta(stats.get("Max. Drawdown Duration"), 0), + avg_drawdown_duration=_safe_timedelta(stats.get("Avg. Drawdown Duration"), 0), + num_trades=_safe_int(stats.get("# Trades"), 0), + win_rate_pct=_safe_float(stats.get("Win Rate [%]"), 0), + sqn=_safe_float(stats.get("SQN"), 0), + ) + + +def run_batch_backtest( + codes: list[str], + start_date: str, + end_date: str, + strategy_file: str, + cash: float = config.DEFAULT_CASH, + commission: float = config.DEFAULT_COMMISSION, + warmup_days: int = config.DEFAULT_WARMUP_DAYS, + output_dir: Optional[str] = None, + show_progress: bool = True, +) -> list[BacktestResult]: + results = [] + + codes_iter = tqdm(codes, desc="批量回测") if show_progress else codes + + for code in codes_iter: + result = run_backtest( + code=code, + start_date=start_date, + end_date=end_date, + strategy_file=strategy_file, + cash=cash, + commission=commission, + warmup_days=warmup_days, + output_dir=output_dir, + ) + results.append(result) + + return results diff --git a/config.py b/config.py new file mode 100644 index 0000000..ef71359 --- /dev/null +++ b/config.py @@ -0,0 +1,20 @@ +""" +配置文件 + +集中管理数据库配置、默认回测参数、图表配色 +""" + +DB_HOST = "81.71.3.24" +DB_PORT = 6785 +DB_NAME = "leopard_dev" +DB_USER = "leopard" +DB_PASSWORD = "9NEzFzovnddf@PyEP?e*AYAWnCyd7UhYwQK$pJf>7?ccFiN^x4$eKEZ5~E<7<+~X" + +DEFAULT_CASH = 100000 +DEFAULT_COMMISSION = 0.002 +DEFAULT_WARMUP_DAYS = 365 + +from bokeh.colors.named import tomato, lime + +BULL_COLOR = tomato +BEAR_COLOR = lime diff --git a/note_refactor.md b/note_refactor.md new file mode 100644 index 0000000..8041b10 --- /dev/null +++ b/note_refactor.md @@ -0,0 +1,295 @@ +# 回测代码重构说明 + +## 概述 + +本次重构将原有的单一文件 `backtest.py` 拆分为模块化架构,提升代码复用性和可维护性。 + +## 文件结构变化 + +### 新增文件 + +1. **config.py** - 配置管理模块 + - 数据库配置(DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD) + - 默认回测参数(DEFAULT_CASH, DEFAULT_COMMISSION, DEFAULT_WARMUP_DAYS) + - 图表配色(BULL_COLOR, BEAR_COLOR) + +2. **backtest_core.py** - 核心回测引擎 + - `BacktestResult` 数据类:结构化回测结果 + - `load_data_from_db()`:从数据库加载历史数据 + - `load_strategy()`:动态加载策略文件 + - `apply_color_scheme()`:应用图表配色 + - `run_backtest()`:单股票回测函数 + - `run_batch_backtest()`:批量回测函数(串行执行) + +3. **backtest_command.py** - 命令行界面 + - `parse_arguments()`:解析命令行参数 + - `format_single_result()`:详细格式输出(单股票) + - `format_batch_results()`:表格格式输出(多股票,使用 tabulate) + - `main()`:主流程编排 + +### 删除文件 + +1. **backtest.py** - 原有单一文件(284 行) + +## 接口变化 + +### 新增 API + +```python +# 单股票回测 +result = backtest_core.run_backtest( + code='000001.SZ', + start_date='2024-01-01', + end_date='2024-12-31', + strategy_file='strategies/sma_strategy.py', + cash=100000, + commission=0.002, + warmup_days=365, + output_dir=None # 可选,为 None 时不生成图表 +) + +# 批量回测 +results = backtest_core.run_batch_backtest( + codes=['000001.SZ', '600000.SH'], + start_date='2024-01-01', + end_date='2024-12-31', + strategy_file='strategies/sma_strategy.py', + cash=100000, + commission=0.002, + warmup_days=365, + output_dir='output/', # 可选,为每个股票生成 {code}.html + show_progress=True # 可选,是否显示 tqdm 进度条 +) +``` + +### 新增数据结构 + +```python +@dataclasses.dataclass +class BacktestResult: + code: str + equity_final: float + equity_peak: float + return_pct: float + buy_hold_return_pct: float + return_ann_pct: float + volatility_ann_pct: float + sortino_ratio: float + calmar_ratio: float + max_drawdown_pct: float + avg_drawdown_pct: float + max_drawdown_duration: float + avg_drawdown_duration: float + num_trades: int + win_rate_pct: float + sqn: float +``` + +## 命令行使用方式变化 + +### 旧方式(已删除) + +```bash +python backtest.py --code 000001.SZ --start-date 2024-01-01 --end-date 2024-12-31 --strategy-file strategy.py +``` + +### 新方式 + +```bash +uv run python backtest_command.py --codes 000001.SZ --start-date 2024-01-01 --end-date 2024-12-31 --strategy-file strategies/sma_strategy.py +``` + +### 参数变化 + +| 参数名 | 变化 | 说明 | +|--------|--------|------| +| `--code` | 改为 `--codes` | 从单一参数改为多值参数(`nargs='+'`) | +| `--output` | 改为 `--output-dir` | 指定目录而非文件路径 | + +### 新增参数 + +- `--output-dir`:指定图表输出目录(可选) + - 单股票时:生成 `{code}.html` 在指定目录 + - 多股票时:为每个股票生成 `{code}.html` 在指定目录 + - 不指定时不生成图表 + +## 输出格式变化 + +### 单股票输出 + +保持原有的详细格式输出,每个指标单独一行: + +``` +============================================================ +股票代码: 000001.SZ +============================================================ +最终收益: 100981.58 +峰值收益: 103731.54 +总收益率(%): 0.98 +... +============================================================ +``` + +### 多股票输出 + +新增表格格式输出(使用 tabulate,grid 格式): + +``` ++------------+-----------+---------+-------------+------------+-------+ +| 股票代码 | 收益率% | 胜率% | 最大回撤% | 交易次数 | SQN | ++============+===========+=========+=============+============+=======+ +| 000001.SZ | 0.98 | 100 | -2.65 | 1 | nan | +| 600000.SH | 0.04 | 100 | -1.5 | 1 | nan | ++------------+-----------+---------+-------------+------------+-------+ +``` + +### 进度条 + +多股票回测时显示 tqdm 进度条: + +``` +批量回测: 50%|█████ | 1/2 [00:07<00:07, 7.82s/it] +``` + +## 依赖变化 + +### 新增依赖 + +- `tabulate`:表格格式化 + - 版本:0.9.0 + - 用途:批量回测结果的表格化输出 + +- `tqdm`:进度条显示 + - 版本:4.67.1 + - 用途:批量回测时的实时进度反馈 + +## 特性增强 + +### 新增功能 + +1. **批量回测**:支持传入多个股票代码进行串行回测 + - 命令:`--codes 000001.SZ 600000.SH` + - 输出:表格化结果对比 + - 进度条:实时显示回测进度 + +2. **图表生成**:为每个股票生成独立 HTML 图表 + - 参数:`--output-dir output/` + - 输出:`{code}.html` 在指定目录 + - 自动创建目录:`os.makedirs(output_dir, exist_ok=True)` + +3. **进度条显示**:使用 tqdm 提供实时反馈 + - 多股票时自动显示 + - 可通过 `show_progress=False` 禁用 + +## 兼容性说明 + +### BREAKING CHANGES + +1. **命令行入口变化** + - 旧:`python backtest.py` + - 新:`uv run python backtest_command.py` + +2. **参数名称变化** + - `--code` → `--codes`(从单值改为多值) + +### 兼容性保证 + +- 所有原有功能完整保留 +- 核心回测逻辑无变化 +- 策略加载方式不变 +- 数据访问接口不变 + +## 代码行数对比 + +| 文件 | 旧行数 | 新行数 | 变化 | +|------|---------|---------|------| +| backtest.py | 284 | - | -284 | +| config.py | - | 20 | +20 | +| backtest_core.py | - | ~200 | +200 | +| backtest_command.py | - | ~120 | +120 | +| **总计** | **284** | **~340** | **+56** | + +## 迁移指南 + +### 对于开发者 + +如果需要在其他模块中调用回测功能: + +```python +from backtest_core import run_backtest, run_batch_backtest, BacktestResult + +# 单股票回测 +result = run_backtest( + code='000001.SZ', + start_date='2024-01-01', + end_date='2024-12-31', + strategy_file='strategies/sma_strategy.py' +) + +# 批量回测 +results = run_batch_backtest( + codes=['000001.SZ', '600000.SH'], + start_date='2024-01-01', + end_date='2024-12-31', + strategy_file='strategies/sma_strategy.py' +) + +# 访问结果 +print(result.return_pct) +print(result.win_rate_pct) +``` + +### 对于终端用户 + +**单股票回测示例:** + +```bash +uv run python backtest_command.py \ + --codes 000001.SZ \ + --start-date 2024-01-01 \ + --end-date 2024-12-31 \ + --strategy-file strategies/sma_strategy.py +``` + +**多股票回测示例:** + +```bash +uv run python backtest_command.py \ + --codes 000001.SZ 600000.SH \ + --start-date 2024-01-01 \ + --end-date 2024-12-31 \ + --strategy-file strategies/sma_strategy.py +``` + +**生成图表示例:** + +```bash +uv run python backtest_command.py \ + --codes 000001.SZ \ + --start-date 2024-01-01 \ + --end-date 2024-12-31 \ + --strategy-file strategies/sma_strategy.py \ + --output-dir output/ +``` + +## 错误处理 + +- **立即失败策略**:遇到第一个错误立即停止,不继续执行其他股票 +- **友好错误提示**:捕获异常并打印清晰的错误信息 +- **退出状态码**:成功返回 0,失败返回非零 +- **回溯信息**:打印完整的堆栈跟踪以便调试 + +## 性能考虑 + +- **串行执行**:当前采用串行执行,确保简单可靠 +- **未来扩展**:未来可改为并行执行(ThreadPoolExecutor)以提升性能 +- **数据加载**:每次回测创建独立的数据库连接,避免连接池复杂度 + +## 总结 + +本次重构实现了: +- ✅ 代码模块化:核心逻辑与 CLI 界面分离 +- ✅ 可复用性:提供标准化 API 供其他模块调用 +- ✅ 功能增强:支持批量回测和图表生成 +- ✅ 用户体验:表格化结果和进度条显示 +- ✅ 代码质量:更清晰的模块划分和类型提示 diff --git a/openspec/changes/archive/2026-01-28-refactor-backtest-separate-cli/.openspec.yaml b/openspec/changes/archive/2026-01-28-refactor-backtest-separate-cli/.openspec.yaml new file mode 100644 index 0000000..df18424 --- /dev/null +++ b/openspec/changes/archive/2026-01-28-refactor-backtest-separate-cli/.openspec.yaml @@ -0,0 +1,2 @@ +schema: spec-driven +created: 2026-01-28 diff --git a/openspec/changes/archive/2026-01-28-refactor-backtest-separate-cli/design.md b/openspec/changes/archive/2026-01-28-refactor-backtest-separate-cli/design.md new file mode 100644 index 0000000..04ad2c3 --- /dev/null +++ b/openspec/changes/archive/2026-01-28-refactor-backtest-separate-cli/design.md @@ -0,0 +1,287 @@ +## Context + +**Current State:** +`backtest.py` (284 lines) 是单一文件,包含: +- 命令行参数解析 +- 数据库连接和数据加载 +- 策略动态加载和验证 +- 回测执行逻辑 +- 结果格式化输出 +- 图表生成 + +**Constraints:** +- 需要保持现有功能完整性(数据加载、策略加载、回测执行、结果展示) +- 需要支持多股票回测(串行执行) +- 不考虑并发实现(保持简单) +- 错误处理采用立即失败策略 +- 数据库配置明文存储,不考虑环境变量 + +**Stakeholders:** +- 开发者:需要清晰的模块划分和可复用的接口 +- 终端用户:需要友好的 CLI 输出(进度条、表格化结果) + +## Goals / Non-Goals + +**Goals:** +1. 分离核心逻辑与 CLI 界面,提升代码复用性 +2. 提供标准化函数接口,供其他模块调用回测功能 +3. 支持多股票批量回测(串行执行) +4. 集中管理配置(数据库、参数、配色) +5. 优化 CLI 输出体验(tabulate 表格化、tqdm 进度条) + +**Non-Goals:** +1. 并行执行多股票回测(性能优化非目标) +2. 环境变量管理配置(配置明文存储即可) +3. 复杂的聚合统计(仅单股票结果拼接) +4. 图表文件合并(每个股票生成独立 HTML) +5. 配置文件热重载(启动时加载一次) + +## Decisions + +### Decision 1: 三层模块划分 +**选择:** 分离为 `config.py`、`backtest_core.py`、`backtest_command.py` 三个文件 + +**理由:** +- **config.py**:集中管理所有配置,避免硬编码分散 +- **backtest_core.py**:纯粹的业务逻辑,提供可复用的函数接口 +- **backtest_command.py**:CLI 界面,负责参数解析和结果展示 + +**替代方案:** +- 方案 A:保留单一文件,但改进内部结构(函数分离) + - 拒绝理由:仍无法复用,CLI 和业务逻辑耦合 +- 方案 B:使用类封装(如 `BacktestEngine` 类) + - 拒绝理由:增加复杂度,函数接口已足够 + +### Decision 2: BacktestResult 数据类 +**选择:** 使用 `dataclasses.dataclass` 定义 `BacktestResult` + +**理由:** +- 结构化返回结果,便于序列化和导出 +- 类型提示支持,提升代码可读性 +- 自动生成 `__init__`、`__repr__` 等方法,减少样板代码 + +**替代方案:** +- 方案 A:直接返回原始 `stats` 对象(backtesting 库返回) + - 拒绝理由:依赖 backtesting 库内部结构,耦合度高 +- 方案 B:返回字典 + - 拒绝理由:缺乏类型提示,容易拼写错误 + +### Decision 3: 批量回测策略 +**选择:** 串行执行(`for` 循环),立即失败 + +**理由:** +- 简单可靠,易于调试 +- 错误处理清晰(第一个失败就停止) +- 避免并发带来的资源竞争和复杂度 + +**替代方案:** +- 方案 A:并行执行(ThreadPoolExecutor) + - 拒绝理由:性能非目标,并发增加复杂度 +- 方案 B:继续执行其他股票,最后统一报告错误 + - 拒绝理由:用户需求是立即失败 + +### Decision 4: CLI 参数设计 +**选择:** `--codes` 多值参数(`nargs='+'`),`--output-dir` 目录参数 + +**理由:** +- `--codes` 支持传入多个股票代码,如 `--codes 000001.SZ 600000.SH` +- `--output-dir` 为每个股票生成 `{code}.html`,如 `output/000001.SZ.html` +- 保持原有参数(`--start-date`、`--end-date`、`--strategy-file`、`--cash`、`--commission`、`--warmup-days`) + +**替代方案:** +- 方案 A:`--code` 逗号分隔(如 `--code 000001.SZ,600000.SH`) + - 拒绝理由:需要额外解析逻辑,不直观 +- 方案 B:`--code` 多次调用(如 `--code 000001.SZ --code 600000.SH`) + - 拒绝理由:argparse 的 `nargs='+'` 更符合习惯 + +### Decision 5: 输出优化库 +**选择:** 使用 `tabulate` 表格化批量结果,使用 `tqdm` 显示进度条 + +**理由:** +- **tabulate**:提供美观的表格输出,支持多种格式(grid、simple 等) +- **tqdm**:提供实时进度条,提升用户体验 +- 两个库都是轻量级,不引入复杂依赖 + +**替代方案:** +- 方案 A:手动格式化表格(字符串拼接) + - 拒绝理由:代码冗余,格式不够美观 +- 方案 B:不使用进度条(仅输出完成提示) + - 拒绝理由:多股票回测耗时较长,用户需要进度反馈 + +### Decision 6: 结果展示策略 +**选择:** 单股票使用详细格式(现有),多股票使用表格格式(新增) + +**理由:** +- 单股票:保持原有的详细输出(每个指标单独一行) +- 多股票:使用 `tabulate` 表格横向对比,节省垂直空间 + +**替代方案:** +- 方案 A:所有情况都使用详细格式(拼接) + - 拒绝理由:多股票时输出过长,难以阅读 +- 方案 B:所有情况都使用表格格式 + - 拒绝理由:单股票时表格优势不明显,详细格式更清晰 + +### Decision 7: 配置管理方式 +**选择:** 明文常量存储在 `config.py` + +**理由:** +- 满足用户需求(不考虑信息安全) +- 避免引入 `python-dotenv` 依赖 +- 代码简洁,修改直接 + +**替代方案:** +- 方案 A:环境变量(`os.getenv`) + - 拒绝理由:用户明确不需要 +- 方案 B:配置文件(JSON/YAML) + - 拒绝理由:增加文件管理和解析复杂度 + +### Decision 8: 数据访问接口 +**选择:** `load_data_from_db(code, start_date, end_date)` 函数签名保持不变 + +**理由:** +- 现有接口已满足需求(单次查询一个股票) +- 迁移成本低,直接复制到 `backtest_core.py` + +**替代方案:** +- 方案 A:批量查询(`load_data_from_db(codes, start_date, end_date)`) + - 拒绝理由:需要修改 SQL 为 `IN` 子句,且结果聚合复杂 +- 方案 B:连接池复用(全局 engine 对象) + - 拒绝理由:每次创建引擎的开销可接受(串行执行) + +### Decision 9: 策略加载接口 +**选择:** `load_strategy(strategy_file)` 返回 `(calculate_indicators, strategy_class)` 元组 + +**理由:** +- 保持现有接口,迁移成本低 +- 函数返回两个值符合 Python 惯例 + +**替代方案:** +- 方案 A:返回类对象(策略类自带指标计算方法) + - 拒绝理由:现有策略文件结构分离了两者,修改成本高 +- 方案 B:返回命名空间对象(封装两个属性) + - 拒绝理由:增加复杂度,元组足够 + +### Decision 10: 错误处理策略 +**选择:** 立即失败(不捕获部分错误继续执行) + +**理由:** +- 符合用户需求 +- 简化错误追踪(第一个错误直接暴露) +- 避免"部分成功"的歧义状态 + +**替代方案:** +- 方案 A:捕获错误但继续执行,最后统一报告 + - 拒绝理由:用户明确要求立即失败 + +## Risks / Trade-offs + +### Risk 1: CLI 命令变化导致用户习惯中断 +**风险:** 用户习惯使用 `python backtest.py`,需要切换到 `uv run python backtest_command.py` + +**缓解:** +- 在项目根目录创建软链接 `backtest.py -> backtest_command.py`(可选) +- 或在 README 中明确说明新的使用方式 +- 提供迁移指南(参数变化说明) + +### Risk 2: 多股票串行执行耗时较长 +**风险:** 10 个股票可能需要 10 倍时间(每个 30 秒 → 总计 5 分钟) + +**缓解:** +- 使用 `tqdm` 进度条提供实时反馈 +- 在 README 中说明性能限制 +- 未来可扩展为并行执行(非当前目标) + +### Risk 3: BacktestResult 字段可能与 backtesting 库不兼容 +**风险:** backtesting 库升级后,stats 对象的键名可能变化 + +**缓解:** +- 使用 `.get(key, default)` 方法访问,避免 KeyError +- 提供默认值(0 或空字符串) +- 在文档中说明依赖的 backtesting 版本 + +### Risk 4: tabulate/tqdm 依赖未安装 +**风险:** 用户运行时缺少依赖,导致 ImportError + +**缓解:** +- 使用 `uv add` 明确添加依赖到 pyproject.toml +- 在 README 中说明安装步骤 +- 错误信息中提示安装命令(`uv add tabulate tqdm`) + +### Risk 5: 策略文件路径处理不一致 +**风险:** 策略文件路径可能是相对路径或绝对路径,导致加载失败 + +**缓解:** +- 使用 `os.path.abspath()` 转换为绝对路径 +- 在错误信息中提示用户检查路径 +- 测试相对路径和绝对路径两种情况 + +### Risk 6: 图表输出目录不存在 +**风险:** 用户指定的 `--output-dir` 不存在,导致保存失败 + +**缓解:** +- 使用 `os.makedirs(output_dir, exist_ok=True)` 自动创建 +- 在错误信息中提示用户检查目录权限 + +### Risk 7: 内存占用(多股票同时加载数据) +**风险:** 如果同时加载多个股票数据,内存占用可能较高 + +**缓解:** +- 串行执行确保一次只加载一个股票的数据 +- 单个股票的数据量可控(10 年约几 MB) +- future 可考虑流式处理(非当前目标) + +## Migration Plan + +### Step 1: 创建 config.py +1. 从 `backtest.py` 提取数据库配置 +2. 添加默认回测参数 +3. 添加图表配色配置 +4. 测试导入无错误 + +### Step 2: 创建 backtest_core.py +1. 迁移 `load_data_from_db()` 函数(导入 config) +2. 迁移 `load_strategy()` 函数 +3. 迁移 `apply_color_scheme()` 函数(使用 config 配置) +4. 定义 `BacktestResult` 数据类 +5. 实现 `run_backtest()` 函数 +6. 实现 `run_batch_backtest()` 函数 +7. 单元测试核心函数 + +### Step 3: 创建 backtest_command.py +1. 实现 `parse_arguments()` 函数(支持 `--codes`) +2. 实现 `format_single_result()` 函数(详细格式) +3. 实现 `format_batch_results()` 函数(使用 tabulate) +4. 实现 `main()` 函数(调用 `run_batch_backtest()`) +5. 测试单股票回测 +6. 测试多股票回测 + +### Step 4: 更新依赖 +1. 运行 `uv add tabulate` 添加依赖 +2. 运行 `uv add tqdm` 添加依赖 +3. 运行 `uv sync` 同步依赖 + +### Step 5: 删除 backtest.py +1. 确认新功能完整(单股票、多股票、图表输出) +2. 确认错误处理正确(立即失败) +3. 删除 `backtest.py` 文件 +4. 更新 README 说明新的使用方式 + +### Rollback Strategy +如果迁移过程中发现问题: +1. 保留 `backtest.py` 直到 `backtest_command.py` 完全可用 +2. 使用 `git` 版本控制,可随时回退 +3. 逐步迁移(先核心函数,后 CLI),确保每步可验证 + +## Open Questions + +1. **BacktestResult 字段完整性:** 是否需要包含所有 backtesting.stats 的键,或仅包含当前用到的字段? + - 倾向:仅包含当前用到的字段(未来可扩展) + +2. **表格格式选择:** tabulate 支持多种格式(grid、simple、pipe、html),多股票结果使用哪种? + - 倾向:grid(美观的边框格式) + +3. **进度条粒度:** tqdm 进度条应该显示每个股票的回测进度,还是仅显示批量回测的总进度? + - 倾向:仅显示批量回测的总进度(股票 N/M) + +4. **图表输出目录结构:** 多股票图表是平铺在 `output/` 下,还是按日期/策略分组? + - 倾向:平铺在 `output/` 下(简单) diff --git a/openspec/changes/archive/2026-01-28-refactor-backtest-separate-cli/proposal.md b/openspec/changes/archive/2026-01-28-refactor-backtest-separate-cli/proposal.md new file mode 100644 index 0000000..ffb6b1b --- /dev/null +++ b/openspec/changes/archive/2026-01-28-refactor-backtest-separate-cli/proposal.md @@ -0,0 +1,54 @@ +## Why + +当前 `backtest.py` 存在职责混杂的问题:命令行参数解析、核心回测逻辑、数据访问、结果展示都耦合在单一文件中,导致: +- 难以在其他模块中复用回测功能 +- 无法进行单元测试 +- 仅支持单股票回测,无法批量处理 + +需要重构为分层架构,将核心逻辑与 CLI 界面分离,提升代码复用性和可维护性。 + +## What Changes + +- **创建 `config.py`**:集中管理数据库配置、默认回测参数、图表配色 +- **创建 `backtest_core.py`**:核心回测引擎 + - 提供标准化接口 `run_backtest()`(单股票) + - 提供批量接口 `run_batch_backtest()`(多股票,串行执行) + - 封装数据访问和策略加载逻辑 + - 返回结构化结果对象 `BacktestResult` +- **创建 `backtest_command.py`**:命令行界面 + - 支持多股票代码参数 `--codes`(接受多个值) + - 使用 `tabulate` 优化批量结果的表格展示 + - 使用 `tqdm` 显示批量回测进度条 + - 保留原有的单股票详细输出格式 +- **删除 `backtest.py`**:不再需要,功能已迁移 +- **依赖更新**:添加 `tabulate`、`tqdm` 到项目依赖 + +## Capabilities + +### New Capabilities +- `batch-backtest`: 批量回测功能,支持传入多个股票代码进行串行回测,并提供进度条和表格化结果展示 + +### Modified Capabilities +- 无(其他均为实现重构,不改变 spec 级别行为) + +## Impact + +- **代码影响**: + - 删除 `backtest.py`(284 行) + - 新增 `config.py`(约 30 行) + - 新增 `backtest_core.py`(约 250 行) + - 新增 `backtest_command.py`(约 150 行) +- **API 变化**: + - 新增 `run_backtest(code, start_date, end_date, strategy_file, ...)` 函数 + - 新增 `run_batch_backtest(codes, start_date, end_date, strategy_file, ...)` 函数 + - 新增 `BacktestResult` 数据类 +- **命令行变化**: + - 单参数 `--code` 改为多值参数 `--codes` + - 新增 `--output-dir` 参数,为每个股票生成独立 HTML 图表 + - 批量回测时显示进度条和表格化结果 +- **依赖变化**: + - 新增 `tabulate`(表格格式化) + - 新增 `tqdm`(进度条显示) +- **兼容性**: + - **BREAKING**: 删除原有 `backtest.py`,命令行使用方式从 `python backtest.py` 改为 `uv run python backtest_command.py` + - 参数名称从 `--code` 改为 `--codes` diff --git a/openspec/changes/archive/2026-01-28-refactor-backtest-separate-cli/specs/batch-backtest/spec.md b/openspec/changes/archive/2026-01-28-refactor-backtest-separate-cli/specs/batch-backtest/spec.md new file mode 100644 index 0000000..81de9d9 --- /dev/null +++ b/openspec/changes/archive/2026-01-28-refactor-backtest-separate-cli/specs/batch-backtest/spec.md @@ -0,0 +1,310 @@ +# Spec: Batch Backtest + +## ADDED Requirements + +### Requirement: 多股票回测参数 +系统 SHALL 支持通过命令行参数传入多个股票代码进行批量回测。 + +#### Scenario: 传入多个股票代码 +- **WHEN** 用户执行 `python backtest_command.py --codes 000001.SZ 600000.SH --start-date 2024-01-01 --end-date 2025-12-31 --strategy-file strategies/macd_strategy.py` +- **THEN** 系统解析所有股票代码到列表 `['000001.SZ', '600000.SH']` +- **THEN** 系统按顺序依次执行每个股票的回测 +- **THEN** 系统为每个股票生成独立的回测结果 + +#### Scenario: 传入单个股票代码 +- **WHEN** 用户执行 `python backtest_command.py --codes 000001.SZ --start-date 2024-01-01 --end-date 2025-12-31 --strategy-file strategies/macd_strategy.py` +- **THEN** 系统解析为单个股票代码列表 `['000001.SZ']` +- **THEN** 系统执行单个股票回测 +- **THEN** 系统输出详细格式的回测结果 + +#### Scenario: 缺少 --codes 参数 +- **WHEN** 用户未提供 `--codes` 参数 +- **THEN** 系统输出错误信息:"错误: 需要以下参数: --codes" +- **THEN** 系统退出并返回非零状态码 + +--- + +### Requirement: 批量回测执行 +系统 SHALL 串行执行多个股票的回测,每次加载一个股票的数据并执行回测。 + +#### Scenario: 成功执行多个股票回测 +- **WHEN** 用户传入 N 个股票代码 +- **THEN** 系统循环 N 次,每次加载一个股票的数据 +- **THEN** 系统每次执行完整的回测流程(数据加载、指标计算、回测执行) +- **THEN** 系统每次执行完成后生成 `BacktestResult` 对象 +- **THEN** 系统返回包含 N 个 `BacktestResult` 的列表 + +#### Scenario: 每个股票独立预热期 +- **WHEN** 系统执行第 i 个股票的回测 +- **THEN** 系统使用 `start_date - warmup_days` 计算该股票的预热开始日期 +- **THEN** 系统独立加载该股票的预热期数据 +- **THEN** 不同股票的预热期互不影响 + +#### Scenario: 第一个股票回测失败 +- **WHEN** 系统执行第一个股票回测时发生错误(数据库连接失败、策略加载失败等) +- **THEN** 系统捕获异常并输出错误信息 +- **THEN** 系统停止执行后续股票的回测 +- **THEN** 系统退出并返回非零状态码(立即失败策略) + +#### Scenario: 中间股票回测失败 +- **WHEN** 系统执行第 i 个股票回测时发生错误 +- **THEN** 系统输出错误信息(包含股票代码) +- **THEN** 系统停止执行后续股票的回测 +- **THEN** 系统退出并返回非零状态码 + +#### Scenario: 资源管理 +- **WHEN** 系统完成第 i 个股票的回测 +- **THEN** 系统关闭该股票的数据库连接(`engine.dispose()`) +- **THEN** 系统释放该股票的数据内存 +- **THEN** 系统开始加载第 i+1 个股票的数据 + +--- + +### Requirement: 批量回测进度显示 +系统 SHALL 使用 tqdm 显示批量回测的实时进度,提供用户反馈。 + +#### Scenario: 显示进度条 +- **WHEN** 系统开始执行 N 个股票的批量回测 +- **THEN** 系统显示进度条格式:`回测进度: 25%|█████▌ | 1/4 [00:30<01:30, 12.5s/it]` +- **THEN** 系统在完成每个股票回测后更新进度条 +- **THEN** 进度条显示当前进度(i/N)、已用时间、预计剩余时间 +- **THEN** 进度条在所有股票回测完成后消失 + +#### Scenario: 单股票回测不显示进度条 +- **WHEN** 用户传入单个股票代码 +- **THEN** 系统不显示 tqdm 进度条 +- **THEN** 系统直接输出回测结果 + +#### Scenario: 进度条描述文本 +- **WHEN** 系统显示批量回测进度 +- **THEN** 进度条描述 SHALL 为 "回测进度"(中文) +- **THEN** 进度条显示已完成/总数(如 "1/4", "2/4") + +--- + +### Requirement: 批量回测结果展示 +系统 SHALL 使用 tabulate 将多个股票的回测结果格式化为表格,便于横向对比。 + +#### Scenario: 表格化输出多股票结果 +- **WHEN** 用户传入多个股票代码且回测成功 +- **THEN** 系统使用 tabulate 生成表格 +- **THEN** 表格格式 SHALL 为 grid(带边框) +- **THEN** 表格列 SHALL 包含:股票代码、收益率%、胜率%、最大回撤%、交易次数、SQN +- **THEN** 系统在表格上方显示表头(中文列名) +- **THEN** 数值保留 2 位小数(交易次数为整数) + +#### Scenario: 表格内容填充 +- **WHEN** 系统格式化第 i 个股票的结果 +- **THEN** 系统从 `BacktestResult` 对象提取字段 +- **THEN** "股票代码" 列填充 `result.code` +- **THEN** "收益率%" 列填充 `result.return_pct` +- **THEN** "胜率%" 列填充 `result.win_rate` +- **THEN** "最大回撤%" 列填充 `result.max_drawdown` +- **THEN** "交易次数" 列填充 `result.trades` +- **THEN** "SQN" 列填充 `result.sqn` + +#### Scenario: 单股票回测不使用表格 +- **WHEN** 用户传入单个股票代码 +- **THEN** 系统不使用 tabulate 生成表格 +- **THEN** 系统使用详细格式输出(每个指标单独一行) +- **THEN** 系统保持原有 `print_stats()` 的输出格式 + +#### Scenario: 表格示例输出 +- **WHEN** 用户传入 2 个股票代码 +- **THEN** 系统输出格式 SHALL 为: + ``` + +-------------+-----------+--------+------------+----------+-------+ + | 股票代码 | 收益率% | 胜率% | 最大回撤% | 交易次数 | SQN | + +-------------+-----------+--------+------------+----------+-------+ + | 000001.SZ | 20.35 | 55.00 | -8.50 | 45 | 1.85 | + | 600000.SH | 15.00 | 48.00 | -12.30 | 38 | 1.42 | + +-------------+-----------+--------+------------+----------+-------+ + ``` + +--- + +### Requirement: 多股票图表输出 +系统 SHALL 为每个股票生成独立的 HTML 图表文件,文件名格式为 `{code}.html`。 + +#### Scenario: 指定 --output-dir 参数 +- **WHEN** 用户传入 `--output-dir output/` +- **THEN** 系统为每个股票生成 HTML 文件到 `output/{code}.html` +- **THEN** 文件名 SHALL 为股票代码,如 `000001.SZ.html`, `600000.SH.html` +- **THEN** 系统自动创建 `output/` 目录(`exist_ok=True`) +- **THEN** 系统在完成后输出提示:"图表已保存到目录: output/" 后列出所有文件 + +#### Scenario: 未指定 --output-dir 参数 +- **WHEN** 用户未传入 `--output-dir` 参数 +- **THEN** 系统不为任何股票生成图表文件 +- **THEN** 系统仅输出控制台统计信息 + +#### Scenario: 图表文件覆盖 +- **WHEN** 系统再次执行相同的批量回测 +- **THEN** 系统覆盖已存在的 HTML 文件 +- **THEN** 系统不提示文件已存在 + +--- + +### Requirement: 结构化回测结果 +系统 SHALL 返回标准化的 `BacktestResult` 对象,包含所有关键指标。 + +#### Scenario: BacktestResult 对象创建 +- **WHEN** 系统完成单股票回测 +- **THEN** 系统从 `stats` 对象提取指标到 `BacktestResult` +- **THEN** `BacktestResult.code` 设置为股票代码 +- **THEN** `BacktestResult.start_date` 设置为回测开始日期 +- **THEN** `BacktestResult.end_date` 设置为回测结束日期 +- **THEN** `BacktestResult.equity_final` 设置为最终权益 +- **THEN** `BacktestResult.equity_peak` 设置为峰值收益 +- **THEN** `BacktestResult.return_pct` 设置为总收益率 +- **THEN** `BacktestResult.buy_hold_return` 设置为买入持有收益率 +- **THEN** `BacktestResult.return_annual` 设置为年化收益率 +- **THEN** `BacktestResult.volatility_annual` 设置为年化波动率 +- **THEN** `BacktestResult.max_drawdown` 设置为最大回撤 +- **THEN** `BacktestResult.avg_drawdown` 设置为平均回撤 +- **THEN** `BacktestResult.max_drawdown_duration` 设置为最大回撤持续时长 +- **THEN** `BacktestResult.avg_drawdown_duration` 设置为平均回撤持续时长 +- **THEN** `BacktestResult.sortino_ratio` 设置为索提诺比率 +- **THEN** `BacktestResult.calmar_ratio` 设置为卡尔玛比率 +- **THEN** `BacktestResult.trades` 设置为交易次数 +- **THEN** `BacktestResult.win_rate` 设置为胜率 +- **THEN** `BacktestResult.sqn` 设置为系统质量数 +- **THEN** `BacktestResult.cash` 设置为初始资金 +- **THEN** `BacktestResult.commission` 设置为手续费率 + +#### Scenario: BacktestResult 列表返回 +- **WHEN** 系统完成批量回测 +- **THEN** 系统返回 `List[BacktestResult]` +- **THEN** 列表顺序 SHALL 与输入股票代码顺序一致 +- **THEN** 列表长度 SHALL 等于输入股票代码数量(成功时) + +#### Scenario: BacktestResult 数据类型 +- **WHEN** 系统创建 `BacktestResult` 对象 +- **THEN** 数值字段 SHALL 为 float 类型(除 `trades`, `max_drawdown_duration` 为 int) +- **THEN** 日期字段 SHALL 为 str 类型(YYYY-MM-DD 格式) +- **THEN** 系统支持 `result.to_dict()` 方法(dataclass 自动生成) + +--- + +### Requirement: 可复用回测引擎接口 +系统 SHALL 提供标准化的函数接口,供其他模块调用回测功能。 + +#### Scenario: run_backtest 函数调用 +- **WHEN** 其他模块调用 `run_backtest(code, start_date, end_date, strategy_file, cash, commission, warmup_days, output_file)` +- **THEN** 函数接收股票代码、日期范围、策略文件、回测参数、输出文件路径 +- **THEN** 函数执行完整回测流程(数据加载、策略加载、指标计算、回测执行) +- **THEN** 函数返回 `BacktestResult` 对象 +- **THEN** 函数不打印任何输出(纯函数) + +#### Scenario: run_batch_backtest 函数调用 +- **WHEN** 其他模块调用 `run_batch_backtest(codes, start_date, end_date, strategy_file, cash, commission, warmup_days, output_dir)` +- **THEN** 函数接收股票代码列表、日期范围、策略文件、回测参数、输出目录 +- **THEN** 函数串行执行每个股票的回测 +- **THEN** 函数返回 `List[BacktestResult]` +- **THEN** 函数显示 tqdm 进度条(批量时) + +#### Scenario: 函数参数默认值 +- **WHEN** 调用者不指定可选参数 +- **THEN** `cash` 默认为 100000 +- **THEN** `commission` 默认为 0.002 +- **THEN** `warmup_days` 默认为 365 +- **THEN** `output_file` 默认为 None(不生成图表) +- **THEN** `output_dir` 默认为 None(不生成图表) + +#### Scenario: 函数异常抛出 +- **WHEN** `run_backtest` 或 `run_batch_backtest` 执行时发生错误 +- **THEN** 函数 SHALL 抛出异常(不捕获) +- **THEN** 异常类型 SHALL 为 ValueError、TypeError 或原始异常 +- **THEN** 异常信息 SHALL 包含具体错误原因 +- **THEN** 调用者负责捕获和处理异常 + +--- + +### Requirement: 集中配置管理 +系统 SHALL 在 config.py 中集中管理数据库配置、默认回测参数、图表配色。 + +#### Scenario: 数据库配置访问 +- **WHEN** backtest_core.py 需要数据库连接参数 +- **THEN** 模块从 config 导入 `DB_HOST`, `DB_PORT`, `DB_NAME`, `DB_USER`, `DB_PASSWORD` +- **THEN** 模块使用这些常量构建连接字符串 +- **THEN** 模块不重复定义数据库配置 + +#### Scenario: 默认参数访问 +- **WHEN** backtest_core.py 需要默认回测参数 +- **THEN** 模块从 config 导入 `DEFAULT_CASH`, `DEFAULT_COMMISSION`, `DEFAULT_WARMUP_DAYS` +- **THEN** 模块使用这些常量作为函数默认值 +- **THEN** 模块不重复定义默认参数 + +#### Scenario: 图表配色访问 +- **WHEN** backtest_core.py 需要设置图表配色 +- **THEN** 模块从 config 导入 `BULL_COLOR`, `BEAR_COLOR` +- **THEN** 模块使用这些颜色设置 `plotting.BULL_COLOR` 和 `plotting.BEAR_COLOR` +- **THEN** 模块不重复定义颜色配置 + +#### Scenario: 配置文件内容 +- **WHEN** 查看 config.py 文件 +- **THEN** 文件包含数据库配置(DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD) +- **THEN** 文件包含默认回测参数(DEFAULT_CASH, DEFAULT_COMMISSION, DEFAULT_WARMUP_DAYS) +- **THEN** 文件包含图表配色(BULL_COLOR, BEAR_COLOR) +- **THEN** 所有配置使用明文常量(不使用环境变量) + +--- + +### Requirement: 错误处理策略 +系统 SHALL 在批量回测失败时立即停止执行,不继续处理后续股票。 + +#### Scenario: 数据加载失败 +- **WHEN** 系统加载第 i 个股票数据时失败(数据库错误、数据不存在) +- **THEN** 系统捕获异常 +- **THEN** 系统输出错误信息:"回测失败 [{code}]: {error}" +- **THEN** 系统停止执行后续股票的回测 +- **THEN** 系统退出并返回非零状态码 + +#### Scenario: 策略加载失败 +- **WHEN** 系统加载策略文件时失败(文件不存在、接口不完整) +- **THEN** 系统捕获异常 +- **THEN** 系统输出错误信息:"策略加载失败: {error}" +- **THEN** 系统停止执行所有股票的回测 +- **THEN** 系统退出并返回非零状态码 + +#### Scenario: 回测执行失败 +- **WHEN** 系统执行第 i 个股票回测时失败(策略逻辑错误) +- **THEN** 系统捕获异常 +- **THEN** 系统输出错误信息和完整堆栈跟踪 +- **THEN** 系统停止执行后续股票的回测 +- **THEN** 系统退出并返回非零状态码 + +#### Scenario: 图表生成失败 +- **WHEN** 系统生成第 i 个股票图表时失败 +- **THEN** 系统捕获异常 +- **THEN** 系统输出警告:"图表生成失败 [{code}]: {error},但回测已完成" +- **THEN** 系统继续执行后续股票的回测 +- **THEN** 系统在返回的 `BacktestResult` 中设置 `error` 字段(如果设计支持) + +--- + +### Requirement: 依赖管理 +系统 SHALL 在 pyproject.toml 中添加 tabulate 和 tqdm 依赖。 + +#### Scenario: 添加 tabulate 依赖 +- **WHEN** 查看 pyproject.toml 文件 +- **THEN** 文件包含 `tabulate` 依赖 +- **THEN** 依赖版本 SHALL 为兼容当前 Python 版本的版本 +- **THEN** 系统可以导入 `import tabulate` 无错误 + +#### Scenario: 添加 tqdm 依赖 +- **WHEN** 查看 pyproject.toml 文件 +- **THEN** 文件包含 `tqdm` 依赖 +- **THEN** 依赖版本 SHALL 为兼容当前 Python 版本的版本 +- **THEN** 系统可以导入 `from tqdm import tqdm` 无错误 + +#### Scenario: 依赖安装 +- **WHEN** 用户运行 `uv sync` 或 `pip install -e .` +- **THEN** 系统自动安装 tabulate 和 tqdm +- **THEN** 系统显示依赖安装进度 +- **THEN** 系统完成安装后可以正常使用回测工具 + +#### Scenario: 依赖缺失提示 +- **WHEN** 系统导入 tabulate 或 tqdm 时失败 +- **THEN** 系统输出友好错误信息:"缺少依赖: {package_name},请运行: uv add {package_name}" +- **THEN** 系统退出并返回非零状态码 diff --git a/openspec/changes/archive/2026-01-28-refactor-backtest-separate-cli/tasks.md b/openspec/changes/archive/2026-01-28-refactor-backtest-separate-cli/tasks.md new file mode 100644 index 0000000..0f3ef49 --- /dev/null +++ b/openspec/changes/archive/2026-01-28-refactor-backtest-separate-cli/tasks.md @@ -0,0 +1,96 @@ +## 1. 依赖管理 + +- [x] 1.1 运行 `uv add tabulate` 添加依赖 +- [x] 1.2 运行 `uv add tqdm` 添加依赖 +- [x] 1.3 运行 `uv sync` 同步依赖 + +## 2. 配置管理模块 + +- [x] 2.1 创建 config.py 文件 +- [x] 2.2 在 config.py 中定义数据库配置常量(DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD) +- [x] 2.3 在 config.py 中定义默认回测参数(DEFAULT_CASH, DEFAULT_COMMISSION, DEFAULT_WARMUP_DAYS) +- [x] 2.4 在 config.py 中定义图表配色(BULL_COLOR, BEAR_COLOR) +- [x] 2.5 测试 config.py 导入无错误 + +## 3. 核心回测引擎 + +- [x] 3.1 创建 backtest_core.py 文件 +- [x] 3.2 在 backtest_core.py 中导入必要模块和 config +- [x] 3.3 定义 BacktestResult dataclass(包含所有回测指标字段) +- [x] 3.4 迁移 load_data_from_db() 函数(使用 config 数据库配置) +- [x] 3.5 迁移 load_strategy() 函数(保持原有逻辑) +- [x] 3.6 迁移 apply_color_scheme() 函数(使用 config 配色) +- [x] 3.7 实现 run_backtest() 函数(单股票回测) + - [x] 3.7.1 实现预热期日期计算逻辑 + - [x] 3.7.2 实现数据加载和策略加载调用 + - [x] 3.7.3 实现指标计算和数据截取 + - [x] 3.7.4 实现 Backtest 执行 + - [x] 3.7.5 实现图表生成(可选) + - [x] 3.7.6 实现 BacktestResult 对象构建和返回 +- [x] 3.8 实现 run_batch_backtest() 函数(批量回测,串行) + - [x] 3.8.1 实现循环遍历股票代码 + - [x] 3.8.2 实现为每个股票调用 run_backtest() + - [x] 3.8.3 实现为每个股票生成独立 HTML 文件 + - [x] 3.8.4 实现结果列表收集和返回 + - [x] 3.8.5 实现 tqdm 进度条显示(批量时) +- [x] 3.9 测试 run_backtest() 单股票回测 +- [x] 3.10 测试 run_batch_backtest() 多股票回测 + +## 4. CLI 界面模块 + +- [x] 4.1 创建 backtest_command.py 文件 +- [x] 4.2 在 backtest_command.py 中导入必要模块和 backtest_core +- [x] 4.3 实现 parse_arguments() 函数 + - [x] 4.3.1 定义 --codes 多值参数(nargs='+') + - [x] 4.3.2 定义 --output-dir 目录参数 + - [x] 4.3.3 保持原有参数(--start-date, --end-date, --strategy-file, --cash, --commission, --warmup-days) + - [x] 4.3.4 添加参数帮助文档和示例 +- [x] 4.4 实现 format_single_result() 函数(详细格式输出) + - [x] 4.4.1 实现每个指标单独一行的格式化 + - [x] 4.4.2 保持原有 print_stats() 的输出格式 +- [x] 4.5 实现 format_batch_results() 函数(表格格式输出) + - [x] 4.5.1 实现使用 tabulate 生成表格 + - [x] 4.5.2 定义表格列:股票代码、收益率%、胜率%、最大回撤%、交易次数、SQN + - [x] 4.5.3 实现表格数据填充(从 BacktestResult 对象提取) + - [x] 4.5.4 实现表格格式为 grid +- [x] 4.6 实现 main() 函数 + - [x] 4.6.1 调用 parse_arguments() 解析参数 + - [x] 4.6.2 调用 run_batch_backtest() 执行批量回测 + - [x] 4.6.3 根据结果数量调用 format_single_result() 或 format_batch_results() + - [x] 4.6.4 实现图表保存提示(指定 --output-dir 时) + - [x] 4.6.5 实现错误捕获和友好错误信息输出 + - [x] 4.6.6 实现退出状态码设置(成功 0,失败非零) +- [x] 4.7 添加 `if __name__ == "__main__": main()` 入口 +- [x] 4.8 测试单股票回测命令行调用 (`uv run python backtest_command.py`) +- [x] 4.9 测试多股票回测命令行调用 (`uv run python backtest_command.py`) +- [x] 4.10 测试错误处理(参数缺失、文件不存在等) + +## 5. 清理旧代码 + +- [x] 5.1 确认新功能完整(单股票、多股票、图表输出) +- [x] 5.2 确认错误处理正确(立即失败) +- [x] 5.3 删除 backtest.py 文件 +- [x] 5.4 验证 git 状态(仅删除旧文件,无其他修改) + +## 6. 文档更新 + +- [x] 6.1 更新 README.md(如果存在) + - [x] 6.1.1 说明新的命令行使用方式(`uv run python backtest_command.py`) + - [x] 6.1.2 说明参数变化(--code 改为 --codes) + - [x] 6.1.3 提供单股票和多股票示例 + - [x] 6.1.4 说明 --output-dir 用法(多股票图表) +- [x] 6.2 创建 note_refactor.md(可选,记录重构说明) + - [x] 6.2.1 说明文件结构变化 + - [x] 6.2.2 说明接口变化 + - [x] 6.2.3 提供迁移指南 + +## 7. 集成测试 + +- [x] 7.1 测试单个股票完整流程(000001.SZ) +- [x] 7.2 测试多个股票完整流程(000001.SZ 600000.SH) +- [x] 7.3 测试指定 --output-dir 生成图表 +- [x] 7.4 测试不指定 --output-dir(不生成图表) +- [x] 7.5 测试错误情况(无效股票代码、不存在的策略文件等) +- [x] 7.6 验证进度条显示(多股票时) +- [x] 7.7 验证表格格式输出(多股票时) +- [x] 7.8 验证详细格式输出(单股票时) diff --git a/openspec/specs/batch-backtest/spec.md b/openspec/specs/batch-backtest/spec.md new file mode 100644 index 0000000..a4cc0f1 --- /dev/null +++ b/openspec/specs/batch-backtest/spec.md @@ -0,0 +1,312 @@ +# batch-backtest Specification + +## Purpose +TBD - created by archiving change refactor-backtest-separate-cli. Update Purpose after archive. +## Requirements +### Requirement: 多股票回测参数 +系统 SHALL 支持通过命令行参数传入多个股票代码进行批量回测。 + +#### Scenario: 传入多个股票代码 +- **WHEN** 用户执行 `python backtest_command.py --codes 000001.SZ 600000.SH --start-date 2024-01-01 --end-date 2025-12-31 --strategy-file strategies/macd_strategy.py` +- **THEN** 系统解析所有股票代码到列表 `['000001.SZ', '600000.SH']` +- **THEN** 系统按顺序依次执行每个股票的回测 +- **THEN** 系统为每个股票生成独立的回测结果 + +#### Scenario: 传入单个股票代码 +- **WHEN** 用户执行 `python backtest_command.py --codes 000001.SZ --start-date 2024-01-01 --end-date 2025-12-31 --strategy-file strategies/macd_strategy.py` +- **THEN** 系统解析为单个股票代码列表 `['000001.SZ']` +- **THEN** 系统执行单个股票回测 +- **THEN** 系统输出详细格式的回测结果 + +#### Scenario: 缺少 --codes 参数 +- **WHEN** 用户未提供 `--codes` 参数 +- **THEN** 系统输出错误信息:"错误: 需要以下参数: --codes" +- **THEN** 系统退出并返回非零状态码 + +--- + +### Requirement: 批量回测执行 +系统 SHALL 串行执行多个股票的回测,每次加载一个股票的数据并执行回测。 + +#### Scenario: 成功执行多个股票回测 +- **WHEN** 用户传入 N 个股票代码 +- **THEN** 系统循环 N 次,每次加载一个股票的数据 +- **THEN** 系统每次执行完整的回测流程(数据加载、指标计算、回测执行) +- **THEN** 系统每次执行完成后生成 `BacktestResult` 对象 +- **THEN** 系统返回包含 N 个 `BacktestResult` 的列表 + +#### Scenario: 每个股票独立预热期 +- **WHEN** 系统执行第 i 个股票的回测 +- **THEN** 系统使用 `start_date - warmup_days` 计算该股票的预热开始日期 +- **THEN** 系统独立加载该股票的预热期数据 +- **THEN** 不同股票的预热期互不影响 + +#### Scenario: 第一个股票回测失败 +- **WHEN** 系统执行第一个股票回测时发生错误(数据库连接失败、策略加载失败等) +- **THEN** 系统捕获异常并输出错误信息 +- **THEN** 系统停止执行后续股票的回测 +- **THEN** 系统退出并返回非零状态码(立即失败策略) + +#### Scenario: 中间股票回测失败 +- **WHEN** 系统执行第 i 个股票回测时发生错误 +- **THEN** 系统输出错误信息(包含股票代码) +- **THEN** 系统停止执行后续股票的回测 +- **THEN** 系统退出并返回非零状态码 + +#### Scenario: 资源管理 +- **WHEN** 系统完成第 i 个股票的回测 +- **THEN** 系统关闭该股票的数据库连接(`engine.dispose()`) +- **THEN** 系统释放该股票的数据内存 +- **THEN** 系统开始加载第 i+1 个股票的数据 + +--- + +### Requirement: 批量回测进度显示 +系统 SHALL 使用 tqdm 显示批量回测的实时进度,提供用户反馈。 + +#### Scenario: 显示进度条 +- **WHEN** 系统开始执行 N 个股票的批量回测 +- **THEN** 系统显示进度条格式:`回测进度: 25%|█████▌ | 1/4 [00:30<01:30, 12.5s/it]` +- **THEN** 系统在完成每个股票回测后更新进度条 +- **THEN** 进度条显示当前进度(i/N)、已用时间、预计剩余时间 +- **THEN** 进度条在所有股票回测完成后消失 + +#### Scenario: 单股票回测不显示进度条 +- **WHEN** 用户传入单个股票代码 +- **THEN** 系统不显示 tqdm 进度条 +- **THEN** 系统直接输出回测结果 + +#### Scenario: 进度条描述文本 +- **WHEN** 系统显示批量回测进度 +- **THEN** 进度条描述 SHALL 为 "回测进度"(中文) +- **THEN** 进度条显示已完成/总数(如 "1/4", "2/4") + +--- + +### Requirement: 批量回测结果展示 +系统 SHALL 使用 tabulate 将多个股票的回测结果格式化为表格,便于横向对比。 + +#### Scenario: 表格化输出多股票结果 +- **WHEN** 用户传入多个股票代码且回测成功 +- **THEN** 系统使用 tabulate 生成表格 +- **THEN** 表格格式 SHALL 为 grid(带边框) +- **THEN** 表格列 SHALL 包含:股票代码、收益率%、胜率%、最大回撤%、交易次数、SQN +- **THEN** 系统在表格上方显示表头(中文列名) +- **THEN** 数值保留 2 位小数(交易次数为整数) + +#### Scenario: 表格内容填充 +- **WHEN** 系统格式化第 i 个股票的结果 +- **THEN** 系统从 `BacktestResult` 对象提取字段 +- **THEN** "股票代码" 列填充 `result.code` +- **THEN** "收益率%" 列填充 `result.return_pct` +- **THEN** "胜率%" 列填充 `result.win_rate` +- **THEN** "最大回撤%" 列填充 `result.max_drawdown` +- **THEN** "交易次数" 列填充 `result.trades` +- **THEN** "SQN" 列填充 `result.sqn` + +#### Scenario: 单股票回测不使用表格 +- **WHEN** 用户传入单个股票代码 +- **THEN** 系统不使用 tabulate 生成表格 +- **THEN** 系统使用详细格式输出(每个指标单独一行) +- **THEN** 系统保持原有 `print_stats()` 的输出格式 + +#### Scenario: 表格示例输出 +- **WHEN** 用户传入 2 个股票代码 +- **THEN** 系统输出格式 SHALL 为: + ``` + +-------------+-----------+--------+------------+----------+-------+ + | 股票代码 | 收益率% | 胜率% | 最大回撤% | 交易次数 | SQN | + +-------------+-----------+--------+------------+----------+-------+ + | 000001.SZ | 20.35 | 55.00 | -8.50 | 45 | 1.85 | + | 600000.SH | 15.00 | 48.00 | -12.30 | 38 | 1.42 | + +-------------+-----------+--------+------------+----------+-------+ + ``` + +--- + +### Requirement: 多股票图表输出 +系统 SHALL 为每个股票生成独立的 HTML 图表文件,文件名格式为 `{code}.html`。 + +#### Scenario: 指定 --output-dir 参数 +- **WHEN** 用户传入 `--output-dir output/` +- **THEN** 系统为每个股票生成 HTML 文件到 `output/{code}.html` +- **THEN** 文件名 SHALL 为股票代码,如 `000001.SZ.html`, `600000.SH.html` +- **THEN** 系统自动创建 `output/` 目录(`exist_ok=True`) +- **THEN** 系统在完成后输出提示:"图表已保存到目录: output/" 后列出所有文件 + +#### Scenario: 未指定 --output-dir 参数 +- **WHEN** 用户未传入 `--output-dir` 参数 +- **THEN** 系统不为任何股票生成图表文件 +- **THEN** 系统仅输出控制台统计信息 + +#### Scenario: 图表文件覆盖 +- **WHEN** 系统再次执行相同的批量回测 +- **THEN** 系统覆盖已存在的 HTML 文件 +- **THEN** 系统不提示文件已存在 + +--- + +### Requirement: 结构化回测结果 +系统 SHALL 返回标准化的 `BacktestResult` 对象,包含所有关键指标。 + +#### Scenario: BacktestResult 对象创建 +- **WHEN** 系统完成单股票回测 +- **THEN** 系统从 `stats` 对象提取指标到 `BacktestResult` +- **THEN** `BacktestResult.code` 设置为股票代码 +- **THEN** `BacktestResult.start_date` 设置为回测开始日期 +- **THEN** `BacktestResult.end_date` 设置为回测结束日期 +- **THEN** `BacktestResult.equity_final` 设置为最终权益 +- **THEN** `BacktestResult.equity_peak` 设置为峰值收益 +- **THEN** `BacktestResult.return_pct` 设置为总收益率 +- **THEN** `BacktestResult.buy_hold_return` 设置为买入持有收益率 +- **THEN** `BacktestResult.return_annual` 设置为年化收益率 +- **THEN** `BacktestResult.volatility_annual` 设置为年化波动率 +- **THEN** `BacktestResult.max_drawdown` 设置为最大回撤 +- **THEN** `BacktestResult.avg_drawdown` 设置为平均回撤 +- **THEN** `BacktestResult.max_drawdown_duration` 设置为最大回撤持续时长 +- **THEN** `BacktestResult.avg_drawdown_duration` 设置为平均回撤持续时长 +- **THEN** `BacktestResult.sortino_ratio` 设置为索提诺比率 +- **THEN** `BacktestResult.calmar_ratio` 设置为卡尔玛比率 +- **THEN** `BacktestResult.trades` 设置为交易次数 +- **THEN** `BacktestResult.win_rate` 设置为胜率 +- **THEN** `BacktestResult.sqn` 设置为系统质量数 +- **THEN** `BacktestResult.cash` 设置为初始资金 +- **THEN** `BacktestResult.commission` 设置为手续费率 + +#### Scenario: BacktestResult 列表返回 +- **WHEN** 系统完成批量回测 +- **THEN** 系统返回 `List[BacktestResult]` +- **THEN** 列表顺序 SHALL 与输入股票代码顺序一致 +- **THEN** 列表长度 SHALL 等于输入股票代码数量(成功时) + +#### Scenario: BacktestResult 数据类型 +- **WHEN** 系统创建 `BacktestResult` 对象 +- **THEN** 数值字段 SHALL 为 float 类型(除 `trades`, `max_drawdown_duration` 为 int) +- **THEN** 日期字段 SHALL 为 str 类型(YYYY-MM-DD 格式) +- **THEN** 系统支持 `result.to_dict()` 方法(dataclass 自动生成) + +--- + +### Requirement: 可复用回测引擎接口 +系统 SHALL 提供标准化的函数接口,供其他模块调用回测功能。 + +#### Scenario: run_backtest 函数调用 +- **WHEN** 其他模块调用 `run_backtest(code, start_date, end_date, strategy_file, cash, commission, warmup_days, output_file)` +- **THEN** 函数接收股票代码、日期范围、策略文件、回测参数、输出文件路径 +- **THEN** 函数执行完整回测流程(数据加载、策略加载、指标计算、回测执行) +- **THEN** 函数返回 `BacktestResult` 对象 +- **THEN** 函数不打印任何输出(纯函数) + +#### Scenario: run_batch_backtest 函数调用 +- **WHEN** 其他模块调用 `run_batch_backtest(codes, start_date, end_date, strategy_file, cash, commission, warmup_days, output_dir)` +- **THEN** 函数接收股票代码列表、日期范围、策略文件、回测参数、输出目录 +- **THEN** 函数串行执行每个股票的回测 +- **THEN** 函数返回 `List[BacktestResult]` +- **THEN** 函数显示 tqdm 进度条(批量时) + +#### Scenario: 函数参数默认值 +- **WHEN** 调用者不指定可选参数 +- **THEN** `cash` 默认为 100000 +- **THEN** `commission` 默认为 0.002 +- **THEN** `warmup_days` 默认为 365 +- **THEN** `output_file` 默认为 None(不生成图表) +- **THEN** `output_dir` 默认为 None(不生成图表) + +#### Scenario: 函数异常抛出 +- **WHEN** `run_backtest` 或 `run_batch_backtest` 执行时发生错误 +- **THEN** 函数 SHALL 抛出异常(不捕获) +- **THEN** 异常类型 SHALL 为 ValueError、TypeError 或原始异常 +- **THEN** 异常信息 SHALL 包含具体错误原因 +- **THEN** 调用者负责捕获和处理异常 + +--- + +### Requirement: 集中配置管理 +系统 SHALL 在 config.py 中集中管理数据库配置、默认回测参数、图表配色。 + +#### Scenario: 数据库配置访问 +- **WHEN** backtest_core.py 需要数据库连接参数 +- **THEN** 模块从 config 导入 `DB_HOST`, `DB_PORT`, `DB_NAME`, `DB_USER`, `DB_PASSWORD` +- **THEN** 模块使用这些常量构建连接字符串 +- **THEN** 模块不重复定义数据库配置 + +#### Scenario: 默认参数访问 +- **WHEN** backtest_core.py 需要默认回测参数 +- **THEN** 模块从 config 导入 `DEFAULT_CASH`, `DEFAULT_COMMISSION`, `DEFAULT_WARMUP_DAYS` +- **THEN** 模块使用这些常量作为函数默认值 +- **THEN** 模块不重复定义默认参数 + +#### Scenario: 图表配色访问 +- **WHEN** backtest_core.py 需要设置图表配色 +- **THEN** 模块从 config 导入 `BULL_COLOR`, `BEAR_COLOR` +- **THEN** 模块使用这些颜色设置 `plotting.BULL_COLOR` 和 `plotting.BEAR_COLOR` +- **THEN** 模块不重复定义颜色配置 + +#### Scenario: 配置文件内容 +- **WHEN** 查看 config.py 文件 +- **THEN** 文件包含数据库配置(DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD) +- **THEN** 文件包含默认回测参数(DEFAULT_CASH, DEFAULT_COMMISSION, DEFAULT_WARMUP_DAYS) +- **THEN** 文件包含图表配色(BULL_COLOR, BEAR_COLOR) +- **THEN** 所有配置使用明文常量(不使用环境变量) + +--- + +### Requirement: 错误处理策略 +系统 SHALL 在批量回测失败时立即停止执行,不继续处理后续股票。 + +#### Scenario: 数据加载失败 +- **WHEN** 系统加载第 i 个股票数据时失败(数据库错误、数据不存在) +- **THEN** 系统捕获异常 +- **THEN** 系统输出错误信息:"回测失败 [{code}]: {error}" +- **THEN** 系统停止执行后续股票的回测 +- **THEN** 系统退出并返回非零状态码 + +#### Scenario: 策略加载失败 +- **WHEN** 系统加载策略文件时失败(文件不存在、接口不完整) +- **THEN** 系统捕获异常 +- **THEN** 系统输出错误信息:"策略加载失败: {error}" +- **THEN** 系统停止执行所有股票的回测 +- **THEN** 系统退出并返回非零状态码 + +#### Scenario: 回测执行失败 +- **WHEN** 系统执行第 i 个股票回测时失败(策略逻辑错误) +- **THEN** 系统捕获异常 +- **THEN** 系统输出错误信息和完整堆栈跟踪 +- **THEN** 系统停止执行后续股票的回测 +- **THEN** 系统退出并返回非零状态码 + +#### Scenario: 图表生成失败 +- **WHEN** 系统生成第 i 个股票图表时失败 +- **THEN** 系统捕获异常 +- **THEN** 系统输出警告:"图表生成失败 [{code}]: {error},但回测已完成" +- **THEN** 系统继续执行后续股票的回测 +- **THEN** 系统在返回的 `BacktestResult` 中设置 `error` 字段(如果设计支持) + +--- + +### Requirement: 依赖管理 +系统 SHALL 在 pyproject.toml 中添加 tabulate 和 tqdm 依赖。 + +#### Scenario: 添加 tabulate 依赖 +- **WHEN** 查看 pyproject.toml 文件 +- **THEN** 文件包含 `tabulate` 依赖 +- **THEN** 依赖版本 SHALL 为兼容当前 Python 版本的版本 +- **THEN** 系统可以导入 `import tabulate` 无错误 + +#### Scenario: 添加 tqdm 依赖 +- **WHEN** 查看 pyproject.toml 文件 +- **THEN** 文件包含 `tqdm` 依赖 +- **THEN** 依赖版本 SHALL 为兼容当前 Python 版本的版本 +- **THEN** 系统可以导入 `from tqdm import tqdm` 无错误 + +#### Scenario: 依赖安装 +- **WHEN** 用户运行 `uv sync` 或 `pip install -e .` +- **THEN** 系统自动安装 tabulate 和 tqdm +- **THEN** 系统显示依赖安装进度 +- **THEN** 系统完成安装后可以正常使用回测工具 + +#### Scenario: 依赖缺失提示 +- **WHEN** 系统导入 tabulate 或 tqdm 时失败 +- **THEN** 系统输出友好错误信息:"缺少依赖: {package_name},请运行: uv add {package_name}" +- **THEN** 系统退出并返回非零状态码 + diff --git a/pyproject.toml b/pyproject.toml index a5b6809..7268ea0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,4 +15,6 @@ dependencies = [ "psycopg2-binary~=2.9.11", "sqlalchemy>=2.0.46", "ta-lib>=0.6.8", + "tabulate>=0.9.0", + "tqdm>=4.67.1", ] diff --git a/uv.lock b/uv.lock index 478ea43..0261d48 100644 --- a/uv.lock +++ b/uv.lock @@ -912,6 +912,8 @@ dependencies = [ { name = "psycopg2-binary" }, { name = "sqlalchemy" }, { name = "ta-lib" }, + { name = "tabulate" }, + { name = "tqdm" }, ] [package.metadata] @@ -927,6 +929,8 @@ requires-dist = [ { name = "psycopg2-binary", specifier = "~=2.9.11" }, { name = "sqlalchemy", specifier = ">=2.0.46" }, { name = "ta-lib", specifier = ">=0.6.8" }, + { name = "tabulate", specifier = ">=0.9.0" }, + { name = "tqdm", specifier = ">=4.67.1" }, ] [[package]] @@ -1692,6 +1696,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0b/4c/d341020377f8b183405bdf3c5717fc2ca04a8d33b5c59b2348377ee459d9/ta_lib-0.6.8-cp314-cp314-win_arm64.whl", hash = "sha256:bfad1202fb1f9140e3810cc607058395f59032d9128cc0d716900c78bea5f337", size = 755896, upload-time = "2025-10-20T20:49:39.9Z" }, ] +[[package]] +name = "tabulate" +version = "0.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/fe/802052aecb21e3797b8f7902564ab6ea0d60ff8ca23952079064155d1ae1/tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c", size = 81090, upload-time = "2022-10-06T17:21:48.54Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252, upload-time = "2022-10-06T17:21:44.262Z" }, +] + [[package]] name = "terminado" version = "0.18.1" @@ -1737,6 +1750,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/50/49/8dc3fd90902f70084bd2cd059d576ddb4f8bb44c2c7c0e33a11422acb17e/tornado-6.5.4-cp39-abi3-win_arm64.whl", hash = "sha256:053e6e16701eb6cbe641f308f4c1a9541f91b6261991160391bfc342e8a551a1", size = 445910, upload-time = "2025-12-15T19:21:02.571Z" }, ] +[[package]] +name = "tqdm" +version = "4.67.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, +] + [[package]] name = "traitlets" version = "5.14.3"