1
0
Files
leopard-analysis/backtest.py
2026-01-27 18:30:41 +08:00

311 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
量化回测主程序
使用方法:
python backtest.py --code 000001.SZ --start-date 2024-01-01 --end-date 2025-12-31 --strategy-file strategy.py
"""
import argparse
import sys
import os
import importlib.util
import pandas as pd
from datetime import datetime
from backtesting import Backtest
# 数据库配置(直接硬编码,开发环境)
DB_HOST = "81.71.3.24"
DB_PORT = 6785
DB_NAME = "leopard_dev"
DB_USER = "leopard"
DB_PASSWORD = "9NEzFzovnddf@PyEP?e*AYAWnCyd7UhYwQK$pJf>7?ccFiN^x4$eKEZ5~E<7<+~X"
def load_data_from_db(code, start_date, end_date):
"""
从数据库加载历史数据
参数:
code: 股票代码(如 '000001.SZ'
start_date: 开始日期(如 '2024-01-01'
end_date: 结束日期(如 '2025-12-31'
返回:
DataFrame, 包含列: [Open, High, Low, Close, Volume, factor]
"""
import sqlalchemy
import urllib.parse
# 构建连接字符串URL 编码密码中的特殊字符)
encoded_password = urllib.parse.quote_plus(DB_PASSWORD)
conn_str = (
f"postgresql://{DB_USER}:{encoded_password}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)
engine = sqlalchemy.create_engine(conn_str)
try:
# 构建 SQL 查询
query = f"""
SELECT
trade_date,
open * factor AS "Open",
close * factor AS "Close",
high * factor AS "High",
low * factor AS "Low",
volume AS "Volume",
COALESCE(factor, 1.0) AS factor
FROM leopard_daily daily
LEFT JOIN leopard_stock stock ON stock.id = daily.stock_id
WHERE stock.code = '{code}'
AND daily.trade_date BETWEEN '{start_date} 00:00:00'
AND '{end_date} 23:59:59'
ORDER BY daily.trade_date
"""
# 执行查询
df = pd.read_sql(query, engine)
if len(df) == 0:
raise ValueError(f"未找到股票 {code} 在指定时间范围内的数据")
# 潬换日期并设置为索引
df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y-%m-%d")
df.set_index("trade_date", inplace=True)
return df
finally:
# 清理连接
engine.dispose()
def load_strategy(strategy_file):
"""
动态加载策略文件
参数:
strategy_file: 策略文件路径 (如 'strategy.py''strategies/macd.py')
返回:
(calculate_indicators, strategy_class) 元组
"""
# 获取模块名
module_name = strategy_file.replace(".py", "").replace("/", ".")
spec = importlib.util.spec_from_file_location(module_name, strategy_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# 接口验证
if not hasattr(module, "calculate_indicators"):
raise AttributeError(f"策略文件 {strategy_file} 缺少 calculate_indicators 函数")
if not hasattr(module, "get_strategy"):
raise AttributeError(f"策略文件 {strategy_file} 缺少 get_strategy 函数")
calculate_indicators = module.calculate_indicators
strategy_class = module.get_strategy()
# 验证 get_strategy 返回的是类
if not isinstance(strategy_class, type):
raise TypeError("get_strategy() 必须返回一个类")
# 验证策略类继承自 backtesting.Strategy
from backtesting import Strategy
if not issubclass(strategy_class, Strategy):
raise TypeError("策略类必须继承 backtesting.Strategy")
return calculate_indicators, strategy_class
def parse_arguments():
"""
解析命令行参数
返回:
args: 命名空间对象
"""
parser = argparse.ArgumentParser(
description="量化回测工具", formatter_class=argparse.RawDescriptionHelpFormatter
)
# 必需参数
parser.add_argument(
"--code", type=str, required=True, help="股票代码 (如: 000001.SZ)"
)
parser.add_argument(
"--start-date", type=str, required=True, help="回测开始日期 (格式: YYYY-MM-DD)"
)
parser.add_argument(
"--end-date", type=str, required=True, help="回测结束日期 (格式: YYYY-MM-DD)"
)
parser.add_argument(
"--strategy-file",
type=str,
required=True,
help="策略文件路径 (如: strategy.py)",
)
# 可选参数
parser.add_argument(
"--cash", type=float, default=100000, help="初始资金 (默认: 100000)"
)
parser.add_argument(
"--commission", type=float, default=0.002, help="手续费率 (默认: 0.002)"
)
parser.add_argument(
"--output", type=str, default=None, help="HTML 输出文件路径 (可选)"
)
parser.add_argument(
"--warmup-days", type=int, default=365, help="预热天数 (默认: 365约一年"
)
return parser.parse_args()
def format_value(value, cn_name, key):
"""
格式化数值显示
"""
if isinstance(value, (int, float)):
if "%" in cn_name or key in [
"Sharpe Ratio",
"Sortino Ratio",
"Calmar Ratio",
"Profit Factor",
]:
formatted_value = f"{value:.2f}"
elif "$" in cn_name:
formatted_value = f"{value:.2f}"
elif "次数" in cn_name:
formatted_value = f"{value:.0f}"
else:
formatted_value = f"{value:.4f}"
else:
formatted_value = str(value)
return formatted_value
def print_stats(stats):
"""
打印回测统计结果
参数:
stats: backtesting 库返回的统计对象
"""
print("\n" + "=" * 60)
print("回测结果")
print("=" * 60)
# 基本指标
metrics = [
("Return (%)", "总收益率", "Return [%]"),
("Return", "总收益", "Return"),
("Sharpe Ratio", "夏普比率", "Sharpe Ratio"),
("Sortino Ratio", "索提诺比率", "Sortino Ratio"),
("Calmar Ratio", "卡尔玛比率", "Calmar Ratio"),
("Max Drawdown (%)", "最大回撤 (%)", "Max. Drawdown [%]"),
("Avg Drawdown (%)", "平均回撤 (%)", "Avg. Drawdown [%]"),
("Max Drawdown Duration", "最大回撤持续天数", "Max. Drawdown Duration"),
("Avg Drawdown Duration", "平均回撤持续天数", "Avg. Drawdown Duration"),
]
for key, cn_name, en_name in metrics:
try:
value = getattr(stats, key, None)
if value is not None:
formatted = format_value(value, cn_name, key)
print(f"{cn_name:20s}: {formatted}")
except Exception:
pass
print()
# 交易统计
trade_metrics = [
("# Trades", "总交易次数", "# Trades"),
("Win Rate [%]", "胜率 (%)", "Win Rate [%]"),
("Best Trade", "最佳交易", "Best Trade"),
("Worst Trade", "最差交易", "Worst Trade"),
("Avg Trade", "平均交易", "Avg. Trade"),
("Avg Win Trade", "平均盈利交易", "Avg. Win Trade"),
("Avg Loss Trade", "平均亏损交易", "Avg. Loss Trade"),
("Profit Factor", "盈利因子", "Profit Factor"),
("Expectancy", "期望值", "Expectancy"),
]
for key, cn_name, en_name in trade_metrics:
try:
value = getattr(stats, key, None)
if value is not None:
formatted = format_value(value, cn_name, key)
print(f"{cn_name:20s}: {formatted}")
except Exception:
pass
print("=" * 60 + "\n")
def main():
"""
主函数:编排完整回测流程
"""
try:
# 解析参数
args = parse_arguments()
# 加载数据
print(f"加载股票数据: {args.code} ({args.start_date} ~ {args.end_date})")
data = load_data_from_db(args.code, args.start_date, args.end_date)
print(f"数据加载完成,共 {len(data)} 条记录")
# 截取预热数据
warmup_data = data.iloc[-args.warmup_days :]
print(f"使用预热数据范围: {warmup_data.index[0]} ~ {warmup_data.index[-1]}")
# 加载策略
print(f"加载策略: {args.strategy_file}")
calculate_indicators, strategy_class = load_strategy(args.strategy_file)
# 计算指标
print("计算指标...")
warmup_data = calculate_indicators(warmup_data)
print("指标计算完成")
# 执行回测
print("开始回测...")
from backtesting import Backtest
bt = Backtest(
warmup_data,
strategy_class,
cash=args.cash,
commission=args.commission,
finalize_trades=True,
)
stats = bt.run()
# 输出结果
print_stats(stats)
# 生成图表
if args.output:
print(f"\n生成图表: {args.output}")
bt.plot(filename=args.output, show=False)
print(f"图表已保存到: {args.output}")
print("\n回测完成!")
except Exception as e:
print(f"\n错误: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()