1
0
Files
leopard-analysis/backtest_core.py

209 lines
6.4 KiB
Python

import dataclasses
import importlib.util
import os
from typing import Optional
import pandas as pd
from tqdm import tqdm
import config
@dataclasses.dataclass
class BacktestResult:
    """Summary statistics of one backtest run for a single stock code.

    Instances are produced by ``run_backtest``, which fills every numeric
    field from the statistics returned by ``Backtest.run()`` and falls back
    to ``0`` when a statistic is missing or cannot be converted.
    """

    # Stock code the backtest was run for.
    code: str

    # Equity figures ("Equity Final [$]" / "Equity Peak [$]").
    equity_final: float
    equity_peak: float

    # Return and risk figures; the ``*_pct`` fields come from "[%]" stats.
    return_pct: float
    buy_hold_return_pct: float
    return_ann_pct: float
    volatility_ann_pct: float
    sortino_ratio: float
    calmar_ratio: float

    # Drawdown depth, from the "[%]" drawdown stats.
    max_drawdown_pct: float
    avg_drawdown_pct: float

    # Drawdown durations, stored as a float number of days
    # (run_backtest converts the timedelta via total_seconds() / 86400).
    max_drawdown_duration: float
    avg_drawdown_duration: float

    # Trade statistics ("# Trades", "Win Rate [%]", "SQN").
    num_trades: int
    win_rate_pct: float
    sqn: float
def load_data_from_db(code: str, start_date: str, end_date: str) -> pd.DataFrame:
    """Load factor-adjusted daily OHLCV rows for one stock from PostgreSQL.

    Args:
        code: Stock code matched against ``leopard_stock.code``.
        start_date: First day of the range (inclusive), e.g. ``"2020-01-01"``.
        end_date: Last day of the range (inclusive), e.g. ``"2020-12-31"``.

    Returns:
        A DataFrame indexed by ``trade_date`` (as datetimes) with the columns
        ``Open``, ``Close``, ``High``, ``Low``, ``Volume`` and ``factor``,
        ordered by trade date.

    Raises:
        ValueError: If the query returns no rows for ``code`` in the range.
    """
    import sqlalchemy
    import urllib.parse

    # The password is URL-quoted so special characters survive the DSN.
    encoded_password = urllib.parse.quote_plus(config.DB_PASSWORD)
    conn_str = (
        f"postgresql://{config.DB_USER}:{encoded_password}"
        f"@{config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}"
    )
    engine = sqlalchemy.create_engine(conn_str)
    try:
        # Bind parameters instead of interpolating caller-supplied values
        # into the SQL text: avoids SQL injection and quoting bugs.
        # NOTE(review): the price columns multiply by the raw ``factor`` while
        # the ``factor`` column itself is COALESCE'd to 1.0, so a NULL factor
        # yields NULL prices next to factor == 1.0 -- confirm this is intended.
        query = sqlalchemy.text(
            """
            SELECT trade_date,
                   open * factor AS "Open",
                   close * factor AS "Close",
                   high * factor AS "High",
                   low * factor AS "Low",
                   volume AS "Volume",
                   COALESCE(factor, 1.0) AS factor
            FROM leopard_daily daily
            LEFT JOIN leopard_stock stock ON stock.id = daily.stock_id
            WHERE stock.code = :code
              AND daily.trade_date BETWEEN :range_start AND :range_end
            ORDER BY daily.trade_date
            """
        )
        params = {
            "code": code,
            # Widen the day bounds so both end points are inclusive.
            "range_start": f"{start_date} 00:00:00",
            "range_end": f"{end_date} 23:59:59",
        }
        df = pd.read_sql(query, engine, params=params)
        if df.empty:
            raise ValueError(f"未找到股票 {code} 在指定时间范围内的数据")
        df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y-%m-%d")
        df.set_index("trade_date", inplace=True)
        return df
    finally:
        # Always release pooled connections, on success and on error alike.
        engine.dispose()
