重构回测代码架构,新增批量回测功能
This commit is contained in:
208
backtest_core.py
Normal file
208
backtest_core.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import dataclasses
|
||||
import importlib.util
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
import config
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BacktestResult:
|
||||
code: str
|
||||
equity_final: float
|
||||
equity_peak: float
|
||||
return_pct: float
|
||||
buy_hold_return_pct: float
|
||||
return_ann_pct: float
|
||||
volatility_ann_pct: float
|
||||
sortino_ratio: float
|
||||
calmar_ratio: float
|
||||
max_drawdown_pct: float
|
||||
avg_drawdown_pct: float
|
||||
max_drawdown_duration: float
|
||||
avg_drawdown_duration: float
|
||||
num_trades: int
|
||||
win_rate_pct: float
|
||||
sqn: float
|
||||
|
||||
|
||||
def load_data_from_db(code: str, start_date: str, end_date: str) -> pd.DataFrame:
|
||||
import sqlalchemy
|
||||
import urllib.parse
|
||||
|
||||
encoded_password = urllib.parse.quote_plus(config.DB_PASSWORD)
|
||||
conn_str = f"postgresql://{config.DB_USER}:{encoded_password}@{config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}"
|
||||
engine = sqlalchemy.create_engine(conn_str)
|
||||
|
||||
try:
|
||||
query = f"""
|
||||
SELECT trade_date,
|
||||
open * factor AS "Open",
|
||||
close * factor AS "Close",
|
||||
high * factor AS "High",
|
||||
low * factor AS "Low",
|
||||
volume AS "Volume",
|
||||
COALESCE(factor, 1.0) AS factor
|
||||
FROM leopard_daily daily
|
||||
LEFT JOIN leopard_stock stock ON stock.id = daily.stock_id
|
||||
WHERE stock.code = '{code}'
|
||||
AND daily.trade_date BETWEEN '{start_date} 00:00:00'
|
||||
AND '{end_date} 23:59:59'
|
||||
ORDER BY daily.trade_date
|
||||
"""
|
||||
|
||||
df = pd.read_sql(query, engine)
|
||||
|
||||
if len(df) == 0:
|
||||
raise ValueError(f"未找到股票 {code} 在指定时间范围内的数据")
|
||||
|
||||
df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y-%m-%d")
|
||||
df.set_index("trade_date", inplace=True)
|
||||
|
||||
return df
|
||||
|
||||
finally:
|
||||
engine.dispose()
|
||||
|
||||
|
||||
def load_strategy(strategy_file: str):
|
||||
module_name = strategy_file.replace(".py", "").replace("/", ".")
|
||||
spec = importlib.util.spec_from_file_location(module_name, strategy_file)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
if not hasattr(module, "calculate_indicators"):
|
||||
raise AttributeError(f"策略文件 {strategy_file} 缺少 calculate_indicators 函数")
|
||||
|
||||
if not hasattr(module, "get_strategy"):
|
||||
raise AttributeError(f"策略文件 {strategy_file} 缺少 get_strategy 函数")
|
||||
|
||||
calculate_indicators = module.calculate_indicators
|
||||
strategy_class = module.get_strategy()
|
||||
|
||||
if not isinstance(strategy_class, type):
|
||||
raise TypeError("get_strategy() 必须返回一个类")
|
||||
|
||||
from backtesting import Strategy
|
||||
|
||||
if not issubclass(strategy_class, Strategy):
|
||||
raise TypeError("策略类必须继承 backtesting.Strategy")
|
||||
|
||||
return calculate_indicators, strategy_class
|
||||
|
||||
|
||||
def apply_color_scheme():
|
||||
import backtesting._plotting as plotting
|
||||
|
||||
plotting.BULL_COLOR = config.BULL_COLOR
|
||||
plotting.BEAR_COLOR = config.BEAR_COLOR
|
||||
|
||||
|
||||
def run_backtest(
|
||||
code: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
strategy_file: str,
|
||||
cash: float = config.DEFAULT_CASH,
|
||||
commission: float = config.DEFAULT_COMMISSION,
|
||||
warmup_days: int = config.DEFAULT_WARMUP_DAYS,
|
||||
output_dir: Optional[str] = None,
|
||||
) -> BacktestResult:
|
||||
warmup_start_date = (pd.to_datetime(start_date) - pd.Timedelta(days=warmup_days)).strftime("%Y-%m-%d")
|
||||
|
||||
data = load_data_from_db(code, warmup_start_date, end_date)
|
||||
|
||||
calculate_indicators, strategy_class = load_strategy(strategy_file)
|
||||
|
||||
data = calculate_indicators(data)
|
||||
|
||||
data = data.loc[start_date:end_date]
|
||||
|
||||
from backtesting import Backtest
|
||||
|
||||
bt = Backtest(data, strategy_class, cash=cash, commission=commission, finalize_trades=True)
|
||||
stats = bt.run()
|
||||
|
||||
apply_color_scheme()
|
||||
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_path = os.path.join(output_dir, f"{code}.html")
|
||||
bt.plot(filename=output_path, open_browser=False)
|
||||
|
||||
def _safe_float(value, default=0):
|
||||
if value is None:
|
||||
return default
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
def _safe_int(value, default=0):
|
||||
if value is None:
|
||||
return default
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
def _safe_timedelta(value, default=0):
|
||||
if value is None:
|
||||
return default
|
||||
try:
|
||||
return float(value.total_seconds() / 86400)
|
||||
except (TypeError, AttributeError):
|
||||
return default
|
||||
|
||||
return BacktestResult(
|
||||
code=code,
|
||||
equity_final=_safe_float(stats.get("Equity Final [$]"), 0),
|
||||
equity_peak=_safe_float(stats.get("Equity Peak [$]"), 0),
|
||||
return_pct=_safe_float(stats.get("Return [%]"), 0),
|
||||
buy_hold_return_pct=_safe_float(stats.get("Buy & Hold Return [%]"), 0),
|
||||
return_ann_pct=_safe_float(stats.get("Return (Ann.) [%]"), 0),
|
||||
volatility_ann_pct=_safe_float(stats.get("Volatility (Ann.) [%]"), 0),
|
||||
sortino_ratio=_safe_float(stats.get("Sortino Ratio"), 0),
|
||||
calmar_ratio=_safe_float(stats.get("Calmar Ratio"), 0),
|
||||
max_drawdown_pct=_safe_float(stats.get("Max. Drawdown [%]"), 0),
|
||||
avg_drawdown_pct=_safe_float(stats.get("Avg. Drawdown [%]"), 0),
|
||||
max_drawdown_duration=_safe_timedelta(stats.get("Max. Drawdown Duration"), 0),
|
||||
avg_drawdown_duration=_safe_timedelta(stats.get("Avg. Drawdown Duration"), 0),
|
||||
num_trades=_safe_int(stats.get("# Trades"), 0),
|
||||
win_rate_pct=_safe_float(stats.get("Win Rate [%]"), 0),
|
||||
sqn=_safe_float(stats.get("SQN"), 0),
|
||||
)
|
||||
|
||||
|
||||
def run_batch_backtest(
|
||||
codes: list[str],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
strategy_file: str,
|
||||
cash: float = config.DEFAULT_CASH,
|
||||
commission: float = config.DEFAULT_COMMISSION,
|
||||
warmup_days: int = config.DEFAULT_WARMUP_DAYS,
|
||||
output_dir: Optional[str] = None,
|
||||
show_progress: bool = True,
|
||||
) -> list[BacktestResult]:
|
||||
results = []
|
||||
|
||||
codes_iter = tqdm(codes, desc="批量回测") if show_progress else codes
|
||||
|
||||
for code in codes_iter:
|
||||
result = run_backtest(
|
||||
code=code,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
strategy_file=strategy_file,
|
||||
cash=cash,
|
||||
commission=commission,
|
||||
warmup_days=warmup_days,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
Reference in New Issue
Block a user