修复代码问题
This commit is contained in:
112
backtest.py
112
backtest.py
@@ -7,12 +7,11 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from datetime import datetime
|
|
||||||
from backtesting import Backtest
|
|
||||||
|
|
||||||
# 数据库配置(直接硬编码,开发环境)
|
# 数据库配置(直接硬编码,开发环境)
|
||||||
DB_HOST = "81.71.3.24"
|
DB_HOST = "81.71.3.24"
|
||||||
@@ -47,20 +46,19 @@ def load_data_from_db(code, start_date, end_date):
|
|||||||
try:
|
try:
|
||||||
# 构建 SQL 查询
|
# 构建 SQL 查询
|
||||||
query = f"""
|
query = f"""
|
||||||
SELECT
|
SELECT trade_date,
|
||||||
trade_date,
|
open * factor AS "Open",
|
||||||
open * factor AS "Open",
|
close * factor AS "Close",
|
||||||
close * factor AS "Close",
|
high * factor AS "High",
|
||||||
high * factor AS "High",
|
low * factor AS "Low",
|
||||||
low * factor AS "Low",
|
volume AS "Volume",
|
||||||
volume AS "Volume",
|
COALESCE(factor, 1.0) AS factor
|
||||||
COALESCE(factor, 1.0) AS factor
|
FROM leopard_daily daily
|
||||||
FROM leopard_daily daily
|
LEFT JOIN leopard_stock stock ON stock.id = daily.stock_id
|
||||||
LEFT JOIN leopard_stock stock ON stock.id = daily.stock_id
|
WHERE stock.code = '{code}'
|
||||||
WHERE stock.code = '{code}'
|
AND daily.trade_date BETWEEN '{start_date} 00:00:00'
|
||||||
AND daily.trade_date BETWEEN '{start_date} 00:00:00'
|
AND '{end_date} 23:59:59'
|
||||||
AND '{end_date} 23:59:59'
|
ORDER BY daily.trade_date
|
||||||
ORDER BY daily.trade_date
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 执行查询
|
# 执行查询
|
||||||
@@ -175,30 +173,6 @@ def parse_arguments():
|
|||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def format_value(value, cn_name, key):
|
|
||||||
"""
|
|
||||||
格式化数值显示
|
|
||||||
"""
|
|
||||||
if isinstance(value, (int, float)):
|
|
||||||
if "%" in cn_name or key in [
|
|
||||||
"Sharpe Ratio",
|
|
||||||
"Sortino Ratio",
|
|
||||||
"Calmar Ratio",
|
|
||||||
"Profit Factor",
|
|
||||||
]:
|
|
||||||
formatted_value = f"{value:.2f}"
|
|
||||||
elif "$" in cn_name:
|
|
||||||
formatted_value = f"{value:.2f}"
|
|
||||||
elif "次数" in cn_name:
|
|
||||||
formatted_value = f"{value:.0f}"
|
|
||||||
else:
|
|
||||||
formatted_value = f"{value:.4f}"
|
|
||||||
else:
|
|
||||||
formatted_value = str(value)
|
|
||||||
|
|
||||||
return formatted_value
|
|
||||||
|
|
||||||
|
|
||||||
def print_stats(stats):
|
def print_stats(stats):
|
||||||
"""
|
"""
|
||||||
打印回测统计结果
|
打印回测统计结果
|
||||||
@@ -206,33 +180,32 @@ def print_stats(stats):
|
|||||||
参数:
|
参数:
|
||||||
stats: backtesting 库返回的统计对象
|
stats: backtesting 库返回的统计对象
|
||||||
"""
|
"""
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("回测结果")
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
print("回测结果")
|
||||||
|
|
||||||
indicator_name_mapping = {
|
indicator_name_mapping = {
|
||||||
# 'Start': '回测开始时间',
|
# 'Start': '回测开始时间',
|
||||||
# 'End': '回测结束时间',
|
# 'End': '回测结束时间',
|
||||||
# 'Duration': '回测持续时长',
|
# 'Duration': '回测持续时长',
|
||||||
# 'Exposure Time [%]': '持仓时间占比(%)',
|
# 'Exposure Time [%]': '持仓时间占比(%)',
|
||||||
'Equity Final [$]': '最终收益',
|
"Equity Final [$]": "最终收益",
|
||||||
'Equity Peak [$]': '峰值收益',
|
"Equity Peak [$]": "峰值收益",
|
||||||
'Return [%]': '总收益率(%)',
|
"Return [%]": "总收益率(%)",
|
||||||
'Buy & Hold Return [%]': '买入并持有收益率(%)',
|
"Buy & Hold Return [%]": "买入并持有收益率(%)",
|
||||||
'Return (Ann.) [%]': '年化收益率(%)',
|
"Return (Ann.) [%]": "年化收益率(%)",
|
||||||
'Volatility (Ann.) [%]': '年化波动率(%)',
|
"Volatility (Ann.) [%]": "年化波动率(%)",
|
||||||
# 'CAGR [%]': '复合年均增长率(%)',
|
# 'CAGR [%]': '复合年均增长率(%)',
|
||||||
# 'Sharpe Ratio': '夏普比率',
|
# 'Sharpe Ratio': '夏普比率',
|
||||||
'Sortino Ratio': '索提诺比率',
|
"Sortino Ratio": "索提诺比率",
|
||||||
'Calmar Ratio': '卡尔玛比率',
|
"Calmar Ratio": "卡尔玛比率",
|
||||||
# 'Alpha [%]': '阿尔法系数(%)',
|
# 'Alpha [%]': '阿尔法系数(%)',
|
||||||
# 'Beta': '贝塔系数',
|
# 'Beta': '贝塔系数',
|
||||||
'Max. Drawdown [%]': '最大回撤(%)',
|
"Max. Drawdown [%]": "最大回撤(%)",
|
||||||
'Avg. Drawdown [%]': '平均回撤(%)',
|
"Avg. Drawdown [%]": "平均回撤(%)",
|
||||||
'Max. Drawdown Duration': '最大回撤持续时长',
|
"Max. Drawdown Duration": "最大回撤持续时长",
|
||||||
'Avg. Drawdown Duration': '平均回撤持续时长',
|
"Avg. Drawdown Duration": "平均回撤持续时长",
|
||||||
'# Trades': '总交易次数',
|
"# Trades": "总交易次数",
|
||||||
'Win Rate [%]': '胜率(%)',
|
"Win Rate [%]": "胜率(%)",
|
||||||
# 'Best Trade [%]': '最佳单笔交易收益率(%)',
|
# 'Best Trade [%]': '最佳单笔交易收益率(%)',
|
||||||
# 'Worst Trade [%]': '最差单笔交易收益率(%)',
|
# 'Worst Trade [%]': '最差单笔交易收益率(%)',
|
||||||
# 'Avg. Trade [%]': '平均单笔交易收益率(%)',
|
# 'Avg. Trade [%]': '平均单笔交易收益率(%)',
|
||||||
@@ -240,14 +213,19 @@ def print_stats(stats):
|
|||||||
# 'Avg. Trade Duration': '单笔交易平均持有时长',
|
# 'Avg. Trade Duration': '单笔交易平均持有时长',
|
||||||
# 'Profit Factor': '盈利因子',
|
# 'Profit Factor': '盈利因子',
|
||||||
# 'Expectancy [%]': '期望收益(%)',
|
# 'Expectancy [%]': '期望收益(%)',
|
||||||
'SQN': '系统质量数',
|
"SQN": "系统质量数",
|
||||||
# 'Kelly Criterion': '凯利准则',
|
# 'Kelly Criterion': '凯利准则',
|
||||||
}
|
}
|
||||||
for k, v in stats.items():
|
for k, v in stats.items():
|
||||||
if k in indicator_name_mapping:
|
if k in indicator_name_mapping:
|
||||||
cn_name = indicator_name_mapping.get(k, k)
|
cn_name = indicator_name_mapping.get(k, k)
|
||||||
if isinstance(v, (int, float)):
|
if isinstance(v, (int, float)):
|
||||||
if "%" in cn_name or k in ['Sharpe Ratio', 'Sortino Ratio', 'Calmar Ratio', 'Profit Factor']:
|
if "%" in cn_name or k in [
|
||||||
|
"Sharpe Ratio",
|
||||||
|
"Sortino Ratio",
|
||||||
|
"Calmar Ratio",
|
||||||
|
"Profit Factor",
|
||||||
|
]:
|
||||||
formatted_value = f"{v:.2f}"
|
formatted_value = f"{v:.2f}"
|
||||||
elif "$" in cn_name:
|
elif "$" in cn_name:
|
||||||
formatted_value = f"{v:.2f}"
|
formatted_value = f"{v:.2f}"
|
||||||
@@ -257,9 +235,9 @@ def print_stats(stats):
|
|||||||
formatted_value = f"{v:.4f}"
|
formatted_value = f"{v:.4f}"
|
||||||
else:
|
else:
|
||||||
formatted_value = str(v)
|
formatted_value = str(v)
|
||||||
print(f'{cn_name}: {formatted_value}')
|
print(f"{cn_name}: {formatted_value}")
|
||||||
|
|
||||||
print("=" * 60 + "\n")
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -278,20 +256,16 @@ def main():
|
|||||||
print(f"数据加载完成,共 {len(data)} 条记录")
|
print(f"数据加载完成,共 {len(data)} 条记录")
|
||||||
|
|
||||||
# 截取预热数据
|
# 截取预热数据
|
||||||
warmup_data = data.iloc[-args.warmup_days :]
|
warmup_data = data.iloc[-args.warmup_days:]
|
||||||
print(f"使用预热数据范围: {warmup_data.index[0]} ~ {warmup_data.index[-1]}")
|
print(f"使用预热数据范围: {warmup_data.index[0]} ~ {warmup_data.index[-1]}")
|
||||||
|
|
||||||
# 加载策略
|
# 加载策略
|
||||||
print(f"加载策略: {args.strategy_file}")
|
|
||||||
calculate_indicators, strategy_class = load_strategy(args.strategy_file)
|
calculate_indicators, strategy_class = load_strategy(args.strategy_file)
|
||||||
|
|
||||||
# 计算指标
|
# 计算指标
|
||||||
print("计算指标...")
|
|
||||||
warmup_data = calculate_indicators(warmup_data)
|
warmup_data = calculate_indicators(warmup_data)
|
||||||
print("指标计算完成")
|
|
||||||
|
|
||||||
# 执行回测
|
# 执行回测
|
||||||
print("开始回测...")
|
|
||||||
from backtesting import Backtest
|
from backtesting import Backtest
|
||||||
|
|
||||||
bt = Backtest(
|
bt = Backtest(
|
||||||
@@ -308,12 +282,10 @@ def main():
|
|||||||
|
|
||||||
# 生成图表
|
# 生成图表
|
||||||
if args.output:
|
if args.output:
|
||||||
print(f"\n生成图表: {args.output}")
|
os.makedirs(os.path.dirname(args.output), exist_ok=True)
|
||||||
bt.plot(filename=args.output, open_browser=False)
|
bt.plot(filename=args.output, open_browser=False)
|
||||||
print(f"图表已保存到: {args.output}")
|
print(f"图表已保存到: {args.output}")
|
||||||
|
|
||||||
print("\n回测完成!")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n错误: {e}")
|
print(f"\n错误: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|||||||
@@ -2,28 +2,27 @@
|
|||||||
MACD 趋势跟踪策略
|
MACD 趋势跟踪策略
|
||||||
|
|
||||||
策略逻辑:
|
策略逻辑:
|
||||||
- 当 MACD 线上穿信号线时 (金叉),且价格 > EMA200 时,买入
|
- 当 MACD 线上穿信号线时 (金叉),且价格 > EMA 时,买入
|
||||||
- 当 MACD 线下穿信号线时 (死叉),或价格 < EMA200 时,卖出
|
- 当 MACD 线下穿信号线时 (死叉),或价格 < EMA 时,卖出
|
||||||
|
|
||||||
指标计算:
|
指标计算:
|
||||||
- MACD(10, 20, 9): 快线 10 日,慢线 20 日,信号线 9 日
|
- MACD(10, 20, 9): 快线 10 日,慢线 20 日,信号线 9 日
|
||||||
- EMA200: 200 日指数移动平均线(趋势确认)
|
- EMA: 200 日指数移动平均线(趋势确认)
|
||||||
|
|
||||||
参数选择理由:
|
参数选择理由:
|
||||||
- 快线 10: 比标准 12 更敏感,适应 A 股较高波动性
|
- 快线 10: 比标准 12 更敏感,适应 A 股较高波动性
|
||||||
- 慢线 20: 比标准 26 更快响应,同时保持趋势跟踪稳定性
|
- 慢线 20: 比标准 26 更快响应,同时保持趋势跟踪稳定性
|
||||||
- 信号线 9: 保持标准,避免信号过于频繁
|
- 信号线 9: 保持标准,避免信号过于频繁
|
||||||
- EMA200: 被广泛认可为牛熊分界线,避免逆势交易
|
- EMA: 被广泛认可为牛熊分界线,避免逆势交易
|
||||||
|
|
||||||
趋势过滤:
|
趋势过滤:
|
||||||
- EMA200 上方: 确认为上升趋势,允许开多仓
|
- EMA 上方: 确认为上升趋势,允许开多仓
|
||||||
- EMA200 下方: 确认为下降趋势,不开多仓,强制平仓
|
- EMA 下方: 确认为下降趋势,不开多仓,强制平仓
|
||||||
|
|
||||||
Author: Sisyphus
|
Author: Sisyphus
|
||||||
Date: 2025-01-27
|
Date: 2025-01-27
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
from backtesting import Strategy
|
from backtesting import Strategy
|
||||||
from backtesting.lib import crossover
|
from backtesting.lib import crossover
|
||||||
|
|
||||||
@@ -32,32 +31,30 @@ def calculate_indicators(data):
|
|||||||
"""
|
"""
|
||||||
计算策略所需的技术指标
|
计算策略所需的技术指标
|
||||||
|
|
||||||
使用 ta-lib 库计算 MACD 和 EMA200 指标
|
使用 ta-lib 库计算 MACD 和 EMA 指标
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
data: DataFrame, 包含 [Open, High, Low, Close, Volume, factor]
|
data: DataFrame, 包含 [Open, High, Low, Close, Volume, factor]
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
DataFrame, 添加了指标列:
|
DataFrame, 添加了指标列:
|
||||||
- MACD_10_20_9: MACD 线 (DIF)
|
- dif: MACD 线 (dif)
|
||||||
- MACDs_10_20_9: MACD 信号线 (DEA)
|
- signal: MACD 信号线 (DEA)
|
||||||
- MACDh_10_20_9: MACD 柱状图 (Histogram)
|
- hist: MACD 柱状图 (Histogram)
|
||||||
- EMA_200: 200 日指数移动平均线
|
- ema: 日指数移动平均线
|
||||||
"""
|
"""
|
||||||
data = data.copy()
|
data = data.copy()
|
||||||
|
|
||||||
# 计算 MACD 指标 (10, 20, 9)
|
# 计算 MACD 指标 (10, 20, 9)
|
||||||
# talib.MACD 返回三个值: (macd, macdsignal, macdhist)
|
# talib.MACD 返回三个值: (macd, macdsignal, macdhist)
|
||||||
macd, macdsignal, macdhist = talib.MACD(
|
macd, macdsignal, macdhist = talib.MACD(data["Close"], fastperiod=10, slowperiod=20, signalperiod=9)
|
||||||
data["Close"], fastperiod=10, slowperiod=20, signalperiod=9
|
|
||||||
)
|
|
||||||
|
|
||||||
data["MACD_10_20_9"] = macd
|
data["dif"] = macd
|
||||||
data["MACDs_10_20_9"] = macdsignal
|
data["signal"] = macdsignal
|
||||||
data["MACDh_10_20_9"] = macdhist
|
data["hist"] = macdhist
|
||||||
|
|
||||||
# 计算 EMA200 趋势线
|
# 计算 EMA 趋势线
|
||||||
data["EMA_200"] = talib.EMA(data["Close"], timeperiod=200)
|
data["ema"] = talib.SMA(data["Close"], timeperiod=120)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -76,7 +73,7 @@ class MacdTrendStrategy(Strategy):
|
|||||||
"""
|
"""
|
||||||
MACD 趋势跟踪策略
|
MACD 趋势跟踪策略
|
||||||
|
|
||||||
结合 MACD 金叉/死叉信号和 EMA200 趋势过滤
|
结合 MACD 金叉/死叉信号和 EMA 趋势过滤
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
fast_period: MACD 快线周期 (默认: 10)
|
fast_period: MACD 快线周期 (默认: 10)
|
||||||
@@ -95,13 +92,13 @@ class MacdTrendStrategy(Strategy):
|
|||||||
注册指标到 backtesting 框架
|
注册指标到 backtesting 框架
|
||||||
"""
|
"""
|
||||||
# 注册 MACD 线
|
# 注册 MACD 线
|
||||||
self.macd = self.I(lambda x: x, self.data.MACD_10_20_9)
|
self.macd = self.I(lambda x: x, self.data.dif)
|
||||||
|
|
||||||
# 注册 MACD 信号线
|
# 注册 MACD 信号线
|
||||||
self.macd_signal = self.I(lambda x: x, self.data.MACDs_10_20_9)
|
self.macd_signal = self.I(lambda x: x, self.data.signal)
|
||||||
|
|
||||||
# 注册 EMA200 趋势线
|
# 注册 EMA 趋势线
|
||||||
self.ema200 = self.I(lambda x: x, self.data.EMA_200)
|
self.ema = self.I(lambda x: x, self.data.ema)
|
||||||
|
|
||||||
def next(self):
|
def next(self):
|
||||||
"""
|
"""
|
||||||
@@ -109,25 +106,19 @@ class MacdTrendStrategy(Strategy):
|
|||||||
|
|
||||||
买入条件:
|
买入条件:
|
||||||
- MACD 金叉 (MACD 线上穿信号线)
|
- MACD 金叉 (MACD 线上穿信号线)
|
||||||
- 价格 > EMA200 (确认上升趋势)
|
- 价格 > EMA (确认上升趋势)
|
||||||
|
|
||||||
卖出条件:
|
卖出条件:
|
||||||
- MACD 死叉 (MACD 线下穿信号线)
|
- MACD 死叉 (MACD 线下穿信号线)
|
||||||
- 或价格 < EMA200 (趋势转向,强制平仓)
|
- 或价格 < EMA (趋势转向,强制平仓)
|
||||||
"""
|
"""
|
||||||
# 买入条件: MACD 金叉 AND 价格 > EMA200
|
# 买入条件: MACD 金叉 AND 价格 > EMA
|
||||||
if (
|
if crossover(self.macd, self.macd_signal) and self.data.Close[-1] > self.ema[-1]:
|
||||||
crossover(self.macd, self.macd_signal)
|
|
||||||
and self.data.Close[-1] > self.ema200[-1]
|
|
||||||
):
|
|
||||||
self.position.close() # 先平掉现有仓位
|
self.position.close() # 先平掉现有仓位
|
||||||
self.buy() # 开多仓
|
self.buy() # 开多仓
|
||||||
|
|
||||||
# 卖出条件: MACD 死叉 OR 价格 < EMA200
|
# 卖出条件: MACD 死叉 OR 价格 < EMA
|
||||||
elif (
|
elif crossover(self.macd_signal, self.macd) or self.data.Close[-1] < self.ema[-1]:
|
||||||
crossover(self.macd_signal, self.macd)
|
|
||||||
or self.data.Close[-1] < self.ema200[-1]
|
|
||||||
):
|
|
||||||
self.position.close() # 平掉多仓
|
self.position.close() # 平掉多仓
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ SMA 双均线交叉策略
|
|||||||
- 当短期均线上穿长期均线时 (金叉),买入
|
- 当短期均线上穿长期均线时 (金叉),买入
|
||||||
- 当短期均线下穿长期均线时 (死叉),卖出
|
- 当短期均线下穿长期均线时 (死叉),卖出
|
||||||
|
|
||||||
指标计算:
|
指标计算 (使用 ta-lib):
|
||||||
- SMA10: 10 日简单移动平均线
|
- SMA10: 10 日简单移动平均线
|
||||||
- SMA30: 30 日简单移动平均线
|
- SMA30: 30 日简单移动平均线
|
||||||
- SMA60: 60 日简单移动平均线
|
- SMA60: 60 日简单移动平均线
|
||||||
@@ -21,19 +21,25 @@ def calculate_indicators(data):
|
|||||||
"""
|
"""
|
||||||
计算策略所需的技术指标
|
计算策略所需的技术指标
|
||||||
|
|
||||||
|
使用 ta-lib 库计算 SMA 指标
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
data: DataFrame, 包含 [Open, High, Low, Close, Volume, factor]
|
data: DataFrame, 包含 [Open, High, Low, Close, Volume, factor]
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
DataFrame, 添加了指标列
|
DataFrame, 添加了指标列:
|
||||||
|
- sma10: 10 日简单移动平均线
|
||||||
|
- sma30: 30 日简单移动平均线
|
||||||
|
- sma60: 60 日简单移动平均线
|
||||||
|
- sma120: 120 日简单移动平均线
|
||||||
"""
|
"""
|
||||||
data = data.copy()
|
data = data.copy()
|
||||||
|
|
||||||
# 计算不同周期的移动平均线
|
# 计算不同周期的移动平均线
|
||||||
data["sma10"] = data["Close"].rolling(window=10).mean()
|
data["sma10"] = talib.SMA(data["Close"], timeperiod=10)
|
||||||
data["sma30"] = data["Close"].rolling(window=30).mean()
|
data["sma30"] = talib.SMA(data["Close"], timeperiod=30)
|
||||||
data["sma60"] = data["Close"].rolling(window=60).mean()
|
data["sma60"] = talib.SMA(data["Close"], timeperiod=60)
|
||||||
data["sma120"] = data["Close"].rolling(window=120).mean()
|
data["sma120"] = talib.SMA(data["Close"], timeperiod=120)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -85,3 +91,7 @@ class SmaCross(Strategy):
|
|||||||
elif crossover(self.data.sma30, self.data.sma10):
|
elif crossover(self.data.sma30, self.data.sma10):
|
||||||
self.position.close() # 先平掉现有仓位
|
self.position.close() # 先平掉现有仓位
|
||||||
self.sell() # 开空仓
|
self.sell() # 开空仓
|
||||||
|
|
||||||
|
|
||||||
|
# 导入 talib (必须在文件末尾,因为 calculate_indicators 函数中使用了 talib)
|
||||||
|
import talib
|
||||||
|
|||||||
Reference in New Issue
Block a user