1
0

feat(cache): 实现 RoutingCache 和 StatsBuffer 优化数据库写入

- 新增 RoutingCache 组件,使用 sync.Map 缓存 Provider 和 Model
- 新增 StatsBuffer 组件,使用 sync.Map + atomic.Int64 缓冲统计数据
- 扩展 StatsRepository.BatchUpdate 支持批量增量更新
- 改造 RoutingService/StatsService/ProviderService/ModelService 集成缓存
- 更新 usage-statistics spec,新增 routing-cache 和 stats-buffer spec
- 新增单元测试覆盖缓存命中/失效/并发场景
This commit is contained in:
2026-04-22 19:24:36 +08:00
parent f5e45d032e
commit df253559a5
20 changed files with 1377 additions and 91 deletions

View File

@@ -67,13 +67,25 @@ func main() {
modelRepo := repository.NewModelRepository(db)
statsRepo := repository.NewStatsRepository(db)
// 5. 初始化 service 层
providerService := service.NewProviderService(providerRepo, modelRepo)
modelService := service.NewModelService(modelRepo, providerRepo)
routingService := service.NewRoutingService(modelRepo, providerRepo)
statsService := service.NewStatsService(statsRepo)
// 5. 初始化缓存
routingCache := service.NewRoutingCache(modelRepo, providerRepo, zapLogger)
if err := routingCache.Preload(); err != nil {
zapLogger.Warn("缓存预热失败,将使用懒加载", zap.Error(err))
}
// 6. 创建 ConversionEngine
// 6. 初始化统计缓冲
statsBuffer := service.NewStatsBuffer(statsRepo, zapLogger,
service.WithFlushInterval(5*time.Second),
service.WithFlushThreshold(100))
statsBuffer.Start()
// 7. 初始化 service 层
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
routingService := service.NewRoutingService(routingCache)
statsService := service.NewStatsService(statsRepo, statsBuffer)
// 8. 创建 ConversionEngine
registry := conversion.NewMemoryRegistry()
if err := registry.Register(openai.NewAdapter()); err != nil {
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error()))
@@ -83,16 +95,16 @@ func main() {
}
engine := conversion.NewConversionEngine(registry, zapLogger)
// 7. 初始化 provider client
// 9. 初始化 provider client
providerClient := provider.NewClient()
// 8. 初始化 handler 层
// 10. 初始化 handler 层
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
providerHandler := handler.NewProviderHandler(providerService)
modelHandler := handler.NewModelHandler(modelService)
statsHandler := handler.NewStatsHandler(statsService)
// 9. 创建 Gin 引擎
// 11. 创建 Gin 引擎
gin.SetMode(gin.ReleaseMode)
r := gin.New()
@@ -103,7 +115,7 @@ func main() {
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
// 10. 启动服务器
// 12. 启动服务器
srv := &http.Server{
Addr: formatAddr(cfg.Server.Port),
Handler: r,
@@ -131,6 +143,8 @@ func main() {
zapLogger.Fatal("服务器强制关闭", zap.String("error", err.Error()))
}
statsBuffer.Stop()
zapLogger.Info("服务器已关闭")
}

View File

@@ -11,5 +11,6 @@ import (
// StatsRepository 统计数据仓库接口
type StatsRepository interface {
Record(providerID, modelName string) error
BatchUpdate(providerID, modelName string, date time.Time, delta int) error
Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error)
}

View File

@@ -43,6 +43,28 @@ func (r *statsRepository) Record(providerID, modelName string) error {
})
}
func (r *statsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
return r.db.Transaction(func(tx *gorm.DB) error {
var stats config.UsageStats
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
providerID, modelName, date).First(&stats).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return tx.Create(&config.UsageStats{
ProviderID: providerID,
ModelName: modelName,
RequestCount: delta,
Date: date,
}).Error
} else if err != nil {
return err
}
return tx.Model(&stats).
Update("request_count", gorm.Expr("request_count + ?", delta)).Error
})
}
func (r *statsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
var stats []config.UsageStats
query := r.db.Model(&config.UsageStats{})

View File

@@ -11,27 +11,30 @@ import (
type modelService struct {
modelRepo repository.ModelRepository
providerRepo repository.ProviderRepository
cache *RoutingCache
}
func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) ModelService {
return &modelService{modelRepo: modelRepo, providerRepo: providerRepo}
func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository, cache *RoutingCache) ModelService {
return &modelService{modelRepo: modelRepo, providerRepo: providerRepo, cache: cache}
}
func (s *modelService) Create(model *domain.Model) error {
// 校验供应商存在
if _, err := s.providerRepo.GetByID(model.ProviderID); err != nil {
return appErrors.ErrProviderNotFound
}
// 联合唯一校验:同一供应商下 model_name 不重复
if err := s.checkDuplicateModelName(model.ProviderID, model.ModelName, ""); err != nil {
return err
}
// 自动生成 UUID 作为 id
model.ID = uuid.New().String()
model.Enabled = true
return s.modelRepo.Create(model)
err := s.modelRepo.Create(model)
if err != nil {
return err
}
s.cache.SetModel(model)
return nil
}
func (s *modelService) Get(id string) (*domain.Model, error) {
@@ -47,20 +50,17 @@ func (s *modelService) ListEnabled() ([]domain.Model, error) {
}
func (s *modelService) Update(id string, updates map[string]interface{}) error {
// 获取当前模型
current, err := s.modelRepo.GetByID(id)
if err != nil {
return appErrors.ErrModelNotFound
}
// 如果更新 provider_id校验新供应商存在
if providerID, ok := updates["provider_id"].(string); ok {
if _, err := s.providerRepo.GetByID(providerID); err != nil {
return appErrors.ErrProviderNotFound
}
}
// 确定更新后的 provider_id 和 model_name
newProviderID := current.ProviderID
if v, ok := updates["provider_id"].(string); ok {
newProviderID = v
@@ -70,18 +70,37 @@ func (s *modelService) Update(id string, updates map[string]interface{}) error {
newModelName = v
}
// 如果 provider_id 或 model_name 发生变化,校验联合唯一
if newProviderID != current.ProviderID || newModelName != current.ModelName {
if err := s.checkDuplicateModelName(newProviderID, newModelName, id); err != nil {
return err
}
}
return s.modelRepo.Update(id, updates)
err = s.modelRepo.Update(id, updates)
if err != nil {
return err
}
s.cache.InvalidateModel(current.ProviderID, current.ModelName)
if newProviderID != current.ProviderID || newModelName != current.ModelName {
s.cache.InvalidateModel(newProviderID, newModelName)
}
return nil
}
func (s *modelService) Delete(id string) error {
return s.modelRepo.Delete(id)
model, err := s.modelRepo.GetByID(id)
if err != nil {
return appErrors.ErrModelNotFound
}
err = s.modelRepo.Delete(id)
if err != nil {
return err
}
s.cache.InvalidateModel(model.ProviderID, model.ModelName)
return nil
}
// checkDuplicateModelName 校验同一供应商下 model_name 是否重复

View File

@@ -13,23 +13,27 @@ import (
type providerService struct {
providerRepo repository.ProviderRepository
modelRepo repository.ModelRepository
cache *RoutingCache
}
func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository) ProviderService {
return &providerService{providerRepo: providerRepo, modelRepo: modelRepo}
func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository, cache *RoutingCache) ProviderService {
return &providerService{providerRepo: providerRepo, modelRepo: modelRepo, cache: cache}
}
func (s *providerService) Create(provider *domain.Provider) error {
// 校验 provider_id 字符集
if err := modelid.ValidateProviderID(provider.ID); err != nil {
return appErrors.ErrInvalidProviderID
}
provider.Enabled = true
err := s.providerRepo.Create(provider)
if err != nil && isUniqueConstraintError(err) {
return appErrors.ErrConflict
if err != nil {
if isUniqueConstraintError(err) {
return appErrors.ErrConflict
}
return err
}
return err
s.cache.SetProvider(provider)
return nil
}
func (s *providerService) Get(id string) (*domain.Provider, error) {
@@ -44,11 +48,21 @@ func (s *providerService) Update(id string, updates map[string]interface{}) erro
if _, ok := updates["id"]; ok {
return appErrors.ErrImmutableField
}
return s.providerRepo.Update(id, updates)
err := s.providerRepo.Update(id, updates)
if err != nil {
return err
}
s.cache.InvalidateProvider(id)
return nil
}
func (s *providerService) Delete(id string) error {
return s.providerRepo.Delete(id)
err := s.providerRepo.Delete(id)
if err != nil {
return err
}
s.cache.InvalidateProvider(id)
return nil
}
// ListEnabledModels 返回所有启用的模型(用于 Models 接口本地聚合)

View File

@@ -0,0 +1,134 @@
package service
import (
"strings"
"sync"
"go.uber.org/zap"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
)
type RoutingCache struct {
providers sync.Map
models sync.Map
modelRepo repository.ModelRepository
providerRepo repository.ProviderRepository
logger *zap.Logger
}
func NewRoutingCache(
modelRepo repository.ModelRepository,
providerRepo repository.ProviderRepository,
logger *zap.Logger,
) *RoutingCache {
return &RoutingCache{
modelRepo: modelRepo,
providerRepo: providerRepo,
logger: logger,
}
}
func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
if v, ok := c.providers.Load(id); ok {
return v.(*domain.Provider), nil
}
provider, err := c.providerRepo.GetByID(id)
if err != nil {
return nil, err
}
if v, ok := c.providers.Load(id); ok {
return v.(*domain.Provider), nil
}
c.providers.Store(id, provider)
return provider, nil
}
func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, error) {
key := providerID + "/" + modelName
if v, ok := c.models.Load(key); ok {
return v.(*domain.Model), nil
}
model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName)
if err != nil {
return nil, err
}
if v, ok := c.models.Load(key); ok {
return v.(*domain.Model), nil
}
c.models.Store(key, model)
return model, nil
}
func (c *RoutingCache) SetProvider(provider *domain.Provider) {
c.providers.Store(provider.ID, provider)
}
func (c *RoutingCache) SetModel(model *domain.Model) {
key := model.ProviderID + "/" + model.ModelName
c.models.Store(key, model)
}
func (c *RoutingCache) InvalidateProvider(id string) {
c.providers.Delete(id)
c.invalidateModelsByProvider(id)
c.logger.Debug("Provider 缓存失效", zap.String("provider_id", id))
}
func (c *RoutingCache) InvalidateModel(providerID, modelName string) {
key := providerID + "/" + modelName
c.models.Delete(key)
c.logger.Debug("Model 缓存失效",
zap.String("provider_id", providerID),
zap.String("model_name", modelName))
}
func (c *RoutingCache) invalidateModelsByProvider(providerID string) {
prefix := providerID + "/"
count := 0
c.models.Range(func(key, value interface{}) bool {
if strings.HasPrefix(key.(string), prefix) {
c.models.Delete(key)
count++
}
return true
})
if count > 0 {
c.logger.Debug("清除 Provider 相关 Model 缓存",
zap.String("provider_id", providerID),
zap.Int("count", count))
}
}
func (c *RoutingCache) Preload() error {
providers, err := c.providerRepo.List()
if err != nil {
return err
}
for i := range providers {
c.providers.Store(providers[i].ID, &providers[i])
}
models, err := c.modelRepo.List("")
if err != nil {
return err
}
for i := range models {
key := models[i].ProviderID + "/" + models[i].ModelName
c.models.Store(key, &models[i])
}
c.logger.Info("缓存预热完成",
zap.Int("providers", len(providers)),
zap.Int("models", len(models)))
return nil
}

View File

@@ -0,0 +1,273 @@
package service
import (
"errors"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"nex/backend/internal/domain"
)
type mockModelRepo struct {
models map[string]*domain.Model
}
func (m *mockModelRepo) Create(model *domain.Model) error {
return nil
}
func (m *mockModelRepo) GetByID(id string) (*domain.Model, error) {
return nil, nil
}
func (m *mockModelRepo) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
key := providerID + "/" + modelName
if model, ok := m.models[key]; ok {
return model, nil
}
return nil, errors.New("not found")
}
func (m *mockModelRepo) List(providerID string) ([]domain.Model, error) {
var result []domain.Model
for _, model := range m.models {
if providerID == "" || model.ProviderID == providerID {
result = append(result, *model)
}
}
return result, nil
}
func (m *mockModelRepo) ListEnabled() ([]domain.Model, error) {
return nil, nil
}
func (m *mockModelRepo) Update(id string, updates map[string]interface{}) error {
return nil
}
func (m *mockModelRepo) Delete(id string) error {
return nil
}
type mockProviderRepo struct {
providers map[string]*domain.Provider
}
func (m *mockProviderRepo) Create(provider *domain.Provider) error {
return nil
}
func (m *mockProviderRepo) GetByID(id string) (*domain.Provider, error) {
if provider, ok := m.providers[id]; ok {
return provider, nil
}
return nil, errors.New("not found")
}
func (m *mockProviderRepo) List() ([]domain.Provider, error) {
var result []domain.Provider
for _, provider := range m.providers {
result = append(result, *provider)
}
return result, nil
}
func (m *mockProviderRepo) Update(id string, updates map[string]interface{}) error {
return nil
}
func (m *mockProviderRepo) Delete(id string) error {
return nil
}
func TestRoutingCache_GetProvider_CacheHit(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
provider, err := cache.GetProvider("openai")
require.NoError(t, err)
assert.Equal(t, "openai", provider.ID)
provider2, err := cache.GetProvider("openai")
require.NoError(t, err)
assert.Equal(t, provider, provider2)
}
func TestRoutingCache_GetProvider_CacheMiss(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
_, err := cache.GetProvider("notexist")
assert.Error(t, err)
}
func TestRoutingCache_GetModel_CacheHit(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
model, err := cache.GetModel("openai", "gpt-4")
require.NoError(t, err)
assert.Equal(t, "gpt-4", model.ModelName)
model2, err := cache.GetModel("openai", "gpt-4")
require.NoError(t, err)
assert.Equal(t, model, model2)
}
func TestRoutingCache_GetModel_CacheMiss(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
_, err := cache.GetModel("openai", "notexist")
assert.Error(t, err)
}
func TestRoutingCache_DoubleCheck(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, err := cache.GetProvider("openai")
assert.NoError(t, err)
}()
}
wg.Wait()
}
func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
"openai/gpt-3.5": {ID: "2", ProviderID: "openai", ModelName: "gpt-3.5", Enabled: true},
"anthropic/claude": {ID: "3", ProviderID: "anthropic", ModelName: "claude", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
"anthropic": {ID: "anthropic", Name: "Anthropic", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
_, err := cache.GetModel("openai", "gpt-4")
require.NoError(t, err)
_, err = cache.GetModel("openai", "gpt-3.5")
require.NoError(t, err)
_, err = cache.GetModel("anthropic", "claude")
require.NoError(t, err)
cache.InvalidateProvider("openai")
var openaiCount, anthropicCount int
cache.models.Range(func(key, value interface{}) bool {
if key.(string) == "anthropic/claude" {
anthropicCount++
}
return true
})
assert.Equal(t, 0, openaiCount)
assert.Equal(t, 1, anthropicCount)
}
func TestRoutingCache_InvalidateModel(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
_, err := cache.GetModel("openai", "gpt-4")
require.NoError(t, err)
cache.InvalidateModel("openai", "gpt-4")
var count int
cache.models.Range(func(key, value interface{}) bool {
count++
return true
})
assert.Equal(t, 0, count)
}
func TestRoutingCache_Preload_Success(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
err := cache.Preload()
require.NoError(t, err)
var providerCount, modelCount int
cache.providers.Range(func(key, value interface{}) bool {
providerCount++
return true
})
cache.models.Range(func(key, value interface{}) bool {
modelCount++
return true
})
assert.Equal(t, 1, providerCount)
assert.Equal(t, 1, modelCount)
}
func TestRoutingCache_ConcurrentAccess(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, _ = cache.GetProvider("openai")
_, _ = cache.GetModel("openai", "gpt-4")
cache.InvalidateProvider("openai")
cache.InvalidateModel("openai", "gpt-4")
}()
}
wg.Wait()
}

