1
0

Compare commits

...

7 Commits

18 changed files with 1976 additions and 438 deletions

1
.idea/vcs.xml generated
View File

@@ -2,5 +2,6 @@
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
<mapping directory="$PROJECT_DIR$/backtestingpy" vcs="Git" />
</component>
</project>

File diff suppressed because one or more lines are too long

View File

@@ -1,326 +0,0 @@
#!/usr/bin/env python3
"""
量化回测主程序
使用方法:
python backtest.py --code 000001.SZ --start-date 2024-01-01 --end-date 2025-12-31 --strategy-file strategy.py
"""
import argparse
import sys
import os
import importlib.util
import pandas as pd
from datetime import datetime
from backtesting import Backtest
# 数据库配置(直接硬编码,开发环境)
DB_HOST = "81.71.3.24"
DB_PORT = 6785
DB_NAME = "leopard_dev"
DB_USER = "leopard"
DB_PASSWORD = "9NEzFzovnddf@PyEP?e*AYAWnCyd7UhYwQK$pJf>7?ccFiN^x4$eKEZ5~E<7<+~X"
def load_data_from_db(code, start_date, end_date):
"""
从数据库加载历史数据
参数:
code: 股票代码(如 '000001.SZ'
start_date: 开始日期(如 '2024-01-01'
end_date: 结束日期(如 '2025-12-31'
返回:
DataFrame, 包含列: [Open, High, Low, Close, Volume, factor]
"""
import sqlalchemy
import urllib.parse
# 构建连接字符串URL 编码密码中的特殊字符)
encoded_password = urllib.parse.quote_plus(DB_PASSWORD)
conn_str = (
f"postgresql://{DB_USER}:{encoded_password}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)
engine = sqlalchemy.create_engine(conn_str)
try:
# 构建 SQL 查询
query = f"""
SELECT
trade_date,
open * factor AS "Open",
close * factor AS "Close",
high * factor AS "High",
low * factor AS "Low",
volume AS "Volume",
COALESCE(factor, 1.0) AS factor
FROM leopard_daily daily
LEFT JOIN leopard_stock stock ON stock.id = daily.stock_id
WHERE stock.code = '{code}'
AND daily.trade_date BETWEEN '{start_date} 00:00:00'
AND '{end_date} 23:59:59'
ORDER BY daily.trade_date
"""
# 执行查询
df = pd.read_sql(query, engine)
if len(df) == 0:
raise ValueError(f"未找到股票 {code} 在指定时间范围内的数据")
# 潬换日期并设置为索引
df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y-%m-%d")
df.set_index("trade_date", inplace=True)
return df
finally:
# 清理连接
engine.dispose()
def load_strategy(strategy_file):
"""
动态加载策略文件
参数:
strategy_file: 策略文件路径 (如 'strategy.py''strategies/macd.py')
返回:
(calculate_indicators, strategy_class) 元组
"""
# 获取模块名
module_name = strategy_file.replace(".py", "").replace("/", ".")
spec = importlib.util.spec_from_file_location(module_name, strategy_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# 接口验证
if not hasattr(module, "calculate_indicators"):
raise AttributeError(f"策略文件 {strategy_file} 缺少 calculate_indicators 函数")
if not hasattr(module, "get_strategy"):
raise AttributeError(f"策略文件 {strategy_file} 缺少 get_strategy 函数")
calculate_indicators = module.calculate_indicators
strategy_class = module.get_strategy()
# 验证 get_strategy 返回的是类
if not isinstance(strategy_class, type):
raise TypeError("get_strategy() 必须返回一个类")
# 验证策略类继承自 backtesting.Strategy
from backtesting import Strategy
if not issubclass(strategy_class, Strategy):
raise TypeError("策略类必须继承 backtesting.Strategy")
return calculate_indicators, strategy_class
def apply_color_scheme():
"""
应用颜色方案:红涨绿跌(中国股市风格)
"""
import backtesting._plotting as plotting
from bokeh.colors.named import tomato, lime
plotting.BULL_COLOR = tomato
plotting.BEAR_COLOR = lime
def parse_arguments():
"""
解析命令行参数
返回:
args: 命名空间对象
"""
parser = argparse.ArgumentParser(
description="量化回测工具", formatter_class=argparse.RawDescriptionHelpFormatter
)
# 必需参数
parser.add_argument(
"--code", type=str, required=True, help="股票代码 (如: 000001.SZ)"
)
parser.add_argument(
"--start-date", type=str, required=True, help="回测开始日期 (格式: YYYY-MM-DD)"
)
parser.add_argument(
"--end-date", type=str, required=True, help="回测结束日期 (格式: YYYY-MM-DD)"
)
parser.add_argument(
"--strategy-file",
type=str,
required=True,
help="策略文件路径 (如: strategy.py)",
)
# 可选参数
parser.add_argument(
"--cash", type=float, default=100000, help="初始资金 (默认: 100000)"
)
parser.add_argument(
"--commission", type=float, default=0.002, help="手续费率 (默认: 0.002)"
)
parser.add_argument(
"--output", type=str, default=None, help="HTML 输出文件路径 (可选)"
)
parser.add_argument(
"--warmup-days", type=int, default=365, help="预热天数 (默认: 365约一年"
)
return parser.parse_args()
def format_value(value, cn_name, key):
"""
格式化数值显示
"""
if isinstance(value, (int, float)):
if "%" in cn_name or key in [
"Sharpe Ratio",
"Sortino Ratio",
"Calmar Ratio",
"Profit Factor",
]:
formatted_value = f"{value:.2f}"
elif "$" in cn_name:
formatted_value = f"{value:.2f}"
elif "次数" in cn_name:
formatted_value = f"{value:.0f}"
else:
formatted_value = f"{value:.4f}"
else:
formatted_value = str(value)
return formatted_value
def print_stats(stats):
"""
打印回测统计结果
参数:
stats: backtesting 库返回的统计对象
"""
print("\n" + "=" * 60)
print("回测结果")
print("=" * 60)
indicator_name_mapping = {
# 'Start': '回测开始时间',
# 'End': '回测结束时间',
# 'Duration': '回测持续时长',
# 'Exposure Time [%]': '持仓时间占比(%',
'Equity Final [$]': '最终收益',
'Equity Peak [$]': '峰值收益',
'Return [%]': '总收益率(%',
'Buy & Hold Return [%]': '买入并持有收益率(%',
'Return (Ann.) [%]': '年化收益率(%',
'Volatility (Ann.) [%]': '年化波动率(%',
# 'CAGR [%]': '复合年均增长率(%',
# 'Sharpe Ratio': '夏普比率',
'Sortino Ratio': '索提诺比率',
'Calmar Ratio': '卡尔玛比率',
# 'Alpha [%]': '阿尔法系数(%',
# 'Beta': '贝塔系数',
'Max. Drawdown [%]': '最大回撤(%',
'Avg. Drawdown [%]': '平均回撤(%',
'Max. Drawdown Duration': '最大回撤持续时长',
'Avg. Drawdown Duration': '平均回撤持续时长',
'# Trades': '总交易次数',
'Win Rate [%]': '胜率(%',
# 'Best Trade [%]': '最佳单笔交易收益率(%',
# 'Worst Trade [%]': '最差单笔交易收益率(%',
# 'Avg. Trade [%]': '平均单笔交易收益率(%',
# 'Max. Trade Duration': '单笔交易最长持有时长',
# 'Avg. Trade Duration': '单笔交易平均持有时长',
# 'Profit Factor': '盈利因子',
# 'Expectancy [%]': '期望收益(%',
'SQN': '系统质量数',
# 'Kelly Criterion': '凯利准则',
}
for k, v in stats.items():
if k in indicator_name_mapping:
cn_name = indicator_name_mapping.get(k, k)
if isinstance(v, (int, float)):
if "%" in cn_name or k in ['Sharpe Ratio', 'Sortino Ratio', 'Calmar Ratio', 'Profit Factor']:
formatted_value = f"{v:.2f}"
elif "$" in cn_name:
formatted_value = f"{v:.2f}"
elif "次数" in cn_name:
formatted_value = f"{v:.0f}"
else:
formatted_value = f"{v:.4f}"
else:
formatted_value = str(v)
print(f'{cn_name}: {formatted_value}')
print("=" * 60 + "\n")
def main():
"""
主函数:编排完整回测流程
"""
try:
# 解析参数
args = parse_arguments()
apply_color_scheme()
# 加载数据
print(f"加载股票数据: {args.code} ({args.start_date} ~ {args.end_date})")
data = load_data_from_db(args.code, args.start_date, args.end_date)
print(f"数据加载完成,共 {len(data)} 条记录")
# 截取预热数据
warmup_data = data.iloc[-args.warmup_days :]
print(f"使用预热数据范围: {warmup_data.index[0]} ~ {warmup_data.index[-1]}")
# 加载策略
print(f"加载策略: {args.strategy_file}")
calculate_indicators, strategy_class = load_strategy(args.strategy_file)
# 计算指标
print("计算指标...")
warmup_data = calculate_indicators(warmup_data)
print("指标计算完成")
# 执行回测
print("开始回测...")
from backtesting import Backtest
bt = Backtest(
warmup_data,
strategy_class,
cash=args.cash,
commission=args.commission,
finalize_trades=True,
)
stats = bt.run()
# 输出结果
print_stats(stats)
# 生成图表
if args.output:
print(f"\n生成图表: {args.output}")
bt.plot(filename=args.output, open_browser=False)
print(f"图表已保存到: {args.output}")
print("\n回测完成!")
except Exception as e:
print(f"\n错误: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()

120
backtest_command.py Executable file
View File

@@ -0,0 +1,120 @@
#!/usr/bin/env python3
import argparse
import sys
import tabulate
import backtest_core
def parse_arguments():
parser = argparse.ArgumentParser(description="量化回测工具", formatter_class=argparse.RawDescriptionHelpFormatter)
parser.add_argument("--codes", type=str, nargs="+", required=True, help="股票代码列表 (如: 000001.SZ 600000.SH)", )
parser.add_argument("--start-date", type=str, required=True, help="回测开始日期 (格式: YYYY-MM-DD)", )
parser.add_argument("--end-date", type=str, required=True, help="回测结束日期 (格式: YYYY-MM-DD)", )
parser.add_argument("--strategy-file", type=str, required=True, help="策略文件路径 (如: strategy.py)", )
parser.add_argument("--cash", type=float, default=100000, help="初始资金 (默认: 100000)", )
parser.add_argument("--commission", type=float, default=0.002, help="手续费率 (默认: 0.002)", )
parser.add_argument("--warmup-days", type=int, default=365, help="预热天数 (默认: 365约一年)", )
parser.add_argument("--output-dir", type=str, default=None, help="HTML 图表输出目录 (可选,为每个股票生成 {code}.html)", )
return parser.parse_args()
def format_single_result(result: backtest_core.BacktestResult):
print("=" * 60)
print(f"股票代码: {result.code}")
print("=" * 60)
indicator_mapping = {
"最终收益": f"{result.equity_final:.2f}",
"峰值收益": f"{result.equity_peak:.2f}",
"总收益率(%": f"{result.return_pct:.2f}",
"买入并持有收益率(%": f"{result.buy_hold_return_pct:.2f}",
"年化收益率(%": f"{result.return_ann_pct:.2f}",
"年化波动率(%": f"{result.volatility_ann_pct:.2f}",
"索提诺比率": f"{result.sortino_ratio:.2f}",
"卡尔玛比率": f"{result.calmar_ratio:.2f}",
"最大回撤(%": f"{result.max_drawdown_pct:.2f}",
"平均回撤(%": f"{result.avg_drawdown_pct:.2f}",
"最大回撤持续时长": f"{result.max_drawdown_duration:.0f}",
"平均回撤持续时长": f"{result.avg_drawdown_duration:.0f}",
"总交易次数": f"{result.num_trades:.0f}",
"胜率(%": f"{result.win_rate_pct:.2f}",
"系统质量数": f"{result.sqn:.2f}",
}
for name, value in indicator_mapping.items():
print(f"{name}: {value}")
print("=" * 60)
def format_batch_results(results: list[backtest_core.BacktestResult]):
table_data = []
for result in results:
table_data.append(
[
result.code,
f"{result.return_pct:.2f}",
f"{result.buy_hold_return_pct:.2f}",
f"{result.return_ann_pct:.2f}",
f"{result.volatility_ann_pct:.2f}",
f"{result.win_rate_pct:.2f}",
f"{result.max_drawdown_pct:.2f}",
f"{result.sortino_ratio:.2f}",
f"{result.num_trades:.0f}",
f"{result.sqn:.2f}",
]
)
headers = [
"股票代码",
"收益率%",
"买入持有%",
"年化收益%",
"年化波动%",
"胜率%",
"最大回撤%",
"索提诺比率",
"交易次数",
"SQN",
]
print(tabulate.tabulate(table_data, headers=headers, tablefmt="grid"))
def main():
args = parse_arguments()
try:
results = backtest_core.run_batch_backtest(
codes=args.codes,
start_date=args.start_date,
end_date=args.end_date,
strategy_file=args.strategy_file,
cash=args.cash,
commission=args.commission,
warmup_days=args.warmup_days,
output_dir=args.output_dir,
show_progress=True,
)
if len(results) == 1:
format_single_result(results[0])
else:
format_batch_results(results)
if args.output_dir:
print(f"\n图表已保存到: {args.output_dir}/")
except Exception as e:
print(f"\n错误: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()

208
backtest_core.py Normal file
View File

@@ -0,0 +1,208 @@
import dataclasses
import importlib.util
import os
from typing import Optional
import pandas as pd
from tqdm import tqdm
import config
@dataclasses.dataclass
class BacktestResult:
code: str
equity_final: float
equity_peak: float
return_pct: float
buy_hold_return_pct: float
return_ann_pct: float
volatility_ann_pct: float
sortino_ratio: float
calmar_ratio: float
max_drawdown_pct: float
avg_drawdown_pct: float
max_drawdown_duration: float
avg_drawdown_duration: float
num_trades: int
win_rate_pct: float
sqn: float
def load_data_from_db(code: str, start_date: str, end_date: str) -> pd.DataFrame:
import sqlalchemy
import urllib.parse
encoded_password = urllib.parse.quote_plus(config.DB_PASSWORD)
conn_str = f"postgresql://{config.DB_USER}:{encoded_password}@{config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}"
engine = sqlalchemy.create_engine(conn_str)
try:
query = f"""
SELECT trade_date,
open * factor AS "Open",
close * factor AS "Close",
high * factor AS "High",
low * factor AS "Low",
volume AS "Volume",
COALESCE(factor, 1.0) AS factor
FROM leopard_daily daily
LEFT JOIN leopard_stock stock ON stock.id = daily.stock_id
WHERE stock.code = '{code}'
AND daily.trade_date BETWEEN '{start_date} 00:00:00'
AND '{end_date} 23:59:59'
ORDER BY daily.trade_date
"""
df = pd.read_sql(query, engine)
if len(df) == 0:
raise ValueError(f"未找到股票 {code} 在指定时间范围内的数据")
df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y-%m-%d")
df.set_index("trade_date", inplace=True)
return df
finally:
engine.dispose()
def load_strategy(strategy_file: str):
module_name = strategy_file.replace(".py", "").replace("/", ".")
spec = importlib.util.spec_from_file_location(module_name, strategy_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
if not hasattr(module, "calculate_indicators"):
raise AttributeError(f"策略文件 {strategy_file} 缺少 calculate_indicators 函数")
if not hasattr(module, "get_strategy"):
raise AttributeError(f"策略文件 {strategy_file} 缺少 get_strategy 函数")
calculate_indicators = module.calculate_indicators
strategy_class = module.get_strategy()
if not isinstance(strategy_class, type):
raise TypeError("get_strategy() 必须返回一个类")
from backtesting import Strategy
if not issubclass(strategy_class, Strategy):
raise TypeError("策略类必须继承 backtesting.Strategy")
return calculate_indicators, strategy_class
def apply_color_scheme():
import backtesting._plotting as plotting
plotting.BULL_COLOR = config.BULL_COLOR
plotting.BEAR_COLOR = config.BEAR_COLOR
def run_backtest(
code: str,
start_date: str,
end_date: str,
strategy_file: str,
cash: float = config.DEFAULT_CASH,
commission: float = config.DEFAULT_COMMISSION,
warmup_days: int = config.DEFAULT_WARMUP_DAYS,
output_dir: Optional[str] = None,
) -> BacktestResult:
warmup_start_date = (pd.to_datetime(start_date) - pd.Timedelta(days=warmup_days)).strftime("%Y-%m-%d")
data = load_data_from_db(code, warmup_start_date, end_date)
calculate_indicators, strategy_class = load_strategy(strategy_file)
data = calculate_indicators(data)
data = data.loc[start_date:end_date]
from backtesting import Backtest
bt = Backtest(data, strategy_class, cash=cash, commission=commission, finalize_trades=True)
stats = bt.run()
apply_color_scheme()
if output_dir:
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, f"{code}.html")
bt.plot(filename=output_path, open_browser=False)
def _safe_float(value, default=0):
if value is None:
return default
try:
return float(value)
except (TypeError, ValueError):
return default
def _safe_int(value, default=0):
if value is None:
return default
try:
return int(value)
except (TypeError, ValueError):
return default
def _safe_timedelta(value, default=0):
if value is None:
return default
try:
return float(value.total_seconds() / 86400)
except (TypeError, AttributeError):
return default
return BacktestResult(
code=code,
equity_final=_safe_float(stats.get("Equity Final [$]"), 0),
equity_peak=_safe_float(stats.get("Equity Peak [$]"), 0),
return_pct=_safe_float(stats.get("Return [%]"), 0),
buy_hold_return_pct=_safe_float(stats.get("Buy & Hold Return [%]"), 0),
return_ann_pct=_safe_float(stats.get("Return (Ann.) [%]"), 0),
volatility_ann_pct=_safe_float(stats.get("Volatility (Ann.) [%]"), 0),
sortino_ratio=_safe_float(stats.get("Sortino Ratio"), 0),
calmar_ratio=_safe_float(stats.get("Calmar Ratio"), 0),
max_drawdown_pct=_safe_float(stats.get("Max. Drawdown [%]"), 0),
avg_drawdown_pct=_safe_float(stats.get("Avg. Drawdown [%]"), 0),
max_drawdown_duration=_safe_timedelta(stats.get("Max. Drawdown Duration"), 0),
avg_drawdown_duration=_safe_timedelta(stats.get("Avg. Drawdown Duration"), 0),
num_trades=_safe_int(stats.get("# Trades"), 0),
win_rate_pct=_safe_float(stats.get("Win Rate [%]"), 0),
sqn=_safe_float(stats.get("SQN"), 0),
)
def run_batch_backtest(
codes: list[str],
start_date: str,
end_date: str,
strategy_file: str,
cash: float = config.DEFAULT_CASH,
commission: float = config.DEFAULT_COMMISSION,
warmup_days: int = config.DEFAULT_WARMUP_DAYS,
output_dir: Optional[str] = None,
show_progress: bool = True,
) -> list[BacktestResult]:
results = []
codes_iter = tqdm(codes, desc="批量回测") if show_progress else codes
for code in codes_iter:
result = run_backtest(
code=code,
start_date=start_date,
end_date=end_date,
strategy_file=strategy_file,
cash=cash,
commission=commission,
warmup_days=warmup_days,
output_dir=output_dir,
)
results.append(result)
return results

20
config.py Normal file
View File

@@ -0,0 +1,20 @@
"""
配置文件
集中管理数据库配置、默认回测参数、图表配色
"""
DB_HOST = "81.71.3.24"
DB_PORT = 6785
DB_NAME = "leopard_dev"
DB_USER = "leopard"
DB_PASSWORD = "9NEzFzovnddf@PyEP?e*AYAWnCyd7UhYwQK$pJf>7?ccFiN^x4$eKEZ5~E<7<+~X"
DEFAULT_CASH = 100000
DEFAULT_COMMISSION = 0.002
DEFAULT_WARMUP_DAYS = 365
from bokeh.colors.named import tomato, lime
BULL_COLOR = tomato
BEAR_COLOR = lime

295
note_refactor.md Normal file
View File

@@ -0,0 +1,295 @@
# 回测代码重构说明
## 概述
本次重构将原有的单一文件 `backtest.py` 拆分为模块化架构,提升代码复用性和可维护性。
## 文件结构变化
### 新增文件
1. **config.py** - 配置管理模块
- 数据库配置DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD
- 默认回测参数DEFAULT_CASH, DEFAULT_COMMISSION, DEFAULT_WARMUP_DAYS
- 图表配色BULL_COLOR, BEAR_COLOR
2. **backtest_core.py** - 核心回测引擎
- `BacktestResult` 数据类:结构化回测结果
- `load_data_from_db()`:从数据库加载历史数据
- `load_strategy()`:动态加载策略文件
- `apply_color_scheme()`:应用图表配色
- `run_backtest()`:单股票回测函数
- `run_batch_backtest()`:批量回测函数(串行执行)
3. **backtest_command.py** - 命令行界面
- `parse_arguments()`:解析命令行参数
- `format_single_result()`:详细格式输出(单股票)
- `format_batch_results()`:表格格式输出(多股票,使用 tabulate
- `main()`:主流程编排
### 删除文件
1. **backtest.py** - 原有单一文件284 行)
## 接口变化
### 新增 API
```python
# 单股票回测
result = backtest_core.run_backtest(
code='000001.SZ',
start_date='2024-01-01',
end_date='2024-12-31',
strategy_file='strategies/sma_strategy.py',
cash=100000,
commission=0.002,
warmup_days=365,
output_dir=None # 可选,为 None 时不生成图表
)
# 批量回测
results = backtest_core.run_batch_backtest(
codes=['000001.SZ', '600000.SH'],
start_date='2024-01-01',
end_date='2024-12-31',
strategy_file='strategies/sma_strategy.py',
cash=100000,
commission=0.002,
warmup_days=365,
output_dir='output/', # 可选,为每个股票生成 {code}.html
show_progress=True # 可选,是否显示 tqdm 进度条
)
```
### 新增数据结构
```python
@dataclasses.dataclass
class BacktestResult:
code: str
equity_final: float
equity_peak: float
return_pct: float
buy_hold_return_pct: float
return_ann_pct: float
volatility_ann_pct: float
sortino_ratio: float
calmar_ratio: float
max_drawdown_pct: float
avg_drawdown_pct: float
max_drawdown_duration: float
avg_drawdown_duration: float
num_trades: int
win_rate_pct: float
sqn: float
```
## 命令行使用方式变化
### 旧方式(已删除)
```bash
python backtest.py --code 000001.SZ --start-date 2024-01-01 --end-date 2024-12-31 --strategy-file strategy.py
```
### 新方式
```bash
uv run python backtest_command.py --codes 000001.SZ --start-date 2024-01-01 --end-date 2024-12-31 --strategy-file strategies/sma_strategy.py
```
### 参数变化
| 参数名 | 变化 | 说明 |
|--------|--------|------|
| `--code` | 改为 `--codes` | 从单一参数改为多值参数(`nargs='+'` |
| `--output` | 改为 `--output-dir` | 指定目录而非文件路径 |
### 新增参数
- `--output-dir`:指定图表输出目录(可选)
- 单股票时:生成 `{code}.html` 在指定目录
- 多股票时:为每个股票生成 `{code}.html` 在指定目录
- 不指定时不生成图表
## 输出格式变化
### 单股票输出
保持原有的详细格式输出,每个指标单独一行:
```
============================================================
股票代码: 000001.SZ
============================================================
最终收益: 100981.58
峰值收益: 103731.54
总收益率(%: 0.98
...
============================================================
```
### 多股票输出
新增表格格式输出(使用 tabulategrid 格式):
```
+------------+-----------+---------+-------------+------------+-------+
| 股票代码 | 收益率% | 胜率% | 最大回撤% | 交易次数 | SQN |
+============+===========+=========+=============+============+=======+
| 000001.SZ | 0.98 | 100 | -2.65 | 1 | nan |
| 600000.SH | 0.04 | 100 | -1.5 | 1 | nan |
+------------+-----------+---------+-------------+------------+-------+
```
### 进度条
多股票回测时显示 tqdm 进度条:
```
批量回测: 50%|█████ | 1/2 [00:07<00:07, 7.82s/it]
```
## 依赖变化
### 新增依赖
- `tabulate`:表格格式化
- 版本0.9.0
- 用途:批量回测结果的表格化输出
- `tqdm`:进度条显示
- 版本4.67.1
- 用途:批量回测时的实时进度反馈
## 特性增强
### 新增功能
1. **批量回测**:支持传入多个股票代码进行串行回测
- 命令:`--codes 000001.SZ 600000.SH`
- 输出:表格化结果对比
- 进度条:实时显示回测进度
2. **图表生成**:为每个股票生成独立 HTML 图表
- 参数:`--output-dir output/`
- 输出:`{code}.html` 在指定目录
- 自动创建目录:`os.makedirs(output_dir, exist_ok=True)`
3. **进度条显示**:使用 tqdm 提供实时反馈
- 多股票时自动显示
- 可通过 `show_progress=False` 禁用
## 兼容性说明
### BREAKING CHANGES
1. **命令行入口变化**
- 旧:`python backtest.py`
- 新:`uv run python backtest_command.py`
2. **参数名称变化**
- `--code``--codes`(从单值改为多值)
### 兼容性保证
- 所有原有功能完整保留
- 核心回测逻辑无变化
- 策略加载方式不变
- 数据访问接口不变
## 代码行数对比
| 文件 | 旧行数 | 新行数 | 变化 |
|------|---------|---------|------|
| backtest.py | 284 | - | -284 |
| config.py | - | 20 | +20 |
| backtest_core.py | - | ~200 | +200 |
| backtest_command.py | - | ~120 | +120 |
| **总计** | **284** | **~340** | **+56** |
## 迁移指南
### 对于开发者
如果需要在其他模块中调用回测功能:
```python
from backtest_core import run_backtest, run_batch_backtest, BacktestResult
# 单股票回测
result = run_backtest(
code='000001.SZ',
start_date='2024-01-01',
end_date='2024-12-31',
strategy_file='strategies/sma_strategy.py'
)
# 批量回测
results = run_batch_backtest(
codes=['000001.SZ', '600000.SH'],
start_date='2024-01-01',
end_date='2024-12-31',
strategy_file='strategies/sma_strategy.py'
)
# 访问结果
print(result.return_pct)
print(result.win_rate_pct)
```
### 对于终端用户
**单股票回测示例:**
```bash
uv run python backtest_command.py \
--codes 000001.SZ \
--start-date 2024-01-01 \
--end-date 2024-12-31 \
--strategy-file strategies/sma_strategy.py
```
**多股票回测示例:**
```bash
uv run python backtest_command.py \
--codes 000001.SZ 600000.SH \
--start-date 2024-01-01 \
--end-date 2024-12-31 \
--strategy-file strategies/sma_strategy.py
```
**生成图表示例:**
```bash
uv run python backtest_command.py \
--codes 000001.SZ \
--start-date 2024-01-01 \
--end-date 2024-12-31 \
--strategy-file strategies/sma_strategy.py \
--output-dir output/
```
## 错误处理
- **立即失败策略**:遇到第一个错误立即停止,不继续执行其他股票
- **友好错误提示**:捕获异常并打印清晰的错误信息
- **退出状态码**:成功返回 0失败返回非零
- **回溯信息**:打印完整的堆栈跟踪以便调试
## 性能考虑
- **串行执行**:当前采用串行执行,确保简单可靠
- **未来扩展**未来可改为并行执行ThreadPoolExecutor以提升性能
- **数据加载**:每次回测创建独立的数据库连接,避免连接池复杂度
## 总结
本次重构实现了:
- ✅ 代码模块化:核心逻辑与 CLI 界面分离
- ✅ 可复用性:提供标准化 API 供其他模块调用
- ✅ 功能增强:支持批量回测和图表生成
- ✅ 用户体验:表格化结果和进度条显示
- ✅ 代码质量:更清晰的模块划分和类型提示

View File

@@ -0,0 +1,2 @@
schema: spec-driven
created: 2026-01-28

View File

@@ -0,0 +1,287 @@
## Context
**Current State:**
`backtest.py` (284 lines) 是单一文件,包含:
- 命令行参数解析
- 数据库连接和数据加载
- 策略动态加载和验证
- 回测执行逻辑
- 结果格式化输出
- 图表生成
**Constraints:**
- 需要保持现有功能完整性(数据加载、策略加载、回测执行、结果展示)
- 需要支持多股票回测(串行执行)
- 不考虑并发实现(保持简单)
- 错误处理采用立即失败策略
- 数据库配置明文存储,不考虑环境变量
**Stakeholders:**
- 开发者:需要清晰的模块划分和可复用的接口
- 终端用户:需要友好的 CLI 输出(进度条、表格化结果)
## Goals / Non-Goals
**Goals:**
1. 分离核心逻辑与 CLI 界面,提升代码复用性
2. 提供标准化函数接口,供其他模块调用回测功能
3. 支持多股票批量回测(串行执行)
4. 集中管理配置(数据库、参数、配色)
5. 优化 CLI 输出体验tabulate 表格化、tqdm 进度条)
**Non-Goals:**
1. 并行执行多股票回测(性能优化非目标)
2. 环境变量管理配置(配置明文存储即可)
3. 复杂的聚合统计(仅单股票结果拼接)
4. 图表文件合并(每个股票生成独立 HTML
5. 配置文件热重载(启动时加载一次)
## Decisions
### Decision 1: 三层模块划分
**选择:** 分离为 `config.py``backtest_core.py``backtest_command.py` 三个文件
**理由:**
- **config.py**:集中管理所有配置,避免硬编码分散
- **backtest_core.py**:纯粹的业务逻辑,提供可复用的函数接口
- **backtest_command.py**CLI 界面,负责参数解析和结果展示
**替代方案:**
- 方案 A保留单一文件但改进内部结构函数分离
- 拒绝理由仍无法复用CLI 和业务逻辑耦合
- 方案 B使用类封装`BacktestEngine` 类)
- 拒绝理由:增加复杂度,函数接口已足够
### Decision 2: BacktestResult 数据类
**选择:** 使用 `dataclasses.dataclass` 定义 `BacktestResult`
**理由:**
- 结构化返回结果,便于序列化和导出
- 类型提示支持,提升代码可读性
- 自动生成 `__init__``__repr__` 等方法,减少样板代码
**替代方案:**
- 方案 A直接返回原始 `stats` 对象backtesting 库返回)
- 拒绝理由:依赖 backtesting 库内部结构,耦合度高
- 方案 B返回字典
- 拒绝理由:缺乏类型提示,容易拼写错误
### Decision 3: 批量回测策略
**选择:** 串行执行(`for` 循环),立即失败
**理由:**
- 简单可靠,易于调试
- 错误处理清晰(第一个失败就停止)
- 避免并发带来的资源竞争和复杂度
**替代方案:**
- 方案 A并行执行ThreadPoolExecutor
- 拒绝理由:性能非目标,并发增加复杂度
- 方案 B继续执行其他股票最后统一报告错误
- 拒绝理由:用户需求是立即失败
### Decision 4: CLI 参数设计
**选择:** `--codes` 多值参数(`nargs='+'``--output-dir` 目录参数
**理由:**
- `--codes` 支持传入多个股票代码,如 `--codes 000001.SZ 600000.SH`
- `--output-dir` 为每个股票生成 `{code}.html`,如 `output/000001.SZ.html`
- 保持原有参数(`--start-date``--end-date``--strategy-file``--cash``--commission``--warmup-days`
**替代方案:**
- 方案 A`--code` 逗号分隔(如 `--code 000001.SZ,600000.SH`
- 拒绝理由:需要额外解析逻辑,不直观
- 方案 B`--code` 多次调用(如 `--code 000001.SZ --code 600000.SH`
- 拒绝理由argparse 的 `nargs='+'` 更符合习惯
### Decision 5: 输出优化库
**选择:** 使用 `tabulate` 表格化批量结果,使用 `tqdm` 显示进度条
**理由:**
- **tabulate**提供美观的表格输出支持多种格式grid、simple 等)
- **tqdm**:提供实时进度条,提升用户体验
- 两个库都是轻量级,不引入复杂依赖
**替代方案:**
- 方案 A手动格式化表格字符串拼接
- 拒绝理由:代码冗余,格式不够美观
- 方案 B不使用进度条仅输出完成提示
- 拒绝理由:多股票回测耗时较长,用户需要进度反馈
### Decision 6: 结果展示策略
**选择:** 单股票使用详细格式(现有),多股票使用表格格式(新增)
**理由:**
- 单股票:保持原有的详细输出(每个指标单独一行)
- 多股票:使用 `tabulate` 表格横向对比,节省垂直空间
**替代方案:**
- 方案 A所有情况都使用详细格式拼接
- 拒绝理由:多股票时输出过长,难以阅读
- 方案 B所有情况都使用表格格式
- 拒绝理由:单股票时表格优势不明显,详细格式更清晰
### Decision 7: 配置管理方式
**选择:** 明文常量存储在 `config.py`
**理由:**
- 满足用户需求(不考虑信息安全)
- 避免引入 `python-dotenv` 依赖
- 代码简洁,修改直接
**替代方案:**
- 方案 A环境变量`os.getenv`
- 拒绝理由:用户明确不需要
- 方案 B配置文件JSON/YAML
- 拒绝理由:增加文件管理和解析复杂度
### Decision 8: 数据访问接口
**选择:** `load_data_from_db(code, start_date, end_date)` 函数签名保持不变
**理由:**
- 现有接口已满足需求(单次查询一个股票)
- 迁移成本低,直接复制到 `backtest_core.py`
**替代方案:**
- 方案 A批量查询`load_data_from_db(codes, start_date, end_date)`
- 拒绝理由:需要修改 SQL 为 `IN` 子句,且结果聚合复杂
- 方案 B连接池复用全局 engine 对象)
- 拒绝理由:每次创建引擎的开销可接受(串行执行)
### Decision 9: 策略加载接口
**选择:** `load_strategy(strategy_file)` 返回 `(calculate_indicators, strategy_class)` 元组
**理由:**
- 保持现有接口,迁移成本低
- 函数返回两个值符合 Python 惯例
**替代方案:**
- 方案 A返回类对象策略类自带指标计算方法
- 拒绝理由:现有策略文件结构分离了两者,修改成本高
- 方案 B返回命名空间对象封装两个属性
- 拒绝理由:增加复杂度,元组足够
### Decision 10: 错误处理策略
**选择:** 立即失败(不捕获部分错误继续执行)
**理由:**
- 符合用户需求
- 简化错误追踪(第一个错误直接暴露)
- 避免"部分成功"的歧义状态
**替代方案:**
- 方案 A捕获错误但继续执行最后统一报告
- 拒绝理由:用户明确要求立即失败
## Risks / Trade-offs
### Risk 1: CLI 命令变化导致用户习惯中断
**风险:** 用户习惯使用 `python backtest.py`,需要切换到 `uv run python backtest_command.py`
**缓解:**
- 在项目根目录创建软链接 `backtest.py -> backtest_command.py`(可选)
- 或在 README 中明确说明新的使用方式
- 提供迁移指南(参数变化说明)
### Risk 2: 多股票串行执行耗时较长
**风险:** 10 个股票可能需要 10 倍时间(每个 30 秒 → 总计 5 分钟)
**缓解:**
- 使用 `tqdm` 进度条提供实时反馈
- 在 README 中说明性能限制
- 未来可扩展为并行执行(非当前目标)
### Risk 3: BacktestResult 字段可能与 backtesting 库不兼容
**风险:** backtesting 库升级后stats 对象的键名可能变化
**缓解:**
- 使用 `.get(key, default)` 方法访问,避免 KeyError
- 提供默认值0 或空字符串)
- 在文档中说明依赖的 backtesting 版本
### Risk 4: tabulate/tqdm 依赖未安装
**风险:** 用户运行时缺少依赖,导致 ImportError
**缓解:**
- 使用 `uv add` 明确添加依赖到 pyproject.toml
- 在 README 中说明安装步骤
- 错误信息中提示安装命令(`uv add tabulate tqdm`
### Risk 5: 策略文件路径处理不一致
**风险:** 策略文件路径可能是相对路径或绝对路径,导致加载失败
**缓解:**
- 使用 `os.path.abspath()` 转换为绝对路径
- 在错误信息中提示用户检查路径
- 测试相对路径和绝对路径两种情况
### Risk 6: 图表输出目录不存在
**风险:** 用户指定的 `--output-dir` 不存在,导致保存失败
**缓解:**
- 使用 `os.makedirs(output_dir, exist_ok=True)` 自动创建
- 在错误信息中提示用户检查目录权限
### Risk 7: 内存占用(多股票同时加载数据)
**风险:** 如果同时加载多个股票数据,内存占用可能较高
**缓解:**
- 串行执行确保一次只加载一个股票的数据
- 单个股票的数据量可控10 年约几 MB
- future 可考虑流式处理(非当前目标)
## Migration Plan
### Step 1: 创建 config.py
1.`backtest.py` 提取数据库配置
2. 添加默认回测参数
3. 添加图表配色配置
4. 测试导入无错误
### Step 2: 创建 backtest_core.py
1. 迁移 `load_data_from_db()` 函数(导入 config
2. 迁移 `load_strategy()` 函数
3. 迁移 `apply_color_scheme()` 函数(使用 config 配置)
4. 定义 `BacktestResult` 数据类
5. 实现 `run_backtest()` 函数
6. 实现 `run_batch_backtest()` 函数
7. 单元测试核心函数
### Step 3: 创建 backtest_command.py
1. 实现 `parse_arguments()` 函数(支持 `--codes`
2. 实现 `format_single_result()` 函数(详细格式)
3. 实现 `format_batch_results()` 函数(使用 tabulate
4. 实现 `main()` 函数(调用 `run_batch_backtest()`
5. 测试单股票回测
6. 测试多股票回测
### Step 4: 更新依赖
1. 运行 `uv add tabulate` 添加依赖
2. 运行 `uv add tqdm` 添加依赖
3. 运行 `uv sync` 同步依赖
### Step 5: 删除 backtest.py
1. 确认新功能完整(单股票、多股票、图表输出)
2. 确认错误处理正确(立即失败)
3. 删除 `backtest.py` 文件
4. 更新 README 说明新的使用方式
### Rollback Strategy
如果迁移过程中发现问题:
1. 保留 `backtest.py` 直到 `backtest_command.py` 完全可用
2. 使用 `git` 版本控制,可随时回退
3. 逐步迁移(先核心函数,后 CLI确保每步可验证
## Open Questions
1. **BacktestResult 字段完整性:** 是否需要包含所有 backtesting.stats 的键,或仅包含当前用到的字段?
- 倾向:仅包含当前用到的字段(未来可扩展)
2. **表格格式选择:** tabulate 支持多种格式grid、simple、pipe、html多股票结果使用哪种
- 倾向grid美观的边框格式
3. **进度条粒度:** tqdm 进度条应该显示每个股票的回测进度,还是仅显示批量回测的总进度?
- 倾向:仅显示批量回测的总进度(股票 N/M
4. **图表输出目录结构:** 多股票图表是平铺在 `output/` 下,还是按日期/策略分组?
- 倾向:平铺在 `output/` 下(简单)

View File

@@ -0,0 +1,54 @@
## Why
当前 `backtest.py` 存在职责混杂的问题:命令行参数解析、核心回测逻辑、数据访问、结果展示都耦合在单一文件中,导致:
- 难以在其他模块中复用回测功能
- 无法进行单元测试
- 仅支持单股票回测,无法批量处理
需要重构为分层架构,将核心逻辑与 CLI 界面分离,提升代码复用性和可维护性。
## What Changes
- **创建 `config.py`**:集中管理数据库配置、默认回测参数、图表配色
- **创建 `backtest_core.py`**:核心回测引擎
- 提供标准化接口 `run_backtest()`(单股票)
- 提供批量接口 `run_batch_backtest()`(多股票,串行执行)
- 封装数据访问和策略加载逻辑
- 返回结构化结果对象 `BacktestResult`
- **创建 `backtest_command.py`**:命令行界面
- 支持多股票代码参数 `--codes`(接受多个值)
- 使用 `tabulate` 优化批量结果的表格展示
- 使用 `tqdm` 显示批量回测进度条
- 保留原有的单股票详细输出格式
- **删除 `backtest.py`**:不再需要,功能已迁移
- **依赖更新**:添加 `tabulate``tqdm` 到项目依赖
## Capabilities
### New Capabilities
- `batch-backtest`: 批量回测功能,支持传入多个股票代码进行串行回测,并提供进度条和表格化结果展示
### Modified Capabilities
- 无(其他均为实现重构,不改变 spec 级别行为)
## Impact
- **代码影响**
- 删除 `backtest.py`284 行)
- 新增 `config.py`(约 30 行)
- 新增 `backtest_core.py`(约 250 行)
- 新增 `backtest_command.py`(约 150 行)
- **API 变化**
- 新增 `run_backtest(code, start_date, end_date, strategy_file, ...)` 函数
- 新增 `run_batch_backtest(codes, start_date, end_date, strategy_file, ...)` 函数
- 新增 `BacktestResult` 数据类
- **命令行变化**
- 单参数 `--code` 改为多值参数 `--codes`
- 新增 `--output-dir` 参数,为每个股票生成独立 HTML 图表
- 批量回测时显示进度条和表格化结果
- **依赖变化**
- 新增 `tabulate`(表格格式化)
- 新增 `tqdm`(进度条显示)
- **兼容性**
- **BREAKING**: 删除原有 `backtest.py`,命令行使用方式从 `python backtest.py` 改为 `uv run python backtest_command.py`
- 参数名称从 `--code` 改为 `--codes`

View File

@@ -0,0 +1,310 @@
# Spec: Batch Backtest
## ADDED Requirements
### Requirement: 多股票回测参数
系统 SHALL 支持通过命令行参数传入多个股票代码进行批量回测。
#### Scenario: 传入多个股票代码
- **WHEN** 用户执行 `python backtest_command.py --codes 000001.SZ 600000.SH --start-date 2024-01-01 --end-date 2025-12-31 --strategy-file strategies/macd_strategy.py`
- **THEN** 系统解析所有股票代码到列表 `['000001.SZ', '600000.SH']`
- **THEN** 系统按顺序依次执行每个股票的回测
- **THEN** 系统为每个股票生成独立的回测结果
#### Scenario: 传入单个股票代码
- **WHEN** 用户执行 `python backtest_command.py --codes 000001.SZ --start-date 2024-01-01 --end-date 2025-12-31 --strategy-file strategies/macd_strategy.py`
- **THEN** 系统解析为单个股票代码列表 `['000001.SZ']`
- **THEN** 系统执行单个股票回测
- **THEN** 系统输出详细格式的回测结果
#### Scenario: 缺少 --codes 参数
- **WHEN** 用户未提供 `--codes` 参数
- **THEN** 系统输出错误信息:"错误: 需要以下参数: --codes"
- **THEN** 系统退出并返回非零状态码
---
### Requirement: 批量回测执行
系统 SHALL 串行执行多个股票的回测,每次加载一个股票的数据并执行回测。
#### Scenario: 成功执行多个股票回测
- **WHEN** 用户传入 N 个股票代码
- **THEN** 系统循环 N 次,每次加载一个股票的数据
- **THEN** 系统每次执行完整的回测流程(数据加载、指标计算、回测执行)
- **THEN** 系统每次执行完成后生成 `BacktestResult` 对象
- **THEN** 系统返回包含 N 个 `BacktestResult` 的列表
#### Scenario: 每个股票独立预热期
- **WHEN** 系统执行第 i 个股票的回测
- **THEN** 系统使用 `start_date - warmup_days` 计算该股票的预热开始日期
- **THEN** 系统独立加载该股票的预热期数据
- **THEN** 不同股票的预热期互不影响
#### Scenario: 第一个股票回测失败
- **WHEN** 系统执行第一个股票回测时发生错误(数据库连接失败、策略加载失败等)
- **THEN** 系统捕获异常并输出错误信息
- **THEN** 系统停止执行后续股票的回测
- **THEN** 系统退出并返回非零状态码(立即失败策略)
#### Scenario: 中间股票回测失败
- **WHEN** 系统执行第 i 个股票回测时发生错误
- **THEN** 系统输出错误信息(包含股票代码)
- **THEN** 系统停止执行后续股票的回测
- **THEN** 系统退出并返回非零状态码
#### Scenario: 资源管理
- **WHEN** 系统完成第 i 个股票的回测
- **THEN** 系统关闭该股票的数据库连接(`engine.dispose()`
- **THEN** 系统释放该股票的数据内存
- **THEN** 系统开始加载第 i+1 个股票的数据
---
### Requirement: 批量回测进度显示
系统 SHALL 使用 tqdm 显示批量回测的实时进度,提供用户反馈。
#### Scenario: 显示进度条
- **WHEN** 系统开始执行 N 个股票的批量回测
- **THEN** 系统显示进度条格式:`回测进度: 25%|█████▌ | 1/4 [00:30<01:30, 12.5s/it]`
- **THEN** 系统在完成每个股票回测后更新进度条
- **THEN** 进度条显示当前进度i/N、已用时间、预计剩余时间
- **THEN** 进度条在所有股票回测完成后消失
#### Scenario: 单股票回测不显示进度条
- **WHEN** 用户传入单个股票代码
- **THEN** 系统不显示 tqdm 进度条
- **THEN** 系统直接输出回测结果
#### Scenario: 进度条描述文本
- **WHEN** 系统显示批量回测进度
- **THEN** 进度条描述 SHALL 为 "回测进度"(中文)
- **THEN** 进度条显示已完成/总数(如 "1/4", "2/4"
---
### Requirement: 批量回测结果展示
系统 SHALL 使用 tabulate 将多个股票的回测结果格式化为表格,便于横向对比。
#### Scenario: 表格化输出多股票结果
- **WHEN** 用户传入多个股票代码且回测成功
- **THEN** 系统使用 tabulate 生成表格
- **THEN** 表格格式 SHALL 为 grid带边框
- **THEN** 表格列 SHALL 包含:股票代码、收益率%、胜率%、最大回撤%、交易次数、SQN
- **THEN** 系统在表格上方显示表头(中文列名)
- **THEN** 数值保留 2 位小数(交易次数为整数)
#### Scenario: 表格内容填充
- **WHEN** 系统格式化第 i 个股票的结果
- **THEN** 系统从 `BacktestResult` 对象提取字段
- **THEN** "股票代码" 列填充 `result.code`
- **THEN** "收益率%" 列填充 `result.return_pct`
- **THEN** "胜率%" 列填充 `result.win_rate`
- **THEN** "最大回撤%" 列填充 `result.max_drawdown`
- **THEN** "交易次数" 列填充 `result.trades`
- **THEN** "SQN" 列填充 `result.sqn`
#### Scenario: 单股票回测不使用表格
- **WHEN** 用户传入单个股票代码
- **THEN** 系统不使用 tabulate 生成表格
- **THEN** 系统使用详细格式输出(每个指标单独一行)
- **THEN** 系统保持原有 `print_stats()` 的输出格式
#### Scenario: 表格示例输出
- **WHEN** 用户传入 2 个股票代码
- **THEN** 系统输出格式 SHALL 为:
```
+-------------+-----------+--------+------------+----------+-------+
| 股票代码 | 收益率% | 胜率% | 最大回撤% | 交易次数 | SQN |
+-------------+-----------+--------+------------+----------+-------+
| 000001.SZ | 20.35 | 55.00 | -8.50 | 45 | 1.85 |
| 600000.SH | 15.00 | 48.00 | -12.30 | 38 | 1.42 |
+-------------+-----------+--------+------------+----------+-------+
```
---
### Requirement: 多股票图表输出
系统 SHALL 为每个股票生成独立的 HTML 图表文件,文件名格式为 `{code}.html`。
#### Scenario: 指定 --output-dir 参数
- **WHEN** 用户传入 `--output-dir output/`
- **THEN** 系统为每个股票生成 HTML 文件到 `output/{code}.html`
- **THEN** 文件名 SHALL 为股票代码,如 `000001.SZ.html`, `600000.SH.html`
- **THEN** 系统自动创建 `output/` 目录(`exist_ok=True`
- **THEN** 系统在完成后输出提示:"图表已保存到目录: output/" 后列出所有文件
#### Scenario: 未指定 --output-dir 参数
- **WHEN** 用户未传入 `--output-dir` 参数
- **THEN** 系统不为任何股票生成图表文件
- **THEN** 系统仅输出控制台统计信息
#### Scenario: 图表文件覆盖
- **WHEN** 系统再次执行相同的批量回测
- **THEN** 系统覆盖已存在的 HTML 文件
- **THEN** 系统不提示文件已存在
---
### Requirement: 结构化回测结果
系统 SHALL 返回标准化的 `BacktestResult` 对象,包含所有关键指标。
#### Scenario: BacktestResult 对象创建
- **WHEN** 系统完成单股票回测
- **THEN** 系统从 `stats` 对象提取指标到 `BacktestResult`
- **THEN** `BacktestResult.code` 设置为股票代码
- **THEN** `BacktestResult.start_date` 设置为回测开始日期
- **THEN** `BacktestResult.end_date` 设置为回测结束日期
- **THEN** `BacktestResult.equity_final` 设置为最终权益
- **THEN** `BacktestResult.equity_peak` 设置为峰值收益
- **THEN** `BacktestResult.return_pct` 设置为总收益率
- **THEN** `BacktestResult.buy_hold_return` 设置为买入持有收益率
- **THEN** `BacktestResult.return_annual` 设置为年化收益率
- **THEN** `BacktestResult.volatility_annual` 设置为年化波动率
- **THEN** `BacktestResult.max_drawdown` 设置为最大回撤
- **THEN** `BacktestResult.avg_drawdown` 设置为平均回撤
- **THEN** `BacktestResult.max_drawdown_duration` 设置为最大回撤持续时长
- **THEN** `BacktestResult.avg_drawdown_duration` 设置为平均回撤持续时长
- **THEN** `BacktestResult.sortino_ratio` 设置为索提诺比率
- **THEN** `BacktestResult.calmar_ratio` 设置为卡尔玛比率
- **THEN** `BacktestResult.trades` 设置为交易次数
- **THEN** `BacktestResult.win_rate` 设置为胜率
- **THEN** `BacktestResult.sqn` 设置为系统质量数
- **THEN** `BacktestResult.cash` 设置为初始资金
- **THEN** `BacktestResult.commission` 设置为手续费率
#### Scenario: BacktestResult 列表返回
- **WHEN** 系统完成批量回测
- **THEN** 系统返回 `List[BacktestResult]`
- **THEN** 列表顺序 SHALL 与输入股票代码顺序一致
- **THEN** 列表长度 SHALL 等于输入股票代码数量(成功时)
#### Scenario: BacktestResult 数据类型
- **WHEN** 系统创建 `BacktestResult` 对象
- **THEN** 数值字段 SHALL 为 float 类型(除 `trades`, `max_drawdown_duration` 为 int
- **THEN** 日期字段 SHALL 为 str 类型YYYY-MM-DD 格式)
- **THEN** 系统支持 `result.to_dict()` 方法dataclass 自动生成)
---
### Requirement: 可复用回测引擎接口
系统 SHALL 提供标准化的函数接口,供其他模块调用回测功能。
#### Scenario: run_backtest 函数调用
- **WHEN** 其他模块调用 `run_backtest(code, start_date, end_date, strategy_file, cash, commission, warmup_days, output_file)`
- **THEN** 函数接收股票代码、日期范围、策略文件、回测参数、输出文件路径
- **THEN** 函数执行完整回测流程(数据加载、策略加载、指标计算、回测执行)
- **THEN** 函数返回 `BacktestResult` 对象
- **THEN** 函数不打印任何输出(纯函数)
#### Scenario: run_batch_backtest 函数调用
- **WHEN** 其他模块调用 `run_batch_backtest(codes, start_date, end_date, strategy_file, cash, commission, warmup_days, output_dir)`
- **THEN** 函数接收股票代码列表、日期范围、策略文件、回测参数、输出目录
- **THEN** 函数串行执行每个股票的回测
- **THEN** 函数返回 `List[BacktestResult]`
- **THEN** 函数显示 tqdm 进度条(批量时)
#### Scenario: 函数参数默认值
- **WHEN** 调用者不指定可选参数
- **THEN** `cash` 默认为 100000
- **THEN** `commission` 默认为 0.002
- **THEN** `warmup_days` 默认为 365
- **THEN** `output_file` 默认为 None不生成图表
- **THEN** `output_dir` 默认为 None不生成图表
#### Scenario: 函数异常抛出
- **WHEN** `run_backtest` 或 `run_batch_backtest` 执行时发生错误
- **THEN** 函数 SHALL 抛出异常(不捕获)
- **THEN** 异常类型 SHALL 为 ValueError、TypeError 或原始异常
- **THEN** 异常信息 SHALL 包含具体错误原因
- **THEN** 调用者负责捕获和处理异常
---
### Requirement: 集中配置管理
系统 SHALL 在 config.py 中集中管理数据库配置、默认回测参数、图表配色。
#### Scenario: 数据库配置访问
- **WHEN** backtest_core.py 需要数据库连接参数
- **THEN** 模块从 config 导入 `DB_HOST`, `DB_PORT`, `DB_NAME`, `DB_USER`, `DB_PASSWORD`
- **THEN** 模块使用这些常量构建连接字符串
- **THEN** 模块不重复定义数据库配置
#### Scenario: 默认参数访问
- **WHEN** backtest_core.py 需要默认回测参数
- **THEN** 模块从 config 导入 `DEFAULT_CASH`, `DEFAULT_COMMISSION`, `DEFAULT_WARMUP_DAYS`
- **THEN** 模块使用这些常量作为函数默认值
- **THEN** 模块不重复定义默认参数
#### Scenario: 图表配色访问
- **WHEN** backtest_core.py 需要设置图表配色
- **THEN** 模块从 config 导入 `BULL_COLOR`, `BEAR_COLOR`
- **THEN** 模块使用这些颜色设置 `plotting.BULL_COLOR` 和 `plotting.BEAR_COLOR`
- **THEN** 模块不重复定义颜色配置
#### Scenario: 配置文件内容
- **WHEN** 查看 config.py 文件
- **THEN** 文件包含数据库配置DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD
- **THEN** 文件包含默认回测参数DEFAULT_CASH, DEFAULT_COMMISSION, DEFAULT_WARMUP_DAYS
- **THEN** 文件包含图表配色BULL_COLOR, BEAR_COLOR
- **THEN** 所有配置使用明文常量(不使用环境变量)
---
### Requirement: 错误处理策略
系统 SHALL 在批量回测失败时立即停止执行,不继续处理后续股票。
#### Scenario: 数据加载失败
- **WHEN** 系统加载第 i 个股票数据时失败(数据库错误、数据不存在)
- **THEN** 系统捕获异常
- **THEN** 系统输出错误信息:"回测失败 [{code}]: {error}"
- **THEN** 系统停止执行后续股票的回测
- **THEN** 系统退出并返回非零状态码
#### Scenario: 策略加载失败
- **WHEN** 系统加载策略文件时失败(文件不存在、接口不完整)
- **THEN** 系统捕获异常
- **THEN** 系统输出错误信息:"策略加载失败: {error}"
- **THEN** 系统停止执行所有股票的回测
- **THEN** 系统退出并返回非零状态码
#### Scenario: 回测执行失败
- **WHEN** 系统执行第 i 个股票回测时失败(策略逻辑错误)
- **THEN** 系统捕获异常
- **THEN** 系统输出错误信息和完整堆栈跟踪
- **THEN** 系统停止执行后续股票的回测
- **THEN** 系统退出并返回非零状态码
#### Scenario: 图表生成失败
- **WHEN** 系统生成第 i 个股票图表时失败
- **THEN** 系统捕获异常
- **THEN** 系统输出警告:"图表生成失败 [{code}]: {error},但回测已完成"
- **THEN** 系统继续执行后续股票的回测
- **THEN** 系统在返回的 `BacktestResult` 中设置 `error` 字段(如果设计支持)
---
### Requirement: 依赖管理
系统 SHALL 在 pyproject.toml 中添加 tabulate 和 tqdm 依赖。
#### Scenario: 添加 tabulate 依赖
- **WHEN** 查看 pyproject.toml 文件
- **THEN** 文件包含 `tabulate` 依赖
- **THEN** 依赖版本 SHALL 为兼容当前 Python 版本的版本
- **THEN** 系统可以导入 `import tabulate` 无错误
#### Scenario: 添加 tqdm 依赖
- **WHEN** 查看 pyproject.toml 文件
- **THEN** 文件包含 `tqdm` 依赖
- **THEN** 依赖版本 SHALL 为兼容当前 Python 版本的版本
- **THEN** 系统可以导入 `from tqdm import tqdm` 无错误
#### Scenario: 依赖安装
- **WHEN** 用户运行 `uv sync` 或 `pip install -e .`
- **THEN** 系统自动安装 tabulate 和 tqdm
- **THEN** 系统显示依赖安装进度
- **THEN** 系统完成安装后可以正常使用回测工具
#### Scenario: 依赖缺失提示
- **WHEN** 系统导入 tabulate 或 tqdm 时失败
- **THEN** 系统输出友好错误信息:"缺少依赖: {package_name},请运行: uv add {package_name}"
- **THEN** 系统退出并返回非零状态码

View File

@@ -0,0 +1,96 @@
## 1. 依赖管理
- [x] 1.1 运行 `uv add tabulate` 添加依赖
- [x] 1.2 运行 `uv add tqdm` 添加依赖
- [x] 1.3 运行 `uv sync` 同步依赖
## 2. 配置管理模块
- [x] 2.1 创建 config.py 文件
- [x] 2.2 在 config.py 中定义数据库配置常量DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD
- [x] 2.3 在 config.py 中定义默认回测参数DEFAULT_CASH, DEFAULT_COMMISSION, DEFAULT_WARMUP_DAYS
- [x] 2.4 在 config.py 中定义图表配色BULL_COLOR, BEAR_COLOR
- [x] 2.5 测试 config.py 导入无错误
## 3. 核心回测引擎
- [x] 3.1 创建 backtest_core.py 文件
- [x] 3.2 在 backtest_core.py 中导入必要模块和 config
- [x] 3.3 定义 BacktestResult dataclass包含所有回测指标字段
- [x] 3.4 迁移 load_data_from_db() 函数(使用 config 数据库配置)
- [x] 3.5 迁移 load_strategy() 函数(保持原有逻辑)
- [x] 3.6 迁移 apply_color_scheme() 函数(使用 config 配色)
- [x] 3.7 实现 run_backtest() 函数(单股票回测)
- [x] 3.7.1 实现预热期日期计算逻辑
- [x] 3.7.2 实现数据加载和策略加载调用
- [x] 3.7.3 实现指标计算和数据截取
- [x] 3.7.4 实现 Backtest 执行
- [x] 3.7.5 实现图表生成(可选)
- [x] 3.7.6 实现 BacktestResult 对象构建和返回
- [x] 3.8 实现 run_batch_backtest() 函数(批量回测,串行)
- [x] 3.8.1 实现循环遍历股票代码
- [x] 3.8.2 实现为每个股票调用 run_backtest()
- [x] 3.8.3 实现为每个股票生成独立 HTML 文件
- [x] 3.8.4 实现结果列表收集和返回
- [x] 3.8.5 实现 tqdm 进度条显示(批量时)
- [x] 3.9 测试 run_backtest() 单股票回测
- [x] 3.10 测试 run_batch_backtest() 多股票回测
## 4. CLI 界面模块
- [x] 4.1 创建 backtest_command.py 文件
- [x] 4.2 在 backtest_command.py 中导入必要模块和 backtest_core
- [x] 4.3 实现 parse_arguments() 函数
- [x] 4.3.1 定义 --codes 多值参数nargs='+'
- [x] 4.3.2 定义 --output-dir 目录参数
- [x] 4.3.3 保持原有参数(--start-date, --end-date, --strategy-file, --cash, --commission, --warmup-days
- [x] 4.3.4 添加参数帮助文档和示例
- [x] 4.4 实现 format_single_result() 函数(详细格式输出)
- [x] 4.4.1 实现每个指标单独一行的格式化
- [x] 4.4.2 保持原有 print_stats() 的输出格式
- [x] 4.5 实现 format_batch_results() 函数(表格格式输出)
- [x] 4.5.1 实现使用 tabulate 生成表格
- [x] 4.5.2 定义表格列:股票代码、收益率%、胜率%、最大回撤%、交易次数、SQN
- [x] 4.5.3 实现表格数据填充(从 BacktestResult 对象提取)
- [x] 4.5.4 实现表格格式为 grid
- [x] 4.6 实现 main() 函数
- [x] 4.6.1 调用 parse_arguments() 解析参数
- [x] 4.6.2 调用 run_batch_backtest() 执行批量回测
- [x] 4.6.3 根据结果数量调用 format_single_result() 或 format_batch_results()
- [x] 4.6.4 实现图表保存提示(指定 --output-dir 时)
- [x] 4.6.5 实现错误捕获和友好错误信息输出
- [x] 4.6.6 实现退出状态码设置(成功 0失败非零
- [x] 4.7 添加 `if __name__ == "__main__": main()` 入口
- [x] 4.8 测试单股票回测命令行调用 (`uv run python backtest_command.py`)
- [x] 4.9 测试多股票回测命令行调用 (`uv run python backtest_command.py`)
- [x] 4.10 测试错误处理(参数缺失、文件不存在等)
## 5. 清理旧代码
- [x] 5.1 确认新功能完整(单股票、多股票、图表输出)
- [x] 5.2 确认错误处理正确(立即失败)
- [x] 5.3 删除 backtest.py 文件
- [x] 5.4 验证 git 状态(仅删除旧文件,无其他修改)
## 6. 文档更新
- [x] 6.1 更新 README.md如果存在
- [x] 6.1.1 说明新的命令行使用方式(`uv run python backtest_command.py`
- [x] 6.1.2 说明参数变化(--code 改为 --codes
- [x] 6.1.3 提供单股票和多股票示例
- [x] 6.1.4 说明 --output-dir 用法(多股票图表)
- [x] 6.2 创建 note_refactor.md可选记录重构说明
- [x] 6.2.1 说明文件结构变化
- [x] 6.2.2 说明接口变化
- [x] 6.2.3 提供迁移指南
## 7. 集成测试
- [x] 7.1 测试单个股票完整流程000001.SZ
- [x] 7.2 测试多个股票完整流程000001.SZ 600000.SH
- [x] 7.3 测试指定 --output-dir 生成图表
- [x] 7.4 测试不指定 --output-dir不生成图表
- [x] 7.5 测试错误情况(无效股票代码、不存在的策略文件等)
- [x] 7.6 验证进度条显示(多股票时)
- [x] 7.7 验证表格格式输出(多股票时)
- [x] 7.8 验证详细格式输出(单股票时)

View File

@@ -2,6 +2,6 @@ schema: spec-driven
Example:
context: |
使用 uv 工具进行 python 环境的管理和三方依赖的管理
严禁在主机环境直接运行 pip、pip3 安装依赖包,必须使用 uv 虚拟环境
使用 uv 工具进行 python 环境的管理和三方依赖的管理运行python命令的时候使用uv run python xxx
严禁在主机环境直接运行 pip、pip3 安装依赖包,必须使用 uv add xxx命令安装
项目面向中文开发者文档输出、日志输出、agent 交流时都要使用中文

View File

@@ -0,0 +1,312 @@
# batch-backtest Specification
## Purpose
TBD - created by archiving change refactor-backtest-separate-cli. Update Purpose after archive.
## Requirements
### Requirement: 多股票回测参数
系统 SHALL 支持通过命令行参数传入多个股票代码进行批量回测。
#### Scenario: 传入多个股票代码
- **WHEN** 用户执行 `python backtest_command.py --codes 000001.SZ 600000.SH --start-date 2024-01-01 --end-date 2025-12-31 --strategy-file strategies/macd_strategy.py`
- **THEN** 系统解析所有股票代码到列表 `['000001.SZ', '600000.SH']`
- **THEN** 系统按顺序依次执行每个股票的回测
- **THEN** 系统为每个股票生成独立的回测结果
#### Scenario: 传入单个股票代码
- **WHEN** 用户执行 `python backtest_command.py --codes 000001.SZ --start-date 2024-01-01 --end-date 2025-12-31 --strategy-file strategies/macd_strategy.py`
- **THEN** 系统解析为单个股票代码列表 `['000001.SZ']`
- **THEN** 系统执行单个股票回测
- **THEN** 系统输出详细格式的回测结果
#### Scenario: 缺少 --codes 参数
- **WHEN** 用户未提供 `--codes` 参数
- **THEN** 系统输出错误信息:"错误: 需要以下参数: --codes"
- **THEN** 系统退出并返回非零状态码
---
### Requirement: 批量回测执行
系统 SHALL 串行执行多个股票的回测,每次加载一个股票的数据并执行回测。
#### Scenario: 成功执行多个股票回测
- **WHEN** 用户传入 N 个股票代码
- **THEN** 系统循环 N 次,每次加载一个股票的数据
- **THEN** 系统每次执行完整的回测流程(数据加载、指标计算、回测执行)
- **THEN** 系统每次执行完成后生成 `BacktestResult` 对象
- **THEN** 系统返回包含 N 个 `BacktestResult` 的列表
#### Scenario: 每个股票独立预热期
- **WHEN** 系统执行第 i 个股票的回测
- **THEN** 系统使用 `start_date - warmup_days` 计算该股票的预热开始日期
- **THEN** 系统独立加载该股票的预热期数据
- **THEN** 不同股票的预热期互不影响
#### Scenario: 第一个股票回测失败
- **WHEN** 系统执行第一个股票回测时发生错误(数据库连接失败、策略加载失败等)
- **THEN** 系统捕获异常并输出错误信息
- **THEN** 系统停止执行后续股票的回测
- **THEN** 系统退出并返回非零状态码(立即失败策略)
#### Scenario: 中间股票回测失败
- **WHEN** 系统执行第 i 个股票回测时发生错误
- **THEN** 系统输出错误信息(包含股票代码)
- **THEN** 系统停止执行后续股票的回测
- **THEN** 系统退出并返回非零状态码
#### Scenario: 资源管理
- **WHEN** 系统完成第 i 个股票的回测
- **THEN** 系统关闭该股票的数据库连接(`engine.dispose()`
- **THEN** 系统释放该股票的数据内存
- **THEN** 系统开始加载第 i+1 个股票的数据
---
### Requirement: 批量回测进度显示
系统 SHALL 使用 tqdm 显示批量回测的实时进度,提供用户反馈。
#### Scenario: 显示进度条
- **WHEN** 系统开始执行 N 个股票的批量回测
- **THEN** 系统显示进度条格式:`回测进度: 25%|█████▌ | 1/4 [00:30<01:30, 12.5s/it]`
- **THEN** 系统在完成每个股票回测后更新进度条
- **THEN** 进度条显示当前进度i/N、已用时间、预计剩余时间
- **THEN** 进度条在所有股票回测完成后消失
#### Scenario: 单股票回测不显示进度条
- **WHEN** 用户传入单个股票代码
- **THEN** 系统不显示 tqdm 进度条
- **THEN** 系统直接输出回测结果
#### Scenario: 进度条描述文本
- **WHEN** 系统显示批量回测进度
- **THEN** 进度条描述 SHALL 为 "回测进度"(中文)
- **THEN** 进度条显示已完成/总数(如 "1/4", "2/4"
---
### Requirement: 批量回测结果展示
系统 SHALL 使用 tabulate 将多个股票的回测结果格式化为表格,便于横向对比。
#### Scenario: 表格化输出多股票结果
- **WHEN** 用户传入多个股票代码且回测成功
- **THEN** 系统使用 tabulate 生成表格
- **THEN** 表格格式 SHALL 为 grid带边框
- **THEN** 表格列 SHALL 包含:股票代码、收益率%、胜率%、最大回撤%、交易次数、SQN
- **THEN** 系统在表格上方显示表头(中文列名)
- **THEN** 数值保留 2 位小数(交易次数为整数)
#### Scenario: 表格内容填充
- **WHEN** 系统格式化第 i 个股票的结果
- **THEN** 系统从 `BacktestResult` 对象提取字段
- **THEN** "股票代码" 列填充 `result.code`
- **THEN** "收益率%" 列填充 `result.return_pct`
- **THEN** "胜率%" 列填充 `result.win_rate`
- **THEN** "最大回撤%" 列填充 `result.max_drawdown`
- **THEN** "交易次数" 列填充 `result.trades`
- **THEN** "SQN" 列填充 `result.sqn`
#### Scenario: 单股票回测不使用表格
- **WHEN** 用户传入单个股票代码
- **THEN** 系统不使用 tabulate 生成表格
- **THEN** 系统使用详细格式输出(每个指标单独一行)
- **THEN** 系统保持原有 `print_stats()` 的输出格式
#### Scenario: 表格示例输出
- **WHEN** 用户传入 2 个股票代码
- **THEN** 系统输出格式 SHALL 为:
```
+-------------+-----------+--------+------------+----------+-------+
| 股票代码 | 收益率% | 胜率% | 最大回撤% | 交易次数 | SQN |
+-------------+-----------+--------+------------+----------+-------+
| 000001.SZ | 20.35 | 55.00 | -8.50 | 45 | 1.85 |
| 600000.SH | 15.00 | 48.00 | -12.30 | 38 | 1.42 |
+-------------+-----------+--------+------------+----------+-------+
```
---
### Requirement: 多股票图表输出
系统 SHALL 为每个股票生成独立的 HTML 图表文件,文件名格式为 `{code}.html`。
#### Scenario: 指定 --output-dir 参数
- **WHEN** 用户传入 `--output-dir output/`
- **THEN** 系统为每个股票生成 HTML 文件到 `output/{code}.html`
- **THEN** 文件名 SHALL 为股票代码,如 `000001.SZ.html`, `600000.SH.html`
- **THEN** 系统自动创建 `output/` 目录(`exist_ok=True`
- **THEN** 系统在完成后输出提示:"图表已保存到目录: output/" 后列出所有文件
#### Scenario: 未指定 --output-dir 参数
- **WHEN** 用户未传入 `--output-dir` 参数
- **THEN** 系统不为任何股票生成图表文件
- **THEN** 系统仅输出控制台统计信息
#### Scenario: 图表文件覆盖
- **WHEN** 系统再次执行相同的批量回测
- **THEN** 系统覆盖已存在的 HTML 文件
- **THEN** 系统不提示文件已存在
---
### Requirement: 结构化回测结果
系统 SHALL 返回标准化的 `BacktestResult` 对象,包含所有关键指标。
#### Scenario: BacktestResult 对象创建
- **WHEN** 系统完成单股票回测
- **THEN** 系统从 `stats` 对象提取指标到 `BacktestResult`
- **THEN** `BacktestResult.code` 设置为股票代码
- **THEN** `BacktestResult.start_date` 设置为回测开始日期
- **THEN** `BacktestResult.end_date` 设置为回测结束日期
- **THEN** `BacktestResult.equity_final` 设置为最终权益
- **THEN** `BacktestResult.equity_peak` 设置为峰值收益
- **THEN** `BacktestResult.return_pct` 设置为总收益率
- **THEN** `BacktestResult.buy_hold_return` 设置为买入持有收益率
- **THEN** `BacktestResult.return_annual` 设置为年化收益率
- **THEN** `BacktestResult.volatility_annual` 设置为年化波动率
- **THEN** `BacktestResult.max_drawdown` 设置为最大回撤
- **THEN** `BacktestResult.avg_drawdown` 设置为平均回撤
- **THEN** `BacktestResult.max_drawdown_duration` 设置为最大回撤持续时长
- **THEN** `BacktestResult.avg_drawdown_duration` 设置为平均回撤持续时长
- **THEN** `BacktestResult.sortino_ratio` 设置为索提诺比率
- **THEN** `BacktestResult.calmar_ratio` 设置为卡尔玛比率
- **THEN** `BacktestResult.trades` 设置为交易次数
- **THEN** `BacktestResult.win_rate` 设置为胜率
- **THEN** `BacktestResult.sqn` 设置为系统质量数
- **THEN** `BacktestResult.cash` 设置为初始资金
- **THEN** `BacktestResult.commission` 设置为手续费率
#### Scenario: BacktestResult 列表返回
- **WHEN** 系统完成批量回测
- **THEN** 系统返回 `List[BacktestResult]`
- **THEN** 列表顺序 SHALL 与输入股票代码顺序一致
- **THEN** 列表长度 SHALL 等于输入股票代码数量(成功时)
#### Scenario: BacktestResult 数据类型
- **WHEN** 系统创建 `BacktestResult` 对象
- **THEN** 数值字段 SHALL 为 float 类型(除 `trades`, `max_drawdown_duration` 为 int
- **THEN** 日期字段 SHALL 为 str 类型YYYY-MM-DD 格式)
- **THEN** 系统支持 `result.to_dict()` 方法dataclass 自动生成)
---
### Requirement: 可复用回测引擎接口
系统 SHALL 提供标准化的函数接口,供其他模块调用回测功能。
#### Scenario: run_backtest 函数调用
- **WHEN** 其他模块调用 `run_backtest(code, start_date, end_date, strategy_file, cash, commission, warmup_days, output_file)`
- **THEN** 函数接收股票代码、日期范围、策略文件、回测参数、输出文件路径
- **THEN** 函数执行完整回测流程(数据加载、策略加载、指标计算、回测执行)
- **THEN** 函数返回 `BacktestResult` 对象
- **THEN** 函数不打印任何输出(纯函数)
#### Scenario: run_batch_backtest 函数调用
- **WHEN** 其他模块调用 `run_batch_backtest(codes, start_date, end_date, strategy_file, cash, commission, warmup_days, output_dir)`
- **THEN** 函数接收股票代码列表、日期范围、策略文件、回测参数、输出目录
- **THEN** 函数串行执行每个股票的回测
- **THEN** 函数返回 `List[BacktestResult]`
- **THEN** 函数显示 tqdm 进度条(批量时)
#### Scenario: 函数参数默认值
- **WHEN** 调用者不指定可选参数
- **THEN** `cash` 默认为 100000
- **THEN** `commission` 默认为 0.002
- **THEN** `warmup_days` 默认为 365
- **THEN** `output_file` 默认为 None不生成图表
- **THEN** `output_dir` 默认为 None不生成图表
#### Scenario: 函数异常抛出
- **WHEN** `run_backtest` 或 `run_batch_backtest` 执行时发生错误
- **THEN** 函数 SHALL 抛出异常(不捕获)
- **THEN** 异常类型 SHALL 为 ValueError、TypeError 或原始异常
- **THEN** 异常信息 SHALL 包含具体错误原因
- **THEN** 调用者负责捕获和处理异常
---
### Requirement: 集中配置管理
系统 SHALL 在 config.py 中集中管理数据库配置、默认回测参数、图表配色。
#### Scenario: 数据库配置访问
- **WHEN** backtest_core.py 需要数据库连接参数
- **THEN** 模块从 config 导入 `DB_HOST`, `DB_PORT`, `DB_NAME`, `DB_USER`, `DB_PASSWORD`
- **THEN** 模块使用这些常量构建连接字符串
- **THEN** 模块不重复定义数据库配置
#### Scenario: 默认参数访问
- **WHEN** backtest_core.py 需要默认回测参数
- **THEN** 模块从 config 导入 `DEFAULT_CASH`, `DEFAULT_COMMISSION`, `DEFAULT_WARMUP_DAYS`
- **THEN** 模块使用这些常量作为函数默认值
- **THEN** 模块不重复定义默认参数
#### Scenario: 图表配色访问
- **WHEN** backtest_core.py 需要设置图表配色
- **THEN** 模块从 config 导入 `BULL_COLOR`, `BEAR_COLOR`
- **THEN** 模块使用这些颜色设置 `plotting.BULL_COLOR` 和 `plotting.BEAR_COLOR`
- **THEN** 模块不重复定义颜色配置
#### Scenario: 配置文件内容
- **WHEN** 查看 config.py 文件
- **THEN** 文件包含数据库配置DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD
- **THEN** 文件包含默认回测参数DEFAULT_CASH, DEFAULT_COMMISSION, DEFAULT_WARMUP_DAYS
- **THEN** 文件包含图表配色BULL_COLOR, BEAR_COLOR
- **THEN** 所有配置使用明文常量(不使用环境变量)
---
### Requirement: 错误处理策略
系统 SHALL 在批量回测失败时立即停止执行,不继续处理后续股票。
#### Scenario: 数据加载失败
- **WHEN** 系统加载第 i 个股票数据时失败(数据库错误、数据不存在)
- **THEN** 系统捕获异常
- **THEN** 系统输出错误信息:"回测失败 [{code}]: {error}"
- **THEN** 系统停止执行后续股票的回测
- **THEN** 系统退出并返回非零状态码
#### Scenario: 策略加载失败
- **WHEN** 系统加载策略文件时失败(文件不存在、接口不完整)
- **THEN** 系统捕获异常
- **THEN** 系统输出错误信息:"策略加载失败: {error}"
- **THEN** 系统停止执行所有股票的回测
- **THEN** 系统退出并返回非零状态码
#### Scenario: 回测执行失败
- **WHEN** 系统执行第 i 个股票回测时失败(策略逻辑错误)
- **THEN** 系统捕获异常
- **THEN** 系统输出错误信息和完整堆栈跟踪
- **THEN** 系统停止执行后续股票的回测
- **THEN** 系统退出并返回非零状态码
#### Scenario: 图表生成失败
- **WHEN** 系统生成第 i 个股票图表时失败
- **THEN** 系统捕获异常
- **THEN** 系统输出警告:"图表生成失败 [{code}]: {error},但回测已完成"
- **THEN** 系统继续执行后续股票的回测
- **THEN** 系统在返回的 `BacktestResult` 中设置 `error` 字段(如果设计支持)
---
### Requirement: 依赖管理
系统 SHALL 在 pyproject.toml 中添加 tabulate 和 tqdm 依赖。
#### Scenario: 添加 tabulate 依赖
- **WHEN** 查看 pyproject.toml 文件
- **THEN** 文件包含 `tabulate` 依赖
- **THEN** 依赖版本 SHALL 为兼容当前 Python 版本的版本
- **THEN** 系统可以导入 `import tabulate` 无错误
#### Scenario: 添加 tqdm 依赖
- **WHEN** 查看 pyproject.toml 文件
- **THEN** 文件包含 `tqdm` 依赖
- **THEN** 依赖版本 SHALL 为兼容当前 Python 版本的版本
- **THEN** 系统可以导入 `from tqdm import tqdm` 无错误
#### Scenario: 依赖安装
- **WHEN** 用户运行 `uv sync` 或 `pip install -e .`
- **THEN** 系统自动安装 tabulate 和 tqdm
- **THEN** 系统显示依赖安装进度
- **THEN** 系统完成安装后可以正常使用回测工具
#### Scenario: 依赖缺失提示
- **WHEN** 系统导入 tabulate 或 tqdm 时失败
- **THEN** 系统输出友好错误信息:"缺少依赖: {package_name},请运行: uv add {package_name}"
- **THEN** 系统退出并返回非零状态码

View File

@@ -15,4 +15,6 @@ dependencies = [
"psycopg2-binary~=2.9.11",
"sqlalchemy>=2.0.46",
"ta-lib>=0.6.8",
"tabulate>=0.9.0",
"tqdm>=4.67.1",
]

View File

@@ -2,28 +2,27 @@
MACD 趋势跟踪策略
策略逻辑:
- 当 MACD 线上穿信号线时 (金叉),且价格 > EMA200 时,买入
- 当 MACD 线下穿信号线时 (死叉),或价格 < EMA200 时,卖出
- 当 MACD 线上穿信号线时 (金叉),且价格 > EMA 时,买入
- 当 MACD 线下穿信号线时 (死叉),或价格 < EMA 时,卖出
指标计算:
- MACD(10, 20, 9): 快线 10 日,慢线 20 日,信号线 9 日
- EMA200: 200 日指数移动平均线(趋势确认)
- EMA: 200 日指数移动平均线(趋势确认)
参数选择理由:
- 快线 10: 比标准 12 更敏感,适应 A 股较高波动性
- 慢线 20: 比标准 26 更快响应,同时保持趋势跟踪稳定性
- 信号线 9: 保持标准,避免信号过于频繁
- EMA200: 被广泛认可为牛熊分界线,避免逆势交易
- EMA: 被广泛认可为牛熊分界线,避免逆势交易
趋势过滤:
- EMA200 上方: 确认为上升趋势,允许开多仓
- EMA200 下方: 确认为下降趋势,不开多仓,强制平仓
- EMA 上方: 确认为上升趋势,允许开多仓
- EMA 下方: 确认为下降趋势,不开多仓,强制平仓
Author: Sisyphus
Date: 2025-01-27
"""
import pandas as pd
from backtesting import Strategy
from backtesting.lib import crossover
@@ -32,32 +31,30 @@ def calculate_indicators(data):
"""
计算策略所需的技术指标
使用 ta-lib 库计算 MACD 和 EMA200 指标
使用 ta-lib 库计算 MACD 和 EMA 指标
参数:
data: DataFrame, 包含 [Open, High, Low, Close, Volume, factor]
返回:
DataFrame, 添加了指标列:
- MACD_10_20_9: MACD 线 (DIF)
- MACDs_10_20_9: MACD 信号线 (DEA)
- MACDh_10_20_9: MACD 柱状图 (Histogram)
- EMA_200: 200 日指数移动平均线
- macd: MACD 线 (macd)
- signal: MACD 信号线 (DEA)
- hist: MACD 柱状图 (Histogram)
- ema: 日指数移动平均线
"""
data = data.copy()
# 计算 MACD 指标 (10, 20, 9)
# talib.MACD 返回三个值: (macd, macdsignal, macdhist)
macd, macdsignal, macdhist = talib.MACD(
data["Close"], fastperiod=10, slowperiod=20, signalperiod=9
)
macd, macdsignal, macdhist = talib.MACD(data["Close"], fastperiod=10, slowperiod=20, signalperiod=9)
data["MACD_10_20_9"] = macd
data["MACDs_10_20_9"] = macdsignal
data["MACDh_10_20_9"] = macdhist
data["macd"] = macd
data["signal"] = macdsignal
data["hist"] = macdhist
# 计算 EMA200 趋势线
data["EMA_200"] = talib.EMA(data["Close"], timeperiod=200)
# 计算 EMA 趋势线
data["ema"] = talib.SMA(data["Close"], timeperiod=120)
return data
@@ -76,7 +73,7 @@ class MacdTrendStrategy(Strategy):
"""
MACD 趋势跟踪策略
结合 MACD 金叉/死叉信号和 EMA200 趋势过滤
结合 MACD 金叉/死叉信号和 EMA 趋势过滤
参数:
fast_period: MACD 快线周期 (默认: 10)
@@ -95,13 +92,13 @@ class MacdTrendStrategy(Strategy):
注册指标到 backtesting 框架
"""
# 注册 MACD 线
self.macd = self.I(lambda x: x, self.data.MACD_10_20_9)
self.macd = self.I(lambda x: x, self.data.macd)
# 注册 MACD 信号线
self.macd_signal = self.I(lambda x: x, self.data.MACDs_10_20_9)
self.signal = self.I(lambda x: x, self.data.signal)
# 注册 EMA200 趋势线
self.ema200 = self.I(lambda x: x, self.data.EMA_200)
# 注册 EMA 趋势线
self.ema = self.I(lambda x: x, self.data.ema)
def next(self):
"""
@@ -109,25 +106,18 @@ class MacdTrendStrategy(Strategy):
买入条件:
- MACD 金叉 (MACD 线上穿信号线)
- 价格 > EMA200 (确认上升趋势)
- 价格 > EMA (确认上升趋势)
卖出条件:
- MACD 死叉 (MACD 线下穿信号线)
- 或价格 < EMA200 (趋势转向,强制平仓)
- 或价格 < EMA (趋势转向,强制平仓)
"""
# 买入条件: MACD 金叉 AND 价格 > EMA200
if (
crossover(self.macd, self.macd_signal)
and self.data.Close[-1] > self.ema200[-1]
):
self.position.close() # 先平掉现有仓位
# 买入条件: MACD 金叉 AND 价格 > EMA
if crossover(self.macd, self.signal) and self.data.Close[-1] > self.ema[-1]:
self.buy() # 开多仓
# 卖出条件: MACD 死叉 OR 价格 < EMA200
elif (
crossover(self.macd_signal, self.macd)
or self.data.Close[-1] < self.ema200[-1]
):
# 卖出条件: MACD 死叉 OR 价格 < EMA
elif self.position.size > 0 and (crossover(self.signal, self.macd) or self.data.Close[-1] < self.ema[-1]):
self.position.close() # 平掉多仓

View File

@@ -5,14 +5,13 @@ SMA 双均线交叉策略
- 当短期均线上穿长期均线时 (金叉),买入
- 当短期均线下穿长期均线时 (死叉),卖出
指标计算:
指标计算 (使用 ta-lib):
- SMA10: 10 日简单移动平均线
- SMA30: 30 日简单移动平均线
- SMA60: 60 日简单移动平均线
- SMA120: 120 日简单移动平均线
"""
import pandas as pd
from backtesting import Strategy
from backtesting.lib import crossover
@@ -21,19 +20,25 @@ def calculate_indicators(data):
"""
计算策略所需的技术指标
使用 ta-lib 库计算 SMA 指标
参数:
data: DataFrame, 包含 [Open, High, Low, Close, Volume, factor]
返回:
DataFrame, 添加了指标列
DataFrame, 添加了指标列:
- sma10: 10 日简单移动平均线
- sma30: 30 日简单移动平均线
- sma60: 60 日简单移动平均线
- sma120: 120 日简单移动平均线
"""
data = data.copy()
# 计算不同周期的移动平均线
data["sma10"] = data["Close"].rolling(window=10).mean()
data["sma30"] = data["Close"].rolling(window=30).mean()
data["sma60"] = data["Close"].rolling(window=60).mean()
data["sma120"] = data["Close"].rolling(window=120).mean()
data["sma10"] = talib.SMA(data["Close"], timeperiod=10)
data["sma30"] = talib.SMA(data["Close"], timeperiod=30)
data["sma60"] = talib.SMA(data["Close"], timeperiod=60)
data["sma120"] = talib.SMA(data["Close"], timeperiod=120)
return data
@@ -78,10 +83,12 @@ class SmaCross(Strategy):
"""
# 金叉:短期均线上穿长期均线
if crossover(self.data.sma10, self.data.sma30):
self.position.close() # 先平掉现有仓位
self.buy() # 开多仓
# 死叉:短期均线下穿长期均线
elif crossover(self.data.sma30, self.data.sma10):
self.position.close() # 先平掉现有仓位
self.sell() # 开空仓
elif self.position.size > 0 and crossover(self.data.sma30, self.data.sma10):
self.position.close() # 开空仓
# 导入 talib (必须在文件末尾,因为 calculate_indicators 函数中使用了 talib)
import talib

25
uv.lock generated
View File

@@ -912,6 +912,8 @@ dependencies = [
{ name = "psycopg2-binary" },
{ name = "sqlalchemy" },
{ name = "ta-lib" },
{ name = "tabulate" },
{ name = "tqdm" },
]
[package.metadata]
@@ -927,6 +929,8 @@ requires-dist = [
{ name = "psycopg2-binary", specifier = "~=2.9.11" },
{ name = "sqlalchemy", specifier = ">=2.0.46" },
{ name = "ta-lib", specifier = ">=0.6.8" },
{ name = "tabulate", specifier = ">=0.9.0" },
{ name = "tqdm", specifier = ">=4.67.1" },
]
[[package]]
@@ -1692,6 +1696,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/0b/4c/d341020377f8b183405bdf3c5717fc2ca04a8d33b5c59b2348377ee459d9/ta_lib-0.6.8-cp314-cp314-win_arm64.whl", hash = "sha256:bfad1202fb1f9140e3810cc607058395f59032d9128cc0d716900c78bea5f337", size = 755896, upload-time = "2025-10-20T20:49:39.9Z" },
]
[[package]]
name = "tabulate"
version = "0.9.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ec/fe/802052aecb21e3797b8f7902564ab6ea0d60ff8ca23952079064155d1ae1/tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c", size = 81090, upload-time = "2022-10-06T17:21:48.54Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252, upload-time = "2022-10-06T17:21:44.262Z" },
]
[[package]]
name = "terminado"
version = "0.18.1"
@@ -1737,6 +1750,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/50/49/8dc3fd90902f70084bd2cd059d576ddb4f8bb44c2c7c0e33a11422acb17e/tornado-6.5.4-cp39-abi3-win_arm64.whl", hash = "sha256:053e6e16701eb6cbe641f308f4c1a9541f91b6261991160391bfc342e8a551a1", size = 445910, upload-time = "2025-12-15T19:21:02.571Z" },
]
[[package]]
name = "tqdm"
version = "4.67.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" },
]
[[package]]
name = "traitlets"
version = "5.14.3"