1
0

feat(cache): 实现 RoutingCache 和 StatsBuffer 优化数据库写入

- 新增 RoutingCache 组件,使用 sync.Map 缓存 Provider 和 Model
- 新增 StatsBuffer 组件,使用 sync.Map + atomic.Int64 缓冲统计数据
- 扩展 StatsRepository.BatchUpdate 支持批量增量更新
- 改造 RoutingService/StatsService/ProviderService/ModelService 集成缓存
- 更新 usage-statistics spec,新增 routing-cache 和 stats-buffer spec
- 新增单元测试覆盖缓存命中/失效/并发场景
This commit is contained in:
2026-04-22 19:24:36 +08:00
parent f5e45d032e
commit df253559a5
20 changed files with 1377 additions and 91 deletions

View File

@@ -67,13 +67,25 @@ func main() {
modelRepo := repository.NewModelRepository(db)
statsRepo := repository.NewStatsRepository(db)
// 5. 初始化 service 层
providerService := service.NewProviderService(providerRepo, modelRepo)
modelService := service.NewModelService(modelRepo, providerRepo)
routingService := service.NewRoutingService(modelRepo, providerRepo)
statsService := service.NewStatsService(statsRepo)
// 5. 初始化缓存
routingCache := service.NewRoutingCache(modelRepo, providerRepo, zapLogger)
if err := routingCache.Preload(); err != nil {
zapLogger.Warn("缓存预热失败,将使用懒加载", zap.Error(err))
}
// 6. 创建 ConversionEngine
// 6. 初始化统计缓冲
statsBuffer := service.NewStatsBuffer(statsRepo, zapLogger,
service.WithFlushInterval(5*time.Second),
service.WithFlushThreshold(100))
statsBuffer.Start()
// 7. 初始化 service 层
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
routingService := service.NewRoutingService(routingCache)
statsService := service.NewStatsService(statsRepo, statsBuffer)
// 8. 创建 ConversionEngine
registry := conversion.NewMemoryRegistry()
if err := registry.Register(openai.NewAdapter()); err != nil {
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error()))
@@ -83,16 +95,16 @@ func main() {
}
engine := conversion.NewConversionEngine(registry, zapLogger)
// 7. 初始化 provider client
// 9. 初始化 provider client
providerClient := provider.NewClient()
// 8. 初始化 handler 层
// 10. 初始化 handler 层
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
providerHandler := handler.NewProviderHandler(providerService)
modelHandler := handler.NewModelHandler(modelService)
statsHandler := handler.NewStatsHandler(statsService)
// 9. 创建 Gin 引擎
// 11. 创建 Gin 引擎
gin.SetMode(gin.ReleaseMode)
r := gin.New()
@@ -103,7 +115,7 @@ func main() {
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
// 10. 启动服务器
// 12. 启动服务器
srv := &http.Server{
Addr: formatAddr(cfg.Server.Port),
Handler: r,
@@ -131,6 +143,8 @@ func main() {
zapLogger.Fatal("服务器强制关闭", zap.String("error", err.Error()))
}
statsBuffer.Stop()
zapLogger.Info("服务器已关闭")
}

View File

@@ -11,5 +11,6 @@ import (
// StatsRepository 统计数据仓库接口
type StatsRepository interface {
Record(providerID, modelName string) error
BatchUpdate(providerID, modelName string, date time.Time, delta int) error
Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error)
}

View File

@@ -43,6 +43,28 @@ func (r *statsRepository) Record(providerID, modelName string) error {
})
}
// BatchUpdate atomically adds delta to the request_count of the
// (providerID, modelName, date) usage-stats row, creating the row with
// RequestCount = delta when no row exists yet. Used by StatsBuffer to
// flush buffered increments in one write per key.
// NOTE(review): the read-then-create inside the transaction can still race
// with a concurrent inserter on databases without serializable isolation —
// a unique index plus an upsert (ON CONFLICT) would close that window; confirm.
func (r *statsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
return r.db.Transaction(func(tx *gorm.DB) error {
var stats config.UsageStats
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
providerID, modelName, date).First(&stats).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
// First increment for this key/date: seed the row with the delta.
return tx.Create(&config.UsageStats{
ProviderID: providerID,
ModelName: modelName,
RequestCount: delta,
Date: date,
}).Error
} else if err != nil {
return err
}
// Existing row: increment in SQL so the add happens server-side.
return tx.Model(&stats).
Update("request_count", gorm.Expr("request_count + ?", delta)).Error
})
}
func (r *statsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
var stats []config.UsageStats
query := r.db.Model(&config.UsageStats{})

View File

@@ -11,27 +11,30 @@ import (
type modelService struct {
modelRepo repository.ModelRepository
providerRepo repository.ProviderRepository
cache *RoutingCache
}
func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) ModelService {
return &modelService{modelRepo: modelRepo, providerRepo: providerRepo}
func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository, cache *RoutingCache) ModelService {
return &modelService{modelRepo: modelRepo, providerRepo: providerRepo, cache: cache}
}
func (s *modelService) Create(model *domain.Model) error {
// 校验供应商存在
if _, err := s.providerRepo.GetByID(model.ProviderID); err != nil {
return appErrors.ErrProviderNotFound
}
// 联合唯一校验:同一供应商下 model_name 不重复
if err := s.checkDuplicateModelName(model.ProviderID, model.ModelName, ""); err != nil {
return err
}
// 自动生成 UUID 作为 id
model.ID = uuid.New().String()
model.Enabled = true
return s.modelRepo.Create(model)
err := s.modelRepo.Create(model)
if err != nil {
return err
}
s.cache.SetModel(model)
return nil
}
func (s *modelService) Get(id string) (*domain.Model, error) {
@@ -47,20 +50,17 @@ func (s *modelService) ListEnabled() ([]domain.Model, error) {
}
func (s *modelService) Update(id string, updates map[string]interface{}) error {
// 获取当前模型
current, err := s.modelRepo.GetByID(id)
if err != nil {
return appErrors.ErrModelNotFound
}
// 如果更新 provider_id校验新供应商存在
if providerID, ok := updates["provider_id"].(string); ok {
if _, err := s.providerRepo.GetByID(providerID); err != nil {
return appErrors.ErrProviderNotFound
}
}
// 确定更新后的 provider_id 和 model_name
newProviderID := current.ProviderID
if v, ok := updates["provider_id"].(string); ok {
newProviderID = v
@@ -70,18 +70,37 @@ func (s *modelService) Update(id string, updates map[string]interface{}) error {
newModelName = v
}
// 如果 provider_id 或 model_name 发生变化,校验联合唯一
if newProviderID != current.ProviderID || newModelName != current.ModelName {
if err := s.checkDuplicateModelName(newProviderID, newModelName, id); err != nil {
return err
}
}
return s.modelRepo.Update(id, updates)
err = s.modelRepo.Update(id, updates)
if err != nil {
return err
}
s.cache.InvalidateModel(current.ProviderID, current.ModelName)
if newProviderID != current.ProviderID || newModelName != current.ModelName {
s.cache.InvalidateModel(newProviderID, newModelName)
}
return nil
}
func (s *modelService) Delete(id string) error {
return s.modelRepo.Delete(id)
model, err := s.modelRepo.GetByID(id)
if err != nil {
return appErrors.ErrModelNotFound
}
err = s.modelRepo.Delete(id)
if err != nil {
return err
}
s.cache.InvalidateModel(model.ProviderID, model.ModelName)
return nil
}
// checkDuplicateModelName 校验同一供应商下 model_name 是否重复

View File

@@ -13,23 +13,27 @@ import (
type providerService struct {
providerRepo repository.ProviderRepository
modelRepo repository.ModelRepository
cache *RoutingCache
}
func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository) ProviderService {
return &providerService{providerRepo: providerRepo, modelRepo: modelRepo}
func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository, cache *RoutingCache) ProviderService {
return &providerService{providerRepo: providerRepo, modelRepo: modelRepo, cache: cache}
}
func (s *providerService) Create(provider *domain.Provider) error {
// 校验 provider_id 字符集
if err := modelid.ValidateProviderID(provider.ID); err != nil {
return appErrors.ErrInvalidProviderID
}
provider.Enabled = true
err := s.providerRepo.Create(provider)
if err != nil && isUniqueConstraintError(err) {
return appErrors.ErrConflict
if err != nil {
if isUniqueConstraintError(err) {
return appErrors.ErrConflict
}
return err
}
return err
s.cache.SetProvider(provider)
return nil
}
func (s *providerService) Get(id string) (*domain.Provider, error) {
@@ -44,11 +48,21 @@ func (s *providerService) Update(id string, updates map[string]interface{}) erro
if _, ok := updates["id"]; ok {
return appErrors.ErrImmutableField
}
return s.providerRepo.Update(id, updates)
err := s.providerRepo.Update(id, updates)
if err != nil {
return err
}
s.cache.InvalidateProvider(id)
return nil
}
func (s *providerService) Delete(id string) error {
return s.providerRepo.Delete(id)
err := s.providerRepo.Delete(id)
if err != nil {
return err
}
s.cache.InvalidateProvider(id)
return nil
}
// ListEnabledModels 返回所有启用的模型(用于 Models 接口本地聚合)

View File

@@ -0,0 +1,134 @@
package service
import (
"strings"
"sync"
"go.uber.org/zap"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
)
// RoutingCache is an in-memory read-through cache for routing lookups.
// Providers are keyed by id; models are keyed by "providerID/modelName".
// Both maps use sync.Map so cache hits are lock-free; misses fall back to
// the repositories. NOTE(review): cached entries are pointers shared with
// callers — treat them as read-only after insertion.
type RoutingCache struct {
providers sync.Map
models sync.Map
modelRepo repository.ModelRepository
providerRepo repository.ProviderRepository
logger *zap.Logger
}
// NewRoutingCache builds an empty cache backed by the given repositories.
// Call Preload to warm it eagerly; otherwise entries are loaded lazily.
func NewRoutingCache(
modelRepo repository.ModelRepository,
providerRepo repository.ProviderRepository,
logger *zap.Logger,
) *RoutingCache {
return &RoutingCache{
modelRepo: modelRepo,
providerRepo: providerRepo,
logger: logger,
}
}
// GetProvider returns the provider with the given id, serving from the
// in-memory cache when possible and falling back to the repository on a
// miss. Repository errors are returned unchanged and nothing is cached.
func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
	if v, ok := c.providers.Load(id); ok {
		return v.(*domain.Provider), nil
	}
	provider, err := c.providerRepo.GetByID(id)
	if err != nil {
		return nil, err
	}
	// LoadOrStore replaces the previous manual "load again, then store"
	// double-check: it atomically publishes our fetch or, if a concurrent
	// goroutine won the race, returns that goroutine's entry — so all
	// callers converge on a single cached pointer.
	actual, _ := c.providers.LoadOrStore(id, provider)
	return actual.(*domain.Provider), nil
}
// GetModel returns the model identified by (providerID, modelName),
// serving from the cache under the composite key "providerID/modelName"
// and falling back to the repository on a miss. Repository errors are
// returned unchanged and nothing is cached.
func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, error) {
	key := providerID + "/" + modelName
	if v, ok := c.models.Load(key); ok {
		return v.(*domain.Model), nil
	}
	model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName)
	if err != nil {
		return nil, err
	}
	// LoadOrStore replaces the previous manual "load again, then store"
	// double-check: concurrent misses for the same key atomically converge
	// on a single cached entry instead of racing two Stores.
	actual, _ := c.models.LoadOrStore(key, model)
	return actual.(*domain.Model), nil
}
// SetProvider inserts or overwrites the cache entry for provider.ID.
// Called by ProviderService after a successful create.
func (c *RoutingCache) SetProvider(provider *domain.Provider) {
c.providers.Store(provider.ID, provider)
}
// SetModel inserts or overwrites the cache entry for the model under the
// composite key "providerID/modelName". Called by ModelService after a
// successful create.
func (c *RoutingCache) SetModel(model *domain.Model) {
key := model.ProviderID + "/" + model.ModelName
c.models.Store(key, model)
}
// InvalidateProvider evicts the provider entry for id and cascades to every
// cached model belonging to that provider, since their routing data may
// have changed with it.
func (c *RoutingCache) InvalidateProvider(id string) {
c.providers.Delete(id)
c.invalidateModelsByProvider(id)
c.logger.Debug("Provider 缓存失效", zap.String("provider_id", id))
}
// InvalidateModel evicts the single model entry keyed by
// "providerID/modelName". Deleting a missing key is a no-op.
func (c *RoutingCache) InvalidateModel(providerID, modelName string) {
key := providerID + "/" + modelName
c.models.Delete(key)
c.logger.Debug("Model 缓存失效",
zap.String("provider_id", providerID),
zap.String("model_name", modelName))
}
// invalidateModelsByProvider drops every cached model whose key starts with
// "<providerID>/" and logs how many entries were removed (silent when none).
func (c *RoutingCache) invalidateModelsByProvider(providerID string) {
	keyPrefix := providerID + "/"
	removed := 0
	c.models.Range(func(k, _ interface{}) bool {
		if strings.HasPrefix(k.(string), keyPrefix) {
			c.models.Delete(k)
			removed++
		}
		return true
	})
	if removed == 0 {
		return
	}
	c.logger.Debug("清除 Provider 相关 Model 缓存",
		zap.String("provider_id", providerID),
		zap.Int("count", removed))
}
// Preload warms both caches in one pass: every provider and every model is
// read from the repositories and stored up front, so the first routing
// requests avoid the database round-trip. Returns the first listing error.
func (c *RoutingCache) Preload() error {
	allProviders, err := c.providerRepo.List()
	if err != nil {
		return err
	}
	for idx := range allProviders {
		// Store the address of the slice element, not of the loop copy.
		p := &allProviders[idx]
		c.providers.Store(p.ID, p)
	}
	allModels, err := c.modelRepo.List("")
	if err != nil {
		return err
	}
	for idx := range allModels {
		m := &allModels[idx]
		c.models.Store(m.ProviderID+"/"+m.ModelName, m)
	}
	c.logger.Info("缓存预热完成",
		zap.Int("providers", len(allProviders)),
		zap.Int("models", len(allModels)))
	return nil
}

View File

@@ -0,0 +1,273 @@
package service
import (
"errors"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"nex/backend/internal/domain"
)
// mockModelRepo is an in-memory repository.ModelRepository stub keyed by
// "providerID/modelName". Only FindByProviderAndModelName and List are
// meaningful; the remaining methods are no-ops satisfying the interface.
type mockModelRepo struct {
models map[string]*domain.Model
}
func (m *mockModelRepo) Create(model *domain.Model) error {
return nil
}
func (m *mockModelRepo) GetByID(id string) (*domain.Model, error) {
return nil, nil
}
// FindByProviderAndModelName looks up the composite key, returning a plain
// error on a miss (the cache under test only checks err != nil).
func (m *mockModelRepo) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
key := providerID + "/" + modelName
if model, ok := m.models[key]; ok {
return model, nil
}
return nil, errors.New("not found")
}
// List returns all models, or only those of providerID when it is non-empty.
func (m *mockModelRepo) List(providerID string) ([]domain.Model, error) {
var result []domain.Model
for _, model := range m.models {
if providerID == "" || model.ProviderID == providerID {
result = append(result, *model)
}
}
return result, nil
}
func (m *mockModelRepo) ListEnabled() ([]domain.Model, error) {
return nil, nil
}
func (m *mockModelRepo) Update(id string, updates map[string]interface{}) error {
return nil
}
func (m *mockModelRepo) Delete(id string) error {
return nil
}
// mockProviderRepo is an in-memory repository.ProviderRepository stub keyed
// by provider id. Only GetByID and List are meaningful; the remaining
// methods are no-ops satisfying the interface.
type mockProviderRepo struct {
providers map[string]*domain.Provider
}
func (m *mockProviderRepo) Create(provider *domain.Provider) error {
return nil
}
// GetByID returns the stored provider, or a plain error on a miss.
func (m *mockProviderRepo) GetByID(id string) (*domain.Provider, error) {
if provider, ok := m.providers[id]; ok {
return provider, nil
}
return nil, errors.New("not found")
}
// List returns all stored providers in map-iteration (random) order.
func (m *mockProviderRepo) List() ([]domain.Provider, error) {
var result []domain.Provider
for _, provider := range m.providers {
result = append(result, *provider)
}
return result, nil
}
func (m *mockProviderRepo) Update(id string, updates map[string]interface{}) error {
return nil
}
func (m *mockProviderRepo) Delete(id string) error {
return nil
}
// Verifies a provider fetched once is served from the cache afterwards:
// the second GetProvider must return the identical pointer.
func TestRoutingCache_GetProvider_CacheHit(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
provider, err := cache.GetProvider("openai")
require.NoError(t, err)
assert.Equal(t, "openai", provider.ID)
provider2, err := cache.GetProvider("openai")
require.NoError(t, err)
assert.Equal(t, provider, provider2)
}
// Verifies a miss in both cache and repository surfaces as an error.
func TestRoutingCache_GetProvider_CacheMiss(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
_, err := cache.GetProvider("notexist")
assert.Error(t, err)
}
// Verifies a model fetched once is served from the cache afterwards:
// the second GetModel must return the identical pointer.
func TestRoutingCache_GetModel_CacheHit(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
model, err := cache.GetModel("openai", "gpt-4")
require.NoError(t, err)
assert.Equal(t, "gpt-4", model.ModelName)
model2, err := cache.GetModel("openai", "gpt-4")
require.NoError(t, err)
assert.Equal(t, model, model2)
}
// Verifies a model miss in both cache and repository surfaces as an error.
func TestRoutingCache_GetModel_CacheMiss(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
_, err := cache.GetModel("openai", "notexist")
assert.Error(t, err)
}
// Exercises concurrent GetProvider calls for the same key. Meaningful mainly
// under `go test -race`: it detects races in the miss-then-store path rather
// than asserting which goroutine's store wins.
func TestRoutingCache_DoubleCheck(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, err := cache.GetProvider("openai")
// assert (not require) inside a goroutine: require would call
// t.FailNow off the test goroutine, which is not allowed.
assert.NoError(t, err)
}()
}
wg.Wait()
}
// Verifies InvalidateProvider cascades: all cached "openai/*" model entries
// must be evicted while the unrelated "anthropic/claude" entry survives.
//
// Fix: the previous version never incremented openaiCount inside the Range
// callback, so `assert.Equal(t, 0, openaiCount)` passed vacuously and the
// cascade was not actually verified. The callback now counts both the
// openai keys and the anthropic key.
func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) {
	logger := zap.NewNop()
	modelRepo := &mockModelRepo{models: map[string]*domain.Model{
		"openai/gpt-4":     {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
		"openai/gpt-3.5":   {ID: "2", ProviderID: "openai", ModelName: "gpt-3.5", Enabled: true},
		"anthropic/claude": {ID: "3", ProviderID: "anthropic", ModelName: "claude", Enabled: true},
	}}
	providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
		"openai":    {ID: "openai", Name: "OpenAI", Enabled: true},
		"anthropic": {ID: "anthropic", Name: "Anthropic", Enabled: true},
	}}
	cache := NewRoutingCache(modelRepo, providerRepo, logger)
	// Populate the cache with all three models.
	_, err := cache.GetModel("openai", "gpt-4")
	require.NoError(t, err)
	_, err = cache.GetModel("openai", "gpt-3.5")
	require.NoError(t, err)
	_, err = cache.GetModel("anthropic", "claude")
	require.NoError(t, err)
	cache.InvalidateProvider("openai")
	var openaiCount, anthropicCount int
	cache.models.Range(func(key, value interface{}) bool {
		switch key.(string) {
		case "openai/gpt-4", "openai/gpt-3.5":
			openaiCount++
		case "anthropic/claude":
			anthropicCount++
		}
		return true
	})
	assert.Equal(t, 0, openaiCount)
	assert.Equal(t, 1, anthropicCount)
}
// Verifies InvalidateModel evicts the single cached entry: after the call
// the models map must be empty.
func TestRoutingCache_InvalidateModel(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
_, err := cache.GetModel("openai", "gpt-4")
require.NoError(t, err)
cache.InvalidateModel("openai", "gpt-4")
var count int
cache.models.Range(func(key, value interface{}) bool {
count++
return true
})
assert.Equal(t, 0, count)
}
// Verifies Preload populates both maps eagerly: one provider and one model
// seeded in the mocks must appear in the cache without any Get call.
func TestRoutingCache_Preload_Success(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
err := cache.Preload()
require.NoError(t, err)
var providerCount, modelCount int
cache.providers.Range(func(key, value interface{}) bool {
providerCount++
return true
})
cache.models.Range(func(key, value interface{}) bool {
modelCount++
return true
})
assert.Equal(t, 1, providerCount)
assert.Equal(t, 1, modelCount)
}
// Hammers reads and invalidations from 100 goroutines. There is no result
// assertion — the test's value is surfacing data races under `go test -race`
// and panics from the Load/Store/Delete interleavings.
func TestRoutingCache_ConcurrentAccess(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
// Errors are intentionally ignored: a Get racing an Invalidate may
// legitimately miss; only race-freedom is under test here.
_, _ = cache.GetProvider("openai")
_, _ = cache.GetModel("openai", "gpt-4")
cache.InvalidateProvider("openai")
cache.InvalidateModel("openai", "gpt-4")
}()
}
wg.Wait()
}

