package service

import (
	"errors"
	"path/filepath"
	"testing"

	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"

	"nex/backend/internal/config"
	"nex/backend/internal/domain"
	"nex/backend/internal/repository"
	appErrors "nex/backend/pkg/errors"
)

// setupServiceTestDB opens a SQLite database file inside a per-test temporary
// directory, migrates the Provider, Model and UsageStats tables, and registers
// a cleanup that closes the underlying connection when the test finishes.
func setupServiceTestDB(t *testing.T) *gorm.DB {
	t.Helper()

	dbPath := filepath.Join(t.TempDir(), "test.db")
	db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
	require.NoError(t, err)

	err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
	require.NoError(t, err)

	t.Cleanup(func() {
		// Cleanup is best-effort: problems are logged, never fatal.
		sqlDB, err := db.DB()
		if err != nil || sqlDB == nil {
			t.Logf("skipping database close, no underlying sql.DB: %v", err)
			return
		}
		if err := sqlDB.Close(); err != nil {
			t.Logf("closing test database: %v", err)
		}
	})

	return db
}

// ============ RoutingService - RouteByModelName tests ============

// TestRoutingService_RouteByModelName_Success checks that an enabled model of
// an enabled provider is resolved to that provider/model pair.
func TestRoutingService_RouteByModelName_Success(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewRoutingService(modelRepo, providerRepo)

	// Seed one provider and one model; a failed insert must abort the test
	// instead of surfacing later as a confusing routing error.
	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
	require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))

	result, err := svc.RouteByModelName("openai", "gpt-4")
	require.NoError(t, err)
	assert.Equal(t, "openai", result.Provider.ID)
	assert.Equal(t, "gpt-4", result.Model.ModelName)
}

// TestRoutingService_RouteByModelName_NotFound checks that routing against an
// empty database reports ErrModelNotFound.
func TestRoutingService_RouteByModelName_NotFound(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewRoutingService(modelRepo, providerRepo)

	_, err := svc.RouteByModelName("openai", "nonexistent-model")
	assert.True(t, errors.Is(err, appErrors.ErrModelNotFound), "expected ErrModelNotFound, got: %v", err)
}

// TestRoutingService_RouteByModelName_DisabledModel checks that routing to a
// disabled model reports ErrModelDisabled.
func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewRoutingService(modelRepo, providerRepo)

	// Seed an enabled provider and model, then switch the model off via Update.
	// NOTE(review): presumably the model is created enabled first because a
	// direct insert with Enabled: false would not persist — confirm.
	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
	require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
	require.NoError(t, modelRepo.Update("m1", map[string]interface{}{"enabled": false}))

	_, err := svc.RouteByModelName("openai", "gpt-4")
	assert.True(t, errors.Is(err, appErrors.ErrModelDisabled), "expected ErrModelDisabled, got: %v", err)
}

// TestRoutingService_RouteByModelName_DisabledProvider checks that routing to
// an enabled model whose provider is disabled reports ErrProviderDisabled.
func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewRoutingService(modelRepo, providerRepo)

	// Seed an enabled provider and model, then disable the provider.
	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
	require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
	require.NoError(t, providerRepo.Update("openai", map[string]interface{}{"enabled": false}))

	_, err := svc.RouteByModelName("openai", "gpt-4")
	assert.True(t, errors.Is(err, appErrors.ErrProviderDisabled), "expected ErrProviderDisabled, got: %v", err)
}

// ============ ModelService - Create with UUID tests ============

// TestModelService_Create_GeneratesUUID checks that Create fills in a valid
// UUID on a model submitted without an ID and that the model can be read back.
func TestModelService_Create_GeneratesUUID(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewModelService(modelRepo, providerRepo)

	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))

	model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
	err := svc.Create(model)
	require.NoError(t, err)

	// The model passed in must now carry a parseable UUID.
	assert.NotEmpty(t, model.ID)
	_, err = uuid.Parse(model.ID)
	assert.NoError(t, err, "model.ID should be a valid UUID")

	// Verify persistence by reading the model back through Get.
	stored, err := svc.Get(model.ID)
	require.NoError(t, err)
	assert.Equal(t, model.ID, stored.ID)
	assert.Equal(t, "gpt-4", stored.ModelName)
}

// TestModelService_Create_DuplicateModelName checks that a second model with
// the same (ProviderID, ModelName) pair is rejected with ErrDuplicateModel.
func TestModelService_Create_DuplicateModelName(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewModelService(modelRepo, providerRepo)

	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))

	model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
	err := svc.Create(model1)
	require.NoError(t, err)

	// Creating a second model with the same (providerID, modelName) must fail.
	model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
	err = svc.Create(model2)
	assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel), "expected ErrDuplicateModel, got: %v", err)
}

