package service

import (
	"path/filepath"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"

	"nex/backend/internal/config"
	"nex/backend/internal/domain"
	"nex/backend/internal/repository"
)

// setupServiceTestDB opens a fresh SQLite database file inside a per-test
// temporary directory, auto-migrates the Provider, Model and UsageStats
// tables, and registers a cleanup that closes the underlying sql.DB.
func setupServiceTestDB(t *testing.T) *gorm.DB {
	t.Helper()

	dsn := filepath.Join(t.TempDir(), "test.db")
	db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
	require.NoError(t, err)

	err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
	require.NoError(t, err)

	// Cleanup functions run last-registered-first, so this close runs before
	// the t.TempDir cleanup removes the directory holding the database file.
	t.Cleanup(func() {
		sqlDB, dbErr := db.DB()
		if dbErr != nil {
			t.Errorf("getting sql.DB from gorm: %v", dbErr)
			return
		}
		if closeErr := sqlDB.Close(); closeErr != nil {
			t.Errorf("closing test database: %v", closeErr)
		}
	})
	return db
}

// ============ ProviderService tests ============

// TestProviderService_Create checks that Create succeeds and leaves the
// provider enabled even though the input did not set Enabled.
func TestProviderService_Create(t *testing.T) {
	db := setupServiceTestDB(t)
	repo := repository.NewProviderRepository(db)
	svc := NewProviderService(repo)

	provider := &domain.Provider{
		ID:      "test-p",
		Name:    "Test",
		APIKey:  "sk-test",
		BaseURL: "https://api.test.com",
	}
	err := svc.Create(provider)
	require.NoError(t, err)
	assert.True(t, provider.Enabled)
}

// TestProviderService_Get_MaskKey checks that Get returns a masked API key
// when its second argument is true and the raw key when it is false.
func TestProviderService_Get_MaskKey(t *testing.T) {
	db := setupServiceTestDB(t)
	repo := repository.NewProviderRepository(db)
	svc := NewProviderService(repo)

	require.NoError(t, svc.Create(&domain.Provider{
		ID:      "p1",
		Name:    "Test",
		APIKey:  "sk-long-api-key-12345",
		BaseURL: "https://test.com",
	}))

	result, err := svc.Get("p1", true)
	require.NoError(t, err)
	assert.Equal(t, "***2345", result.APIKey)

	result, err = svc.Get("p1", false)
	require.NoError(t, err)
	assert.Equal(t, "sk-long-api-key-12345", result.APIKey)
}

// TestProviderService_List checks that List returns every provider and that
// each returned API key is masked.
func TestProviderService_List(t *testing.T) {
	db := setupServiceTestDB(t)
	repo := repository.NewProviderRepository(db)
	svc := NewProviderService(repo)

	require.NoError(t, svc.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key123", BaseURL: "https://a.com"}))
	require.NoError(t, svc.Create(&domain.Provider{ID: "p2", Name: "P2", APIKey: "key456", BaseURL: "https://b.com"}))

	providers, err := svc.List()
	require.NoError(t, err)
	require.Len(t, providers, 2)
	// Check every entry, not just the first, so a partially-masking List fails.
	for _, p := range providers {
		assert.Contains(t, p.APIKey, "***")
	}
}

// TestProviderService_Delete checks that a deleted provider can no longer be
// fetched with Get.
func TestProviderService_Delete(t *testing.T) {
	db := setupServiceTestDB(t)
	repo := repository.NewProviderRepository(db)
	svc := NewProviderService(repo)

	// Setup must succeed; otherwise the Get error below would pass vacuously.
	require.NoError(t, svc.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))

	err := svc.Delete("p1")
	require.NoError(t, err)

	_, err = svc.Get("p1", false)
	assert.Error(t, err)
}

// ============ ModelService tests ============

// TestModelService_Create checks that Create succeeds for an existing
// provider and leaves the model enabled.
func TestModelService_Create(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewModelService(modelRepo, providerRepo)

	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))

	model := &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}
	err := svc.Create(model)
	require.NoError(t, err)
	assert.True(t, model.Enabled)
}

// TestModelService_Create_ProviderNotFound checks that Create rejects a model
// whose ProviderID does not exist.
func TestModelService_Create_ProviderNotFound(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewModelService(modelRepo, providerRepo)

	model := &domain.Model{ID: "m1", ProviderID: "nonexistent", ModelName: "gpt-4"}
	err := svc.Create(model)
	assert.Error(t, err)
}

