1
0

feat: 增加简单回测

This commit is contained in:
2025-10-16 23:19:56 +08:00
parent e2c5729f87
commit 83574e1229
2 changed files with 112 additions and 10 deletions

View File

@@ -1,11 +1,16 @@
package com.lanyuanxiaoyao.leopard.core.strategy;
import cn.hutool.core.util.ObjectUtil;
import com.lanyuanxiaoyao.leopard.core.entity.Daily;
import com.lanyuanxiaoyao.leopard.core.entity.QDaily;
import com.lanyuanxiaoyao.leopard.core.repository.DailyRepository;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@@ -24,23 +29,80 @@ public class TradeEngine {
this.dailyRepository = dailyRepository;
}
public void backtest(List<Long> stocks, TradeStrategy strategy) {
public Asset backtest(List<Long> stocks, TradeStrategy strategy, LocalDate startDate, LocalDate endDate) {
var dailies = dailyRepository.findAll(
QDaily.daily.stock.id.in(stocks)
.and(QDaily.daily.tradeDate.before(endDate)),
QDaily.daily.tradeDate.asc()
);
var validTradeDates = dailies.stream()
.map(Daily::getTradeDate)
.distinct()
.toList();
var asset = new Asset();
for (var now = startDate; now.isBefore(endDate) || now.isEqual(endDate); now = now.plusDays(1)) {
if (!validTradeDates.contains(now)) {
continue;
}
final var currentDate = now;
var trades = strategy.trade(
now,
asset,
dailies.stream()
.filter(daily -> daily.getTradeDate().isBefore(currentDate))
.collect(Collectors.groupingBy(daily -> daily.getStock().getId()))
);
for (var trade : trades) {
dailies.stream()
.filter(daily -> ObjectUtil.equals(daily.getStock().getId(), trade.stockId))
.filter(daily -> ObjectUtil.equals(daily.getTradeDate(), currentDate))
.findFirst()
.map(Daily::getHfqClose)
.ifPresent(close -> {
if (trade.volume < 0) {
asset.setCash(asset.getCash() + Math.abs(trade.volume) * close);
} else if (trade.volume > 0) {
asset.setCash(asset.getCash() - Math.abs(trade.volume) * close);
}
asset.getStocks().put(
trade.stockId,
asset.getStocks().getOrDefault(trade.stockId, 0) + trade.volume
);
});
}
asset.getHistories().add(new Asset.History(
now,
asset.getCash(),
asset.getStocks(),
trades.stream()
.collect(Collectors.groupingBy(trade -> trade.stockId))
));
}
return asset;
}
public interface TradeStrategy {
List<Trade> trade(LocalDate now, Asset asset, Map<Long, List<Daily>> dailies);
}
public record Asset(
Double cash,
Map<Long, Double> stocks
) {
public Asset() {
this(0.0, new HashMap<>());
@Data
public static final class Asset {
private double cash = 0;
private double profit = 0.0;
private Map<Long, Integer> stocks = new HashMap<>();
private List<History> histories = new ArrayList<>();
public record History(
LocalDate date,
double cash,
Map<Long, Integer> stocks,
Map<Long, List<Trade>> trades
) {
}
}
public record Trade(
LocalDate date,
Long stockId,
// 用正负数表达买卖
Integer volume