// TestModelService_Create_ProviderNotFound checks that creating a model for a
// provider that was never stored reports ErrProviderNotFound.
func TestModelService_Create_ProviderNotFound(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewModelService(modelRepo, providerRepo)

	model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
	err := svc.Create(model)
	assert.True(t, errors.Is(err, appErrors.ErrProviderNotFound), "expected ErrProviderNotFound, got: %v", err)
}

// ============ ProviderService - Create with validation tests ============

// TestProviderService_Create_InvalidID checks that the hyphenated provider ID
// "open-ai" is rejected with ErrInvalidProviderID.
func TestProviderService_Create_InvalidID(t *testing.T) {
	db := setupServiceTestDB(t)
	repo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewProviderService(repo, modelRepo)

	provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
	err := svc.Create(provider)
	assert.True(t, errors.Is(err, appErrors.ErrInvalidProviderID), "expected ErrInvalidProviderID, got: %v", err)
}

// TestProviderService_Create_ValidID checks that a provider with the ID
// "openai" is accepted, keeps its ID, and comes back with Enabled set to true
// even though the test never sets that field.
func TestProviderService_Create_ValidID(t *testing.T) {
	db := setupServiceTestDB(t)
	repo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewProviderService(repo, modelRepo)

	provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
	err := svc.Create(provider)
	require.NoError(t, err)
	assert.Equal(t, "openai", provider.ID)
	assert.True(t, provider.Enabled)
}

// ============ ModelService - Update with duplicate check tests ============

// TestModelService_Update_DuplicateModelName checks that an update which would
// make two models share the same (provider_id, model_name) pair is rejected
// with ErrDuplicateModel.
func TestModelService_Update_DuplicateModelName(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewModelService(modelRepo, providerRepo)

	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"}))

	model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
	err := svc.Create(model1)
	require.NoError(t, err)

	model2 := &domain.Model{ProviderID: "anthropic", ModelName: "claude-3"}
	err = svc.Create(model2)
	require.NoError(t, err)

	// Move model2 to provider "openai" with model_name "gpt-4", which
	// collides with model1.
	err = svc.Update(model2.ID, map[string]interface{}{
		"provider_id": "openai",
		"model_name":  "gpt-4",
	})
	assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel), "expected ErrDuplicateModel, got: %v", err)
}

// TestModelService_Update_ModelNotFound checks that updating an ID that was
// never stored reports ErrModelNotFound.
func TestModelService_Update_ModelNotFound(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewModelService(modelRepo, providerRepo)

	err := svc.Update("nonexistent-id", map[string]interface{}{
		"model_name": "gpt-4",
	})
	assert.True(t, errors.Is(err, appErrors.ErrModelNotFound), "expected ErrModelNotFound, got: %v", err)
}

// TestModelService_Update_Success checks that renaming a model to a
// non-conflicting name succeeds and is visible through Get.
func TestModelService_Update_Success(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewModelService(modelRepo, providerRepo)

	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))

	model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
	err := svc.Create(model)
	require.NoError(t, err)

	// Rename to a value that does not collide with any other model.
	err = svc.Update(model.ID, map[string]interface{}{
		"model_name": "gpt-4-turbo",
	})
	require.NoError(t, err)

	updated, err := svc.Get(model.ID)
	require.NoError(t, err)
	assert.Equal(t, "gpt-4-turbo", updated.ModelName)
}

// ============ ProviderService - Update immutable ID tests ============

// TestProviderService_Update_ImmutableID checks that an update containing the
// "id" key is rejected with ErrImmutableField.
func TestProviderService_Update_ImmutableID(t *testing.T) {
	db := setupServiceTestDB(t)
	repo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewProviderService(repo, modelRepo)

	provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
	err := svc.Create(provider)
	require.NoError(t, err)

	// Attempt to change the id field.
	err = svc.Update("openai", map[string]interface{}{
		"id": "new-id",
	})
	assert.True(t, errors.Is(err, appErrors.ErrImmutableField), "expected ErrImmutableField, got: %v", err)
}

// TestProviderService_Update_Success checks that updating a provider's name
// succeeds and is visible through Get.
func TestProviderService_Update_Success(t *testing.T) {
	db := setupServiceTestDB(t)
	repo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	svc := NewProviderService(repo, modelRepo)

	provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
	err := svc.Create(provider)
	require.NoError(t, err)

	// Update the name only.
	err = svc.Update("openai", map[string]interface{}{
		"name": "OpenAI Updated",
	})
	require.NoError(t, err)

	// NOTE(review): the meaning of Get's second argument (false) is not
	// visible in this file — confirm against ProviderService.Get.
	updated, err := svc.Get("openai", false)
	require.NoError(t, err)
	assert.Equal(t, "OpenAI Updated", updated.Name)
}