View File

@@ -4,20 +4,18 @@ import (
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
)
type routingService struct {
modelRepo repository.ModelRepository
providerRepo repository.ProviderRepository
cache *RoutingCache
}
func NewRoutingService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) RoutingService {
return &routingService{modelRepo: modelRepo, providerRepo: providerRepo}
func NewRoutingService(cache *RoutingCache) RoutingService {
return &routingService{cache: cache}
}
func (s *routingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
model, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
model, err := s.cache.GetModel(providerID, modelName)
if err != nil {
return nil, appErrors.ErrModelNotFound
}
@@ -26,7 +24,7 @@ func (s *routingService) RouteByModelName(providerID, modelName string) (*domain
return nil, appErrors.ErrModelDisabled
}
provider, err := s.providerRepo.GetByID(model.ProviderID)
provider, err := s.cache.GetProvider(model.ProviderID)
if err != nil {
return nil, appErrors.ErrProviderNotFound
}

View File

@@ -14,7 +14,8 @@ func TestProviderService_Update(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
require.NoError(t, svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"}))
@@ -30,7 +31,8 @@ func TestProviderService_Update_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
err := svc.Update("nonexistent", map[string]interface{}{"name": "test"})
assert.Error(t, err)
@@ -40,7 +42,8 @@ func TestModelService_Get(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
@@ -55,7 +58,8 @@ func TestModelService_Update(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
@@ -73,7 +77,8 @@ func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
@@ -87,7 +92,8 @@ func TestModelService_Delete(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
@@ -104,7 +110,8 @@ func TestModelService_Delete_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
modelRepo := repository.NewModelRepository(db)
providerRepo := repository.NewProviderRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
err := svc.Delete("nonexistent")
assert.Error(t, err)
@@ -112,7 +119,8 @@ func TestModelService_Delete_NotFound(t *testing.T) {
func TestStatsService_Aggregate_Default(t *testing.T) {
statsRepo := repository.NewStatsRepository(nil)
svc := NewStatsService(statsRepo)
buffer := NewStatsBuffer(statsRepo, nil)
svc := NewStatsService(statsRepo, buffer)
stats := []domain.UsageStats{
{ProviderID: "p1", RequestCount: 10},
@@ -133,7 +141,8 @@ func TestModelService_Update_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
modelRepo := repository.NewModelRepository(db)
providerRepo := repository.NewProviderRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
err := svc.Update("nonexistent", map[string]interface{}{"model_name": "test"})
assert.Error(t, err)

View File

@@ -8,6 +8,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"gorm.io/gorm"
testHelpers "nex/backend/tests"
@@ -22,13 +23,21 @@ func setupServiceTestDB(t *testing.T) *gorm.DB {
return testHelpers.SetupTestDB(t)
}
func setupRoutingCache(t *testing.T, db *gorm.DB) *RoutingCache {
t.Helper()
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
return NewRoutingCache(modelRepo, providerRepo, zap.NewNop())
}
// ============ RoutingService - RouteByModelName 测试 ============
func TestRoutingService_RouteByModelName_Success(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewRoutingService(cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
@@ -41,9 +50,8 @@ func TestRoutingService_RouteByModelName_Success(t *testing.T) {
func TestRoutingService_RouteByModelName_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewRoutingService(cache)
_, err := svc.RouteByModelName("openai", "nonexistent-model")
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
@@ -53,7 +61,8 @@ func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewRoutingService(cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
@@ -67,7 +76,8 @@ func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewRoutingService(cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
@@ -83,7 +93,8 @@ func TestModelService_Create_GeneratesUUID(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
@@ -91,12 +102,10 @@ func TestModelService_Create_GeneratesUUID(t *testing.T) {
err := svc.Create(model)
require.NoError(t, err)
// 验证返回的 model 拥有有效的 UUID
assert.NotEmpty(t, model.ID)
_, err = uuid.Parse(model.ID)
assert.NoError(t, err, "model.ID should be a valid UUID")
// 通过 Get 验证持久化
stored, err := svc.Get(model.ID)
require.NoError(t, err)
assert.Equal(t, model.ID, stored.ID)
@@ -107,7 +116,8 @@ func TestModelService_Create_DuplicateModelName(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
@@ -115,7 +125,6 @@ func TestModelService_Create_DuplicateModelName(t *testing.T) {
err := svc.Create(model1)
require.NoError(t, err)
// 使用相同的 (providerID, modelName) 创建第二个模型应失败
model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err = svc.Create(model2)
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
@@ -125,7 +134,8 @@ func TestModelService_Create_ProviderNotFound(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
err := svc.Create(model)
@@ -138,7 +148,8 @@ func TestProviderService_Create_InvalidID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
@@ -149,7 +160,8 @@ func TestProviderService_Create_ValidID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
@@ -164,7 +176,8 @@ func TestModelService_Update_DuplicateModelName(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"}))
@@ -189,7 +202,8 @@ func TestModelService_Update_ModelNotFound(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
err := svc.Update("nonexistent-id", map[string]interface{}{
"model_name": "gpt-4",
@@ -201,7 +215,8 @@ func TestModelService_Update_Success(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
@@ -226,7 +241,8 @@ func TestProviderService_Update_ImmutableID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
@@ -243,7 +259,8 @@ func TestProviderService_Update_Success(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
@@ -301,7 +318,7 @@ func TestStatsService_Aggregate_ByModel(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db)
svc := NewStatsService(statsRepo)
buffer := NewStatsBuffer(statsRepo, nil); svc := NewStatsService(statsRepo, buffer)
result := svc.Aggregate(tt.stats, "model")
@@ -362,7 +379,7 @@ func TestStatsService_Aggregate_ByDate(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db)
svc := NewStatsService(statsRepo)
buffer := NewStatsBuffer(statsRepo, nil); svc := NewStatsService(statsRepo, buffer)
result := svc.Aggregate(tt.stats, "date")
@@ -431,7 +448,8 @@ func TestProviderService_List_APIKeyNotMasked(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider1 := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"}
provider2 := &domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "sk-anthropic1234", BaseURL: "https://api.anthropic.com"}
@@ -456,7 +474,8 @@ func TestModelService_ConcurrentCreate(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))

View File

@@ -0,0 +1,169 @@
package service
import (
"strings"
"sync"
"sync/atomic"
"time"
"go.uber.org/zap"
"nex/backend/internal/repository"
)
type StatsBuffer struct {
counters sync.Map
flushInterval time.Duration
flushThreshold int
totalCount atomic.Int64
statsRepo repository.StatsRepository
logger *zap.Logger
stopCh chan struct{}
doneCh chan struct{}
}
type StatsBufferOption func(*StatsBuffer)
func WithFlushInterval(d time.Duration) StatsBufferOption {
return func(b *StatsBuffer) {
b.flushInterval = d
}
}
func WithFlushThreshold(threshold int) StatsBufferOption {
return func(b *StatsBuffer) {
b.flushThreshold = threshold
}
}
func NewStatsBuffer(
statsRepo repository.StatsRepository,
logger *zap.Logger,
opts ...StatsBufferOption,
) *StatsBuffer {
b := &StatsBuffer{
statsRepo: statsRepo,
logger: logger,
flushInterval: 5 * time.Second,
flushThreshold: 100,
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
for _, opt := range opts {
opt(b)
}
return b
}
func (b *StatsBuffer) Increment(providerID, modelName string) {
today := time.Now().Format("2006-01-02")
key := providerID + "/" + modelName + "/" + today
var counter *int64
if v, ok := b.counters.Load(key); ok {
counter = v.(*int64)
} else {
val := int64(0)
counter = &val
actual, loaded := b.counters.LoadOrStore(key, counter)
if loaded {
counter = actual.(*int64)
}
}
atomic.AddInt64(counter, 1)
if b.totalCount.Add(1) >= int64(b.flushThreshold) {
go b.flush()
}
}
func (b *StatsBuffer) Start() {
go func() {
ticker := time.NewTicker(b.flushInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
b.flush()
case <-b.stopCh:
b.flush()
close(b.doneCh)
return
}
}
}()
}
func (b *StatsBuffer) Stop() {
close(b.stopCh)
<-b.doneCh
}
func (b *StatsBuffer) flush() {
type statEntry struct {
providerID string
modelName string
date string
count int64
}
var entries []statEntry
b.counters.Range(func(key, value interface{}) bool {
keyStr := key.(string)
parts := strings.Split(keyStr, "/")
if len(parts) != 3 {
return true
}
counter := value.(*int64)
count := atomic.SwapInt64(counter, 0)
if count > 0 {
entries = append(entries, statEntry{
providerID: parts[0],
modelName: parts[1],
date: parts[2],
count: count,
})
}
return true
})
if len(entries) == 0 {
return
}
success := 0
for _, entry := range entries {
date, _ := time.Parse("2006-01-02", entry.date)
err := b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count))
if err != nil {
b.logger.Error("批量更新统计失败",
zap.String("provider_id", entry.providerID),
zap.String("model_name", entry.modelName),
zap.Int64("count", entry.count),
zap.Error(err))
key := entry.providerID + "/" + entry.modelName + "/" + entry.date
if v, ok := b.counters.Load(key); ok {
counter := v.(*int64)
atomic.AddInt64(counter, entry.count)
}
} else {
success++
}
}
b.totalCount.Store(0)
b.logger.Debug("统计刷新完成",
zap.Int("total", len(entries)),
zap.Int("success", success),
zap.Int("failed", len(entries)-success))
}

View File

@@ -0,0 +1,251 @@
package service
import (
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"nex/backend/internal/domain"
)
type mockStatsRepo struct {
records []struct {
providerID string
modelName string
date time.Time
delta int
}
fail bool
mu sync.Mutex
}
func (m *mockStatsRepo) Record(providerID, modelName string) error {
return nil
}
func (m *mockStatsRepo) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.fail {
return errors.New("db error")
}
m.records = append(m.records, struct {
providerID string
modelName string
date time.Time
delta int
}{providerID, modelName, date, delta})
return nil
}
func (m *mockStatsRepo) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
return nil, nil
}
func TestStatsBuffer_Increment(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger)
buffer.Increment("openai", "gpt-4")
buffer.Increment("openai", "gpt-4")
buffer.Increment("openai", "gpt-3.5")
var count int64
buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64)
count += atomic.LoadInt64(counter)
return true
})
assert.Equal(t, int64(3), count)
}
func TestStatsBuffer_ConcurrentIncrement(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger)
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
buffer.Increment("openai", "gpt-4")
}()
}
wg.Wait()
var count int64
buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64)
count = atomic.LoadInt64(counter)
return true
})
assert.Equal(t, int64(100), count)
}
func TestStatsBuffer_LoadOrStore(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger)
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
buffer.Increment("openai", "gpt-4")
}()
}
wg.Wait()
var counterCount int
buffer.counters.Range(func(key, value interface{}) bool {
counterCount++
return true
})
assert.Equal(t, 1, counterCount)
}
func TestStatsBuffer_FlushByInterval(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger,
WithFlushInterval(100*time.Millisecond))
buffer.Start()
defer buffer.Stop()
buffer.Increment("openai", "gpt-4")
buffer.Increment("openai", "gpt-4")
time.Sleep(200 * time.Millisecond)
statsRepo.mu.Lock()
assert.GreaterOrEqual(t, len(statsRepo.records), 1)
statsRepo.mu.Unlock()
}
func TestStatsBuffer_FlushByThreshold(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger,
WithFlushThreshold(10))
buffer.Start()
defer buffer.Stop()
for i := 0; i < 10; i++ {
buffer.Increment("openai", "gpt-4")
}
time.Sleep(50 * time.Millisecond)
statsRepo.mu.Lock()
assert.GreaterOrEqual(t, len(statsRepo.records), 1)
statsRepo.mu.Unlock()
}
func TestStatsBuffer_SwapInt64(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger)
buffer.Increment("openai", "gpt-4")
buffer.Increment("openai", "gpt-4")
var beforeCount int64
buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64)
beforeCount = atomic.LoadInt64(counter)
return true
})
assert.Equal(t, int64(2), beforeCount)
buffer.flush()
var afterCount int64
buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64)
afterCount = atomic.LoadInt64(counter)
return true
})
assert.Equal(t, int64(0), afterCount)
}
func TestStatsBuffer_FailRetry(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{fail: true}
buffer := NewStatsBuffer(statsRepo, logger)
buffer.Increment("openai", "gpt-4")
buffer.Increment("openai", "gpt-4")
buffer.flush()
var count int64
buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64)
count = atomic.LoadInt64(counter)
return true
})
assert.Equal(t, int64(2), count)
}
func TestStatsBuffer_Stop(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger,
WithFlushInterval(10*time.Second))
buffer.Start()
buffer.Increment("openai", "gpt-4")
buffer.Increment("openai", "gpt-4")
start := time.Now()
buffer.Stop()
elapsed := time.Since(start)
assert.Less(t, elapsed, 1*time.Second)
statsRepo.mu.Lock()
assert.GreaterOrEqual(t, len(statsRepo.records), 1)
statsRepo.mu.Unlock()
}
func TestStatsBuffer_ConcurrentIncrementAndFlush(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger,
WithFlushInterval(50*time.Millisecond))
buffer.Start()
defer buffer.Stop()
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
buffer.Increment("openai", "gpt-4")
}()
}
wg.Wait()
time.Sleep(100 * time.Millisecond)
statsRepo.mu.Lock()
totalDelta := 0
for _, r := range statsRepo.records {
totalDelta += r.delta
}
statsRepo.mu.Unlock()
assert.Equal(t, 100, totalDelta)
}

