From 5cc140259ed6c235db84ec0e3a733db51f026754 Mon Sep 17 00:00:00 2001 From: lanyuanxiaoyao Date: Wed, 28 Jan 2026 09:46:44 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BB=A3=E7=A0=81=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backtest.py | 112 ++++++++++++++---------------------- strategies/macd_strategy.py | 65 +++++++++------------ strategies/sma_strategy.py | 22 +++++-- 3 files changed, 86 insertions(+), 113 deletions(-) diff --git a/backtest.py b/backtest.py index 3bc850b..a653a87 100644 --- a/backtest.py +++ b/backtest.py @@ -7,12 +7,11 @@ """ import argparse -import sys -import os import importlib.util +import os +import sys + import pandas as pd -from datetime import datetime -from backtesting import Backtest # 数据库配置(直接硬编码,开发环境) DB_HOST = "81.71.3.24" @@ -47,20 +46,19 @@ def load_data_from_db(code, start_date, end_date): try: # 构建 SQL 查询 query = f""" - SELECT - trade_date, - open * factor AS "Open", - close * factor AS "Close", - high * factor AS "High", - low * factor AS "Low", - volume AS "Volume", - COALESCE(factor, 1.0) AS factor - FROM leopard_daily daily - LEFT JOIN leopard_stock stock ON stock.id = daily.stock_id - WHERE stock.code = '{code}' - AND daily.trade_date BETWEEN '{start_date} 00:00:00' - AND '{end_date} 23:59:59' - ORDER BY daily.trade_date +SELECT trade_date, + open * factor AS "Open", + close * factor AS "Close", + high * factor AS "High", + low * factor AS "Low", + volume AS "Volume", + COALESCE(factor, 1.0) AS factor +FROM leopard_daily daily + LEFT JOIN leopard_stock stock ON stock.id = daily.stock_id +WHERE stock.code = '{code}' + AND daily.trade_date BETWEEN '{start_date} 00:00:00' + AND '{end_date} 23:59:59' +ORDER BY daily.trade_date """ # 执行查询 @@ -175,30 +173,6 @@ def parse_arguments(): return parser.parse_args() -def format_value(value, cn_name, key): - """ - 格式化数值显示 - """ - if isinstance(value, (int, float)): - if "%" in cn_name or key in [ - "Sharpe Ratio", - "Sortino Ratio", - "Calmar Ratio", - "Profit Factor", - ]: - formatted_value = f"{value:.2f}" - elif "$" in cn_name: - formatted_value = f"{value:.2f}" - elif "次数" in cn_name: - formatted_value = f"{value:.0f}" - else: - formatted_value = f"{value:.4f}" - else: - formatted_value = str(value) - - return formatted_value - - def print_stats(stats): """ 打印回测统计结果 @@ -206,33 +180,32 @@ def print_stats(stats): 参数: stats: backtesting 库返回的统计对象 """ - print("\n" + "=" * 60) - print("回测结果") print("=" * 60) + print("回测结果") indicator_name_mapping = { # 'Start': '回测开始时间', # 'End': '回测结束时间', # 'Duration': '回测持续时长', # 'Exposure Time [%]': '持仓时间占比(%)', - 'Equity Final [$]': '最终收益', - 'Equity Peak [$]': '峰值收益', - 'Return [%]': '总收益率(%)', - 'Buy & Hold Return [%]': '买入并持有收益率(%)', - 'Return (Ann.) [%]': '年化收益率(%)', - 'Volatility (Ann.) [%]': '年化波动率(%)', + "Equity Final [$]": "最终收益", + "Equity Peak [$]": "峰值收益", + "Return [%]": "总收益率(%)", + "Buy & Hold Return [%]": "买入并持有收益率(%)", + "Return (Ann.) [%]": "年化收益率(%)", + "Volatility (Ann.) [%]": "年化波动率(%)", # 'CAGR [%]': '复合年均增长率(%)', # 'Sharpe Ratio': '夏普比率', - 'Sortino Ratio': '索提诺比率', - 'Calmar Ratio': '卡尔玛比率', + "Sortino Ratio": "索提诺比率", + "Calmar Ratio": "卡尔玛比率", # 'Alpha [%]': '阿尔法系数(%)', # 'Beta': '贝塔系数', - 'Max. Drawdown [%]': '最大回撤(%)', - 'Avg. Drawdown [%]': '平均回撤(%)', - 'Max. Drawdown Duration': '最大回撤持续时长', - 'Avg. Drawdown Duration': '平均回撤持续时长', - '# Trades': '总交易次数', - 'Win Rate [%]': '胜率(%)', + "Max. Drawdown [%]": "最大回撤(%)", + "Avg. Drawdown [%]": "平均回撤(%)", + "Max. Drawdown Duration": "最大回撤持续时长", + "Avg. Drawdown Duration": "平均回撤持续时长", + "# Trades": "总交易次数", + "Win Rate [%]": "胜率(%)", # 'Best Trade [%]': '最佳单笔交易收益率(%)', # 'Worst Trade [%]': '最差单笔交易收益率(%)', # 'Avg. Trade [%]': '平均单笔交易收益率(%)', @@ -240,14 +213,19 @@ def print_stats(stats): # 'Avg. Trade Duration': '单笔交易平均持有时长', # 'Profit Factor': '盈利因子', # 'Expectancy [%]': '期望收益(%)', - 'SQN': '系统质量数', + "SQN": "系统质量数", # 'Kelly Criterion': '凯利准则', } for k, v in stats.items(): if k in indicator_name_mapping: cn_name = indicator_name_mapping.get(k, k) if isinstance(v, (int, float)): - if "%" in cn_name or k in ['Sharpe Ratio', 'Sortino Ratio', 'Calmar Ratio', 'Profit Factor']: + if "%" in cn_name or k in [ + "Sharpe Ratio", + "Sortino Ratio", + "Calmar Ratio", + "Profit Factor", + ]: formatted_value = f"{v:.2f}" elif "$" in cn_name: formatted_value = f"{v:.2f}" @@ -257,9 +235,9 @@ def print_stats(stats): formatted_value = f"{v:.4f}" else: formatted_value = str(v) - print(f'{cn_name}: {formatted_value}') + print(f"{cn_name}: {formatted_value}") - print("=" * 60 + "\n") + print("=" * 60) def main(): @@ -278,20 +256,16 @@ def main(): print(f"数据加载完成,共 {len(data)} 条记录") # 截取预热数据 - warmup_data = data.iloc[-args.warmup_days :] + warmup_data = data.iloc[-args.warmup_days:] print(f"使用预热数据范围: {warmup_data.index[0]} ~ {warmup_data.index[-1]}") # 加载策略 - print(f"加载策略: {args.strategy_file}") calculate_indicators, strategy_class = load_strategy(args.strategy_file) # 计算指标 - print("计算指标...") warmup_data = calculate_indicators(warmup_data) - print("指标计算完成") # 执行回测 - print("开始回测...") from backtesting import Backtest bt = Backtest( @@ -308,12 +282,10 @@ def main(): # 生成图表 if args.output: - print(f"\n生成图表: {args.output}") + os.makedirs(os.path.dirname(args.output), exist_ok=True) bt.plot(filename=args.output, open_browser=False) print(f"图表已保存到: {args.output}") - print("\n回测完成!") - except Exception as e: print(f"\n错误: {e}") import traceback diff --git a/strategies/macd_strategy.py b/strategies/macd_strategy.py index d570c8c..4f77bf5 100644 --- a/strategies/macd_strategy.py +++ b/strategies/macd_strategy.py @@ -2,28 +2,27 @@ MACD 趋势跟踪策略 策略逻辑: -- 当 MACD 线上穿信号线时 (金叉),且价格 > EMA200 时,买入 -- 当 MACD 线下穿信号线时 (死叉),或价格 < EMA200 时,卖出 +- 当 MACD 线上穿信号线时 (金叉),且价格 > EMA 时,买入 +- 当 MACD 线下穿信号线时 (死叉),或价格 < EMA 时,卖出 指标计算: - MACD(10, 20, 9): 快线 10 日,慢线 20 日,信号线 9 日 -- EMA200: 200 日指数移动平均线(趋势确认) +- EMA: 200 日指数移动平均线(趋势确认) 参数选择理由: - 快线 10: 比标准 12 更敏感,适应 A 股较高波动性 - 慢线 20: 比标准 26 更快响应,同时保持趋势跟踪稳定性 - 信号线 9: 保持标准,避免信号过于频繁 -- EMA200: 被广泛认可为牛熊分界线,避免逆势交易 +- EMA: 被广泛认可为牛熊分界线,避免逆势交易 趋势过滤: -- EMA200 上方: 确认为上升趋势,允许开多仓 -- EMA200 下方: 确认为下降趋势,不开多仓,强制平仓 +- EMA 上方: 确认为上升趋势,允许开多仓 +- EMA 下方: 确认为下降趋势,不开多仓,强制平仓 Author: Sisyphus Date: 2025-01-27 """ -import pandas as pd from backtesting import Strategy from backtesting.lib import crossover @@ -32,32 +31,30 @@ def calculate_indicators(data): """ 计算策略所需的技术指标 - 使用 ta-lib 库计算 MACD 和 EMA200 指标 + 使用 ta-lib 库计算 MACD 和 EMA 指标 参数: data: DataFrame, 包含 [Open, High, Low, Close, Volume, factor] 返回: DataFrame, 添加了指标列: - - MACD_10_20_9: MACD 线 (DIF) - - MACDs_10_20_9: MACD 信号线 (DEA) - - MACDh_10_20_9: MACD 柱状图 (Histogram) - - EMA_200: 200 日指数移动平均线 + - dif: MACD 线 (dif) + - signal: MACD 信号线 (DEA) + - hist: MACD 柱状图 (Histogram) + - ema: 日指数移动平均线 """ data = data.copy() # 计算 MACD 指标 (10, 20, 9) # talib.MACD 返回三个值: (macd, macdsignal, macdhist) - macd, macdsignal, macdhist = talib.MACD( - data["Close"], fastperiod=10, slowperiod=20, signalperiod=9 - ) + macd, macdsignal, macdhist = talib.MACD(data["Close"], fastperiod=10, slowperiod=20, signalperiod=9) - data["MACD_10_20_9"] = macd - data["MACDs_10_20_9"] = macdsignal - data["MACDh_10_20_9"] = macdhist + data["dif"] = macd + data["signal"] = macdsignal + data["hist"] = macdhist - # 计算 EMA200 趋势线 - data["EMA_200"] = talib.EMA(data["Close"], timeperiod=200) + # 计算 EMA 趋势线 + data["ema"] = talib.SMA(data["Close"], timeperiod=120) return data @@ -76,7 +73,7 @@ class MacdTrendStrategy(Strategy): """ MACD 趋势跟踪策略 - 结合 MACD 金叉/死叉信号和 EMA200 趋势过滤 + 结合 MACD 金叉/死叉信号和 EMA 趋势过滤 参数: fast_period: MACD 快线周期 (默认: 10) @@ -95,13 +92,13 @@ class MacdTrendStrategy(Strategy): 注册指标到 backtesting 框架 """ # 注册 MACD 线 - self.macd = self.I(lambda x: x, self.data.MACD_10_20_9) + self.macd = self.I(lambda x: x, self.data.dif) # 注册 MACD 信号线 - self.macd_signal = self.I(lambda x: x, self.data.MACDs_10_20_9) + self.macd_signal = self.I(lambda x: x, self.data.signal) - # 注册 EMA200 趋势线 - self.ema200 = self.I(lambda x: x, self.data.EMA_200) + # 注册 EMA 趋势线 + self.ema = self.I(lambda x: x, self.data.ema) def next(self): """ @@ -109,25 +106,19 @@ class MacdTrendStrategy(Strategy): 买入条件: - MACD 金叉 (MACD 线上穿信号线) - - 价格 > EMA200 (确认上升趋势) + - 价格 > EMA (确认上升趋势) 卖出条件: - MACD 死叉 (MACD 线下穿信号线) - - 或价格 < EMA200 (趋势转向,强制平仓) + - 或价格 < EMA (趋势转向,强制平仓) """ - # 买入条件: MACD 金叉 AND 价格 > EMA200 - if ( - crossover(self.macd, self.macd_signal) - and self.data.Close[-1] > self.ema200[-1] - ): + # 买入条件: MACD 金叉 AND 价格 > EMA + if crossover(self.macd, self.macd_signal) and self.data.Close[-1] > self.ema[-1]: self.position.close() # 先平掉现有仓位 self.buy() # 开多仓 - # 卖出条件: MACD 死叉 OR 价格 < EMA200 - elif ( - crossover(self.macd_signal, self.macd) - or self.data.Close[-1] < self.ema200[-1] - ): + # 卖出条件: MACD 死叉 OR 价格 < EMA + elif crossover(self.macd_signal, self.macd) or self.data.Close[-1] < self.ema[-1]: self.position.close() # 平掉多仓 diff --git a/strategies/sma_strategy.py b/strategies/sma_strategy.py index 627e6fa..87bddee 100644 --- a/strategies/sma_strategy.py +++ b/strategies/sma_strategy.py @@ -5,7 +5,7 @@ SMA 双均线交叉策略 - 当短期均线上穿长期均线时 (金叉),买入 - 当短期均线下穿长期均线时 (死叉),卖出 -指标计算: +指标计算 (使用 ta-lib): - SMA10: 10 日简单移动平均线 - SMA30: 30 日简单移动平均线 - SMA60: 60 日简单移动平均线 @@ -21,19 +21,25 @@ def calculate_indicators(data): """ 计算策略所需的技术指标 + 使用 ta-lib 库计算 SMA 指标 + 参数: data: DataFrame, 包含 [Open, High, Low, Close, Volume, factor] 返回: - DataFrame, 添加了指标列 + DataFrame, 添加了指标列: + - sma10: 10 日简单移动平均线 + - sma30: 30 日简单移动平均线 + - sma60: 60 日简单移动平均线 + - sma120: 120 日简单移动平均线 """ data = data.copy() # 计算不同周期的移动平均线 - data["sma10"] = data["Close"].rolling(window=10).mean() - data["sma30"] = data["Close"].rolling(window=30).mean() - data["sma60"] = data["Close"].rolling(window=60).mean() - data["sma120"] = data["Close"].rolling(window=120).mean() + data["sma10"] = talib.SMA(data["Close"], timeperiod=10) + data["sma30"] = talib.SMA(data["Close"], timeperiod=30) + data["sma60"] = talib.SMA(data["Close"], timeperiod=60) + data["sma120"] = talib.SMA(data["Close"], timeperiod=120) return data @@ -85,3 +91,7 @@ class SmaCross(Strategy): elif crossover(self.data.sma30, self.data.sma10): self.position.close() # 先平掉现有仓位 self.sell() # 开空仓 + + +# 导入 talib (必须在文件末尾,因为 calculate_indicators 函数中使用了 talib) +import talib