// TestModelService_List checks that List returns the models created for the
// given provider ID.
func TestModelService_List(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewModelService(modelRepo, providerRepo)

	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
	require.NoError(t, svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}))
	require.NoError(t, svc.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"}))

	models, err := svc.List("p1")
	require.NoError(t, err)
	assert.Len(t, models, 2)
}

// ============ RoutingService tests ============

// TestRoutingService_Route checks that an enabled model on an enabled
// provider is routed to that provider.
func TestRoutingService_Route(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewRoutingService(modelRepo, providerRepo)

	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true}))
	require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))

	result, err := svc.Route("gpt-4")
	require.NoError(t, err)
	assert.Equal(t, "p1", result.Provider.ID)
	assert.Equal(t, "gpt-4", result.Model.ModelName)
}

// TestRoutingService_Route_ModelNotFound checks that routing an unknown model
// name returns an error.
func TestRoutingService_Route_ModelNotFound(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewRoutingService(modelRepo, providerRepo)

	_, err := svc.Route("nonexistent-model")
	assert.Error(t, err)
}

// TestRoutingService_Route_ModelDisabled checks that a disabled model is not
// routed even when its provider is enabled.
func TestRoutingService_Route_ModelDisabled(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewRoutingService(modelRepo, providerRepo)

	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true}))
	// Create the model enabled first, then disable it via Update.
	// NOTE(review): presumably two steps because Create would not persist
	// Enabled=false (e.g. a column default) — confirm.
	require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
	require.NoError(t, modelRepo.Update("m1", map[string]interface{}{"enabled": false}))

	_, err := svc.Route("gpt-4")
	assert.Error(t, err)
}

// TestRoutingService_Route_ProviderDisabled checks that an enabled model is
// not routed when its provider is disabled.
func TestRoutingService_Route_ProviderDisabled(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewRoutingService(modelRepo, providerRepo)

	// Create the provider enabled first, then disable it via Update.
	// Setup errors must fail the test, otherwise a missing provider would make
	// the assert.Error below pass for the wrong reason.
	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true}))
	require.NoError(t, providerRepo.Update("p1", map[string]interface{}{"enabled": false}))
	require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))

	_, err := svc.Route("gpt-4")
	assert.Error(t, err)
}

// ============ StatsService tests ============

// TestStatsService_RecordAndGet checks that one Record call produces one
// stats row retrievable by provider ID.
func TestStatsService_RecordAndGet(t *testing.T) {
	db := setupServiceTestDB(t)
	statsRepo := repository.NewStatsRepository(db)
	svc := NewStatsService(statsRepo)

	err := svc.Record("p1", "gpt-4")
	require.NoError(t, err)

	stats, err := svc.Get("p1", "", nil, nil)
	require.NoError(t, err)
	assert.Len(t, stats, 1)
}

// TestStatsService_Aggregate_ByProvider checks that Aggregate with "provider"
// sums RequestCount per provider ID. Aggregate operates on the slice passed
// in, so the repository is constructed with a nil DB.
func TestStatsService_Aggregate_ByProvider(t *testing.T) {
	statsRepo := repository.NewStatsRepository(nil)
	svc := NewStatsService(statsRepo)

	stats := []domain.UsageStats{
		{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
		{ProviderID: "p1", ModelName: "gpt-3.5", RequestCount: 5},
		{ProviderID: "p2", ModelName: "claude-3", RequestCount: 8},
	}

	result := svc.Aggregate(stats, "provider")
	require.Len(t, result, 2)

	// Keep the raw values and compare with assert.Equal so an unexpected value
	// type produces a test failure instead of a type-assertion panic.
	var p1Count, p2Count interface{}
	for _, r := range result {
		switch r["provider_id"] {
		case "p1":
			p1Count = r["request_count"]
		case "p2":
			p2Count = r["request_count"]
		}
	}
	assert.Equal(t, 15, p1Count)
	assert.Equal(t, 8, p2Count)
}

// TestStatsService_Aggregate_ByDate checks that Aggregate with "date" folds
// rows sharing the same (here zero) date into a single summed entry.
func TestStatsService_Aggregate_ByDate(t *testing.T) {
	statsRepo := repository.NewStatsRepository(nil)
	svc := NewStatsService(statsRepo)

	stats := []domain.UsageStats{
		{ProviderID: "p1", RequestCount: 10},
		{ProviderID: "p2", RequestCount: 5},
	}

	result := svc.Aggregate(stats, "date")
	// require, not assert: result[0] below must not be indexed on a length mismatch.
	require.Len(t, result, 1)
	assert.Equal(t, 15, result[0]["request_count"])
}