1
0

修复代码问题

This commit is contained in:
2026-01-28 09:46:44 +08:00
parent 9a46bd7e4c
commit 5cc140259e
3 changed files with 86 additions and 113 deletions

View File

@@ -7,12 +7,11 @@
""" """
import argparse import argparse
import sys
import os
import importlib.util import importlib.util
import os
import sys
import pandas as pd import pandas as pd
from datetime import datetime
from backtesting import Backtest
# 数据库配置(直接硬编码,开发环境) # 数据库配置(直接硬编码,开发环境)
DB_HOST = "81.71.3.24" DB_HOST = "81.71.3.24"
@@ -47,20 +46,19 @@ def load_data_from_db(code, start_date, end_date):
try: try:
# 构建 SQL 查询 # 构建 SQL 查询
query = f""" query = f"""
SELECT SELECT trade_date,
trade_date, open * factor AS "Open",
open * factor AS "Open", close * factor AS "Close",
close * factor AS "Close", high * factor AS "High",
high * factor AS "High", low * factor AS "Low",
low * factor AS "Low", volume AS "Volume",
volume AS "Volume", COALESCE(factor, 1.0) AS factor
COALESCE(factor, 1.0) AS factor FROM leopard_daily daily
FROM leopard_daily daily LEFT JOIN leopard_stock stock ON stock.id = daily.stock_id
LEFT JOIN leopard_stock stock ON stock.id = daily.stock_id WHERE stock.code = '{code}'
WHERE stock.code = '{code}' AND daily.trade_date BETWEEN '{start_date} 00:00:00'
AND daily.trade_date BETWEEN '{start_date} 00:00:00' AND '{end_date} 23:59:59'
AND '{end_date} 23:59:59' ORDER BY daily.trade_date
ORDER BY daily.trade_date
""" """
# 执行查询 # 执行查询
@@ -175,30 +173,6 @@ def parse_arguments():
return parser.parse_args() return parser.parse_args()
def format_value(value, cn_name, key):
"""
格式化数值显示
"""
if isinstance(value, (int, float)):
if "%" in cn_name or key in [
"Sharpe Ratio",
"Sortino Ratio",
"Calmar Ratio",
"Profit Factor",
]:
formatted_value = f"{value:.2f}"
elif "$" in cn_name:
formatted_value = f"{value:.2f}"
elif "次数" in cn_name:
formatted_value = f"{value:.0f}"
else:
formatted_value = f"{value:.4f}"
else:
formatted_value = str(value)
return formatted_value
def print_stats(stats): def print_stats(stats):
""" """
打印回测统计结果 打印回测统计结果
@@ -206,33 +180,32 @@ def print_stats(stats):
参数: 参数:
stats: backtesting 库返回的统计对象 stats: backtesting 库返回的统计对象
""" """
print("\n" + "=" * 60)
print("回测结果")
print("=" * 60) print("=" * 60)
print("回测结果")
indicator_name_mapping = { indicator_name_mapping = {
# 'Start': '回测开始时间', # 'Start': '回测开始时间',
# 'End': '回测结束时间', # 'End': '回测结束时间',
# 'Duration': '回测持续时长', # 'Duration': '回测持续时长',
# 'Exposure Time [%]': '持仓时间占比(%', # 'Exposure Time [%]': '持仓时间占比(%',
'Equity Final [$]': '最终收益', "Equity Final [$]": "最终收益",
'Equity Peak [$]': '峰值收益', "Equity Peak [$]": "峰值收益",
'Return [%]': '总收益率(%', "Return [%]": "总收益率(%",
'Buy & Hold Return [%]': '买入并持有收益率(%', "Buy & Hold Return [%]": "买入并持有收益率(%",
'Return (Ann.) [%]': '年化收益率(%', "Return (Ann.) [%]": "年化收益率(%",
'Volatility (Ann.) [%]': '年化波动率(%', "Volatility (Ann.) [%]": "年化波动率(%",
# 'CAGR [%]': '复合年均增长率(%', # 'CAGR [%]': '复合年均增长率(%',
# 'Sharpe Ratio': '夏普比率', # 'Sharpe Ratio': '夏普比率',
'Sortino Ratio': '索提诺比率', "Sortino Ratio": "索提诺比率",
'Calmar Ratio': '卡尔玛比率', "Calmar Ratio": "卡尔玛比率",
# 'Alpha [%]': '阿尔法系数(%', # 'Alpha [%]': '阿尔法系数(%',
# 'Beta': '贝塔系数', # 'Beta': '贝塔系数',
'Max. Drawdown [%]': '最大回撤(%', "Max. Drawdown [%]": "最大回撤(%",
'Avg. Drawdown [%]': '平均回撤(%', "Avg. Drawdown [%]": "平均回撤(%",
'Max. Drawdown Duration': '最大回撤持续时长', "Max. Drawdown Duration": "最大回撤持续时长",
'Avg. Drawdown Duration': '平均回撤持续时长', "Avg. Drawdown Duration": "平均回撤持续时长",
'# Trades': '总交易次数', "# Trades": "总交易次数",
'Win Rate [%]': '胜率(%', "Win Rate [%]": "胜率(%",
# 'Best Trade [%]': '最佳单笔交易收益率(%', # 'Best Trade [%]': '最佳单笔交易收益率(%',
# 'Worst Trade [%]': '最差单笔交易收益率(%', # 'Worst Trade [%]': '最差单笔交易收益率(%',
# 'Avg. Trade [%]': '平均单笔交易收益率(%', # 'Avg. Trade [%]': '平均单笔交易收益率(%',
@@ -240,14 +213,19 @@ def print_stats(stats):
# 'Avg. Trade Duration': '单笔交易平均持有时长', # 'Avg. Trade Duration': '单笔交易平均持有时长',
# 'Profit Factor': '盈利因子', # 'Profit Factor': '盈利因子',
# 'Expectancy [%]': '期望收益(%', # 'Expectancy [%]': '期望收益(%',
'SQN': '系统质量数', "SQN": "系统质量数",
# 'Kelly Criterion': '凯利准则', # 'Kelly Criterion': '凯利准则',
} }
for k, v in stats.items(): for k, v in stats.items():
if k in indicator_name_mapping: if k in indicator_name_mapping:
cn_name = indicator_name_mapping.get(k, k) cn_name = indicator_name_mapping.get(k, k)
if isinstance(v, (int, float)): if isinstance(v, (int, float)):
if "%" in cn_name or k in ['Sharpe Ratio', 'Sortino Ratio', 'Calmar Ratio', 'Profit Factor']: if "%" in cn_name or k in [
"Sharpe Ratio",
"Sortino Ratio",
"Calmar Ratio",
"Profit Factor",
]:
formatted_value = f"{v:.2f}" formatted_value = f"{v:.2f}"
elif "$" in cn_name: elif "$" in cn_name:
formatted_value = f"{v:.2f}" formatted_value = f"{v:.2f}"
@@ -257,9 +235,9 @@ def print_stats(stats):
formatted_value = f"{v:.4f}" formatted_value = f"{v:.4f}"
else: else:
formatted_value = str(v) formatted_value = str(v)
print(f'{cn_name}: {formatted_value}') print(f"{cn_name}: {formatted_value}")
print("=" * 60 + "\n") print("=" * 60)
def main(): def main():
@@ -278,20 +256,16 @@ def main():
print(f"数据加载完成,共 {len(data)} 条记录") print(f"数据加载完成,共 {len(data)} 条记录")
# 截取预热数据 # 截取预热数据
warmup_data = data.iloc[-args.warmup_days :] warmup_data = data.iloc[-args.warmup_days:]
print(f"使用预热数据范围: {warmup_data.index[0]} ~ {warmup_data.index[-1]}") print(f"使用预热数据范围: {warmup_data.index[0]} ~ {warmup_data.index[-1]}")
# 加载策略 # 加载策略
print(f"加载策略: {args.strategy_file}")
calculate_indicators, strategy_class = load_strategy(args.strategy_file) calculate_indicators, strategy_class = load_strategy(args.strategy_file)
# 计算指标 # 计算指标
print("计算指标...")
warmup_data = calculate_indicators(warmup_data) warmup_data = calculate_indicators(warmup_data)
print("指标计算完成")
# 执行回测 # 执行回测
print("开始回测...")
from backtesting import Backtest from backtesting import Backtest
bt = Backtest( bt = Backtest(
@@ -308,12 +282,10 @@ def main():
# 生成图表 # 生成图表
if args.output: if args.output:
print(f"\n生成图表: {args.output}") os.makedirs(os.path.dirname(args.output), exist_ok=True)
bt.plot(filename=args.output, open_browser=False) bt.plot(filename=args.output, open_browser=False)
print(f"图表已保存到: {args.output}") print(f"图表已保存到: {args.output}")
print("\n回测完成!")
except Exception as e: except Exception as e:
print(f"\n错误: {e}") print(f"\n错误: {e}")
import traceback import traceback

View File

@@ -2,28 +2,27 @@
MACD 趋势跟踪策略 MACD 趋势跟踪策略
策略逻辑: 策略逻辑:
- 当 MACD 线上穿信号线时 (金叉),且价格 > EMA200 时,买入 - 当 MACD 线上穿信号线时 (金叉),且价格 > EMA 时,买入
- 当 MACD 线下穿信号线时 (死叉),或价格 < EMA200 时,卖出 - 当 MACD 线下穿信号线时 (死叉),或价格 < EMA 时,卖出
指标计算: 指标计算:
- MACD(10, 20, 9): 快线 10 日,慢线 20 日,信号线 9 日 - MACD(10, 20, 9): 快线 10 日,慢线 20 日,信号线 9 日
- EMA200: 200 日指数移动平均线(趋势确认) - EMA: 200 日指数移动平均线(趋势确认)
参数选择理由: 参数选择理由:
- 快线 10: 比标准 12 更敏感,适应 A 股较高波动性 - 快线 10: 比标准 12 更敏感,适应 A 股较高波动性
- 慢线 20: 比标准 26 更快响应,同时保持趋势跟踪稳定性 - 慢线 20: 比标准 26 更快响应,同时保持趋势跟踪稳定性
- 信号线 9: 保持标准,避免信号过于频繁 - 信号线 9: 保持标准,避免信号过于频繁
- EMA200: 被广泛认可为牛熊分界线,避免逆势交易 - EMA: 被广泛认可为牛熊分界线,避免逆势交易
趋势过滤: 趋势过滤:
- EMA200 上方: 确认为上升趋势,允许开多仓 - EMA 上方: 确认为上升趋势,允许开多仓
- EMA200 下方: 确认为下降趋势,不开多仓,强制平仓 - EMA 下方: 确认为下降趋势,不开多仓,强制平仓
Author: Sisyphus Author: Sisyphus
Date: 2025-01-27 Date: 2025-01-27
""" """
import pandas as pd
from backtesting import Strategy from backtesting import Strategy
from backtesting.lib import crossover from backtesting.lib import crossover
@@ -32,32 +31,30 @@ def calculate_indicators(data):
""" """
计算策略所需的技术指标 计算策略所需的技术指标
使用 ta-lib 库计算 MACD 和 EMA200 指标 使用 ta-lib 库计算 MACD 和 EMA 指标
参数: 参数:
data: DataFrame, 包含 [Open, High, Low, Close, Volume, factor] data: DataFrame, 包含 [Open, High, Low, Close, Volume, factor]
返回: 返回:
DataFrame, 添加了指标列: DataFrame, 添加了指标列:
- MACD_10_20_9: MACD 线 (DIF) - dif: MACD 线 (dif)
- MACDs_10_20_9: MACD 信号线 (DEA) - signal: MACD 信号线 (DEA)
- MACDh_10_20_9: MACD 柱状图 (Histogram) - hist: MACD 柱状图 (Histogram)
- EMA_200: 200 日指数移动平均线 - ema: 日指数移动平均线
""" """
data = data.copy() data = data.copy()
# 计算 MACD 指标 (10, 20, 9) # 计算 MACD 指标 (10, 20, 9)
# talib.MACD 返回三个值: (macd, macdsignal, macdhist) # talib.MACD 返回三个值: (macd, macdsignal, macdhist)
macd, macdsignal, macdhist = talib.MACD( macd, macdsignal, macdhist = talib.MACD(data["Close"], fastperiod=10, slowperiod=20, signalperiod=9)
data["Close"], fastperiod=10, slowperiod=20, signalperiod=9
)
data["MACD_10_20_9"] = macd data["dif"] = macd
data["MACDs_10_20_9"] = macdsignal data["signal"] = macdsignal
data["MACDh_10_20_9"] = macdhist data["hist"] = macdhist
# 计算 EMA200 趋势线 # 计算 EMA 趋势线
data["EMA_200"] = talib.EMA(data["Close"], timeperiod=200) data["ema"] = talib.SMA(data["Close"], timeperiod=120)
return data return data
@@ -76,7 +73,7 @@ class MacdTrendStrategy(Strategy):
""" """
MACD 趋势跟踪策略 MACD 趋势跟踪策略
结合 MACD 金叉/死叉信号和 EMA200 趋势过滤 结合 MACD 金叉/死叉信号和 EMA 趋势过滤
参数: 参数:
fast_period: MACD 快线周期 (默认: 10) fast_period: MACD 快线周期 (默认: 10)
@@ -95,13 +92,13 @@ class MacdTrendStrategy(Strategy):
注册指标到 backtesting 框架 注册指标到 backtesting 框架
""" """
# 注册 MACD 线 # 注册 MACD 线
self.macd = self.I(lambda x: x, self.data.MACD_10_20_9) self.macd = self.I(lambda x: x, self.data.dif)
# 注册 MACD 信号线 # 注册 MACD 信号线
self.macd_signal = self.I(lambda x: x, self.data.MACDs_10_20_9) self.macd_signal = self.I(lambda x: x, self.data.signal)
# 注册 EMA200 趋势线 # 注册 EMA 趋势线
self.ema200 = self.I(lambda x: x, self.data.EMA_200) self.ema = self.I(lambda x: x, self.data.ema)
def next(self): def next(self):
""" """
@@ -109,25 +106,19 @@ class MacdTrendStrategy(Strategy):
买入条件: 买入条件:
- MACD 金叉 (MACD 线上穿信号线) - MACD 金叉 (MACD 线上穿信号线)
- 价格 > EMA200 (确认上升趋势) - 价格 > EMA (确认上升趋势)
卖出条件: 卖出条件:
- MACD 死叉 (MACD 线下穿信号线) - MACD 死叉 (MACD 线下穿信号线)
- 或价格 < EMA200 (趋势转向,强制平仓) - 或价格 < EMA (趋势转向,强制平仓)
""" """
# 买入条件: MACD 金叉 AND 价格 > EMA200 # 买入条件: MACD 金叉 AND 价格 > EMA
if ( if crossover(self.macd, self.macd_signal) and self.data.Close[-1] > self.ema[-1]:
crossover(self.macd, self.macd_signal)
and self.data.Close[-1] > self.ema200[-1]
):
self.position.close() # 先平掉现有仓位 self.position.close() # 先平掉现有仓位
self.buy() # 开多仓 self.buy() # 开多仓
# 卖出条件: MACD 死叉 OR 价格 < EMA200 # 卖出条件: MACD 死叉 OR 价格 < EMA
elif ( elif crossover(self.macd_signal, self.macd) or self.data.Close[-1] < self.ema[-1]:
crossover(self.macd_signal, self.macd)
or self.data.Close[-1] < self.ema200[-1]
):
self.position.close() # 平掉多仓 self.position.close() # 平掉多仓

View File

@@ -5,7 +5,7 @@ SMA 双均线交叉策略
- 当短期均线上穿长期均线时 (金叉),买入 - 当短期均线上穿长期均线时 (金叉),买入
- 当短期均线下穿长期均线时 (死叉),卖出 - 当短期均线下穿长期均线时 (死叉),卖出
指标计算: 指标计算 (使用 ta-lib):
- SMA10: 10 日简单移动平均线 - SMA10: 10 日简单移动平均线
- SMA30: 30 日简单移动平均线 - SMA30: 30 日简单移动平均线
- SMA60: 60 日简单移动平均线 - SMA60: 60 日简单移动平均线
@@ -21,19 +21,25 @@ def calculate_indicators(data):
""" """
计算策略所需的技术指标 计算策略所需的技术指标
使用 ta-lib 库计算 SMA 指标
参数: 参数:
data: DataFrame, 包含 [Open, High, Low, Close, Volume, factor] data: DataFrame, 包含 [Open, High, Low, Close, Volume, factor]
返回: 返回:
DataFrame, 添加了指标列 DataFrame, 添加了指标列:
- sma10: 10 日简单移动平均线
- sma30: 30 日简单移动平均线
- sma60: 60 日简单移动平均线
- sma120: 120 日简单移动平均线
""" """
data = data.copy() data = data.copy()
# 计算不同周期的移动平均线 # 计算不同周期的移动平均线
data["sma10"] = data["Close"].rolling(window=10).mean() data["sma10"] = talib.SMA(data["Close"], timeperiod=10)
data["sma30"] = data["Close"].rolling(window=30).mean() data["sma30"] = talib.SMA(data["Close"], timeperiod=30)
data["sma60"] = data["Close"].rolling(window=60).mean() data["sma60"] = talib.SMA(data["Close"], timeperiod=60)
data["sma120"] = data["Close"].rolling(window=120).mean() data["sma120"] = talib.SMA(data["Close"], timeperiod=120)
return data return data
@@ -85,3 +91,7 @@ class SmaCross(Strategy):
elif crossover(self.data.sma30, self.data.sma10): elif crossover(self.data.sma30, self.data.sma10):
self.position.close() # 先平掉现有仓位 self.position.close() # 先平掉现有仓位
self.sell() # 开空仓 self.sell() # 开空仓
# 导入 talib (必须在文件末尾,因为 calculate_indicators 函数中使用了 talib)
import talib