import dataclasses
import importlib.util
import os
from typing import Optional

import pandas as pd
from tqdm import tqdm

import config


@dataclasses.dataclass
class BacktestResult:
    """Summary statistics for one stock's backtest run.

    Values are extracted from the stats returned by ``Backtest.run()``;
    missing or unconvertible entries are replaced by 0. The two
    ``*_drawdown_duration`` fields are expressed in days.
    """

    code: str
    equity_final: float
    equity_peak: float
    return_pct: float
    buy_hold_return_pct: float
    return_ann_pct: float
    volatility_ann_pct: float
    sortino_ratio: float
    calmar_ratio: float
    max_drawdown_pct: float
    avg_drawdown_pct: float
    max_drawdown_duration: float  # days
    avg_drawdown_duration: float  # days
    num_trades: int
    win_rate_pct: float
    sqn: float


def load_data_from_db(code: str, start_date: str, end_date: str) -> pd.DataFrame:
    """Load daily OHLCV bars for one stock from PostgreSQL.

    Open/Close/High/Low are multiplied by the row's ``factor`` column
    (presumably a price-adjustment factor — confirm against the schema);
    a NULL factor is treated as 1.0, both for the prices and for the
    returned ``factor`` column.

    Args:
        code: Stock code matched against ``leopard_stock.code``.
        start_date: First date to include, ``YYYY-MM-DD``.
        end_date: Last date to include, ``YYYY-MM-DD``.

    Returns:
        DataFrame indexed by ``trade_date`` (datetime) with columns
        ``Open``, ``Close``, ``High``, ``Low``, ``Volume`` and ``factor``,
        ordered by date.

    Raises:
        ValueError: If no rows exist for ``code`` in the requested range.
    """
    import sqlalchemy
    import urllib.parse

    encoded_password = urllib.parse.quote_plus(config.DB_PASSWORD)
    conn_str = f"postgresql://{config.DB_USER}:{encoded_password}@{config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}"
    engine = sqlalchemy.create_engine(conn_str)
    try:
        # The stock code and dates are sent as bound parameters instead of
        # being pasted into the SQL text, so they cannot alter the query.
        # COALESCE on the price multipliers keeps prices intact when factor
        # is NULL, matching the COALESCE already applied to the factor column.
        query = sqlalchemy.text(
            """
            SELECT trade_date,
                   open * COALESCE(factor, 1.0) AS "Open",
                   close * COALESCE(factor, 1.0) AS "Close",
                   high * COALESCE(factor, 1.0) AS "High",
                   low * COALESCE(factor, 1.0) AS "Low",
                   volume AS "Volume",
                   COALESCE(factor, 1.0) AS factor
            FROM leopard_daily daily
            LEFT JOIN leopard_stock stock ON stock.id = daily.stock_id
            WHERE stock.code = :code
              AND daily.trade_date BETWEEN :start_ts AND :end_ts
            ORDER BY daily.trade_date
            """
        )
        params = {
            "code": code,
            # Span whole days so both endpoints are inclusive.
            "start_ts": f"{start_date} 00:00:00",
            "end_ts": f"{end_date} 23:59:59",
        }
        df = pd.read_sql(query, engine, params=params)
        if len(df) == 0:
            raise ValueError(f"未找到股票 {code} 在指定时间范围内的数据")
        df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y-%m-%d")
        df.set_index("trade_date", inplace=True)
        return df
    finally:
        # Release pooled connections; a fresh engine is created per call.
        engine.dispose()


def load_strategy(strategy_file: str):
    """Import a strategy module from a file path and return its entry points.

    The module must define ``calculate_indicators`` and ``get_strategy``;
    ``get_strategy()`` must return a subclass of ``backtesting.Strategy``.

    Args:
        strategy_file: Path to the strategy ``.py`` file.

    Returns:
        Tuple ``(calculate_indicators, strategy_class)``.

    Raises:
        AttributeError: If either required function is missing.
        TypeError: If ``get_strategy()`` does not return a ``Strategy`` subclass.
    """
    module_name = strategy_file.replace(".py", "").replace("/", ".")
    spec = importlib.util.spec_from_file_location(module_name, strategy_file)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    if not hasattr(module, "calculate_indicators"):
        raise AttributeError(f"策略文件 {strategy_file} 缺少 calculate_indicators 函数")
    if not hasattr(module, "get_strategy"):
        raise AttributeError(f"策略文件 {strategy_file} 缺少 get_strategy 函数")

    calculate_indicators = module.calculate_indicators
    strategy_class = module.get_strategy()
    if not isinstance(strategy_class, type):
        raise TypeError("get_strategy() 必须返回一个类")

    from backtesting import Strategy

    if not issubclass(strategy_class, Strategy):
        raise TypeError("策略类必须继承 backtesting.Strategy")
    return calculate_indicators, strategy_class


def apply_color_scheme():
    """Override backtesting.py's (private) plotting bull/bear colors with values from config."""
    import backtesting._plotting as plotting

    plotting.BULL_COLOR = config.BULL_COLOR
    plotting.BEAR_COLOR = config.BEAR_COLOR


