#!/usr/bin/env python3 """ 量化回测主程序 使用方法: python backtest.py --code 000001.SZ --start-date 2024-01-01 --end-date 2025-12-31 --strategy-file strategy.py """ import argparse import sys import os import importlib.util import pandas as pd from datetime import datetime from backtesting import Backtest # 数据库配置(直接硬编码,开发环境) DB_HOST = "81.71.3.24" DB_PORT = 6785 DB_NAME = "leopard_dev" DB_USER = "leopard" DB_PASSWORD = "9NEzFzovnddf@PyEP?e*AYAWnCyd7UhYwQK$pJf>7?ccFiN^x4$eKEZ5~E<7<+~X" def load_data_from_db(code, start_date, end_date): """ 从数据库加载历史数据 参数: code: 股票代码(如 '000001.SZ') start_date: 开始日期(如 '2024-01-01') end_date: 结束日期(如 '2025-12-31') 返回: DataFrame, 包含列: [Open, High, Low, Close, Volume, factor] """ import sqlalchemy import urllib.parse # 构建连接字符串(URL 编码密码中的特殊字符) encoded_password = urllib.parse.quote_plus(DB_PASSWORD) conn_str = ( f"postgresql://{DB_USER}:{encoded_password}@{DB_HOST}:{DB_PORT}/{DB_NAME}" ) engine = sqlalchemy.create_engine(conn_str) try: # 构建 SQL 查询 query = f""" SELECT trade_date, open * factor AS "Open", close * factor AS "Close", high * factor AS "High", low * factor AS "Low", volume AS "Volume", COALESCE(factor, 1.0) AS factor FROM leopard_daily daily LEFT JOIN leopard_stock stock ON stock.id = daily.stock_id WHERE stock.code = '{code}' AND daily.trade_date BETWEEN '{start_date} 00:00:00' AND '{end_date} 23:59:59' ORDER BY daily.trade_date """ # 执行查询 df = pd.read_sql(query, engine) if len(df) == 0: raise ValueError(f"未找到股票 {code} 在指定时间范围内的数据") # 潬换日期并设置为索引 df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y-%m-%d") df.set_index("trade_date", inplace=True) return df finally: # 清理连接 engine.dispose() def load_strategy(strategy_file): """ 动态加载策略文件 参数: strategy_file: 策略文件路径 (如 'strategy.py' 或 'strategies/macd.py') 返回: (calculate_indicators, strategy_class) 元组 """ # 获取模块名 module_name = strategy_file.replace(".py", "").replace("/", ".") spec = importlib.util.spec_from_file_location(module_name, strategy_file) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) # 接口验证 if not hasattr(module, "calculate_indicators"): raise AttributeError(f"策略文件 {strategy_file} 缺少 calculate_indicators 函数") if not hasattr(module, "get_strategy"): raise AttributeError(f"策略文件 {strategy_file} 缺少 get_strategy 函数") calculate_indicators = module.calculate_indicators strategy_class = module.get_strategy() # 验证 get_strategy 返回的是类 if not isinstance(strategy_class, type): raise TypeError("get_strategy() 必须返回一个类") # 验证策略类继承自 backtesting.Strategy from backtesting import Strategy if not issubclass(strategy_class, Strategy): raise TypeError("策略类必须继承 backtesting.Strategy") return calculate_indicators, strategy_class def parse_arguments(): """ 解析命令行参数 返回: args: 命名空间对象 """ parser = argparse.ArgumentParser( description="量化回测工具", formatter_class=argparse.RawDescriptionHelpFormatter ) # 必需参数 parser.add_argument( "--code", type=str, required=True, help="股票代码 (如: 000001.SZ)" ) parser.add_argument( "--start-date", type=str, required=True, help="回测开始日期 (格式: YYYY-MM-DD)" ) parser.add_argument( "--end-date", type=str, required=True, help="回测结束日期 (格式: YYYY-MM-DD)" ) parser.add_argument( "--strategy-file", type=str, required=True, help="策略文件路径 (如: strategy.py)", ) # 可选参数 parser.add_argument( "--cash", type=float, default=100000, help="初始资金 (默认: 100000)" ) parser.add_argument( "--commission", type=float, default=0.002, help="手续费率 (默认: 0.002)" ) parser.add_argument( "--output", type=str, default=None, help="HTML 输出文件路径 (可选)" ) parser.add_argument( "--warmup-days", type=int, default=365, help="预热天数 (默认: 365,约一年)" ) return parser.parse_args() def format_value(value, cn_name, key): """ 格式化数值显示 """ if isinstance(value, (int, float)): if "%" in cn_name or key in [ "Sharpe Ratio", "Sortino Ratio", "Calmar Ratio", "Profit Factor", ]: formatted_value = f"{value:.2f}" elif "$" in cn_name: formatted_value = f"{value:.2f}" elif "次数" in cn_name: formatted_value = f"{value:.0f}" else: formatted_value = f"{value:.4f}" else: formatted_value = str(value) return formatted_value def print_stats(stats): """ 打印回测统计结果 参数: stats: backtesting 库返回的统计对象 """ print("\n" + "=" * 60) print("回测结果") print("=" * 60) # 基本指标 metrics = [ ("Return (%)", "总收益率", "Return [%]"), ("Return", "总收益", "Return"), ("Sharpe Ratio", "夏普比率", "Sharpe Ratio"), ("Sortino Ratio", "索提诺比率", "Sortino Ratio"), ("Calmar Ratio", "卡尔玛比率", "Calmar Ratio"), ("Max Drawdown (%)", "最大回撤 (%)", "Max. Drawdown [%]"), ("Avg Drawdown (%)", "平均回撤 (%)", "Avg. Drawdown [%]"), ("Max Drawdown Duration", "最大回撤持续天数", "Max. Drawdown Duration"), ("Avg Drawdown Duration", "平均回撤持续天数", "Avg. Drawdown Duration"), ] for key, cn_name, en_name in metrics: try: value = getattr(stats, key, None) if value is not None: formatted = format_value(value, cn_name, key) print(f"{cn_name:20s}: {formatted}") except Exception: pass print() # 交易统计 trade_metrics = [ ("# Trades", "总交易次数", "# Trades"), ("Win Rate [%]", "胜率 (%)", "Win Rate [%]"), ("Best Trade", "最佳交易", "Best Trade"), ("Worst Trade", "最差交易", "Worst Trade"), ("Avg Trade", "平均交易", "Avg. Trade"), ("Avg Win Trade", "平均盈利交易", "Avg. Win Trade"), ("Avg Loss Trade", "平均亏损交易", "Avg. Loss Trade"), ("Profit Factor", "盈利因子", "Profit Factor"), ("Expectancy", "期望值", "Expectancy"), ] for key, cn_name, en_name in trade_metrics: try: value = getattr(stats, key, None) if value is not None: formatted = format_value(value, cn_name, key) print(f"{cn_name:20s}: {formatted}") except Exception: pass print("=" * 60 + "\n") def main(): """ 主函数:编排完整回测流程 """ try: # 解析参数 args = parse_arguments() # 加载数据 print(f"加载股票数据: {args.code} ({args.start_date} ~ {args.end_date})") data = load_data_from_db(args.code, args.start_date, args.end_date) print(f"数据加载完成,共 {len(data)} 条记录") # 截取预热数据 warmup_data = data.iloc[-args.warmup_days :] print(f"使用预热数据范围: {warmup_data.index[0]} ~ {warmup_data.index[-1]}") # 加载策略 print(f"加载策略: {args.strategy_file}") calculate_indicators, strategy_class = load_strategy(args.strategy_file) # 计算指标 print("计算指标...") warmup_data = calculate_indicators(warmup_data) print("指标计算完成") # 执行回测 print("开始回测...") from backtesting import Backtest bt = Backtest( warmup_data, strategy_class, cash=args.cash, commission=args.commission, finalize_trades=True, ) stats = bt.run() # 输出结果 print_stats(stats) # 生成图表 if args.output: print(f"\n生成图表: {args.output}") bt.plot(filename=args.output, open_browser=False) print(f"图表已保存到: {args.output}") print("\n回测完成!") except Exception as e: print(f"\n错误: {e}") import traceback traceback.print_exc() sys.exit(1) if __name__ == "__main__": main()