def load_strategy(strategy_file: str):
    """Import a strategy file and return its indicator function and class.

    The file must define ``calculate_indicators`` and a ``get_strategy()``
    function that returns a subclass of ``backtesting.Strategy``.

    Args:
        strategy_file: Path to the Python file containing the strategy.

    Returns:
        A ``(calculate_indicators, strategy_class)`` tuple.

    Raises:
        ImportError: If no import spec/loader can be built for the path
            (for example, the file does not have a Python source suffix).
        AttributeError: If ``calculate_indicators`` or ``get_strategy`` is
            missing from the module.
        TypeError: If ``get_strategy()`` does not return a class, or the
            class does not inherit from ``backtesting.Strategy``.
    """
    # Strip only the trailing extension; a blanket ``.replace(".py", "")``
    # would also eat ".py" inside names such as "happy.pyramid.py".
    stem, _ext = os.path.splitext(strategy_file)
    module_name = stem.replace(os.sep, ".").replace("/", ".")

    spec = importlib.util.spec_from_file_location(module_name, strategy_file)
    # spec_from_file_location returns None when it cannot pick a loader;
    # fail with a clear message instead of an AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(f"无法加载策略文件 {strategy_file}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    for required in ("calculate_indicators", "get_strategy"):
        if not hasattr(module, required):
            raise AttributeError(f"策略文件 {strategy_file} 缺少 {required} 函数")

    calculate_indicators = module.calculate_indicators
    strategy_class = module.get_strategy()
    if not isinstance(strategy_class, type):
        raise TypeError("get_strategy() 必须返回一个类")

    # Imported late so the cheap structural checks above run first.
    from backtesting import Strategy

    if not issubclass(strategy_class, Strategy):
        raise TypeError("策略类必须继承 backtesting.Strategy")
    return calculate_indicators, strategy_class
def apply_color_scheme():
    """Override the bull/bear colors of backtesting's plotting module.

    Copies ``config.BULL_COLOR`` and ``config.BEAR_COLOR`` onto the
    same-named attributes of ``backtesting._plotting``.
    """
    from backtesting import _plotting as bt_plotting

    for color_name in ("BULL_COLOR", "BEAR_COLOR"):
        setattr(bt_plotting, color_name, getattr(config, color_name))