View File

@@ -10,14 +10,16 @@ import (
type statsService struct {
statsRepo repository.StatsRepository
buffer *StatsBuffer
}
func NewStatsService(statsRepo repository.StatsRepository) StatsService {
return &statsService{statsRepo: statsRepo}
func NewStatsService(statsRepo repository.StatsRepository, buffer *StatsBuffer) StatsService {
return &statsService{statsRepo: statsRepo, buffer: buffer}
}
func (s *statsService) Record(providerID, modelName string) error {
return s.statsRepo.Record(providerID, modelName)
s.buffer.Increment(providerID, modelName)
return nil
}
func (s *statsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {

View File

@@ -14,6 +14,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"gorm.io/gorm"
"nex/backend/internal/conversion"
@@ -54,10 +55,14 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server)
modelRepo := repository.NewModelRepository(db)
statsRepo := repository.NewStatsRepository(db)
providerService := service.NewProviderService(providerRepo, modelRepo)
modelService := service.NewModelService(modelRepo, providerRepo)
routingService := service.NewRoutingService(modelRepo, providerRepo)
statsService := service.NewStatsService(statsRepo)
logger := zap.NewNop()
routingCache := service.NewRoutingCache(modelRepo, providerRepo, logger)
statsBuffer := service.NewStatsBuffer(statsRepo, logger)
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
routingService := service.NewRoutingService(routingCache)
statsService := service.NewStatsService(statsRepo, statsBuffer)
// 创建 ConversionEngine
registry := conversion.NewMemoryRegistry()

View File

@@ -15,6 +15,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
@@ -48,10 +49,14 @@ func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) {
modelRepo := repository.NewModelRepository(db)
statsRepo := repository.NewStatsRepository(db)
providerService := service.NewProviderService(providerRepo, modelRepo)
modelService := service.NewModelService(modelRepo, providerRepo)
routingService := service.NewRoutingService(modelRepo, providerRepo)
statsService := service.NewStatsService(statsRepo)
logger := zap.NewNop()
routingCache := service.NewRoutingCache(modelRepo, providerRepo, logger)
statsBuffer := service.NewStatsBuffer(statsRepo, logger)
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
routingService := service.NewRoutingService(routingCache)
statsService := service.NewStatsService(statsRepo, statsBuffer)
registry := conversion.NewMemoryRegistry()
require.NoError(t, registry.Register(openaiConv.NewAdapter()))

View File

@@ -9,6 +9,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"gorm.io/gorm"
"nex/backend/internal/domain"
@@ -30,10 +31,14 @@ func setupIntegrationTest(t *testing.T) (*gin.Engine, *gorm.DB) {
modelRepo := repository.NewModelRepository(db)
statsRepo := repository.NewStatsRepository(db)
providerService := service.NewProviderService(providerRepo, modelRepo)
modelService := service.NewModelService(modelRepo, providerRepo)
_ = service.NewRoutingService(modelRepo, providerRepo)
statsService := service.NewStatsService(statsRepo)
logger := zap.NewNop()
routingCache := service.NewRoutingCache(modelRepo, providerRepo, logger)
statsBuffer := service.NewStatsBuffer(statsRepo, logger)
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
_ = service.NewRoutingService(routingCache)
statsService := service.NewStatsService(statsRepo, statsBuffer)
providerHandler := handler.NewProviderHandler(providerService)
modelHandler := handler.NewModelHandler(modelService)

View File

@@ -41,6 +41,20 @@ func (m *MockStatsRepository) EXPECT() *MockStatsRepositoryMockRecorder {
return m.recorder
}
// BatchUpdate mocks base method.
func (m *MockStatsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BatchUpdate", providerID, modelName, date, delta)
ret0, _ := ret[0].(error)
return ret0
}
// BatchUpdate indicates an expected call of BatchUpdate.
func (mr *MockStatsRepositoryMockRecorder) BatchUpdate(providerID, modelName, date, delta any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdate", reflect.TypeOf((*MockStatsRepository)(nil).BatchUpdate), providerID, modelName, date, delta)
}
// Query mocks base method.
func (m *MockStatsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
m.ctrl.T.Helper()