209 lines
6.4 KiB
Python
209 lines
6.4 KiB
Python
import dataclasses
|
|
import importlib.util
|
|
import os
|
|
from typing import Optional
|
|
|
|
import pandas as pd
|
|
from tqdm import tqdm
|
|
|
|
import config
|
|
|
|
|
|
@dataclasses.dataclass()
class BacktestResult:
    """Summary metrics of a single-stock backtest.

    Percentage fields carry values in percent (``12.5`` means 12.5%), as
    reported under the ``[%]`` keys of the backtesting stats; drawdown
    durations are stored as fractional days.
    """

    code: str  # stock code the backtest was run on
    equity_final: float  # "Equity Final [$]"
    equity_peak: float  # "Equity Peak [$]"
    return_pct: float  # "Return [%]"
    buy_hold_return_pct: float  # "Buy & Hold Return [%]"
    return_ann_pct: float  # "Return (Ann.) [%]"
    volatility_ann_pct: float  # "Volatility (Ann.) [%]"
    sortino_ratio: float  # "Sortino Ratio"
    calmar_ratio: float  # "Calmar Ratio"
    max_drawdown_pct: float  # "Max. Drawdown [%]"
    avg_drawdown_pct: float  # "Avg. Drawdown [%]"
    max_drawdown_duration: float  # "Max. Drawdown Duration", in days
    avg_drawdown_duration: float  # "Avg. Drawdown Duration", in days
    num_trades: int  # "# Trades"
    win_rate_pct: float  # "Win Rate [%]"
    sqn: float  # "SQN" (System Quality Number)
|
|
|
|
|
|
def load_data_from_db(code: str, start_date: str, end_date: str) -> pd.DataFrame:
    """Load daily OHLCV bars for one stock from PostgreSQL.

    Open/Close/High/Low are multiplied by the row's ``factor`` column
    (presumably a price-adjustment factor -- confirm against the schema).
    A NULL factor is treated as 1.0 both in the price columns and in the
    returned ``factor`` column, so such rows keep their raw prices instead of
    turning into NULL/NaN.

    Args:
        code: Stock code matched against ``leopard_stock.code``.
        start_date: First day to include, as ``YYYY-MM-DD``.
        end_date: Last day to include (the whole day), as ``YYYY-MM-DD``.

    Returns:
        DataFrame indexed by ``trade_date`` with columns ``Open``, ``Close``,
        ``High``, ``Low``, ``Volume`` and ``factor``, ordered by date.

    Raises:
        ValueError: If no rows match the code and date range.
    """
    import urllib.parse

    import sqlalchemy

    # quote_plus keeps special characters in the password from breaking the URL.
    encoded_password = urllib.parse.quote_plus(config.DB_PASSWORD)
    conn_str = (
        f"postgresql://{config.DB_USER}:{encoded_password}"
        f"@{config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}"
    )
    engine = sqlalchemy.create_engine(conn_str)

    # Values travel as bound parameters instead of being pasted into the SQL
    # text, so a malformed or hostile code/date string cannot alter the query.
    # COALESCE is applied inside the price products as well: previously a NULL
    # factor produced NULL prices while the factor column claimed 1.0.
    query = sqlalchemy.text(
        """
        SELECT trade_date,
               open * COALESCE(factor, 1.0) AS "Open",
               close * COALESCE(factor, 1.0) AS "Close",
               high * COALESCE(factor, 1.0) AS "High",
               low * COALESCE(factor, 1.0) AS "Low",
               volume AS "Volume",
               COALESCE(factor, 1.0) AS factor
        FROM leopard_daily daily
        LEFT JOIN leopard_stock stock ON stock.id = daily.stock_id
        WHERE stock.code = :code
          AND daily.trade_date BETWEEN :start_ts AND :end_ts
        ORDER BY daily.trade_date
        """
    )
    params = {
        "code": code,
        "start_ts": f"{start_date} 00:00:00",
        "end_ts": f"{end_date} 23:59:59",
    }

    try:
        df = pd.read_sql(query, engine, params=params)
    finally:
        # Release pooled connections; the engine is created per call.
        engine.dispose()

    if df.empty:
        raise ValueError(f"未找到股票 {code} 在指定时间范围内的数据")

    df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y-%m-%d")
    df.set_index("trade_date", inplace=True)
    return df