def _safe_float(value, default=0):
    """Convert ``value`` to float, returning ``default`` for None or unconvertible input.

    NaN converts to NaN (it is not replaced by ``default``).
    """
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _safe_int(value, default=0):
    """Convert ``value`` to int, returning ``default`` for None or unconvertible input (incl. NaN)."""
    if value is None:
        return default
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def _safe_timedelta(value, default=0):
    """Convert a timedelta-like ``value`` to fractional days, or ``default`` if not possible."""
    if value is None:
        return default
    try:
        return float(value.total_seconds() / 86400)
    except (TypeError, AttributeError):
        return default


def run_backtest(
    code: str,
    start_date: str,
    end_date: str,
    strategy_file: str,
    cash: float = config.DEFAULT_CASH,
    commission: float = config.DEFAULT_COMMISSION,
    warmup_days: int = config.DEFAULT_WARMUP_DAYS,
    output_dir: Optional[str] = None,
) -> BacktestResult:
    """Backtest one stock with the strategy in ``strategy_file``.

    Bars are loaded starting ``warmup_days`` calendar days before
    ``start_date`` so indicators have history to warm up on; after
    ``calculate_indicators`` runs, the data is trimmed to
    ``[start_date, end_date]`` before the backtest.

    Args:
        code: Stock code to test.
        start_date: First trading date of the backtest window, ``YYYY-MM-DD``.
        end_date: Last trading date of the backtest window, ``YYYY-MM-DD``.
        strategy_file: Path to the strategy module (see ``load_strategy``).
        cash: Starting cash.
        commission: Commission passed to ``backtesting.Backtest``.
        warmup_days: Calendar days of extra history loaded before ``start_date``.
        output_dir: If given, an HTML chart is written to ``<output_dir>/<code>.html``.

    Returns:
        A ``BacktestResult`` with the extracted statistics.
    """
    warmup_start_date = (pd.to_datetime(start_date) - pd.Timedelta(days=warmup_days)).strftime("%Y-%m-%d")
    data = load_data_from_db(code, warmup_start_date, end_date)

    calculate_indicators, strategy_class = load_strategy(strategy_file)
    data = calculate_indicators(data)
    # Drop the warm-up rows only after indicators are computed on them.
    data = data.loc[start_date:end_date]

    from backtesting import Backtest

    # finalize_trades=True closes trades still open at the last bar so they
    # are included in the statistics.
    bt = Backtest(data, strategy_class, cash=cash, commission=commission, finalize_trades=True)
    stats = bt.run()

    apply_color_scheme()
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f"{code}.html")
        bt.plot(filename=output_path, open_browser=False)

    return BacktestResult(
        code=code,
        equity_final=_safe_float(stats.get("Equity Final [$]"), 0),
        equity_peak=_safe_float(stats.get("Equity Peak [$]"), 0),
        return_pct=_safe_float(stats.get("Return [%]"), 0),
        buy_hold_return_pct=_safe_float(stats.get("Buy & Hold Return [%]"), 0),
        return_ann_pct=_safe_float(stats.get("Return (Ann.) [%]"), 0),
        volatility_ann_pct=_safe_float(stats.get("Volatility (Ann.) [%]"), 0),
        sortino_ratio=_safe_float(stats.get("Sortino Ratio"), 0),
        calmar_ratio=_safe_float(stats.get("Calmar Ratio"), 0),
        max_drawdown_pct=_safe_float(stats.get("Max. Drawdown [%]"), 0),
        avg_drawdown_pct=_safe_float(stats.get("Avg. Drawdown [%]"), 0),
        max_drawdown_duration=_safe_timedelta(stats.get("Max. Drawdown Duration"), 0),
        avg_drawdown_duration=_safe_timedelta(stats.get("Avg. Drawdown Duration"), 0),
        num_trades=_safe_int(stats.get("# Trades"), 0),
        win_rate_pct=_safe_float(stats.get("Win Rate [%]"), 0),
        sqn=_safe_float(stats.get("SQN"), 0),
    )


def run_batch_backtest(
    codes: list[str],
    start_date: str,
    end_date: str,
    strategy_file: str,
    cash: float = config.DEFAULT_CASH,
    commission: float = config.DEFAULT_COMMISSION,
    warmup_days: int = config.DEFAULT_WARMUP_DAYS,
    output_dir: Optional[str] = None,
    show_progress: bool = True,
) -> list[BacktestResult]:
    """Run ``run_backtest`` for each code in order and collect the results.

    An exception from any single backtest propagates and aborts the batch.
    ``show_progress`` wraps the iteration in a tqdm progress bar.
    """
    codes_iter = tqdm(codes, desc="批量回测") if show_progress else codes
    return [
        run_backtest(
            code=code,
            start_date=start_date,
            end_date=end_date,
            strategy_file=strategy_file,
            cash=cash,
            commission=commission,
            warmup_days=warmup_days,
            output_dir=output_dir,
        )
        for code in codes_iter
    ]