1
0
Files
leopard-analysis/backtest.py
2026-01-28 09:46:44 +08:00

299 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
量化回测主程序
使用方法:
python backtest.py --code 000001.SZ --start-date 2024-01-01 --end-date 2025-12-31 --strategy-file strategy.py
"""
import argparse
import importlib.util
import os
import sys
import pandas as pd
# 数据库配置(直接硬编码,开发环境)
DB_HOST = "81.71.3.24"
DB_PORT = 6785
DB_NAME = "leopard_dev"
DB_USER = "leopard"
DB_PASSWORD = "9NEzFzovnddf@PyEP?e*AYAWnCyd7UhYwQK$pJf>7?ccFiN^x4$eKEZ5~E<7<+~X"
def load_data_from_db(code, start_date, end_date):
"""
从数据库加载历史数据
参数:
code: 股票代码(如 '000001.SZ'
start_date: 开始日期(如 '2024-01-01'
end_date: 结束日期(如 '2025-12-31'
返回:
DataFrame, 包含列: [Open, High, Low, Close, Volume, factor]
"""
import sqlalchemy
import urllib.parse
# 构建连接字符串URL 编码密码中的特殊字符)
encoded_password = urllib.parse.quote_plus(DB_PASSWORD)
conn_str = (
f"postgresql://{DB_USER}:{encoded_password}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)
engine = sqlalchemy.create_engine(conn_str)
try:
# 构建 SQL 查询
query = f"""
SELECT trade_date,
open * factor AS "Open",
close * factor AS "Close",
high * factor AS "High",
low * factor AS "Low",
volume AS "Volume",
COALESCE(factor, 1.0) AS factor
FROM leopard_daily daily
LEFT JOIN leopard_stock stock ON stock.id = daily.stock_id
WHERE stock.code = '{code}'
AND daily.trade_date BETWEEN '{start_date} 00:00:00'
AND '{end_date} 23:59:59'
ORDER BY daily.trade_date
"""
# 执行查询
df = pd.read_sql(query, engine)
if len(df) == 0:
raise ValueError(f"未找到股票 {code} 在指定时间范围内的数据")
# 潬换日期并设置为索引
df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y-%m-%d")
df.set_index("trade_date", inplace=True)
return df
finally:
# 清理连接
engine.dispose()
def load_strategy(strategy_file):
"""
动态加载策略文件
参数:
strategy_file: 策略文件路径 (如 'strategy.py''strategies/macd.py')
返回:
(calculate_indicators, strategy_class) 元组
"""
# 获取模块名
module_name = strategy_file.replace(".py", "").replace("/", ".")
spec = importlib.util.spec_from_file_location(module_name, strategy_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# 接口验证
if not hasattr(module, "calculate_indicators"):
raise AttributeError(f"策略文件 {strategy_file} 缺少 calculate_indicators 函数")
if not hasattr(module, "get_strategy"):
raise AttributeError(f"策略文件 {strategy_file} 缺少 get_strategy 函数")
calculate_indicators = module.calculate_indicators
strategy_class = module.get_strategy()
# 验证 get_strategy 返回的是类
if not isinstance(strategy_class, type):
raise TypeError("get_strategy() 必须返回一个类")
# 验证策略类继承自 backtesting.Strategy
from backtesting import Strategy
if not issubclass(strategy_class, Strategy):
raise TypeError("策略类必须继承 backtesting.Strategy")
return calculate_indicators, strategy_class
def apply_color_scheme():
"""
应用颜色方案:红涨绿跌(中国股市风格)
"""
import backtesting._plotting as plotting
from bokeh.colors.named import tomato, lime
plotting.BULL_COLOR = tomato
plotting.BEAR_COLOR = lime
def parse_arguments():
"""
解析命令行参数
返回:
args: 命名空间对象
"""
parser = argparse.ArgumentParser(
description="量化回测工具", formatter_class=argparse.RawDescriptionHelpFormatter
)
# 必需参数
parser.add_argument(
"--code", type=str, required=True, help="股票代码 (如: 000001.SZ)"
)
parser.add_argument(
"--start-date", type=str, required=True, help="回测开始日期 (格式: YYYY-MM-DD)"
)
parser.add_argument(
"--end-date", type=str, required=True, help="回测结束日期 (格式: YYYY-MM-DD)"
)
parser.add_argument(
"--strategy-file",
type=str,
required=True,
help="策略文件路径 (如: strategy.py)",
)
# 可选参数
parser.add_argument(
"--cash", type=float, default=100000, help="初始资金 (默认: 100000)"
)
parser.add_argument(
"--commission", type=float, default=0.002, help="手续费率 (默认: 0.002)"
)
parser.add_argument(
"--output", type=str, default=None, help="HTML 输出文件路径 (可选)"
)
parser.add_argument(
"--warmup-days", type=int, default=365, help="预热天数 (默认: 365约一年"
)
return parser.parse_args()
def print_stats(stats):
"""
打印回测统计结果
参数:
stats: backtesting 库返回的统计对象
"""
print("=" * 60)
print("回测结果")
indicator_name_mapping = {
# 'Start': '回测开始时间',
# 'End': '回测结束时间',
# 'Duration': '回测持续时长',
# 'Exposure Time [%]': '持仓时间占比(%',
"Equity Final [$]": "最终收益",
"Equity Peak [$]": "峰值收益",
"Return [%]": "总收益率(%",
"Buy & Hold Return [%]": "买入并持有收益率(%",
"Return (Ann.) [%]": "年化收益率(%",
"Volatility (Ann.) [%]": "年化波动率(%",
# 'CAGR [%]': '复合年均增长率(%',
# 'Sharpe Ratio': '夏普比率',
"Sortino Ratio": "索提诺比率",
"Calmar Ratio": "卡尔玛比率",
# 'Alpha [%]': '阿尔法系数(%',
# 'Beta': '贝塔系数',
"Max. Drawdown [%]": "最大回撤(%",
"Avg. Drawdown [%]": "平均回撤(%",
"Max. Drawdown Duration": "最大回撤持续时长",
"Avg. Drawdown Duration": "平均回撤持续时长",
"# Trades": "总交易次数",
"Win Rate [%]": "胜率(%",
# 'Best Trade [%]': '最佳单笔交易收益率(%',
# 'Worst Trade [%]': '最差单笔交易收益率(%',
# 'Avg. Trade [%]': '平均单笔交易收益率(%',
# 'Max. Trade Duration': '单笔交易最长持有时长',
# 'Avg. Trade Duration': '单笔交易平均持有时长',
# 'Profit Factor': '盈利因子',
# 'Expectancy [%]': '期望收益(%',
"SQN": "系统质量数",
# 'Kelly Criterion': '凯利准则',
}
for k, v in stats.items():
if k in indicator_name_mapping:
cn_name = indicator_name_mapping.get(k, k)
if isinstance(v, (int, float)):
if "%" in cn_name or k in [
"Sharpe Ratio",
"Sortino Ratio",
"Calmar Ratio",
"Profit Factor",
]:
formatted_value = f"{v:.2f}"
elif "$" in cn_name:
formatted_value = f"{v:.2f}"
elif "次数" in cn_name:
formatted_value = f"{v:.0f}"
else:
formatted_value = f"{v:.4f}"
else:
formatted_value = str(v)
print(f"{cn_name}: {formatted_value}")
print("=" * 60)
def main():
"""
主函数:编排完整回测流程
"""
try:
# 解析参数
args = parse_arguments()
apply_color_scheme()
# 加载数据
print(f"加载股票数据: {args.code} ({args.start_date} ~ {args.end_date})")
data = load_data_from_db(args.code, args.start_date, args.end_date)
print(f"数据加载完成,共 {len(data)} 条记录")
# 截取预热数据
warmup_data = data.iloc[-args.warmup_days:]
print(f"使用预热数据范围: {warmup_data.index[0]} ~ {warmup_data.index[-1]}")
# 加载策略
calculate_indicators, strategy_class = load_strategy(args.strategy_file)
# 计算指标
warmup_data = calculate_indicators(warmup_data)
# 执行回测
from backtesting import Backtest
bt = Backtest(
warmup_data,
strategy_class,
cash=args.cash,
commission=args.commission,
finalize_trades=True,
)
stats = bt.run()
# 输出结果
print_stats(stats)
# 生成图表
if args.output:
os.makedirs(os.path.dirname(args.output), exist_ok=True)
bt.plot(filename=args.output, open_browser=False)
print(f"图表已保存到: {args.output}")
except Exception as e:
print(f"\n错误: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
main()