View File

@@ -4,20 +4,18 @@ import (
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
)
type routingService struct {
modelRepo repository.ModelRepository
providerRepo repository.ProviderRepository
cache *RoutingCache
}
func NewRoutingService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) RoutingService {
return &routingService{modelRepo: modelRepo, providerRepo: providerRepo}
func NewRoutingService(cache *RoutingCache) RoutingService {
return &routingService{cache: cache}
}
func (s *routingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
model, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
model, err := s.cache.GetModel(providerID, modelName)
if err != nil {
return nil, appErrors.ErrModelNotFound
}
@@ -26,7 +24,7 @@ func (s *routingService) RouteByModelName(providerID, modelName string) (*domain
return nil, appErrors.ErrModelDisabled
}
provider, err := s.providerRepo.GetByID(model.ProviderID)
provider, err := s.cache.GetProvider(model.ProviderID)
if err != nil {
return nil, appErrors.ErrProviderNotFound
}

View File

@@ -14,7 +14,8 @@ func TestProviderService_Update(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
require.NoError(t, svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"}))
@@ -30,7 +31,8 @@ func TestProviderService_Update_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
err := svc.Update("nonexistent", map[string]interface{}{"name": "test"})
assert.Error(t, err)
@@ -40,7 +42,8 @@ func TestModelService_Get(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
@@ -55,7 +58,8 @@ func TestModelService_Update(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
@@ -73,7 +77,8 @@ func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
@@ -87,7 +92,8 @@ func TestModelService_Delete(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
@@ -104,7 +110,8 @@ func TestModelService_Delete_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
modelRepo := repository.NewModelRepository(db)
providerRepo := repository.NewProviderRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
err := svc.Delete("nonexistent")
assert.Error(t, err)
@@ -112,7 +119,8 @@ func TestModelService_Delete_NotFound(t *testing.T) {
func TestStatsService_Aggregate_Default(t *testing.T) {
statsRepo := repository.NewStatsRepository(nil)
svc := NewStatsService(statsRepo)
buffer := NewStatsBuffer(statsRepo, nil)
svc := NewStatsService(statsRepo, buffer)
stats := []domain.UsageStats{
{ProviderID: "p1", RequestCount: 10},
@@ -133,7 +141,8 @@ func TestModelService_Update_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
modelRepo := repository.NewModelRepository(db)
providerRepo := repository.NewProviderRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
err := svc.Update("nonexistent", map[string]interface{}{"model_name": "test"})
assert.Error(t, err)

View File

@@ -8,6 +8,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"gorm.io/gorm"
testHelpers "nex/backend/tests"
@@ -22,13 +23,21 @@ func setupServiceTestDB(t *testing.T) *gorm.DB {
return testHelpers.SetupTestDB(t)
}
// setupRoutingCache builds a RoutingCache over fresh repositories bound to
// the given test database, with a no-op logger. Shared helper for the
// service-layer tests below.
func setupRoutingCache(t *testing.T, db *gorm.DB) *RoutingCache {
t.Helper()
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
return NewRoutingCache(modelRepo, providerRepo, zap.NewNop())
}
// ============ RoutingService - RouteByModelName 测试 ============
func TestRoutingService_RouteByModelName_Success(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewRoutingService(cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
@@ -41,9 +50,8 @@ func TestRoutingService_RouteByModelName_Success(t *testing.T) {
func TestRoutingService_RouteByModelName_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewRoutingService(cache)
_, err := svc.RouteByModelName("openai", "nonexistent-model")
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
@@ -53,7 +61,8 @@ func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewRoutingService(cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
@@ -67,7 +76,8 @@ func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewRoutingService(cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
@@ -83,7 +93,8 @@ func TestModelService_Create_GeneratesUUID(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
@@ -91,12 +102,10 @@ func TestModelService_Create_GeneratesUUID(t *testing.T) {
err := svc.Create(model)
require.NoError(t, err)
// 验证返回的 model 拥有有效的 UUID
assert.NotEmpty(t, model.ID)
_, err = uuid.Parse(model.ID)
assert.NoError(t, err, "model.ID should be a valid UUID")
// 通过 Get 验证持久化
stored, err := svc.Get(model.ID)
require.NoError(t, err)
assert.Equal(t, model.ID, stored.ID)
@@ -107,7 +116,8 @@ func TestModelService_Create_DuplicateModelName(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
@@ -115,7 +125,6 @@ func TestModelService_Create_DuplicateModelName(t *testing.T) {
err := svc.Create(model1)
require.NoError(t, err)
// 使用相同的 (providerID, modelName) 创建第二个模型应失败
model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err = svc.Create(model2)
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
@@ -125,7 +134,8 @@ func TestModelService_Create_ProviderNotFound(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
err := svc.Create(model)
@@ -138,7 +148,8 @@ func TestProviderService_Create_InvalidID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
@@ -149,7 +160,8 @@ func TestProviderService_Create_ValidID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
@@ -164,7 +176,8 @@ func TestModelService_Update_DuplicateModelName(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"}))
@@ -189,7 +202,8 @@ func TestModelService_Update_ModelNotFound(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
err := svc.Update("nonexistent-id", map[string]interface{}{
"model_name": "gpt-4",
@@ -201,7 +215,8 @@ func TestModelService_Update_Success(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
@@ -226,7 +241,8 @@ func TestProviderService_Update_ImmutableID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
@@ -243,7 +259,8 @@ func TestProviderService_Update_Success(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
@@ -301,7 +318,7 @@ func TestStatsService_Aggregate_ByModel(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db)
svc := NewStatsService(statsRepo)
buffer := NewStatsBuffer(statsRepo, nil); svc := NewStatsService(statsRepo, buffer)
result := svc.Aggregate(tt.stats, "model")
@@ -362,7 +379,7 @@ func TestStatsService_Aggregate_ByDate(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db)
svc := NewStatsService(statsRepo)
buffer := NewStatsBuffer(statsRepo, nil); svc := NewStatsService(statsRepo, buffer)
result := svc.Aggregate(tt.stats, "date")
@@ -431,7 +448,8 @@ func TestProviderService_List_APIKeyNotMasked(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider1 := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"}
provider2 := &domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "sk-anthropic1234", BaseURL: "https://api.anthropic.com"}
@@ -456,7 +474,8 @@ func TestModelService_ConcurrentCreate(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))

View File

@@ -0,0 +1,169 @@
package service
import (
"strings"
"sync"
"sync/atomic"
"time"
"go.uber.org/zap"
"nex/backend/internal/repository"
)
// StatsBuffer buffers per-(provider, model, day) request counts in memory
// and writes them to the database in batches, either on a timer or when a
// total-count threshold is crossed.
type StatsBuffer struct {
	counters       sync.Map      // key: "providerID/modelName/date" -> *int64 counter
	flushInterval  time.Duration // period of the background flush ticker
	flushThreshold int           // total buffered count that triggers an async flush
	totalCount     atomic.Int64  // running total since the last flush (approximate)
	statsRepo      repository.StatsRepository
	logger         *zap.Logger
	stopCh         chan struct{} // closed by Stop to end the background goroutine
	doneCh         chan struct{} // closed by the background goroutine after the final flush
}
// StatsBufferOption configures a StatsBuffer at construction time.
type StatsBufferOption func(*StatsBuffer)

// WithFlushInterval sets the period of the background flush ticker.
// Non-positive values are ignored: time.NewTicker panics on d <= 0,
// so a bad option must not be allowed to take down Start.
func WithFlushInterval(d time.Duration) StatsBufferOption {
	return func(b *StatsBuffer) {
		if d > 0 {
			b.flushInterval = d
		}
	}
}

// WithFlushThreshold sets the buffered total count that triggers an
// asynchronous flush. Non-positive values are ignored so the threshold
// trigger in Increment stays well-defined.
func WithFlushThreshold(threshold int) StatsBufferOption {
	return func(b *StatsBuffer) {
		if threshold > 0 {
			b.flushThreshold = threshold
		}
	}
}
// NewStatsBuffer creates a StatsBuffer writing to statsRepo. Defaults:
// 5s flush interval, threshold of 100; both can be overridden with
// WithFlushInterval / WithFlushThreshold. Call Start to begin background
// flushing and Stop to drain on shutdown.
func NewStatsBuffer(
	statsRepo repository.StatsRepository,
	logger *zap.Logger,
	opts ...StatsBufferOption,
) *StatsBuffer {
	// Some callers pass a nil logger; fall back to a no-op logger so the
	// log calls in flush do not panic on a nil receiver.
	if logger == nil {
		logger = zap.NewNop()
	}
	b := &StatsBuffer{
		statsRepo:      statsRepo,
		logger:         logger,
		flushInterval:  5 * time.Second,
		flushThreshold: 100,
		stopCh:         make(chan struct{}),
		doneCh:         make(chan struct{}),
	}
	for _, opt := range opts {
		opt(b)
	}
	return b
}
// Increment adds one request to the counter for (providerID, modelName)
// on today's date. It is lock-free and safe for concurrent use. When the
// buffered total crosses flushThreshold, one asynchronous flush is kicked
// off so callers are never blocked on the database.
func (b *StatsBuffer) Increment(providerID, modelName string) {
	today := time.Now().Format("2006-01-02")
	key := providerID + "/" + modelName + "/" + today
	var counter *int64
	if v, ok := b.counters.Load(key); ok {
		counter = v.(*int64)
	} else {
		// LoadOrStore returns the winner under a concurrent first insert,
		// so every goroutine ends up incrementing the same counter.
		val := int64(0)
		actual, _ := b.counters.LoadOrStore(key, &val)
		counter = actual.(*int64)
	}
	atomic.AddInt64(counter, 1)
	// Trigger exactly one async flush when the threshold is crossed.
	// The previous `>=` comparison spawned a new flush goroutine on every
	// Increment once past the threshold until flush reset the total.
	if b.totalCount.Add(1) == int64(b.flushThreshold) {
		go b.flush()
	}
}
// Start launches the background goroutine that flushes the buffer every
// flushInterval. The goroutine runs until Stop is called, performs a final
// flush on shutdown, and signals completion by closing doneCh.
func (b *StatsBuffer) Start() {
	go func() {
		// Closing doneCh last guarantees Stop observes the final flush.
		defer close(b.doneCh)
		ticker := time.NewTicker(b.flushInterval)
		defer ticker.Stop()
		for {
			select {
			case <-b.stopCh:
				b.flush() // drain whatever is still buffered
				return
			case <-ticker.C:
				b.flush()
			}
		}
	}()
}
// Stop shuts down the background flush goroutine and blocks until its
// final flush has completed, so buffered counts are not lost on a graceful
// shutdown. Stop must be called at most once (a second close would panic)
// and only after Start.
func (b *StatsBuffer) Stop() {
	close(b.stopCh)
	<-b.doneCh
}
// flush drains all counters and writes the accumulated deltas to the
// database via StatsRepository.BatchUpdate. Counters are taken with an
// atomic swap so concurrent Increments are never lost; on a write failure
// the delta is added back to its counter and retried on the next flush.
// Safe to call concurrently (timer, threshold trigger, Stop).
func (b *StatsBuffer) flush() {
	type statEntry struct {
		providerID string
		modelName  string
		date       string
		count      int64
	}
	var entries []statEntry
	b.counters.Range(func(key, value interface{}) bool {
		keyStr := key.(string)
		// Key layout is providerID/modelName/date. modelName may itself
		// contain "/" (e.g. "anthropic/claude-3"), so split on the first
		// and last separator; strings.Split would yield != 3 parts and
		// silently drop those counts. Assumes providerID has no "/".
		providerID, rest, ok := strings.Cut(keyStr, "/")
		if !ok {
			return true
		}
		sep := strings.LastIndex(rest, "/")
		if sep < 0 {
			return true
		}
		counter := value.(*int64)
		// SwapInt64 atomically reads-and-zeroes, so Increments racing
		// with flush accumulate cleanly into the next cycle.
		count := atomic.SwapInt64(counter, 0)
		if count > 0 {
			entries = append(entries, statEntry{
				providerID: providerID,
				modelName:  rest[:sep],
				date:       rest[sep+1:],
				count:      count,
			})
		}
		return true
	})
	if len(entries) == 0 {
		return
	}
	// Reset the threshold accumulator up front; failed entries are
	// re-accounted below so they still count toward the next trigger.
	b.totalCount.Store(0)
	success := 0
	for _, entry := range entries {
		date, err := time.Parse("2006-01-02", entry.date)
		if err != nil {
			// Keys are built by Increment with this exact date format, so
			// this should be unreachable; log and drop rather than write
			// a zero-value date.
			b.logger.Error("统计键日期解析失败",
				zap.String("date", entry.date),
				zap.Error(err))
			continue
		}
		if err := b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count)); err != nil {
			b.logger.Error("批量更新统计失败",
				zap.String("provider_id", entry.providerID),
				zap.String("model_name", entry.modelName),
				zap.Int64("count", entry.count),
				zap.Error(err))
			// Put the delta back so the next flush retries it, and count
			// it toward the flush threshold again.
			key := entry.providerID + "/" + entry.modelName + "/" + entry.date
			if v, ok := b.counters.Load(key); ok {
				atomic.AddInt64(v.(*int64), entry.count)
			}
			b.totalCount.Add(entry.count)
		} else {
			success++
		}
	}
	b.logger.Debug("统计刷新完成",
		zap.Int("total", len(entries)),
		zap.Int("success", success),
		zap.Int("failed", len(entries)-success))
}

View File

@@ -0,0 +1,251 @@
package service
import (
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"nex/backend/internal/domain"
)
// mockStatsRepo is an in-memory StatsRepository test double. It records
// every BatchUpdate call and can be switched into failure mode via `fail`
// to exercise the buffer's retry path.
type mockStatsRepo struct {
	// records captures each BatchUpdate invocation in order.
	records []struct {
		providerID string
		modelName  string
		date       time.Time
		delta      int
	}
	fail bool       // when true, BatchUpdate returns an error
	mu   sync.Mutex // guards records and fail; flush runs on another goroutine
}

// Record is a no-op; StatsBuffer writes through BatchUpdate only.
func (m *mockStatsRepo) Record(providerID, modelName string) error {
	return nil
}

// BatchUpdate appends the call to records, or returns an error when the
// mock is in failure mode.
func (m *mockStatsRepo) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.fail {
		return errors.New("db error")
	}
	m.records = append(m.records, struct {
		providerID string
		modelName  string
		date       time.Time
		delta      int
	}{providerID, modelName, date, delta})
	return nil
}

// Query is a no-op stub; these tests never read stats back.
func (m *mockStatsRepo) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
	return nil, nil
}
// TestStatsBuffer_Increment verifies that three Increments across two
// models accumulate a total in-memory count of 3.
func TestStatsBuffer_Increment(t *testing.T) {
	logger := zap.NewNop()
	statsRepo := &mockStatsRepo{}
	buffer := NewStatsBuffer(statsRepo, logger)
	buffer.Increment("openai", "gpt-4")
	buffer.Increment("openai", "gpt-4")
	buffer.Increment("openai", "gpt-3.5")
	// Sum across all counters (two keys: gpt-4 and gpt-3.5).
	var count int64
	buffer.counters.Range(func(key, value interface{}) bool {
		counter := value.(*int64)
		count += atomic.LoadInt64(counter)
		return true
	})
	assert.Equal(t, int64(3), count)
}
// TestStatsBuffer_ConcurrentIncrement checks that 100 concurrent
// Increments on the same key are counted without loss (run with -race).
func TestStatsBuffer_ConcurrentIncrement(t *testing.T) {
	logger := zap.NewNop()
	statsRepo := &mockStatsRepo{}
	buffer := NewStatsBuffer(statsRepo, logger)
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			buffer.Increment("openai", "gpt-4")
		}()
	}
	wg.Wait()
	// Plain assignment is fine here: all goroutines share one key, so
	// the Range visits exactly one counter.
	var count int64
	buffer.counters.Range(func(key, value interface{}) bool {
		counter := value.(*int64)
		count = atomic.LoadInt64(counter)
		return true
	})
	assert.Equal(t, int64(100), count)
}
// TestStatsBuffer_LoadOrStore verifies that concurrent first-time
// Increments on the same key create exactly one counter entry (the
// LoadOrStore race resolution in Increment).
func TestStatsBuffer_LoadOrStore(t *testing.T) {
	logger := zap.NewNop()
	statsRepo := &mockStatsRepo{}
	buffer := NewStatsBuffer(statsRepo, logger)
	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			buffer.Increment("openai", "gpt-4")
		}()
	}
	wg.Wait()
	// Count map entries, not increments: there must be exactly one key.
	var counterCount int
	buffer.counters.Range(func(key, value interface{}) bool {
		counterCount++
		return true
	})
	assert.Equal(t, 1, counterCount)
}
// TestStatsBuffer_FlushByInterval checks that the background ticker
// flushes buffered counts to the repository within ~2 intervals.
func TestStatsBuffer_FlushByInterval(t *testing.T) {
	logger := zap.NewNop()
	statsRepo := &mockStatsRepo{}
	buffer := NewStatsBuffer(statsRepo, logger,
		WithFlushInterval(100*time.Millisecond))
	buffer.Start()
	defer buffer.Stop()
	buffer.Increment("openai", "gpt-4")
	buffer.Increment("openai", "gpt-4")
	// Sleep two intervals so at least one tick has fired.
	time.Sleep(200 * time.Millisecond)
	statsRepo.mu.Lock()
	assert.GreaterOrEqual(t, len(statsRepo.records), 1)
	statsRepo.mu.Unlock()
}
// TestStatsBuffer_FlushByThreshold checks that reaching flushThreshold
// (10) triggers an asynchronous flush well before the default 5s ticker.
func TestStatsBuffer_FlushByThreshold(t *testing.T) {
	logger := zap.NewNop()
	statsRepo := &mockStatsRepo{}
	buffer := NewStatsBuffer(statsRepo, logger,
		WithFlushThreshold(10))
	buffer.Start()
	defer buffer.Stop()
	for i := 0; i < 10; i++ {
		buffer.Increment("openai", "gpt-4")
	}
	// The threshold flush runs on its own goroutine; give it a moment.
	time.Sleep(50 * time.Millisecond)
	statsRepo.mu.Lock()
	assert.GreaterOrEqual(t, len(statsRepo.records), 1)
	statsRepo.mu.Unlock()
}
// TestStatsBuffer_SwapInt64 verifies that flush zeroes counters via the
// atomic swap: 2 before flush, 0 after.
func TestStatsBuffer_SwapInt64(t *testing.T) {
	logger := zap.NewNop()
	statsRepo := &mockStatsRepo{}
	buffer := NewStatsBuffer(statsRepo, logger)
	buffer.Increment("openai", "gpt-4")
	buffer.Increment("openai", "gpt-4")
	var beforeCount int64
	buffer.counters.Range(func(key, value interface{}) bool {
		counter := value.(*int64)
		beforeCount = atomic.LoadInt64(counter)
		return true
	})
	assert.Equal(t, int64(2), beforeCount)
	// Synchronous flush (no Start needed for this test).
	buffer.flush()
	var afterCount int64
	buffer.counters.Range(func(key, value interface{}) bool {
		counter := value.(*int64)
		afterCount = atomic.LoadInt64(counter)
		return true
	})
	assert.Equal(t, int64(0), afterCount)
}
// TestStatsBuffer_FailRetry verifies the failure path: when BatchUpdate
// errors, flush adds the delta back to the counter so the next flush can
// retry it (count stays 2 after a failed flush).
func TestStatsBuffer_FailRetry(t *testing.T) {
	logger := zap.NewNop()
	statsRepo := &mockStatsRepo{fail: true}
	buffer := NewStatsBuffer(statsRepo, logger)
	buffer.Increment("openai", "gpt-4")
	buffer.Increment("openai", "gpt-4")
	buffer.flush()
	var count int64
	buffer.counters.Range(func(key, value interface{}) bool {
		counter := value.(*int64)
		count = atomic.LoadInt64(counter)
		return true
	})
	assert.Equal(t, int64(2), count)
}
// TestStatsBuffer_Stop verifies graceful shutdown: with a 10s interval
// (so the ticker never fires), Stop must return quickly and still have
// performed a final flush of the buffered counts.
func TestStatsBuffer_Stop(t *testing.T) {
	logger := zap.NewNop()
	statsRepo := &mockStatsRepo{}
	buffer := NewStatsBuffer(statsRepo, logger,
		WithFlushInterval(10*time.Second))
	buffer.Start()
	buffer.Increment("openai", "gpt-4")
	buffer.Increment("openai", "gpt-4")
	start := time.Now()
	buffer.Stop()
	elapsed := time.Since(start)
	// Stop waits only for the final flush, not for a ticker period.
	assert.Less(t, elapsed, 1*time.Second)
	statsRepo.mu.Lock()
	assert.GreaterOrEqual(t, len(statsRepo.records), 1)
	statsRepo.mu.Unlock()
}
// TestStatsBuffer_ConcurrentIncrementAndFlush runs 100 concurrent
// Increments while the background flusher ticks every 50ms, then checks
// that no count is lost: the deltas may be split across several
// BatchUpdate calls, but they must sum to exactly 100.
func TestStatsBuffer_ConcurrentIncrementAndFlush(t *testing.T) {
	logger := zap.NewNop()
	statsRepo := &mockStatsRepo{}
	buffer := NewStatsBuffer(statsRepo, logger,
		WithFlushInterval(50*time.Millisecond))
	buffer.Start()
	defer buffer.Stop()
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			buffer.Increment("openai", "gpt-4")
		}()
	}
	wg.Wait()
	// Allow at least one more tick so everything buffered is written.
	time.Sleep(100 * time.Millisecond)
	statsRepo.mu.Lock()
	totalDelta := 0
	for _, r := range statsRepo.records {
		totalDelta += r.delta
	}
	statsRepo.mu.Unlock()
	assert.Equal(t, 100, totalDelta)
}

View File

@@ -10,14 +10,16 @@ import (
// statsService implements StatsService: reads go straight to the
// repository, while Record goes through the in-memory StatsBuffer.
type statsService struct {
	statsRepo repository.StatsRepository // direct access for queries/aggregation
	buffer    *StatsBuffer               // buffered, asynchronous request counting
}
func NewStatsService(statsRepo repository.StatsRepository) StatsService {
return &statsService{statsRepo: statsRepo}
func NewStatsService(statsRepo repository.StatsRepository, buffer *StatsBuffer) StatsService {
return &statsService{statsRepo: statsRepo, buffer: buffer}
}
// Record buffers one request count for (providerID, modelName). The write
// is flushed to the database asynchronously by the StatsBuffer, so this
// never blocks on the database and always returns nil.
// (The merged diff left the old `return s.statsRepo.Record(...)` line in
// front of the new body, which would make the buffer call unreachable.)
func (s *statsService) Record(providerID, modelName string) error {
	s.buffer.Increment(providerID, modelName)
	return nil
}
func (s *statsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {

View File

@@ -14,6 +14,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"gorm.io/gorm"
"nex/backend/internal/conversion"
@@ -54,10 +55,14 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server)
modelRepo := repository.NewModelRepository(db)
statsRepo := repository.NewStatsRepository(db)
providerService := service.NewProviderService(providerRepo, modelRepo)
modelService := service.NewModelService(modelRepo, providerRepo)
routingService := service.NewRoutingService(modelRepo, providerRepo)
statsService := service.NewStatsService(statsRepo)
logger := zap.NewNop()
routingCache := service.NewRoutingCache(modelRepo, providerRepo, logger)
statsBuffer := service.NewStatsBuffer(statsRepo, logger)
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
routingService := service.NewRoutingService(routingCache)
statsService := service.NewStatsService(statsRepo, statsBuffer)
// 创建 ConversionEngine
registry := conversion.NewMemoryRegistry()

View File

@@ -15,6 +15,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
@@ -48,10 +49,14 @@ func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) {
modelRepo := repository.NewModelRepository(db)
statsRepo := repository.NewStatsRepository(db)
providerService := service.NewProviderService(providerRepo, modelRepo)
modelService := service.NewModelService(modelRepo, providerRepo)
routingService := service.NewRoutingService(modelRepo, providerRepo)
statsService := service.NewStatsService(statsRepo)
logger := zap.NewNop()
routingCache := service.NewRoutingCache(modelRepo, providerRepo, logger)
statsBuffer := service.NewStatsBuffer(statsRepo, logger)
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
routingService := service.NewRoutingService(routingCache)
statsService := service.NewStatsService(statsRepo, statsBuffer)
registry := conversion.NewMemoryRegistry()
require.NoError(t, registry.Register(openaiConv.NewAdapter()))

View File

@@ -9,6 +9,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"gorm.io/gorm"
"nex/backend/internal/domain"
@@ -30,10 +31,14 @@ func setupIntegrationTest(t *testing.T) (*gin.Engine, *gorm.DB) {
modelRepo := repository.NewModelRepository(db)
statsRepo := repository.NewStatsRepository(db)
providerService := service.NewProviderService(providerRepo, modelRepo)
modelService := service.NewModelService(modelRepo, providerRepo)
_ = service.NewRoutingService(modelRepo, providerRepo)
statsService := service.NewStatsService(statsRepo)
logger := zap.NewNop()
routingCache := service.NewRoutingCache(modelRepo, providerRepo, logger)
statsBuffer := service.NewStatsBuffer(statsRepo, logger)
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
_ = service.NewRoutingService(routingCache)
statsService := service.NewStatsService(statsRepo, statsBuffer)
providerHandler := handler.NewProviderHandler(providerService)
modelHandler := handler.NewModelHandler(modelService)

View File

@@ -41,6 +41,20 @@ func (m *MockStatsRepository) EXPECT() *MockStatsRepositoryMockRecorder {
return m.recorder
}
// BatchUpdate mocks base method.
func (m *MockStatsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BatchUpdate", providerID, modelName, date, delta)
ret0, _ := ret[0].(error)
return ret0
}
// BatchUpdate indicates an expected call of BatchUpdate.
func (mr *MockStatsRepositoryMockRecorder) BatchUpdate(providerID, modelName, date, delta any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdate", reflect.TypeOf((*MockStatsRepository)(nil).BatchUpdate), providerID, modelName, date, delta)
}
// Query mocks base method.
func (m *MockStatsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
m.ctrl.T.Helper()

View File

@@ -0,0 +1,136 @@
# Routing Cache
## Purpose
为 Provider 和 Model 提供内存缓存,优化路由查询性能。
## Requirements
### Requirement: RoutingCache 缓存 Provider 和 Model
系统 SHALL 为 Provider 和 Model 提供内存缓存,使用 sync.Map 作为缓存数据结构。
#### Scenario: 缓存数据结构
- **WHEN** 创建 RoutingCache
- **THEN** SHALL 使用 sync.Map 存储 Providerkey = providerID
- **THEN** SHALL 使用 sync.Map 存储 Modelkey = providerID/modelName
### Requirement: 缓存查询优先
RoutingCache SHALL 优先从缓存查询,缓存未命中时查询数据库并更新缓存。
#### Scenario: 缓存命中
- **WHEN** 查询 Provider 或 Model
- **THEN** SHALL 先从缓存查询
- **THEN** 如果缓存命中SHALL 直接返回缓存值
- **THEN** SHALL 不查询数据库
#### Scenario: 缓存未命中
- **WHEN** 查询 Provider 或 Model
- **THEN** SHALL 先从缓存查询
- **THEN** 如果缓存未命中SHALL 查询数据库
- **THEN** SHALL 将查询结果存入缓存
- **THEN** SHALL 返回查询结果
#### Scenario: 双重检查防止竞态
- **WHEN** 并发查询同一个 Provider 或 Model
- **THEN** SHALL 在查询数据库后再次检查缓存
- **THEN** 如果已有其他 goroutine 存入缓存SHALL 使用缓存值
- **THEN** SHALL 防止存入旧值
### Requirement: 缓存更新
RoutingCache SHALL 支持 Create 操作后更新缓存。
#### Scenario: Create Provider 后更新缓存
- **WHEN** 创建 Provider 成功
- **THEN** SHALL 调用 RoutingCache.SetProvider
- **THEN** SHALL 将新 Provider 存入缓存
#### Scenario: Create Model 后更新缓存
- **WHEN** 创建 Model 成功
- **THEN** SHALL 调用 RoutingCache.SetModel
- **THEN** SHALL 将新 Model 存入缓存
### Requirement: 缓存失效
RoutingCache SHALL 支持 Update/Delete 操作后清除缓存。
#### Scenario: Update Provider 后清除缓存
- **WHEN** 更新 Provider 成功
- **THEN** SHALL 调用 RoutingCache.InvalidateProvider
- **THEN** SHALL 清除该 Provider 的缓存
- **THEN** SHALL 级联清除该 Provider 的所有 Model 缓存
#### Scenario: Delete Provider 后清除缓存
- **WHEN** 删除 Provider 成功
- **THEN** SHALL 调用 RoutingCache.InvalidateProvider
- **THEN** SHALL 清除该 Provider 的缓存
- **THEN** SHALL 级联清除该 Provider 的所有 Model 缓存
#### Scenario: Update Model 后清除缓存
- **WHEN** 更新 Model 成功
- **THEN** SHALL 调用 RoutingCache.InvalidateModel
- **THEN** SHALL 清除旧位置的 Model 缓存
- **THEN** 如果 provider_id 或 model_name 变化SHALL 也清除新位置的缓存
#### Scenario: Delete Model 后清除缓存
- **WHEN** 删除 Model 成功
- **THEN** SHALL 调用 RoutingCache.InvalidateModel
- **THEN** SHALL 清除该 Model 的缓存
### Requirement: 缓存预热
RoutingCache SHALL 支持启动时预热缓存。
#### Scenario: 预热成功
- **WHEN** 服务启动时
- **THEN** SHALL 调用 RoutingCache.Preload
- **THEN** SHALL 从数据库加载所有 Provider 到缓存
- **THEN** SHALL 从数据库加载所有 Model 到缓存
- **THEN** SHALL 记录预热完成的日志
#### Scenario: 预热失败
- **WHEN** 预热失败时
- **THEN** SHALL 记录警告日志
- **THEN** SHALL 继续启动服务
- **THEN** SHALL 使用懒加载(首次查询时加载)
### Requirement: RoutingService 使用缓存
RoutingService SHALL 使用 RoutingCache 进行路由查询。
#### Scenario: RouteByModelName 使用缓存
- **WHEN** 调用 RoutingService.RouteByModelName
- **THEN** SHALL 调用 RoutingCache.GetModel 获取 Model
- **THEN** SHALL 调用 RoutingCache.GetProvider 获取 Provider
- **THEN** SHALL 不直接调用 Repository
### Requirement: 并发安全
RoutingCache SHALL 支持并发访问。
#### Scenario: 并发查询
- **WHEN** 多个 goroutine 并发查询缓存
- **THEN** SHALL 无竞态条件
- **THEN** SHALL 无 panic
#### Scenario: 并发查询和失效
- **WHEN** 并发查询和失效缓存
- **THEN** SHALL 无竞态条件
- **THEN** SHALL 保证一致性

View File

@@ -0,0 +1,156 @@
# Stats Buffer
## Purpose
为统计数据提供内存缓冲,优化写入性能。
## Requirements
### Requirement: StatsBuffer 内存缓冲
系统 SHALL 为统计数据提供内存缓冲,使用 sync.Map + atomic.Int64 进行计数。
#### Scenario: 缓冲数据结构
- **WHEN** 创建 StatsBuffer
- **THEN** SHALL 使用 sync.Map 存储计数器key = providerID/modelName/date
- **THEN** SHALL 使用 atomic.Int64 进行计数
- **THEN** SHALL 支持配置刷新间隔和阈值
### Requirement: 原子计数
StatsBuffer SHALL 使用原子操作进行计数。
#### Scenario: Increment 计数
- **WHEN** 调用 StatsBuffer.Increment
- **THEN** SHALL 使用 atomic.AddInt64 增加计数
- **THEN** SHALL 无锁操作
- **THEN** SHALL 线程安全
#### Scenario: Increment 创建计数器
- **WHEN** 调用 StatsBuffer.Increment 且计数器不存在
- **THEN** SHALL 使用 sync.Map.LoadOrStore 创建计数器
- **THEN** SHALL 初始化计数器为 0
- **THEN** SHALL 原子增加计数
#### Scenario: 并发计数
- **WHEN** 多个 goroutine 并发 Increment
- **THEN** SHALL 计数准确
- **THEN** SHALL 无竞态条件
### Requirement: 定时刷新
StatsBuffer SHALL 支持定时刷新到数据库。
#### Scenario: 定时刷新触发
- **WHEN** 后台刷新协程运行
- **THEN** SHALL 每隔 flushInterval 触发刷新
- **THEN** SHALL 调用 StatsRepository.BatchUpdate 写入数据库
#### Scenario: 刷新间隔配置
- **WHEN** 创建 StatsBuffer
- **THEN** 默认 flushInterval 为 5 秒
- **THEN** 可通过 WithFlushInterval 选项配置
### Requirement: 阈值触发刷新
StatsBuffer SHALL 支持累计阈值触发刷新。
#### Scenario: 阈值触发
- **WHEN** 累计计数达到 flushThreshold
- **THEN** SHALL 异步触发刷新
- **THEN** SHALL 不阻塞请求
#### Scenario: 阈值配置
- **WHEN** 创建 StatsBuffer
- **THEN** 默认 flushThreshold 为 100
- **THEN** 可通过 WithFlushThreshold 选项配置
### Requirement: 批量写入数据库
StatsBuffer SHALL 批量写入统计数据到数据库。
#### Scenario: 批量写入
- **WHEN** 刷新触发
- **THEN** SHALL 遍历所有计数器
- **THEN** SHALL 使用 atomic.SwapInt64(counter, 0) 获取并清零计数器
- **THEN** SHALL 调用 StatsRepository.BatchUpdate 批量写入
- **THEN** SHALL 重置 totalCount 为 0
#### Scenario: SwapInt64 清零计数器
- **WHEN** flush 收集计数器
- **THEN** SHALL 使用 SwapInt64 原子操作获取当前计数并清零
- **THEN** SHALL 保证计数不丢失(新计数会累加到已清零的计数器)
- **THEN** SHALL 不阻塞后续 Increment 操作
#### Scenario: 写入失败保留计数器
- **WHEN** BatchUpdate 失败
- **THEN** SHALL 将计数加回计数器(使用 atomic.AddInt64
- **THEN** SHALL 记录错误日志
- **THEN** SHALL 继续处理其他条目
### Requirement: 优雅关闭
StatsBuffer SHALL 支持优雅关闭,确保最后的统计写入数据库。
#### Scenario: Stop 等待刷新完成
- **WHEN** 调用 StatsBuffer.Stop
- **THEN** SHALL 停止后台刷新协程
- **THEN** SHALL 执行最后一次刷新
- **THEN** SHALL 等待刷新完成
- **THEN** SHALL 保证统计数据不丢失
#### Scenario: 无超时
- **WHEN** Stop 等待刷新
- **THEN** SHALL 无超时限制
- **THEN** SHALL 等待刷新完成
### Requirement: StatsService 使用缓冲
StatsService SHALL 使用 StatsBuffer 进行统计记录。
#### Scenario: Record 使用缓冲
- **WHEN** 调用 StatsService.Record
- **THEN** SHALL 调用 StatsBuffer.Increment
- **THEN** SHALL 不直接调用 StatsRepository.Record
- **THEN** SHALL 立即返回,不阻塞
### Requirement: StatsRepository 扩展
StatsRepository SHALL 新增 BatchUpdate 方法。
#### Scenario: BatchUpdate 方法
- **WHEN** 调用 StatsRepository.BatchUpdate
- **THEN** SHALL 使用事务更新或创建统计记录
- **THEN** SHALL 使用 request_count + delta 更新
- **THEN** SHALL 支持批量增量更新
### Requirement: 并发安全
StatsBuffer SHALL 支持并发访问。
#### Scenario: 并发 Increment 和 flush
- **WHEN** 并发 Increment 和 flush
- **THEN** SHALL 无竞态条件
- **THEN** SHALL 计数准确(可能延迟到下次 flush
#### Scenario: flush 期间 Increment
- **WHEN** flush 正在执行
- **THEN** 新的 Increment SHALL 继续计数
- **THEN** SHALL 不会丢失计数

View File

@@ -13,14 +13,15 @@
#### Scenario: 记录成功请求
- **WHEN** 请求成功转发到供应商
- **THEN** 网关 SHALL 增加该供应商和模型的请求计数
- **THEN** 网关 SHALL 记录当前日期的统计
- **THEN** 网关 SHALL 通过 StatsBuffer 增加该供应商和模型的请求计数
- **THEN** 网关 SHALL 异步批量写入数据库(定时或阈值触发)
- **THEN** 网关 SHALL 不阻塞响应
#### Scenario: 记录流式请求
- **WHEN** 流式请求成功完成
- **THEN** 网关 SHALL 增加该供应商和模型的请求计数
- **THEN** 网关 SHALL 在流结束后记录统计
- **THEN** 网关 SHALL 通过 StatsBuffer 增加该供应商和模型的请求计数
- **THEN** 网关 SHALL 在流结束后异步记录统计
### Requirement: 使用统计记录统一模型标识
@@ -90,10 +91,9 @@
#### Scenario: 并发请求
- **WHEN** 同时处理多个并发请求
- **THEN** 网关 SHALL 正确为每个请求增加请求计数
- **THEN** 网关 SHALL 使用原子操作正确增加每个请求的计数
- **THEN** 不 SHALL 因并发写入而丢失统计
**变更说明:** 并发控制在 StatsRepository 中通过数据库事务实现。API 接口保持不变。
- **THEN** SHALL 使用 StatsBuffer 的内存计数器
### Requirement: 使用 service 层处理业务逻辑
@@ -121,6 +121,13 @@ Service SHALL 通过 StatsRepository 访问数据。
- **THEN** SHALL 调用对应的 StatsRepository 方法
- **THEN** SHALL 使用 domain.UsageStats 类型
#### Scenario: 批量更新统计
- **WHEN** StatsBuffer 刷新统计
- **THEN** SHALL 调用 StatsRepository.BatchUpdate
- **THEN** SHALL 使用事务更新或创建统计记录
- **THEN** SHALL 支持增量更新request_count + delta
#### Scenario: 事务处理
- **WHEN** 记录统计
@@ -136,3 +143,36 @@ Service SHALL 通过 StatsRepository 访问数据。
- **WHEN** 查询统计
- **THEN** SHALL 使用 (provider_id, model_name, date) 复合索引
- **THEN** SHALL 优化查询性能
### Requirement: 统计数据可接受少量丢失
统计记录方式改为内存缓冲,可接受少量丢失。
#### Scenario: 进程崩溃丢失统计
- **WHEN** 进程崩溃
- **THEN** MAY 丢失最近 flushInterval 内的统计
- **THEN** 统计数据用于监控,可接受少量丢失
#### Scenario: 优雅关闭保证不丢失
- **WHEN** 服务优雅关闭
- **THEN** SHALL 调用 StatsBuffer.Stop
- **THEN** SHALL 等待最后一次刷新完成
- **THEN** SHALL 保证统计数据不丢失
### Requirement: StatsRepository 支持批量更新
StatsRepository SHALL 新增 BatchUpdate 方法支持批量增量更新。
#### Scenario: BatchUpdate 更新现有记录
- **WHEN** 调用 BatchUpdate 且当日记录存在
- **THEN** SHALL 使用事务更新 request_count = request_count + delta
- **THEN** SHALL 不创建新记录
#### Scenario: BatchUpdate 创建新记录
- **WHEN** 调用 BatchUpdate 且当日记录不存在
- **THEN** SHALL 创建新记录request_count = delta
- **THEN** SHALL 使用事务保证原子性