def run_backtest(
    code: str,
    start_date: str,
    end_date: str,
    strategy_file: str,
    cash: float = config.DEFAULT_CASH,
    commission: float = config.DEFAULT_COMMISSION,
    warmup_days: int = config.DEFAULT_WARMUP_DAYS,
    output_dir: Optional[str] = None,
) -> BacktestResult:
    """Run one strategy backtest for a single stock and summarize it.

    Data is loaded starting ``warmup_days`` calendar days before
    ``start_date`` so indicators can be computed over the extra history;
    the frame is then trimmed to ``start_date``..``end_date`` before the
    backtest runs.

    Args:
        code: Stock code to backtest.
        start_date: First day of the backtest window.
        end_date: Last day of the backtest window.
        strategy_file: Path of the strategy file passed to ``load_strategy``.
        cash: Starting cash handed to ``Backtest``.
        commission: Commission handed to ``Backtest``.
        warmup_days: Extra calendar days of history loaded before
            ``start_date``.
        output_dir: If truthy, the directory is created when needed and the
            plot is written to ``<output_dir>/<code>.html``.

    Returns:
        A ``BacktestResult``; statistics that are missing or cannot be
        converted are reported as ``0``.
    """

    def _coerce(raw, convert, handled, fallback=0):
        # Shared conversion guard: None or a handled error -> fallback.
        if raw is None:
            return fallback
        try:
            return convert(raw)
        except handled:
            return fallback

    def _as_float(raw):
        return _coerce(raw, float, (TypeError, ValueError))

    def _as_int(raw):
        return _coerce(raw, int, (TypeError, ValueError))

    def _as_days(raw):
        # Durations arrive as timedelta-like objects; express them in days.
        return _coerce(
            raw,
            lambda span: float(span.total_seconds() / 86400),
            (TypeError, AttributeError),
        )

    # Pull extra history ahead of the window so indicators have warm-up data.
    warmup_start = pd.to_datetime(start_date) - pd.Timedelta(days=warmup_days)
    data = load_data_from_db(code, warmup_start.strftime("%Y-%m-%d"), end_date)

    calculate_indicators, strategy_class = load_strategy(strategy_file)
    data = calculate_indicators(data)
    # Drop the warm-up rows again; only the requested window is backtested.
    data = data.loc[start_date:end_date]

    from backtesting import Backtest

    bt = Backtest(
        data,
        strategy_class,
        cash=cash,
        commission=commission,
        finalize_trades=True,
    )
    stats = bt.run()

    apply_color_scheme()
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        bt.plot(
            filename=os.path.join(output_dir, f"{code}.html"),
            open_browser=False,
        )

    # Result field -> (statistics key, converter).
    field_sources = {
        "equity_final": ("Equity Final [$]", _as_float),
        "equity_peak": ("Equity Peak [$]", _as_float),
        "return_pct": ("Return [%]", _as_float),
        "buy_hold_return_pct": ("Buy & Hold Return [%]", _as_float),
        "return_ann_pct": ("Return (Ann.) [%]", _as_float),
        "volatility_ann_pct": ("Volatility (Ann.) [%]", _as_float),
        "sortino_ratio": ("Sortino Ratio", _as_float),
        "calmar_ratio": ("Calmar Ratio", _as_float),
        "max_drawdown_pct": ("Max. Drawdown [%]", _as_float),
        "avg_drawdown_pct": ("Avg. Drawdown [%]", _as_float),
        "max_drawdown_duration": ("Max. Drawdown Duration", _as_days),
        "avg_drawdown_duration": ("Avg. Drawdown Duration", _as_days),
        "num_trades": ("# Trades", _as_int),
        "win_rate_pct": ("Win Rate [%]", _as_float),
        "sqn": ("SQN", _as_float),
    }
    summary = {
        field: convert(stats.get(stat_key))
        for field, (stat_key, convert) in field_sources.items()
    }
    return BacktestResult(code=code, **summary)
def run_batch_backtest(
    codes: list[str],
    start_date: str,
    end_date: str,
    strategy_file: str,
    cash: float = config.DEFAULT_CASH,
    commission: float = config.DEFAULT_COMMISSION,
    warmup_days: int = config.DEFAULT_WARMUP_DAYS,
    output_dir: Optional[str] = None,
    show_progress: bool = True,
) -> list[BacktestResult]:
    """Run ``run_backtest`` for every code, in order, and collect the results.

    Args:
        codes: Stock codes to backtest.
        start_date: First day of the backtest window.
        end_date: Last day of the backtest window.
        strategy_file: Path of the strategy file used for every code.
        cash: Starting cash forwarded to ``run_backtest``.
        commission: Commission forwarded to ``run_backtest``.
        warmup_days: Warm-up days forwarded to ``run_backtest``.
        output_dir: Output directory forwarded to ``run_backtest``.
        show_progress: When true, wrap the iteration in a ``tqdm`` bar.

    Returns:
        One ``BacktestResult`` per code, in the order of ``codes``. An
        exception raised by any single backtest propagates to the caller.
    """
    pending = codes
    if show_progress:
        pending = tqdm(codes, desc="批量回测")

    return [
        run_backtest(
            code=stock_code,
            start_date=start_date,
            end_date=end_date,
            strategy_file=strategy_file,
            cash=cash,
            commission=commission,
            warmup_days=warmup_days,
            output_dir=output_dir,
        )
        for stock_code in pending
    ]