|
|
|
|
|
|
def load_strategy(strategy_file: str):
    """Import a strategy module from a file path and return its entry points.

    The module must define ``calculate_indicators`` and ``get_strategy``;
    ``get_strategy()`` must return a subclass of ``backtesting.Strategy``.

    Args:
        strategy_file: Path to the strategy ``.py`` file.

    Returns:
        Tuple ``(calculate_indicators, strategy_class)``.

    Raises:
        ImportError: If no module loader can be created for ``strategy_file``
            (for example when it lacks a Python source suffix).
        AttributeError: If a required function is missing from the module.
        TypeError: If ``get_strategy()`` does not return a Strategy subclass.
    """
    # Build a dotted module name from the path. Only the trailing ".py" is
    # stripped; str.replace(".py", "") also mangled names that merely contain
    # ".py" (e.g. a "foo.pyramid/" directory became "fooramid").
    module_name = strategy_file.removesuffix(".py")
    module_name = module_name.replace(os.sep, ".").replace("/", ".")

    spec = importlib.util.spec_from_file_location(module_name, strategy_file)
    # spec_from_file_location returns None when no loader matches the path;
    # fail with a clear error instead of an AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(f"无法加载策略文件 {strategy_file}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    if not hasattr(module, "calculate_indicators"):
        raise AttributeError(f"策略文件 {strategy_file} 缺少 calculate_indicators 函数")

    if not hasattr(module, "get_strategy"):
        raise AttributeError(f"策略文件 {strategy_file} 缺少 get_strategy 函数")

    calculate_indicators = module.calculate_indicators
    strategy_class = module.get_strategy()

    if not isinstance(strategy_class, type):
        raise TypeError("get_strategy() 必须返回一个类")

    # Imported lazily so the module can be imported without backtesting.py.
    from backtesting import Strategy

    if not issubclass(strategy_class, Strategy):
        raise TypeError("策略类必须继承 backtesting.Strategy")

    return calculate_indicators, strategy_class
|
|
|
|
|
|
def apply_color_scheme():
    """Install the project's bull/bear colours into backtesting.py's plotting module.

    Overwrites the module-level ``BULL_COLOR`` and ``BEAR_COLOR`` attributes of
    ``backtesting._plotting`` with ``config.BULL_COLOR`` / ``config.BEAR_COLOR``.
    NOTE(review): relies on a private backtesting module; presumably the
    values are read when a plot is rendered -- confirm on library upgrades.
    """
    from backtesting import _plotting as bt_plotting

    bt_plotting.BULL_COLOR = config.BULL_COLOR
    bt_plotting.BEAR_COLOR = config.BEAR_COLOR
|
|
|
|
|
|
def _safe_float(value, default=0):
    """Return ``float(value)``, or *default* if value is None or not convertible."""
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _safe_int(value, default=0):
    """Return ``int(value)``, or *default* if value is None or not convertible.

    ``int(float('nan'))`` raises ValueError, so NaN also maps to *default*.
    """
    if value is None:
        return default
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def _safe_timedelta(value, default=0):
    """Convert a timedelta-like *value* to fractional days.

    Returns *default* for None or for objects lacking ``total_seconds()``.
    """
    if value is None:
        return default
    try:
        return float(value.total_seconds() / 86400)
    except (TypeError, AttributeError):
        return default


