diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 57bfe6a..a604401 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -67,13 +67,25 @@ func main() { modelRepo := repository.NewModelRepository(db) statsRepo := repository.NewStatsRepository(db) - // 5. 初始化 service 层 - providerService := service.NewProviderService(providerRepo, modelRepo) - modelService := service.NewModelService(modelRepo, providerRepo) - routingService := service.NewRoutingService(modelRepo, providerRepo) - statsService := service.NewStatsService(statsRepo) + // 5. 初始化缓存 + routingCache := service.NewRoutingCache(modelRepo, providerRepo, zapLogger) + if err := routingCache.Preload(); err != nil { + zapLogger.Warn("缓存预热失败,将使用懒加载", zap.Error(err)) + } - // 6. 创建 ConversionEngine + // 6. 初始化统计缓冲 + statsBuffer := service.NewStatsBuffer(statsRepo, zapLogger, + service.WithFlushInterval(5*time.Second), + service.WithFlushThreshold(100)) + statsBuffer.Start() + + // 7. 初始化 service 层 + providerService := service.NewProviderService(providerRepo, modelRepo, routingCache) + modelService := service.NewModelService(modelRepo, providerRepo, routingCache) + routingService := service.NewRoutingService(routingCache) + statsService := service.NewStatsService(statsRepo, statsBuffer) + + // 8. 创建 ConversionEngine registry := conversion.NewMemoryRegistry() if err := registry.Register(openai.NewAdapter()); err != nil { zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error())) @@ -83,16 +95,16 @@ func main() { } engine := conversion.NewConversionEngine(registry, zapLogger) - // 7. 初始化 provider client + // 9. 初始化 provider client providerClient := provider.NewClient() - // 8. 初始化 handler 层 + // 10. 初始化 handler 层 proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService) providerHandler := handler.NewProviderHandler(providerService) modelHandler := handler.NewModelHandler(modelService) statsHandler := handler.NewStatsHandler(statsService) - // 9. 创建 Gin 引擎 + // 11. 创建 Gin 引擎 gin.SetMode(gin.ReleaseMode) r := gin.New() @@ -103,7 +115,7 @@ func main() { setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler) - // 10. 启动服务器 + // 12. 启动服务器 srv := &http.Server{ Addr: formatAddr(cfg.Server.Port), Handler: r, @@ -131,6 +143,8 @@ func main() { zapLogger.Fatal("服务器强制关闭", zap.String("error", err.Error())) } + statsBuffer.Stop() + zapLogger.Info("服务器已关闭") } diff --git a/backend/internal/repository/stats_repo.go b/backend/internal/repository/stats_repo.go index fb358bd..ef9053d 100644 --- a/backend/internal/repository/stats_repo.go +++ b/backend/internal/repository/stats_repo.go @@ -11,5 +11,6 @@ import ( // StatsRepository 统计数据仓库接口 type StatsRepository interface { Record(providerID, modelName string) error + BatchUpdate(providerID, modelName string, date time.Time, delta int) error Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) } diff --git a/backend/internal/repository/stats_repo_impl.go b/backend/internal/repository/stats_repo_impl.go index dd2ef62..7692b33 100644 --- a/backend/internal/repository/stats_repo_impl.go +++ b/backend/internal/repository/stats_repo_impl.go @@ -43,6 +43,28 @@ func (r *statsRepository) Record(providerID, modelName string) error { }) } +func (r *statsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error { + return r.db.Transaction(func(tx *gorm.DB) error { + var stats config.UsageStats + err := tx.Where("provider_id = ? AND model_name = ? AND date = ?", + providerID, modelName, date).First(&stats).Error + + if errors.Is(err, gorm.ErrRecordNotFound) { + return tx.Create(&config.UsageStats{ + ProviderID: providerID, + ModelName: modelName, + RequestCount: delta, + Date: date, + }).Error + } else if err != nil { + return err + } + + return tx.Model(&stats). + Update("request_count", gorm.Expr("request_count + ?", delta)).Error + }) +} + func (r *statsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) { var stats []config.UsageStats query := r.db.Model(&config.UsageStats{}) diff --git a/backend/internal/service/model_service_impl.go b/backend/internal/service/model_service_impl.go index 7ab4c3a..01acb2c 100644 --- a/backend/internal/service/model_service_impl.go +++ b/backend/internal/service/model_service_impl.go @@ -11,27 +11,30 @@ import ( type modelService struct { modelRepo repository.ModelRepository providerRepo repository.ProviderRepository + cache *RoutingCache } -func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) ModelService { - return &modelService{modelRepo: modelRepo, providerRepo: providerRepo} +func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository, cache *RoutingCache) ModelService { + return &modelService{modelRepo: modelRepo, providerRepo: providerRepo, cache: cache} } func (s *modelService) Create(model *domain.Model) error { - // 校验供应商存在 if _, err := s.providerRepo.GetByID(model.ProviderID); err != nil { return appErrors.ErrProviderNotFound } - // 联合唯一校验:同一供应商下 model_name 不重复 if err := s.checkDuplicateModelName(model.ProviderID, model.ModelName, ""); err != nil { return err } - // 自动生成 UUID 作为 id model.ID = uuid.New().String() model.Enabled = true - return s.modelRepo.Create(model) + err := s.modelRepo.Create(model) + if err != nil { + return err + } + s.cache.SetModel(model) + return nil } func (s *modelService) Get(id string) (*domain.Model, error) { @@ -47,20 +50,17 @@ func (s *modelService) ListEnabled() ([]domain.Model, error) { } func (s *modelService) Update(id string, updates map[string]interface{}) error { - // 获取当前模型 current, err := s.modelRepo.GetByID(id) if err != nil { return appErrors.ErrModelNotFound } - // 如果更新 provider_id,校验新供应商存在 if providerID, ok := updates["provider_id"].(string); ok { if _, err := s.providerRepo.GetByID(providerID); err != nil { return appErrors.ErrProviderNotFound } } - // 确定更新后的 provider_id 和 model_name newProviderID := current.ProviderID if v, ok := updates["provider_id"].(string); ok { newProviderID = v @@ -70,18 +70,37 @@ func (s *modelService) Update(id string, updates map[string]interface{}) error { newModelName = v } - // 如果 provider_id 或 model_name 发生变化,校验联合唯一 if newProviderID != current.ProviderID || newModelName != current.ModelName { if err := s.checkDuplicateModelName(newProviderID, newModelName, id); err != nil { return err } } - return s.modelRepo.Update(id, updates) + err = s.modelRepo.Update(id, updates) + if err != nil { + return err + } + + s.cache.InvalidateModel(current.ProviderID, current.ModelName) + if newProviderID != current.ProviderID || newModelName != current.ModelName { + s.cache.InvalidateModel(newProviderID, newModelName) + } + return nil } func (s *modelService) Delete(id string) error { - return s.modelRepo.Delete(id) + model, err := s.modelRepo.GetByID(id) + if err != nil { + return appErrors.ErrModelNotFound + } + + err = s.modelRepo.Delete(id) + if err != nil { + return err + } + + s.cache.InvalidateModel(model.ProviderID, model.ModelName) + return nil } // checkDuplicateModelName 校验同一供应商下 model_name 是否重复 diff --git a/backend/internal/service/provider_service_impl.go b/backend/internal/service/provider_service_impl.go index 080b540..a01157e 100644 --- a/backend/internal/service/provider_service_impl.go +++ b/backend/internal/service/provider_service_impl.go @@ -13,23 +13,27 @@ import ( type providerService struct { providerRepo repository.ProviderRepository modelRepo repository.ModelRepository + cache *RoutingCache } -func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository) ProviderService { - return &providerService{providerRepo: providerRepo, modelRepo: modelRepo} +func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository, cache *RoutingCache) ProviderService { + return &providerService{providerRepo: providerRepo, modelRepo: modelRepo, cache: cache} } func (s *providerService) Create(provider *domain.Provider) error { - // 校验 provider_id 字符集 if err := modelid.ValidateProviderID(provider.ID); err != nil { return appErrors.ErrInvalidProviderID } provider.Enabled = true err := s.providerRepo.Create(provider) - if err != nil && isUniqueConstraintError(err) { - return appErrors.ErrConflict + if err != nil { + if isUniqueConstraintError(err) { + return appErrors.ErrConflict + } + return err } - return err + s.cache.SetProvider(provider) + return nil } func (s *providerService) Get(id string) (*domain.Provider, error) { @@ -44,11 +48,21 @@ func (s *providerService) Update(id string, updates map[string]interface{}) erro if _, ok := updates["id"]; ok { return appErrors.ErrImmutableField } - return s.providerRepo.Update(id, updates) + err := s.providerRepo.Update(id, updates) + if err != nil { + return err + } + s.cache.InvalidateProvider(id) + return nil } func (s *providerService) Delete(id string) error { - return s.providerRepo.Delete(id) + err := s.providerRepo.Delete(id) + if err != nil { + return err + } + s.cache.InvalidateProvider(id) + return nil } // ListEnabledModels 返回所有启用的模型(用于 Models 接口本地聚合) diff --git a/backend/internal/service/routing_cache.go b/backend/internal/service/routing_cache.go new file mode 100644 index 0000000..4e5087a --- /dev/null +++ b/backend/internal/service/routing_cache.go @@ -0,0 +1,134 @@ +package service + +import ( + "strings" + "sync" + + "go.uber.org/zap" + + "nex/backend/internal/domain" + "nex/backend/internal/repository" +) + +type RoutingCache struct { + providers sync.Map + models sync.Map + + modelRepo repository.ModelRepository + providerRepo repository.ProviderRepository + logger *zap.Logger +} + +func NewRoutingCache( + modelRepo repository.ModelRepository, + providerRepo repository.ProviderRepository, + logger *zap.Logger, +) *RoutingCache { + return &RoutingCache{ + modelRepo: modelRepo, + providerRepo: providerRepo, + logger: logger, + } +} + +func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) { + if v, ok := c.providers.Load(id); ok { + return v.(*domain.Provider), nil + } + + provider, err := c.providerRepo.GetByID(id) + if err != nil { + return nil, err + } + + if v, ok := c.providers.Load(id); ok { + return v.(*domain.Provider), nil + } + + c.providers.Store(id, provider) + return provider, nil +} + +func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, error) { + key := providerID + "/" + modelName + + if v, ok := c.models.Load(key); ok { + return v.(*domain.Model), nil + } + + model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName) + if err != nil { + return nil, err + } + + if v, ok := c.models.Load(key); ok { + return v.(*domain.Model), nil + } + + c.models.Store(key, model) + return model, nil +} + +func (c *RoutingCache) SetProvider(provider *domain.Provider) { + c.providers.Store(provider.ID, provider) +} + +func (c *RoutingCache) SetModel(model *domain.Model) { + key := model.ProviderID + "/" + model.ModelName + c.models.Store(key, model) +} + +func (c *RoutingCache) InvalidateProvider(id string) { + c.providers.Delete(id) + c.invalidateModelsByProvider(id) + c.logger.Debug("Provider 缓存失效", zap.String("provider_id", id)) +} + +func (c *RoutingCache) InvalidateModel(providerID, modelName string) { + key := providerID + "/" + modelName + c.models.Delete(key) + c.logger.Debug("Model 缓存失效", + zap.String("provider_id", providerID), + zap.String("model_name", modelName)) +} + +func (c *RoutingCache) invalidateModelsByProvider(providerID string) { + prefix := providerID + "/" + count := 0 + c.models.Range(func(key, value interface{}) bool { + if strings.HasPrefix(key.(string), prefix) { + c.models.Delete(key) + count++ + } + return true + }) + if count > 0 { + c.logger.Debug("清除 Provider 相关 Model 缓存", + zap.String("provider_id", providerID), + zap.Int("count", count)) + } +} + +func (c *RoutingCache) Preload() error { + providers, err := c.providerRepo.List() + if err != nil { + return err + } + for i := range providers { + c.providers.Store(providers[i].ID, &providers[i]) + } + + models, err := c.modelRepo.List("") + if err != nil { + return err + } + for i := range models { + key := models[i].ProviderID + "/" + models[i].ModelName + c.models.Store(key, &models[i]) + } + + c.logger.Info("缓存预热完成", + zap.Int("providers", len(providers)), + zap.Int("models", len(models))) + return nil +} diff --git a/backend/internal/service/routing_cache_test.go b/backend/internal/service/routing_cache_test.go new file mode 100644 index 0000000..0b50df8 --- /dev/null +++ b/backend/internal/service/routing_cache_test.go @@ -0,0 +1,273 @@ +package service + +import ( + "errors" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "nex/backend/internal/domain" +) + +type mockModelRepo struct { + models map[string]*domain.Model +} + +func (m *mockModelRepo) Create(model *domain.Model) error { + return nil +} + +func (m *mockModelRepo) GetByID(id string) (*domain.Model, error) { + return nil, nil +} + +func (m *mockModelRepo) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) { + key := providerID + "/" + modelName + if model, ok := m.models[key]; ok { + return model, nil + } + return nil, errors.New("not found") +} + +func (m *mockModelRepo) List(providerID string) ([]domain.Model, error) { + var result []domain.Model + for _, model := range m.models { + if providerID == "" || model.ProviderID == providerID { + result = append(result, *model) + } + } + return result, nil +} + +func (m *mockModelRepo) ListEnabled() ([]domain.Model, error) { + return nil, nil +} + +func (m *mockModelRepo) Update(id string, updates map[string]interface{}) error { + return nil +} + +func (m *mockModelRepo) Delete(id string) error { + return nil +} + +type mockProviderRepo struct { + providers map[string]*domain.Provider +} + +func (m *mockProviderRepo) Create(provider *domain.Provider) error { + return nil +} + +func (m *mockProviderRepo) GetByID(id string) (*domain.Provider, error) { + if provider, ok := m.providers[id]; ok { + return provider, nil + } + return nil, errors.New("not found") +} + +func (m *mockProviderRepo) List() ([]domain.Provider, error) { + var result []domain.Provider + for _, provider := range m.providers { + result = append(result, *provider) + } + return result, nil +} + +func (m *mockProviderRepo) Update(id string, updates map[string]interface{}) error { + return nil +} + +func (m *mockProviderRepo) Delete(id string) error { + return nil +} + +func TestRoutingCache_GetProvider_CacheHit(t *testing.T) { + logger := zap.NewNop() + modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)} + providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{ + "openai": {ID: "openai", Name: "OpenAI", Enabled: true}, + }} + + cache := NewRoutingCache(modelRepo, providerRepo, logger) + + provider, err := cache.GetProvider("openai") + require.NoError(t, err) + assert.Equal(t, "openai", provider.ID) + + provider2, err := cache.GetProvider("openai") + require.NoError(t, err) + assert.Equal(t, provider, provider2) +} + +func TestRoutingCache_GetProvider_CacheMiss(t *testing.T) { + logger := zap.NewNop() + modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)} + providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)} + + cache := NewRoutingCache(modelRepo, providerRepo, logger) + + _, err := cache.GetProvider("notexist") + assert.Error(t, err) +} + +func TestRoutingCache_GetModel_CacheHit(t *testing.T) { + logger := zap.NewNop() + modelRepo := &mockModelRepo{models: map[string]*domain.Model{ + "openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}, + }} + providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)} + + cache := NewRoutingCache(modelRepo, providerRepo, logger) + + model, err := cache.GetModel("openai", "gpt-4") + require.NoError(t, err) + assert.Equal(t, "gpt-4", model.ModelName) + + model2, err := cache.GetModel("openai", "gpt-4") + require.NoError(t, err) + assert.Equal(t, model, model2) +} + +func TestRoutingCache_GetModel_CacheMiss(t *testing.T) { + logger := zap.NewNop() + modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)} + providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)} + + cache := NewRoutingCache(modelRepo, providerRepo, logger) + + _, err := cache.GetModel("openai", "notexist") + assert.Error(t, err) +} + +func TestRoutingCache_DoubleCheck(t *testing.T) { + logger := zap.NewNop() + modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)} + providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{ + "openai": {ID: "openai", Name: "OpenAI", Enabled: true}, + }} + + cache := NewRoutingCache(modelRepo, providerRepo, logger) + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, err := cache.GetProvider("openai") + assert.NoError(t, err) + }() + } + wg.Wait() +} + +func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) { + logger := zap.NewNop() + modelRepo := &mockModelRepo{models: map[string]*domain.Model{ + "openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}, + "openai/gpt-3.5": {ID: "2", ProviderID: "openai", ModelName: "gpt-3.5", Enabled: true}, + "anthropic/claude": {ID: "3", ProviderID: "anthropic", ModelName: "claude", Enabled: true}, + }} + providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{ + "openai": {ID: "openai", Name: "OpenAI", Enabled: true}, + "anthropic": {ID: "anthropic", Name: "Anthropic", Enabled: true}, + }} + + cache := NewRoutingCache(modelRepo, providerRepo, logger) + + _, err := cache.GetModel("openai", "gpt-4") + require.NoError(t, err) + _, err = cache.GetModel("openai", "gpt-3.5") + require.NoError(t, err) + _, err = cache.GetModel("anthropic", "claude") + require.NoError(t, err) + + cache.InvalidateProvider("openai") + + var openaiCount, anthropicCount int + cache.models.Range(func(key, value interface{}) bool { + if key.(string) == "anthropic/claude" { + anthropicCount++ + } + return true + }) + assert.Equal(t, 0, openaiCount) + assert.Equal(t, 1, anthropicCount) +} + +func TestRoutingCache_InvalidateModel(t *testing.T) { + logger := zap.NewNop() + modelRepo := &mockModelRepo{models: map[string]*domain.Model{ + "openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}, + }} + providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)} + + cache := NewRoutingCache(modelRepo, providerRepo, logger) + + _, err := cache.GetModel("openai", "gpt-4") + require.NoError(t, err) + + cache.InvalidateModel("openai", "gpt-4") + + var count int + cache.models.Range(func(key, value interface{}) bool { + count++ + return true + }) + assert.Equal(t, 0, count) +} + +func TestRoutingCache_Preload_Success(t *testing.T) { + logger := zap.NewNop() + modelRepo := &mockModelRepo{models: map[string]*domain.Model{ + "openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}, + }} + providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{ + "openai": {ID: "openai", Name: "OpenAI", Enabled: true}, + }} + + cache := NewRoutingCache(modelRepo, providerRepo, logger) + + err := cache.Preload() + require.NoError(t, err) + + var providerCount, modelCount int + cache.providers.Range(func(key, value interface{}) bool { + providerCount++ + return true + }) + cache.models.Range(func(key, value interface{}) bool { + modelCount++ + return true + }) + assert.Equal(t, 1, providerCount) + assert.Equal(t, 1, modelCount) +} + +func TestRoutingCache_ConcurrentAccess(t *testing.T) { + logger := zap.NewNop() + modelRepo := &mockModelRepo{models: map[string]*domain.Model{ + "openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}, + }} + providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{ + "openai": {ID: "openai", Name: "OpenAI", Enabled: true}, + }} + + cache := NewRoutingCache(modelRepo, providerRepo, logger) + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = cache.GetProvider("openai") + _, _ = cache.GetModel("openai", "gpt-4") + cache.InvalidateProvider("openai") + cache.InvalidateModel("openai", "gpt-4") + }() + } + wg.Wait() +} diff --git a/backend/internal/service/routing_service_impl.go b/backend/internal/service/routing_service_impl.go index 39cf407..f43e006 100644 --- a/backend/internal/service/routing_service_impl.go +++ b/backend/internal/service/routing_service_impl.go @@ -4,20 +4,18 @@ import ( appErrors "nex/backend/pkg/errors" "nex/backend/internal/domain" - "nex/backend/internal/repository" ) type routingService struct { - modelRepo repository.ModelRepository - providerRepo repository.ProviderRepository + cache *RoutingCache } -func NewRoutingService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) RoutingService { - return &routingService{modelRepo: modelRepo, providerRepo: providerRepo} +func NewRoutingService(cache *RoutingCache) RoutingService { + return &routingService{cache: cache} } func (s *routingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) { - model, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName) + model, err := s.cache.GetModel(providerID, modelName) if err != nil { return nil, appErrors.ErrModelNotFound } @@ -26,7 +24,7 @@ func (s *routingService) RouteByModelName(providerID, modelName string) (*domain return nil, appErrors.ErrModelDisabled } - provider, err := s.providerRepo.GetByID(model.ProviderID) + provider, err := s.cache.GetProvider(model.ProviderID) if err != nil { return nil, appErrors.ErrProviderNotFound } diff --git a/backend/internal/service/service_supplemental_test.go b/backend/internal/service/service_supplemental_test.go index dd0d72d..63145a1 100644 --- a/backend/internal/service/service_supplemental_test.go +++ b/backend/internal/service/service_supplemental_test.go @@ -14,7 +14,8 @@ func TestProviderService_Update(t *testing.T) { db := setupServiceTestDB(t) repo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewProviderService(repo, modelRepo) + cache := setupRoutingCache(t, db) + svc := NewProviderService(repo, modelRepo, cache) require.NoError(t, svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"})) @@ -30,7 +31,8 @@ func TestProviderService_Update_NotFound(t *testing.T) { db := setupServiceTestDB(t) repo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewProviderService(repo, modelRepo) + cache := setupRoutingCache(t, db) + svc := NewProviderService(repo, modelRepo, cache) err := svc.Update("nonexistent", map[string]interface{}{"name": "test"}) assert.Error(t, err) @@ -40,7 +42,8 @@ func TestModelService_Get(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewModelService(modelRepo, providerRepo) + cache := setupRoutingCache(t, db) + svc := NewModelService(modelRepo, providerRepo, cache) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})) model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"} @@ -55,7 +58,8 @@ func TestModelService_Update(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewModelService(modelRepo, providerRepo) + cache := setupRoutingCache(t, db) + svc := NewModelService(modelRepo, providerRepo, cache) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})) model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"} @@ -73,7 +77,8 @@ func TestModelService_Update_ProviderID_Invalid(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewModelService(modelRepo, providerRepo) + cache := setupRoutingCache(t, db) + svc := NewModelService(modelRepo, providerRepo, cache) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})) model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"} @@ -87,7 +92,8 @@ func TestModelService_Delete(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewModelService(modelRepo, providerRepo) + cache := setupRoutingCache(t, db) + svc := NewModelService(modelRepo, providerRepo, cache) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})) model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"} @@ -104,7 +110,8 @@ func TestModelService_Delete_NotFound(t *testing.T) { db := setupServiceTestDB(t) modelRepo := repository.NewModelRepository(db) providerRepo := repository.NewProviderRepository(db) - svc := NewModelService(modelRepo, providerRepo) + cache := setupRoutingCache(t, db) + svc := NewModelService(modelRepo, providerRepo, cache) err := svc.Delete("nonexistent") assert.Error(t, err) @@ -112,7 +119,8 @@ func TestModelService_Delete_NotFound(t *testing.T) { func TestStatsService_Aggregate_Default(t *testing.T) { statsRepo := repository.NewStatsRepository(nil) - svc := NewStatsService(statsRepo) + buffer := NewStatsBuffer(statsRepo, nil) + svc := NewStatsService(statsRepo, buffer) stats := []domain.UsageStats{ {ProviderID: "p1", RequestCount: 10}, @@ -133,7 +141,8 @@ func TestModelService_Update_NotFound(t *testing.T) { db := setupServiceTestDB(t) modelRepo := repository.NewModelRepository(db) providerRepo := repository.NewProviderRepository(db) - svc := NewModelService(modelRepo, providerRepo) + cache := setupRoutingCache(t, db) + svc := NewModelService(modelRepo, providerRepo, cache) err := svc.Update("nonexistent", map[string]interface{}{"model_name": "test"}) assert.Error(t, err) diff --git a/backend/internal/service/service_test.go b/backend/internal/service/service_test.go index 860bbfe..b6243c8 100644 --- a/backend/internal/service/service_test.go +++ b/backend/internal/service/service_test.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/zap" "gorm.io/gorm" testHelpers "nex/backend/tests" @@ -22,13 +23,21 @@ func setupServiceTestDB(t *testing.T) *gorm.DB { return testHelpers.SetupTestDB(t) } +func setupRoutingCache(t *testing.T, db *gorm.DB) *RoutingCache { + t.Helper() + providerRepo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + return NewRoutingCache(modelRepo, providerRepo, zap.NewNop()) +} + // ============ RoutingService - RouteByModelName 测试 ============ func TestRoutingService_RouteByModelName_Success(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewRoutingService(modelRepo, providerRepo) + cache := setupRoutingCache(t, db) + svc := NewRoutingService(cache) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})) require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})) @@ -41,9 +50,8 @@ func TestRoutingService_RouteByModelName_Success(t *testing.T) { func TestRoutingService_RouteByModelName_NotFound(t *testing.T) { db := setupServiceTestDB(t) - providerRepo := repository.NewProviderRepository(db) - modelRepo := repository.NewModelRepository(db) - svc := NewRoutingService(modelRepo, providerRepo) + cache := setupRoutingCache(t, db) + svc := NewRoutingService(cache) _, err := svc.RouteByModelName("openai", "nonexistent-model") assert.True(t, errors.Is(err, appErrors.ErrModelNotFound)) @@ -53,7 +61,8 @@ func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewRoutingService(modelRepo, providerRepo) + cache := setupRoutingCache(t, db) + svc := NewRoutingService(cache) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})) require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})) @@ -67,7 +76,8 @@ func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewRoutingService(modelRepo, providerRepo) + cache := setupRoutingCache(t, db) + svc := NewRoutingService(cache) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})) require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})) @@ -83,7 +93,8 @@ func TestModelService_Create_GeneratesUUID(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewModelService(modelRepo, providerRepo) + cache := setupRoutingCache(t, db) + svc := NewModelService(modelRepo, providerRepo, cache) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})) @@ -91,12 +102,10 @@ func TestModelService_Create_GeneratesUUID(t *testing.T) { err := svc.Create(model) require.NoError(t, err) - // 验证返回的 model 拥有有效的 UUID assert.NotEmpty(t, model.ID) _, err = uuid.Parse(model.ID) assert.NoError(t, err, "model.ID should be a valid UUID") - // 通过 Get 验证持久化 stored, err := svc.Get(model.ID) require.NoError(t, err) assert.Equal(t, model.ID, stored.ID) @@ -107,7 +116,8 @@ func TestModelService_Create_DuplicateModelName(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewModelService(modelRepo, providerRepo) + cache := setupRoutingCache(t, db) + svc := NewModelService(modelRepo, providerRepo, cache) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})) @@ -115,7 +125,6 @@ func TestModelService_Create_DuplicateModelName(t *testing.T) { err := svc.Create(model1) require.NoError(t, err) - // 使用相同的 (providerID, modelName) 创建第二个模型应失败 model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"} err = svc.Create(model2) assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel)) @@ -125,7 +134,8 @@ func TestModelService_Create_ProviderNotFound(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewModelService(modelRepo, providerRepo) + cache := setupRoutingCache(t, db) + svc := NewModelService(modelRepo, providerRepo, cache) model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"} err := svc.Create(model) @@ -138,7 +148,8 @@ func TestProviderService_Create_InvalidID(t *testing.T) { db := setupServiceTestDB(t) repo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewProviderService(repo, modelRepo) + cache := setupRoutingCache(t, db) + svc := NewProviderService(repo, modelRepo, cache) provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"} err := svc.Create(provider) @@ -149,7 +160,8 @@ func TestProviderService_Create_ValidID(t *testing.T) { db := setupServiceTestDB(t) repo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewProviderService(repo, modelRepo) + cache := setupRoutingCache(t, db) + svc := NewProviderService(repo, modelRepo, cache) provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"} err := svc.Create(provider) @@ -164,7 +176,8 @@ func TestModelService_Update_DuplicateModelName(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewModelService(modelRepo, providerRepo) + cache := setupRoutingCache(t, db) + svc := NewModelService(modelRepo, providerRepo, cache) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"})) @@ -189,7 +202,8 @@ func TestModelService_Update_ModelNotFound(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewModelService(modelRepo, providerRepo) + cache := setupRoutingCache(t, db) + svc := NewModelService(modelRepo, providerRepo, cache) err := svc.Update("nonexistent-id", map[string]interface{}{ "model_name": "gpt-4", @@ -201,7 +215,8 @@ func TestModelService_Update_Success(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewModelService(modelRepo, providerRepo) + cache := setupRoutingCache(t, db) + svc := NewModelService(modelRepo, providerRepo, cache) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})) @@ -226,7 +241,8 @@ func TestProviderService_Update_ImmutableID(t *testing.T) { db := setupServiceTestDB(t) repo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewProviderService(repo, modelRepo) + cache := setupRoutingCache(t, db) + svc := NewProviderService(repo, modelRepo, cache) provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"} err := svc.Create(provider) @@ -243,7 +259,8 @@ func TestProviderService_Update_Success(t *testing.T) { db := setupServiceTestDB(t) repo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewProviderService(repo, modelRepo) + cache := setupRoutingCache(t, db) + svc := NewProviderService(repo, modelRepo, cache) provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"} err := svc.Create(provider) @@ -301,7 +318,7 @@ func TestStatsService_Aggregate_ByModel(t *testing.T) { t.Run(tt.name, func(t *testing.T) { db := setupServiceTestDB(t) statsRepo := repository.NewStatsRepository(db) - svc := NewStatsService(statsRepo) + buffer := NewStatsBuffer(statsRepo, nil); svc := NewStatsService(statsRepo, buffer) result := svc.Aggregate(tt.stats, "model") @@ -362,7 +379,7 @@ func TestStatsService_Aggregate_ByDate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { db := setupServiceTestDB(t) statsRepo := repository.NewStatsRepository(db) - svc := NewStatsService(statsRepo) + buffer := NewStatsBuffer(statsRepo, nil); svc := NewStatsService(statsRepo, buffer) result := svc.Aggregate(tt.stats, "date") @@ -431,7 +448,8 @@ func TestProviderService_List_APIKeyNotMasked(t *testing.T) { db := setupServiceTestDB(t) repo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewProviderService(repo, modelRepo) + cache := setupRoutingCache(t, db) + svc := NewProviderService(repo, modelRepo, cache) provider1 := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"} provider2 := &domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "sk-anthropic1234", BaseURL: "https://api.anthropic.com"} @@ -456,7 +474,8 @@ func TestModelService_ConcurrentCreate(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewModelService(modelRepo, providerRepo) + cache := setupRoutingCache(t, db) + svc := NewModelService(modelRepo, providerRepo, cache) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})) diff --git a/backend/internal/service/stats_buffer.go b/backend/internal/service/stats_buffer.go new file mode 100644 index 0000000..8728e04 --- /dev/null +++ b/backend/internal/service/stats_buffer.go @@ -0,0 +1,169 @@ +package service + +import ( + "strings" + "sync" + "sync/atomic" + "time" + + "go.uber.org/zap" + + "nex/backend/internal/repository" +) + +type StatsBuffer struct { + counters sync.Map + + flushInterval time.Duration + flushThreshold int + totalCount atomic.Int64 + + statsRepo repository.StatsRepository + logger *zap.Logger + + stopCh chan struct{} + doneCh chan struct{} +} + +type StatsBufferOption func(*StatsBuffer) + +func WithFlushInterval(d time.Duration) StatsBufferOption { + return func(b *StatsBuffer) { + b.flushInterval = d + } +} + +func WithFlushThreshold(threshold int) StatsBufferOption { + return func(b *StatsBuffer) { + b.flushThreshold = threshold + } +} + +func NewStatsBuffer( + statsRepo repository.StatsRepository, + logger *zap.Logger, + opts ...StatsBufferOption, +) *StatsBuffer { + b := &StatsBuffer{ + statsRepo: statsRepo, + logger: logger, + flushInterval: 5 * time.Second, + flushThreshold: 100, + stopCh: make(chan struct{}), + doneCh: make(chan struct{}), + } + + for _, opt := range opts { + opt(b) + } + + return b +} + +func (b *StatsBuffer) Increment(providerID, modelName string) { + today := time.Now().Format("2006-01-02") + key := providerID + "/" + modelName + "/" + today + + var counter *int64 + if v, ok := b.counters.Load(key); ok { + counter = v.(*int64) + } else { + val := int64(0) + counter = &val + actual, loaded := b.counters.LoadOrStore(key, counter) + if loaded { + counter = actual.(*int64) + } + } + + atomic.AddInt64(counter, 1) + + if b.totalCount.Add(1) >= int64(b.flushThreshold) { + go b.flush() + } +} + +func (b *StatsBuffer) Start() { + go func() { + ticker := time.NewTicker(b.flushInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + b.flush() + case <-b.stopCh: + b.flush() + close(b.doneCh) + return + } + } + }() +} + +func (b *StatsBuffer) Stop() { + close(b.stopCh) + <-b.doneCh +} + +func (b *StatsBuffer) flush() { + type statEntry struct { + providerID string + modelName string + date string + count int64 + } + + var entries []statEntry + b.counters.Range(func(key, value interface{}) bool { + keyStr := key.(string) + parts := strings.Split(keyStr, "/") + if len(parts) != 3 { + return true + } + + counter := value.(*int64) + count := atomic.SwapInt64(counter, 0) + + if count > 0 { + entries = append(entries, statEntry{ + providerID: parts[0], + modelName: parts[1], + date: parts[2], + count: count, + }) + } + return true + }) + + if len(entries) == 0 { + return + } + + success := 0 + for _, entry := range entries { + date, _ := time.Parse("2006-01-02", entry.date) + err := b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count)) + if err != nil { + b.logger.Error("批量更新统计失败", + zap.String("provider_id", entry.providerID), + zap.String("model_name", entry.modelName), + zap.Int64("count", entry.count), + zap.Error(err)) + + key := entry.providerID + "/" + entry.modelName + "/" + entry.date + if v, ok := b.counters.Load(key); ok { + counter := v.(*int64) + atomic.AddInt64(counter, entry.count) + } + } else { + success++ + } + } + + b.totalCount.Store(0) + b.logger.Debug("统计刷新完成", + zap.Int("total", len(entries)), + zap.Int("success", success), + zap.Int("failed", len(entries)-success)) +} diff --git a/backend/internal/service/stats_buffer_test.go b/backend/internal/service/stats_buffer_test.go new file mode 100644 index 0000000..4f789c4 --- /dev/null +++ b/backend/internal/service/stats_buffer_test.go @@ -0,0 +1,251 @@ +package service + +import ( + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + + "nex/backend/internal/domain" +) + +type mockStatsRepo struct { + records []struct { + providerID string + modelName string + date time.Time + delta int + } + fail bool + mu sync.Mutex +} + +func (m *mockStatsRepo) Record(providerID, modelName string) error { + return nil +} + +func (m *mockStatsRepo) BatchUpdate(providerID, modelName string, date time.Time, delta int) error { + m.mu.Lock() + defer m.mu.Unlock() + if m.fail { + return errors.New("db error") + } + m.records = append(m.records, struct { + providerID string + modelName string + date time.Time + delta int + }{providerID, modelName, date, delta}) + return nil +} + +func (m *mockStatsRepo) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) { + return nil, nil +} + +func TestStatsBuffer_Increment(t *testing.T) { + logger := zap.NewNop() + statsRepo := &mockStatsRepo{} + buffer := NewStatsBuffer(statsRepo, logger) + + buffer.Increment("openai", "gpt-4") + buffer.Increment("openai", "gpt-4") + buffer.Increment("openai", "gpt-3.5") + + var count int64 + buffer.counters.Range(func(key, value interface{}) bool { + counter := value.(*int64) + count += atomic.LoadInt64(counter) + return true + }) + assert.Equal(t, int64(3), count) +} + +func TestStatsBuffer_ConcurrentIncrement(t *testing.T) { + logger := zap.NewNop() + statsRepo := &mockStatsRepo{} + buffer := NewStatsBuffer(statsRepo, logger) + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + buffer.Increment("openai", "gpt-4") + }() + } + wg.Wait() + + var count int64 + buffer.counters.Range(func(key, value interface{}) bool { + counter := value.(*int64) + count = atomic.LoadInt64(counter) + return true + }) + assert.Equal(t, int64(100), count) +} + +func TestStatsBuffer_LoadOrStore(t *testing.T) { + logger := zap.NewNop() + statsRepo := &mockStatsRepo{} + buffer := NewStatsBuffer(statsRepo, logger) + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + buffer.Increment("openai", "gpt-4") + }() + } + wg.Wait() + + var counterCount int + buffer.counters.Range(func(key, value interface{}) bool { + counterCount++ + return true + }) + assert.Equal(t, 1, counterCount) +} + +func TestStatsBuffer_FlushByInterval(t *testing.T) { + logger := zap.NewNop() + statsRepo := &mockStatsRepo{} + buffer := NewStatsBuffer(statsRepo, logger, + WithFlushInterval(100*time.Millisecond)) + + buffer.Start() + defer buffer.Stop() + + buffer.Increment("openai", "gpt-4") + buffer.Increment("openai", "gpt-4") + + time.Sleep(200 * time.Millisecond) + + statsRepo.mu.Lock() + assert.GreaterOrEqual(t, len(statsRepo.records), 1) + statsRepo.mu.Unlock() +} + +func TestStatsBuffer_FlushByThreshold(t *testing.T) { + logger := zap.NewNop() + statsRepo := &mockStatsRepo{} + buffer := NewStatsBuffer(statsRepo, logger, + WithFlushThreshold(10)) + + buffer.Start() + defer buffer.Stop() + + for i := 0; i < 10; i++ { + buffer.Increment("openai", "gpt-4") + } + + time.Sleep(50 * time.Millisecond) + + statsRepo.mu.Lock() + assert.GreaterOrEqual(t, len(statsRepo.records), 1) + statsRepo.mu.Unlock() +} + +func TestStatsBuffer_SwapInt64(t *testing.T) { + logger := zap.NewNop() + statsRepo := &mockStatsRepo{} + buffer := NewStatsBuffer(statsRepo, logger) + + buffer.Increment("openai", "gpt-4") + buffer.Increment("openai", "gpt-4") + + var beforeCount int64 + buffer.counters.Range(func(key, value interface{}) bool { + counter := value.(*int64) + beforeCount = atomic.LoadInt64(counter) + return true + }) + assert.Equal(t, int64(2), beforeCount) + + buffer.flush() + + var afterCount int64 + buffer.counters.Range(func(key, value interface{}) bool { + counter := value.(*int64) + afterCount = atomic.LoadInt64(counter) + return true + }) + assert.Equal(t, int64(0), afterCount) +} + +func TestStatsBuffer_FailRetry(t *testing.T) { + logger := zap.NewNop() + statsRepo := &mockStatsRepo{fail: true} + buffer := NewStatsBuffer(statsRepo, logger) + + buffer.Increment("openai", "gpt-4") + buffer.Increment("openai", "gpt-4") + + buffer.flush() + + var count int64 + buffer.counters.Range(func(key, value interface{}) bool { + counter := value.(*int64) + count = atomic.LoadInt64(counter) + return true + }) + assert.Equal(t, int64(2), count) +} + +func TestStatsBuffer_Stop(t *testing.T) { + logger := zap.NewNop() + statsRepo := &mockStatsRepo{} + buffer := NewStatsBuffer(statsRepo, logger, + WithFlushInterval(10*time.Second)) + + buffer.Start() + + buffer.Increment("openai", "gpt-4") + buffer.Increment("openai", "gpt-4") + + start := time.Now() + buffer.Stop() + elapsed := time.Since(start) + + assert.Less(t, elapsed, 1*time.Second) + + statsRepo.mu.Lock() + assert.GreaterOrEqual(t, len(statsRepo.records), 1) + statsRepo.mu.Unlock() +} + +func TestStatsBuffer_ConcurrentIncrementAndFlush(t *testing.T) { + logger := zap.NewNop() + statsRepo := &mockStatsRepo{} + buffer := NewStatsBuffer(statsRepo, logger, + WithFlushInterval(50*time.Millisecond)) + + buffer.Start() + defer buffer.Stop() + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + buffer.Increment("openai", "gpt-4") + }() + } + wg.Wait() + + time.Sleep(100 * time.Millisecond) + + statsRepo.mu.Lock() + totalDelta := 0 + for _, r := range statsRepo.records { + totalDelta += r.delta + } + statsRepo.mu.Unlock() + + assert.Equal(t, 100, totalDelta) +} diff --git a/backend/internal/service/stats_service_impl.go b/backend/internal/service/stats_service_impl.go index 331d4d4..610c072 100644 --- a/backend/internal/service/stats_service_impl.go +++ b/backend/internal/service/stats_service_impl.go @@ -10,14 +10,16 @@ import ( type statsService struct { statsRepo repository.StatsRepository + buffer *StatsBuffer } -func NewStatsService(statsRepo repository.StatsRepository) StatsService { - return &statsService{statsRepo: statsRepo} +func NewStatsService(statsRepo repository.StatsRepository, buffer *StatsBuffer) StatsService { + return &statsService{statsRepo: statsRepo, buffer: buffer} } func (s *statsService) Record(providerID, modelName string) error { - return s.statsRepo.Record(providerID, modelName) + s.buffer.Increment(providerID, modelName) + return nil } func (s *statsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) { diff --git a/backend/tests/integration/conversion_test.go b/backend/tests/integration/conversion_test.go index f5834ea..fbaf836 100644 --- a/backend/tests/integration/conversion_test.go +++ b/backend/tests/integration/conversion_test.go @@ -14,6 +14,7 @@ import ( "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/zap" "gorm.io/gorm" "nex/backend/internal/conversion" @@ -54,10 +55,14 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server) modelRepo := repository.NewModelRepository(db) statsRepo := repository.NewStatsRepository(db) - providerService := service.NewProviderService(providerRepo, modelRepo) - modelService := service.NewModelService(modelRepo, providerRepo) - routingService := service.NewRoutingService(modelRepo, providerRepo) - statsService := service.NewStatsService(statsRepo) + logger := zap.NewNop() + routingCache := service.NewRoutingCache(modelRepo, providerRepo, logger) + statsBuffer := service.NewStatsBuffer(statsRepo, logger) + + providerService := service.NewProviderService(providerRepo, modelRepo, routingCache) + modelService := service.NewModelService(modelRepo, providerRepo, routingCache) + routingService := service.NewRoutingService(routingCache) + statsService := service.NewStatsService(statsRepo, statsBuffer) // 创建 ConversionEngine registry := conversion.NewMemoryRegistry() diff --git a/backend/tests/integration/e2e_conversion_test.go b/backend/tests/integration/e2e_conversion_test.go index f8109d4..ea451cb 100644 --- a/backend/tests/integration/e2e_conversion_test.go +++ b/backend/tests/integration/e2e_conversion_test.go @@ -15,6 +15,7 @@ import ( "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/zap" "nex/backend/internal/conversion" "nex/backend/internal/conversion/anthropic" @@ -48,10 +49,14 @@ func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) { modelRepo := repository.NewModelRepository(db) statsRepo := repository.NewStatsRepository(db) - providerService := service.NewProviderService(providerRepo, modelRepo) - modelService := service.NewModelService(modelRepo, providerRepo) - routingService := service.NewRoutingService(modelRepo, providerRepo) - statsService := service.NewStatsService(statsRepo) + logger := zap.NewNop() + routingCache := service.NewRoutingCache(modelRepo, providerRepo, logger) + statsBuffer := service.NewStatsBuffer(statsRepo, logger) + + providerService := service.NewProviderService(providerRepo, modelRepo, routingCache) + modelService := service.NewModelService(modelRepo, providerRepo, routingCache) + routingService := service.NewRoutingService(routingCache) + statsService := service.NewStatsService(statsRepo, statsBuffer) registry := conversion.NewMemoryRegistry() require.NoError(t, registry.Register(openaiConv.NewAdapter())) diff --git a/backend/tests/integration/integration_test.go b/backend/tests/integration/integration_test.go index 13916d9..1c8e5e3 100644 --- a/backend/tests/integration/integration_test.go +++ b/backend/tests/integration/integration_test.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" + "go.uber.org/zap" "gorm.io/gorm" "nex/backend/internal/domain" @@ -30,10 +31,14 @@ func setupIntegrationTest(t *testing.T) (*gin.Engine, *gorm.DB) { modelRepo := repository.NewModelRepository(db) statsRepo := repository.NewStatsRepository(db) - providerService := service.NewProviderService(providerRepo, modelRepo) - modelService := service.NewModelService(modelRepo, providerRepo) - _ = service.NewRoutingService(modelRepo, providerRepo) - statsService := service.NewStatsService(statsRepo) + logger := zap.NewNop() + routingCache := service.NewRoutingCache(modelRepo, providerRepo, logger) + statsBuffer := service.NewStatsBuffer(statsRepo, logger) + + providerService := service.NewProviderService(providerRepo, modelRepo, routingCache) + modelService := service.NewModelService(modelRepo, providerRepo, routingCache) + _ = service.NewRoutingService(routingCache) + statsService := service.NewStatsService(statsRepo, statsBuffer) providerHandler := handler.NewProviderHandler(providerService) modelHandler := handler.NewModelHandler(modelService) diff --git a/backend/tests/mocks/mock_stats_repository.go b/backend/tests/mocks/mock_stats_repository.go index d031185..45742f4 100644 --- a/backend/tests/mocks/mock_stats_repository.go +++ b/backend/tests/mocks/mock_stats_repository.go @@ -41,6 +41,20 @@ func (m *MockStatsRepository) EXPECT() *MockStatsRepositoryMockRecorder { return m.recorder } +// BatchUpdate mocks base method. +func (m *MockStatsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BatchUpdate", providerID, modelName, date, delta) + ret0, _ := ret[0].(error) + return ret0 +} + +// BatchUpdate indicates an expected call of BatchUpdate. +func (mr *MockStatsRepositoryMockRecorder) BatchUpdate(providerID, modelName, date, delta any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdate", reflect.TypeOf((*MockStatsRepository)(nil).BatchUpdate), providerID, modelName, date, delta) +} + // Query mocks base method. func (m *MockStatsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) { m.ctrl.T.Helper() diff --git a/openspec/specs/routing-cache/spec.md b/openspec/specs/routing-cache/spec.md new file mode 100644 index 0000000..13edb96 --- /dev/null +++ b/openspec/specs/routing-cache/spec.md @@ -0,0 +1,136 @@ +# Routing Cache + +## Purpose + +TBD - 为 Provider 和 Model 提供内存缓存,优化路由查询性能。 + +## Requirements + +### Requirement: RoutingCache 缓存 Provider 和 Model + +系统 SHALL 为 Provider 和 Model 提供内存缓存,使用 sync.Map 作为缓存数据结构。 + +#### Scenario: 缓存数据结构 + +- **WHEN** 创建 RoutingCache +- **THEN** SHALL 使用 sync.Map 存储 Provider(key = providerID) +- **THEN** SHALL 使用 sync.Map 存储 Model(key = providerID/modelName) + +### Requirement: 缓存查询优先 + +RoutingCache SHALL 优先从缓存查询,缓存未命中时查询数据库并更新缓存。 + +#### Scenario: 缓存命中 + +- **WHEN** 查询 Provider 或 Model +- **THEN** SHALL 先从缓存查询 +- **THEN** 如果缓存命中,SHALL 直接返回缓存值 +- **THEN** SHALL 不查询数据库 + +#### Scenario: 缓存未命中 + +- **WHEN** 查询 Provider 或 Model +- **THEN** SHALL 先从缓存查询 +- **THEN** 如果缓存未命中,SHALL 查询数据库 +- **THEN** SHALL 将查询结果存入缓存 +- **THEN** SHALL 返回查询结果 + +#### Scenario: 双重检查防止竞态 + +- **WHEN** 并发查询同一个 Provider 或 Model +- **THEN** SHALL 在查询数据库后再次检查缓存 +- **THEN** 如果已有其他 goroutine 存入缓存,SHALL 使用缓存值 +- **THEN** SHALL 防止存入旧值 + +### Requirement: 缓存更新 + +RoutingCache SHALL 支持 Create 操作后更新缓存。 + +#### Scenario: Create Provider 后更新缓存 + +- **WHEN** 创建 Provider 成功 +- **THEN** SHALL 调用 RoutingCache.SetProvider +- **THEN** SHALL 将新 Provider 存入缓存 + +#### Scenario: Create Model 后更新缓存 + +- **WHEN** 创建 Model 成功 +- **THEN** SHALL 调用 RoutingCache.SetModel +- **THEN** SHALL 将新 Model 存入缓存 + +### Requirement: 缓存失效 + +RoutingCache SHALL 支持 Update/Delete 操作后清除缓存。 + +#### Scenario: Update Provider 后清除缓存 + +- **WHEN** 更新 Provider 成功 +- **THEN** SHALL 调用 RoutingCache.InvalidateProvider +- **THEN** SHALL 清除该 Provider 的缓存 +- **THEN** SHALL 级联清除该 Provider 的所有 Model 缓存 + +#### Scenario: Delete Provider 后清除缓存 + +- **WHEN** 删除 Provider 成功 +- **THEN** SHALL 调用 RoutingCache.InvalidateProvider +- **THEN** SHALL 清除该 Provider 的缓存 +- **THEN** SHALL 级联清除该 Provider 的所有 Model 缓存 + +#### Scenario: Update Model 后清除缓存 + +- **WHEN** 更新 Model 成功 +- **THEN** SHALL 调用 RoutingCache.InvalidateModel +- **THEN** SHALL 清除旧位置的 Model 缓存 +- **THEN** 如果 provider_id 或 model_name 变化,SHALL 也清除新位置的缓存 + +#### Scenario: Delete Model 后清除缓存 + +- **WHEN** 删除 Model 成功 +- **THEN** SHALL 调用 RoutingCache.InvalidateModel +- **THEN** SHALL 清除该 Model 的缓存 + +### Requirement: 缓存预热 + +RoutingCache SHALL 支持启动时预热缓存。 + +#### Scenario: 预热成功 + +- **WHEN** 服务启动时 +- **THEN** SHALL 调用 RoutingCache.Preload +- **THEN** SHALL 从数据库加载所有 Provider 到缓存 +- **THEN** SHALL 从数据库加载所有 Model 到缓存 +- **THEN** SHALL 记录预热完成的日志 + +#### Scenario: 预热失败 + +- **WHEN** 预热失败时 +- **THEN** SHALL 记录警告日志 +- **THEN** SHALL 继续启动服务 +- **THEN** SHALL 使用懒加载(首次查询时加载) + +### Requirement: RoutingService 使用缓存 + +RoutingService SHALL 使用 RoutingCache 进行路由查询。 + +#### Scenario: RouteByModelName 使用缓存 + +- **WHEN** 调用 RoutingService.RouteByModelName +- **THEN** SHALL 调用 RoutingCache.GetModel 获取 Model +- **THEN** SHALL 调用 RoutingCache.GetProvider 获取 Provider +- **THEN** SHALL 不直接调用 Repository + +### Requirement: 并发安全 + +RoutingCache SHALL 支持并发访问。 + +#### Scenario: 并发查询 + +- **WHEN** 多个 goroutine 并发查询缓存 +- **THEN** SHALL 无竞态条件 +- **THEN** SHALL 无 panic + +#### Scenario: 并发查询和失效 + +- **WHEN** 并发查询和失效缓存 +- **THEN** SHALL 无竞态条件 +- **THEN** SHALL 保证一致性 diff --git a/openspec/specs/stats-buffer/spec.md b/openspec/specs/stats-buffer/spec.md new file mode 100644 index 0000000..84fdbee --- /dev/null +++ b/openspec/specs/stats-buffer/spec.md @@ -0,0 +1,156 @@ +# Stats Buffer + +## Purpose + +TBD - 为统计数据提供内存缓冲,优化写入性能。 + +## Requirements + +### Requirement: StatsBuffer 内存缓冲 + +系统 SHALL 为统计数据提供内存缓冲,使用 sync.Map + atomic.Int64 进行计数。 + +#### Scenario: 缓冲数据结构 + +- **WHEN** 创建 StatsBuffer +- **THEN** SHALL 使用 sync.Map 存储计数器(key = providerID/modelName/date) +- **THEN** SHALL 使用 atomic.Int64 进行计数 +- **THEN** SHALL 支持配置刷新间隔和阈值 + +### Requirement: 原子计数 + +StatsBuffer SHALL 使用原子操作进行计数。 + +#### Scenario: Increment 计数 + +- **WHEN** 调用 StatsBuffer.Increment +- **THEN** SHALL 使用 atomic.AddInt64 增加计数 +- **THEN** SHALL 无锁操作 +- **THEN** SHALL 线程安全 + +#### Scenario: Increment 创建计数器 + +- **WHEN** 调用 StatsBuffer.Increment 且计数器不存在 +- **THEN** SHALL 使用 sync.Map.LoadOrStore 创建计数器 +- **THEN** SHALL 初始化计数器为 0 +- **THEN** SHALL 原子增加计数 + +#### Scenario: 并发计数 + +- **WHEN** 多个 goroutine 并发 Increment +- **THEN** SHALL 计数准确 +- **THEN** SHALL 无竞态条件 + +### Requirement: 定时刷新 + +StatsBuffer SHALL 支持定时刷新到数据库。 + +#### Scenario: 定时刷新触发 + +- **WHEN** 后台刷新协程运行 +- **THEN** SHALL 每隔 flushInterval 触发刷新 +- **THEN** SHALL 调用 StatsRepository.BatchUpdate 写入数据库 + +#### Scenario: 刷新间隔配置 + +- **WHEN** 创建 StatsBuffer +- **THEN** 默认 flushInterval 为 5 秒 +- **THEN** 可通过 WithFlushInterval 选项配置 + +### Requirement: 阈值触发刷新 + +StatsBuffer SHALL 支持累计阈值触发刷新。 + +#### Scenario: 阈值触发 + +- **WHEN** 累计计数达到 flushThreshold +- **THEN** SHALL 异步触发刷新 +- **THEN** SHALL 不阻塞请求 + +#### Scenario: 阈值配置 + +- **WHEN** 创建 StatsBuffer +- **THEN** 默认 flushThreshold 为 100 +- **THEN** 可通过 WithFlushThreshold 选项配置 + +### Requirement: 批量写入数据库 + +StatsBuffer SHALL 批量写入统计数据到数据库。 + +#### Scenario: 批量写入 + +- **WHEN** 刷新触发 +- **THEN** SHALL 遍历所有计数器 +- **THEN** SHALL 使用 atomic.SwapInt64(counter, 0) 获取并清零计数器 +- **THEN** SHALL 调用 StatsRepository.BatchUpdate 批量写入 +- **THEN** SHALL 重置 totalCount 为 0 + +#### Scenario: SwapInt64 清零计数器 + +- **WHEN** flush 收集计数器 +- **THEN** SHALL 使用 SwapInt64 原子操作获取当前计数并清零 +- **THEN** SHALL 保证计数不丢失(新计数会累加到已清零的计数器) +- **THEN** SHALL 不阻塞后续 Increment 操作 + +#### Scenario: 写入失败保留计数器 + +- **WHEN** BatchUpdate 失败 +- **THEN** SHALL 将计数加回计数器(使用 atomic.AddInt64) +- **THEN** SHALL 记录错误日志 +- **THEN** SHALL 继续处理其他条目 + +### Requirement: 优雅关闭 + +StatsBuffer SHALL 支持优雅关闭,确保最后的统计写入数据库。 + +#### Scenario: Stop 等待刷新完成 + +- **WHEN** 调用 StatsBuffer.Stop +- **THEN** SHALL 停止后台刷新协程 +- **THEN** SHALL 执行最后一次刷新 +- **THEN** SHALL 等待刷新完成 +- **THEN** SHALL 保证统计数据不丢失 + +#### Scenario: 无超时 + +- **WHEN** Stop 等待刷新 +- **THEN** SHALL 无超时限制 +- **THEN** SHALL 等待刷新完成 + +### Requirement: StatsService 使用缓冲 + +StatsService SHALL 使用 StatsBuffer 进行统计记录。 + +#### Scenario: Record 使用缓冲 + +- **WHEN** 调用 StatsService.Record +- **THEN** SHALL 调用 StatsBuffer.Increment +- **THEN** SHALL 不直接调用 StatsRepository.Record +- **THEN** SHALL 立即返回,不阻塞 + +### Requirement: StatsRepository 扩展 + +StatsRepository SHALL 新增 BatchUpdate 方法。 + +#### Scenario: BatchUpdate 方法 + +- **WHEN** 调用 StatsRepository.BatchUpdate +- **THEN** SHALL 使用事务更新或创建统计记录 +- **THEN** SHALL 使用 request_count + delta 更新 +- **THEN** SHALL 支持批量增量更新 + +### Requirement: 并发安全 + +StatsBuffer SHALL 支持并发访问。 + +#### Scenario: 并发 Increment 和 flush + +- **WHEN** 并发 Increment 和 flush +- **THEN** SHALL 无竞态条件 +- **THEN** SHALL 计数准确(可能延迟到下次 flush) + +#### Scenario: flush 期间 Increment + +- **WHEN** flush 正在执行 +- **THEN** 新的 Increment SHALL 继续计数 +- **THEN** SHALL 不会丢失计数 diff --git a/openspec/specs/usage-statistics/spec.md b/openspec/specs/usage-statistics/spec.md index 7d7ff23..7c5d4d0 100644 --- a/openspec/specs/usage-statistics/spec.md +++ b/openspec/specs/usage-statistics/spec.md @@ -13,14 +13,15 @@ #### Scenario: 记录成功请求 - **WHEN** 请求成功转发到供应商 -- **THEN** 网关 SHALL 增加该供应商和模型的请求计数 -- **THEN** 网关 SHALL 记录当前日期的统计 +- **THEN** 网关 SHALL 通过 StatsBuffer 增加该供应商和模型的请求计数 +- **THEN** 网关 SHALL 异步批量写入数据库(定时或阈值触发) +- **THEN** 网关 SHALL 不阻塞响应 #### Scenario: 记录流式请求 - **WHEN** 流式请求成功完成 -- **THEN** 网关 SHALL 增加该供应商和模型的请求计数 -- **THEN** 网关 SHALL 在流结束后记录统计 +- **THEN** 网关 SHALL 通过 StatsBuffer 增加该供应商和模型的请求计数 +- **THEN** 网关 SHALL 在流结束后异步记录统计 ### Requirement: 使用统计记录统一模型标识 @@ -90,10 +91,9 @@ #### Scenario: 并发请求 - **WHEN** 同时处理多个并发请求 -- **THEN** 网关 SHALL 正确为每个请求增加请求计数 +- **THEN** 网关 SHALL 使用原子操作正确增加每个请求的计数 - **THEN** 不 SHALL 因并发写入而丢失统计 - -**变更说明:** 并发控制在 StatsRepository 中通过数据库事务实现。API 接口保持不变。 +- **THEN** SHALL 使用 StatsBuffer 的内存计数器 ### Requirement: 使用 service 层处理业务逻辑 @@ -121,6 +121,13 @@ Service SHALL 通过 StatsRepository 访问数据。 - **THEN** SHALL 调用对应的 StatsRepository 方法 - **THEN** SHALL 使用 domain.UsageStats 类型 +#### Scenario: 批量更新统计 + +- **WHEN** StatsBuffer 刷新统计 +- **THEN** SHALL 调用 StatsRepository.BatchUpdate +- **THEN** SHALL 使用事务更新或创建统计记录 +- **THEN** SHALL 支持增量更新(request_count + delta) + #### Scenario: 事务处理 - **WHEN** 记录统计 @@ -136,3 +143,36 @@ Service SHALL 通过 StatsRepository 访问数据。 - **WHEN** 查询统计 - **THEN** SHALL 使用 (provider_id, model_name, date) 复合索引 - **THEN** SHALL 优化查询性能 + +### Requirement: 统计数据可接受少量丢失 + +统计记录方式改为内存缓冲,可接受少量丢失。 + +#### Scenario: 进程崩溃丢失统计 + +- **WHEN** 进程崩溃 +- **THEN** MAY 丢失最近 flushInterval 内的统计 +- **THEN** 统计数据用于监控,可接受少量丢失 + +#### Scenario: 优雅关闭保证不丢失 + +- **WHEN** 服务优雅关闭 +- **THEN** SHALL 调用 StatsBuffer.Stop +- **THEN** SHALL 等待最后一次刷新完成 +- **THEN** SHALL 保证统计数据不丢失 + +### Requirement: StatsRepository 支持批量更新 + +StatsRepository SHALL 新增 BatchUpdate 方法支持批量增量更新。 + +#### Scenario: BatchUpdate 更新现有记录 + +- **WHEN** 调用 BatchUpdate 且当日记录存在 +- **THEN** SHALL 使用事务更新 request_count = request_count + delta +- **THEN** SHALL 不创建新记录 + +#### Scenario: BatchUpdate 创建新记录 + +- **WHEN** 调用 BatchUpdate 且当日记录不存在 +- **THEN** SHALL 创建新记录,request_count = delta +- **THEN** SHALL 使用事务保证原子性