package service import ( "errors" "testing" "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/gorm" testHelpers "nex/backend/tests" "nex/backend/internal/domain" "nex/backend/internal/repository" appErrors "nex/backend/pkg/errors" ) func setupServiceTestDB(t *testing.T) *gorm.DB { t.Helper() return testHelpers.SetupTestDB(t) } // ============ RoutingService - RouteByModelName 测试 ============ func TestRoutingService_RouteByModelName_Success(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) svc := NewRoutingService(modelRepo, providerRepo) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})) require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})) result, err := svc.RouteByModelName("openai", "gpt-4") require.NoError(t, err) assert.Equal(t, "openai", result.Provider.ID) assert.Equal(t, "gpt-4", result.Model.ModelName) } func TestRoutingService_RouteByModelName_NotFound(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) svc := NewRoutingService(modelRepo, providerRepo) _, err := svc.RouteByModelName("openai", "nonexistent-model") assert.True(t, errors.Is(err, appErrors.ErrModelNotFound)) } func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) svc := NewRoutingService(modelRepo, providerRepo) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})) require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})) require.NoError(t, modelRepo.Update("m1", map[string]interface{}{"enabled": false})) _, err := svc.RouteByModelName("openai", "gpt-4") assert.True(t, errors.Is(err, appErrors.ErrModelDisabled)) } func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) svc := NewRoutingService(modelRepo, providerRepo) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})) require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})) require.NoError(t, providerRepo.Update("openai", map[string]interface{}{"enabled": false})) _, err := svc.RouteByModelName("openai", "gpt-4") assert.True(t, errors.Is(err, appErrors.ErrProviderDisabled)) } // ============ ModelService - Create with UUID 测试 ============ func TestModelService_Create_GeneratesUUID(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) svc := NewModelService(modelRepo, providerRepo) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})) model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"} err := svc.Create(model) require.NoError(t, err) // 验证返回的 model 拥有有效的 UUID assert.NotEmpty(t, model.ID) _, err = uuid.Parse(model.ID) assert.NoError(t, err, "model.ID should be a valid UUID") // 通过 Get 验证持久化 stored, err := svc.Get(model.ID) require.NoError(t, err) assert.Equal(t, model.ID, stored.ID) assert.Equal(t, "gpt-4", stored.ModelName) } func TestModelService_Create_DuplicateModelName(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) svc := NewModelService(modelRepo, providerRepo) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})) model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"} err := svc.Create(model1) require.NoError(t, err) // 使用相同的 (providerID, modelName) 创建第二个模型应失败 model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"} err = svc.Create(model2) assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel)) } func TestModelService_Create_ProviderNotFound(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) svc := NewModelService(modelRepo, providerRepo) model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"} err := svc.Create(model) assert.True(t, errors.Is(err, appErrors.ErrProviderNotFound)) } // ============ ProviderService - Create with validation 测试 ============ func TestProviderService_Create_InvalidID(t *testing.T) { db := setupServiceTestDB(t) repo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) svc := NewProviderService(repo, modelRepo) provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"} err := svc.Create(provider) assert.True(t, errors.Is(err, appErrors.ErrInvalidProviderID)) } func TestProviderService_Create_ValidID(t *testing.T) { db := setupServiceTestDB(t) repo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) svc := NewProviderService(repo, modelRepo) provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"} err := svc.Create(provider) require.NoError(t, err) assert.Equal(t, "openai", provider.ID) assert.True(t, provider.Enabled) } // ============ ModelService - Update with duplicate check 测试 ============ func TestModelService_Update_DuplicateModelName(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) svc := NewModelService(modelRepo, providerRepo) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"})) model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"} err := svc.Create(model1) require.NoError(t, err) model2 := &domain.Model{ProviderID: "anthropic", ModelName: "claude-3"} err = svc.Create(model2) require.NoError(t, err) // 将 model2 的 model_name 改为 "gpt-4" 且 provider_id 改为 "openai",与 model1 冲突 err = svc.Update(model2.ID, map[string]interface{}{ "provider_id": "openai", "model_name": "gpt-4", }) assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel)) } func TestModelService_Update_ModelNotFound(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) svc := NewModelService(modelRepo, providerRepo) err := svc.Update("nonexistent-id", map[string]interface{}{ "model_name": "gpt-4", }) assert.True(t, errors.Is(err, appErrors.ErrModelNotFound)) } func TestModelService_Update_Success(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) svc := NewModelService(modelRepo, providerRepo) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})) model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"} err := svc.Create(model) require.NoError(t, err) // 更新 model_name 为不冲突的值 err = svc.Update(model.ID, map[string]interface{}{ "model_name": "gpt-4-turbo", }) require.NoError(t, err) updated, err := svc.Get(model.ID) require.NoError(t, err) assert.Equal(t, "gpt-4-turbo", updated.ModelName) } // ============ ProviderService - Update immutable ID 测试 ============ func TestProviderService_Update_ImmutableID(t *testing.T) { db := setupServiceTestDB(t) repo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) svc := NewProviderService(repo, modelRepo) provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"} err := svc.Create(provider) require.NoError(t, err) // 尝试更新 id 字段 err = svc.Update("openai", map[string]interface{}{ "id": "new-id", }) assert.True(t, errors.Is(err, appErrors.ErrImmutableField)) } func TestProviderService_Update_Success(t *testing.T) { db := setupServiceTestDB(t) repo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) svc := NewProviderService(repo, modelRepo) provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"} err := svc.Create(provider) require.NoError(t, err) // 更新 name err = svc.Update("openai", map[string]interface{}{ "name": "OpenAI Updated", }) require.NoError(t, err) updated, err := svc.Get("openai", false) require.NoError(t, err) assert.Equal(t, "OpenAI Updated", updated.Name) } // ============ StatsService - Aggregate ByModel 测试 ============ func TestStatsService_Aggregate_ByModel(t *testing.T) { tests := []struct { name string stats []domain.UsageStats expected []map[string]interface{} }{ { name: "multiple providers with same model name", stats: []domain.UsageStats{ {ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10}, {ProviderID: "azure", ModelName: "gpt-4", RequestCount: 20}, {ProviderID: "openai", ModelName: "gpt-4", RequestCount: 5}, }, expected: []map[string]interface{}{ {"provider_id": "openai", "model_name": "gpt-4", "request_count": 15}, {"provider_id": "azure", "model_name": "gpt-4", "request_count": 20}, }, }, { name: "empty providerID", stats: []domain.UsageStats{ {ProviderID: "", ModelName: "gpt-4", RequestCount: 10}, {ProviderID: "", ModelName: "gpt-4", RequestCount: 5}, }, expected: []map[string]interface{}{ {"provider_id": "", "model_name": "gpt-4", "request_count": 15}, }, }, { name: "empty result set", stats: []domain.UsageStats{}, expected: []map[string]interface{}{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db := setupServiceTestDB(t) statsRepo := repository.NewStatsRepository(db) svc := NewStatsService(statsRepo) result := svc.Aggregate(tt.stats, "model") assert.Len(t, result, len(tt.expected)) for _, exp := range tt.expected { found := false for _, r := range result { if r["provider_id"] == exp["provider_id"] && r["model_name"] == exp["model_name"] { assert.Equal(t, exp["request_count"], r["request_count"]) found = true break } } assert.True(t, found, "expected result not found: %v", exp) } }) } } // ============ StatsService - Aggregate ByDate 测试 ============ func TestStatsService_Aggregate_ByDate(t *testing.T) { tests := []struct { name string stats []domain.UsageStats expected []map[string]interface{} }{ { name: "normal date grouping", stats: []domain.UsageStats{ {Date: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), RequestCount: 10}, {Date: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), RequestCount: 5}, {Date: time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC), RequestCount: 20}, }, expected: []map[string]interface{}{ {"date": "2024-01-01", "request_count": 15}, {"date": "2024-01-02", "request_count": 20}, }, }, { name: "zero-value time", stats: []domain.UsageStats{ {Date: time.Time{}, RequestCount: 10}, {Date: time.Time{}, RequestCount: 5}, }, expected: []map[string]interface{}{ {"date": "0001-01-01", "request_count": 15}, }, }, { name: "empty result set", stats: []domain.UsageStats{}, expected: []map[string]interface{}{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db := setupServiceTestDB(t) statsRepo := repository.NewStatsRepository(db) svc := NewStatsService(statsRepo) result := svc.Aggregate(tt.stats, "date") assert.Len(t, result, len(tt.expected)) for _, exp := range tt.expected { found := false for _, r := range result { if r["date"] == exp["date"] { assert.Equal(t, exp["request_count"], r["request_count"]) found = true break } } assert.True(t, found, "expected result not found: %v", exp) } }) } } // ============ ProviderService - isUniqueConstraintError 测试 ============ func TestProviderService_isUniqueConstraintError(t *testing.T) { tests := []struct { name string err error expected bool }{ { name: "nil error", err: nil, expected: false, }, { name: "UNIQUE constraint failed", err: errors.New("UNIQUE constraint failed"), expected: true, }, { name: "duplicate key value", err: errors.New("duplicate key value"), expected: true, }, { name: "UNIQUE constraint case insensitive", err: errors.New("unique constraint violation"), expected: true, }, { name: "other error", err: errors.New("some other error"), expected: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := isUniqueConstraintError(tt.err) assert.Equal(t, tt.expected, result) }) } } // ============ ProviderService - List MaskAPIKey 测试 ============ func TestProviderService_List_MaskAPIKey(t *testing.T) { db := setupServiceTestDB(t) repo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) svc := NewProviderService(repo, modelRepo) provider1 := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"} provider2 := &domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "sk-anthropic1234", BaseURL: "https://api.anthropic.com"} require.NoError(t, svc.Create(provider1)) require.NoError(t, svc.Create(provider2)) providers, err := svc.List() require.NoError(t, err) require.Len(t, providers, 2) for _, p := range providers { assert.Contains(t, p.APIKey, "***") assert.Len(t, p.APIKey, 7) } } func TestModelService_ConcurrentCreate(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) svc := NewModelService(modelRepo, providerRepo) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})) results := make(chan error, 2) for i := 0; i < 2; i++ { go func() { model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"} results <- svc.Create(model) }() } err1 := <-results err2 := <-results successCount := 0 errorCount := 0 for _, err := range []error{err1, err2} { if err == nil { successCount++ } else { errorCount++ } } assert.Equal(t, 1, successCount) assert.Equal(t, 1, errorCount) }