def run_backtest(
    code: str,
    start_date: str,
    end_date: str,
    strategy_file: str,
    cash: float = config.DEFAULT_CASH,
    commission: float = config.DEFAULT_COMMISSION,
    warmup_days: int = config.DEFAULT_WARMUP_DAYS,
    output_dir: Optional[str] = None,
) -> BacktestResult:
    """Backtest one stock with a strategy file and summarise the statistics.

    Data is loaded from ``warmup_days`` calendar days before ``start_date`` so
    the strategy's ``calculate_indicators`` has history to work with; the
    warm-up rows are dropped before the backtest itself runs.

    Args:
        code: Stock code to backtest.
        start_date: First trading day of the backtest, ``YYYY-MM-DD``.
        end_date: Last trading day of the backtest, ``YYYY-MM-DD``.
        strategy_file: Path to the strategy module (see ``load_strategy``).
        cash: Initial cash passed to ``Backtest``.
        commission: Commission passed to ``Backtest``.
        warmup_days: Extra calendar days of history loaded before start_date.
        output_dir: If given, an HTML plot is written to ``<output_dir>/<code>.html``.

    Returns:
        BacktestResult built from the stats; missing or non-numeric stats fall
        back to 0, and drawdown durations are converted to days.

    Raises:
        ValueError: If no data exists for the code in the loaded window, or if
            none of it falls inside ``[start_date, end_date]``.
    """
    warmup_start_date = (pd.to_datetime(start_date) - pd.Timedelta(days=warmup_days)).strftime("%Y-%m-%d")
    data = load_data_from_db(code, warmup_start_date, end_date)

    calculate_indicators, strategy_class = load_strategy(strategy_file)
    # Indicators are computed before trimming, so look-back calculations can
    # use the warm-up rows.
    data = calculate_indicators(data)
    data = data.loc[start_date:end_date]
    # Warm-up rows alone are not enough to backtest; fail clearly rather than
    # handing an empty frame to Backtest.
    if data.empty:
        raise ValueError(f"股票 {code} 在 {start_date} 至 {end_date} 之间没有可回测的数据")

    from backtesting import Backtest

    bt = Backtest(data, strategy_class, cash=cash, commission=commission, finalize_trades=True)
    stats = bt.run()

    apply_color_scheme()

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f"{code}.html")
        bt.plot(filename=output_path, open_browser=False)

    return BacktestResult(
        code=code,
        equity_final=_safe_float(stats.get("Equity Final [$]"), 0),
        equity_peak=_safe_float(stats.get("Equity Peak [$]"), 0),
        return_pct=_safe_float(stats.get("Return [%]"), 0),
        buy_hold_return_pct=_safe_float(stats.get("Buy & Hold Return [%]"), 0),
        return_ann_pct=_safe_float(stats.get("Return (Ann.) [%]"), 0),
        volatility_ann_pct=_safe_float(stats.get("Volatility (Ann.) [%]"), 0),
        sortino_ratio=_safe_float(stats.get("Sortino Ratio"), 0),
        calmar_ratio=_safe_float(stats.get("Calmar Ratio"), 0),
        max_drawdown_pct=_safe_float(stats.get("Max. Drawdown [%]"), 0),
        avg_drawdown_pct=_safe_float(stats.get("Avg. Drawdown [%]"), 0),
        max_drawdown_duration=_safe_timedelta(stats.get("Max. Drawdown Duration"), 0),
        avg_drawdown_duration=_safe_timedelta(stats.get("Avg. Drawdown Duration"), 0),
        num_trades=_safe_int(stats.get("# Trades"), 0),
        win_rate_pct=_safe_float(stats.get("Win Rate [%]"), 0),
        sqn=_safe_float(stats.get("SQN"), 0),
    )
|
|
|
|
|
|
def run_batch_backtest(
    codes: list[str],
    start_date: str,
    end_date: str,
    strategy_file: str,
    cash: float = config.DEFAULT_CASH,
    commission: float = config.DEFAULT_COMMISSION,
    warmup_days: int = config.DEFAULT_WARMUP_DAYS,
    output_dir: Optional[str] = None,
    show_progress: bool = True,
) -> list[BacktestResult]:
    """Run ``run_backtest`` for every code in *codes* with shared settings.

    Args:
        codes: Stock codes to backtest, processed in order.
        start_date, end_date, strategy_file, cash, commission, warmup_days,
        output_dir: Forwarded unchanged to ``run_backtest`` for each code.
        show_progress: Wrap the iteration in a tqdm progress bar when True.

    Returns:
        One BacktestResult per code, in the same order as *codes*. An exception
        raised for any code propagates immediately and aborts the batch.
    """
    if show_progress:
        iterable = tqdm(codes, desc="批量回测")
    else:
        iterable = codes

    return [
        run_backtest(
            code=stock_code,
            start_date=start_date,
            end_date=end_date,
            strategy_file=strategy_file,
            cash=cash,
            commission=commission,
            warmup_days=warmup_days,
            output_dir=output_dir,
        )
        for stock_code in iterable
    ]
|