299 lines
9.0 KiB
Python
299 lines
9.0 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
量化回测主程序
|
||
|
||
使用方法:
|
||
python backtest.py --code 000001.SZ --start-date 2024-01-01 --end-date 2025-12-31 --strategy-file strategy.py
|
||
"""
|
||
|
||
import argparse
|
||
import importlib.util
|
||
import os
|
||
import sys
|
||
|
||
import pandas as pd
|
||
|
||
# 数据库配置(直接硬编码,开发环境)
|
||
DB_HOST = "81.71.3.24"
|
||
DB_PORT = 6785
|
||
DB_NAME = "leopard_dev"
|
||
DB_USER = "leopard"
|
||
DB_PASSWORD = "9NEzFzovnddf@PyEP?e*AYAWnCyd7UhYwQK$pJf>7?ccFiN^x4$eKEZ5~E<7<+~X"
|
||
|
||
|
||
def load_data_from_db(code, start_date, end_date):
|
||
"""
|
||
从数据库加载历史数据
|
||
|
||
参数:
|
||
code: 股票代码(如 '000001.SZ')
|
||
start_date: 开始日期(如 '2024-01-01')
|
||
end_date: 结束日期(如 '2025-12-31')
|
||
|
||
返回:
|
||
DataFrame, 包含列: [Open, High, Low, Close, Volume, factor]
|
||
"""
|
||
import sqlalchemy
|
||
import urllib.parse
|
||
|
||
# 构建连接字符串(URL 编码密码中的特殊字符)
|
||
encoded_password = urllib.parse.quote_plus(DB_PASSWORD)
|
||
conn_str = (
|
||
f"postgresql://{DB_USER}:{encoded_password}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
|
||
)
|
||
engine = sqlalchemy.create_engine(conn_str)
|
||
|
||
try:
|
||
# 构建 SQL 查询
|
||
query = f"""
|
||
SELECT trade_date,
|
||
open * factor AS "Open",
|
||
close * factor AS "Close",
|
||
high * factor AS "High",
|
||
low * factor AS "Low",
|
||
volume AS "Volume",
|
||
COALESCE(factor, 1.0) AS factor
|
||
FROM leopard_daily daily
|
||
LEFT JOIN leopard_stock stock ON stock.id = daily.stock_id
|
||
WHERE stock.code = '{code}'
|
||
AND daily.trade_date BETWEEN '{start_date} 00:00:00'
|
||
AND '{end_date} 23:59:59'
|
||
ORDER BY daily.trade_date
|
||
"""
|
||
|
||
# 执行查询
|
||
df = pd.read_sql(query, engine)
|
||
|
||
if len(df) == 0:
|
||
raise ValueError(f"未找到股票 {code} 在指定时间范围内的数据")
|
||
|
||
# 潬换日期并设置为索引
|
||
df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y-%m-%d")
|
||
df.set_index("trade_date", inplace=True)
|
||
|
||
return df
|
||
|
||
finally:
|
||
# 清理连接
|
||
engine.dispose()
|
||
|
||
|
||
def load_strategy(strategy_file):
|
||
"""
|
||
动态加载策略文件
|
||
|
||
参数:
|
||
strategy_file: 策略文件路径 (如 'strategy.py' 或 'strategies/macd.py')
|
||
|
||
返回:
|
||
(calculate_indicators, strategy_class) 元组
|
||
"""
|
||
# 获取模块名
|
||
module_name = strategy_file.replace(".py", "").replace("/", ".")
|
||
spec = importlib.util.spec_from_file_location(module_name, strategy_file)
|
||
module = importlib.util.module_from_spec(spec)
|
||
spec.loader.exec_module(module)
|
||
|
||
# 接口验证
|
||
if not hasattr(module, "calculate_indicators"):
|
||
raise AttributeError(f"策略文件 {strategy_file} 缺少 calculate_indicators 函数")
|
||
|
||
if not hasattr(module, "get_strategy"):
|
||
raise AttributeError(f"策略文件 {strategy_file} 缺少 get_strategy 函数")
|
||
|
||
calculate_indicators = module.calculate_indicators
|
||
strategy_class = module.get_strategy()
|
||
|
||
# 验证 get_strategy 返回的是类
|
||
if not isinstance(strategy_class, type):
|
||
raise TypeError("get_strategy() 必须返回一个类")
|
||
|
||
# 验证策略类继承自 backtesting.Strategy
|
||
from backtesting import Strategy
|
||
|
||
if not issubclass(strategy_class, Strategy):
|
||
raise TypeError("策略类必须继承 backtesting.Strategy")
|
||
|
||
return calculate_indicators, strategy_class
|
||
|
||
|
||
def apply_color_scheme():
|
||
"""
|
||
应用颜色方案:红涨绿跌(中国股市风格)
|
||
"""
|
||
import backtesting._plotting as plotting
|
||
from bokeh.colors.named import tomato, lime
|
||
|
||
plotting.BULL_COLOR = tomato
|
||
plotting.BEAR_COLOR = lime
|
||
|
||
|
||
def parse_arguments():
|
||
"""
|
||
解析命令行参数
|
||
|
||
返回:
|
||
args: 命名空间对象
|
||
"""
|
||
parser = argparse.ArgumentParser(
|
||
description="量化回测工具", formatter_class=argparse.RawDescriptionHelpFormatter
|
||
)
|
||
|
||
# 必需参数
|
||
parser.add_argument(
|
||
"--code", type=str, required=True, help="股票代码 (如: 000001.SZ)"
|
||
)
|
||
parser.add_argument(
|
||
"--start-date", type=str, required=True, help="回测开始日期 (格式: YYYY-MM-DD)"
|
||
)
|
||
parser.add_argument(
|
||
"--end-date", type=str, required=True, help="回测结束日期 (格式: YYYY-MM-DD)"
|
||
)
|
||
parser.add_argument(
|
||
"--strategy-file",
|
||
type=str,
|
||
required=True,
|
||
help="策略文件路径 (如: strategy.py)",
|
||
)
|
||
|
||
# 可选参数
|
||
parser.add_argument(
|
||
"--cash", type=float, default=100000, help="初始资金 (默认: 100000)"
|
||
)
|
||
parser.add_argument(
|
||
"--commission", type=float, default=0.002, help="手续费率 (默认: 0.002)"
|
||
)
|
||
parser.add_argument(
|
||
"--output", type=str, default=None, help="HTML 输出文件路径 (可选)"
|
||
)
|
||
parser.add_argument(
|
||
"--warmup-days", type=int, default=365, help="预热天数 (默认: 365,约一年)"
|
||
)
|
||
|
||
return parser.parse_args()
|
||
|
||
|
||
def print_stats(stats):
|
||
"""
|
||
打印回测统计结果
|
||
|
||
参数:
|
||
stats: backtesting 库返回的统计对象
|
||
"""
|
||
print("=" * 60)
|
||
print("回测结果")
|
||
|
||
indicator_name_mapping = {
|
||
# 'Start': '回测开始时间',
|
||
# 'End': '回测结束时间',
|
||
# 'Duration': '回测持续时长',
|
||
# 'Exposure Time [%]': '持仓时间占比(%)',
|
||
"Equity Final [$]": "最终收益",
|
||
"Equity Peak [$]": "峰值收益",
|
||
"Return [%]": "总收益率(%)",
|
||
"Buy & Hold Return [%]": "买入并持有收益率(%)",
|
||
"Return (Ann.) [%]": "年化收益率(%)",
|
||
"Volatility (Ann.) [%]": "年化波动率(%)",
|
||
# 'CAGR [%]': '复合年均增长率(%)',
|
||
# 'Sharpe Ratio': '夏普比率',
|
||
"Sortino Ratio": "索提诺比率",
|
||
"Calmar Ratio": "卡尔玛比率",
|
||
# 'Alpha [%]': '阿尔法系数(%)',
|
||
# 'Beta': '贝塔系数',
|
||
"Max. Drawdown [%]": "最大回撤(%)",
|
||
"Avg. Drawdown [%]": "平均回撤(%)",
|
||
"Max. Drawdown Duration": "最大回撤持续时长",
|
||
"Avg. Drawdown Duration": "平均回撤持续时长",
|
||
"# Trades": "总交易次数",
|
||
"Win Rate [%]": "胜率(%)",
|
||
# 'Best Trade [%]': '最佳单笔交易收益率(%)',
|
||
# 'Worst Trade [%]': '最差单笔交易收益率(%)',
|
||
# 'Avg. Trade [%]': '平均单笔交易收益率(%)',
|
||
# 'Max. Trade Duration': '单笔交易最长持有时长',
|
||
# 'Avg. Trade Duration': '单笔交易平均持有时长',
|
||
# 'Profit Factor': '盈利因子',
|
||
# 'Expectancy [%]': '期望收益(%)',
|
||
"SQN": "系统质量数",
|
||
# 'Kelly Criterion': '凯利准则',
|
||
}
|
||
for k, v in stats.items():
|
||
if k in indicator_name_mapping:
|
||
cn_name = indicator_name_mapping.get(k, k)
|
||
if isinstance(v, (int, float)):
|
||
if "%" in cn_name or k in [
|
||
"Sharpe Ratio",
|
||
"Sortino Ratio",
|
||
"Calmar Ratio",
|
||
"Profit Factor",
|
||
]:
|
||
formatted_value = f"{v:.2f}"
|
||
elif "$" in cn_name:
|
||
formatted_value = f"{v:.2f}"
|
||
elif "次数" in cn_name:
|
||
formatted_value = f"{v:.0f}"
|
||
else:
|
||
formatted_value = f"{v:.4f}"
|
||
else:
|
||
formatted_value = str(v)
|
||
print(f"{cn_name}: {formatted_value}")
|
||
|
||
print("=" * 60)
|
||
|
||
|
||
def main():
|
||
"""
|
||
主函数:编排完整回测流程
|
||
"""
|
||
try:
|
||
# 解析参数
|
||
args = parse_arguments()
|
||
|
||
apply_color_scheme()
|
||
|
||
# 加载数据
|
||
print(f"加载股票数据: {args.code} ({args.start_date} ~ {args.end_date})")
|
||
data = load_data_from_db(args.code, args.start_date, args.end_date)
|
||
print(f"数据加载完成,共 {len(data)} 条记录")
|
||
|
||
# 截取预热数据
|
||
warmup_data = data.iloc[-args.warmup_days:]
|
||
print(f"使用预热数据范围: {warmup_data.index[0]} ~ {warmup_data.index[-1]}")
|
||
|
||
# 加载策略
|
||
calculate_indicators, strategy_class = load_strategy(args.strategy_file)
|
||
|
||
# 计算指标
|
||
warmup_data = calculate_indicators(warmup_data)
|
||
|
||
# 执行回测
|
||
from backtesting import Backtest
|
||
|
||
bt = Backtest(
|
||
warmup_data,
|
||
strategy_class,
|
||
cash=args.cash,
|
||
commission=args.commission,
|
||
finalize_trades=True,
|
||
)
|
||
stats = bt.run()
|
||
|
||
# 输出结果
|
||
print_stats(stats)
|
||
|
||
# 生成图表
|
||
if args.output:
|
||
os.makedirs(os.path.dirname(args.output), exist_ok=True)
|
||
bt.plot(filename=args.output, open_browser=False)
|
||
print(f"图表已保存到: {args.output}")
|
||
|
||
except Exception as e:
|
||
print(f"\n错误: {e}")
|
||
import traceback
|
||
|
||
traceback.print_exc()
|
||
sys.exit(1)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|