package repository import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/gorm" testHelpers "nex/backend/tests" "nex/backend/internal/domain" ) func setupTestDB(t *testing.T) *gorm.DB { t.Helper() return testHelpers.SetupTestDB(t) } // ============ ProviderRepository 测试 ============ func TestProviderRepository_Create(t *testing.T) { db := setupTestDB(t) repo := NewProviderRepository(db) provider := &domain.Provider{ ID: "test-provider", Name: "Test Provider", APIKey: "sk-test-key", BaseURL: "https://api.test.com", Enabled: true, } err := repo.Create(provider) require.NoError(t, err) } func TestProviderRepository_GetByID(t *testing.T) { db := setupTestDB(t) repo := NewProviderRepository(db) provider := &domain.Provider{ ID: "test-provider", Name: "Test", APIKey: "sk-test-key", BaseURL: "https://api.test.com", } err := repo.Create(provider) require.NoError(t, err) result, err := repo.GetByID("test-provider") require.NoError(t, err) assert.Equal(t, "test-provider", result.ID) assert.Equal(t, "Test", result.Name) } func TestProviderRepository_GetByID_NotFound(t *testing.T) { db := setupTestDB(t) repo := NewProviderRepository(db) _, err := repo.GetByID("nonexistent") assert.Error(t, err) } func TestProviderRepository_List(t *testing.T) { db := setupTestDB(t) repo := NewProviderRepository(db) for _, id := range []string{"pA", "pB", "pC"} { err := repo.Create(&domain.Provider{ID: id, Name: id, APIKey: "key", BaseURL: "https://test.com"}) require.NoError(t, err) } providers, err := repo.List() require.NoError(t, err) assert.Len(t, providers, 3) } func TestProviderRepository_Update(t *testing.T) { db := setupTestDB(t) repo := NewProviderRepository(db) require.NoError(t, repo.Create(&domain.Provider{ID: "p1", Name: "Old", APIKey: "key", BaseURL: "https://old.com"})) err := repo.Update("p1", map[string]interface{}{"name": "New"}) require.NoError(t, err) result, _ := repo.GetByID("p1") assert.Equal(t, "New", result.Name) } func TestProviderRepository_Update_NotFound(t *testing.T) { db := setupTestDB(t) repo := NewProviderRepository(db) err := repo.Update("nonexistent", map[string]interface{}{"name": "New"}) assert.Error(t, err) } func TestProviderRepository_Delete(t *testing.T) { db := setupTestDB(t) repo := NewProviderRepository(db) require.NoError(t, repo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})) err := repo.Delete("p1") require.NoError(t, err) _, err = repo.GetByID("p1") assert.Error(t, err) } func TestProviderRepository_Delete_NotFound(t *testing.T) { db := setupTestDB(t) repo := NewProviderRepository(db) err := repo.Delete("nonexistent") assert.Error(t, err) } // ============ ModelRepository 测试 ============ func TestModelRepository_Create(t *testing.T) { db := setupTestDB(t) providerRepo := NewProviderRepository(db) repo := NewModelRepository(db) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})) err := repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}) require.NoError(t, err) } func TestModelRepository_GetByID(t *testing.T) { db := setupTestDB(t) providerRepo := NewProviderRepository(db) repo := NewModelRepository(db) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})) require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})) result, err := repo.GetByID("m1") require.NoError(t, err) assert.Equal(t, "m1", result.ID) assert.Equal(t, "gpt-4", result.ModelName) } func TestModelRepository_FindByProviderAndModelName(t *testing.T) { db := setupTestDB(t) providerRepo := NewProviderRepository(db) repo := NewModelRepository(db) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})) require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})) result, err := repo.FindByProviderAndModelName("p1", "gpt-4") require.NoError(t, err) assert.Equal(t, "m1", result.ID) assert.Equal(t, "p1", result.ProviderID) assert.Equal(t, "gpt-4", result.ModelName) } func TestModelRepository_FindByProviderAndModelName_NotFound(t *testing.T) { db := setupTestDB(t) providerRepo := NewProviderRepository(db) repo := NewModelRepository(db) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})) require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})) // Wrong provider_id _, err := repo.FindByProviderAndModelName("p2", "gpt-4") assert.Error(t, err) // Wrong model_name _, err = repo.FindByProviderAndModelName("p1", "gpt-3.5") assert.Error(t, err) // Both wrong _, err = repo.FindByProviderAndModelName("p2", "claude-3") assert.Error(t, err) } func TestModelRepository_List(t *testing.T) { db := setupTestDB(t) providerRepo := NewProviderRepository(db) repo := NewModelRepository(db) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p2", Name: "Test2", APIKey: "key", BaseURL: "https://test2.com"})) require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})) require.NoError(t, repo.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})) require.NoError(t, repo.Create(&domain.Model{ID: "m3", ProviderID: "p2", ModelName: "claude-3"})) all, err := repo.List("") require.NoError(t, err) assert.Len(t, all, 3) p1Models, err := repo.List("p1") require.NoError(t, err) assert.Len(t, p1Models, 2) } func TestModelRepository_ListEnabled(t *testing.T) { db := setupTestDB(t) providerRepo := NewProviderRepository(db) modelRepo := NewModelRepository(db) // Create two providers (both start enabled due to gorm:"default:true") err := providerRepo.Create(&domain.Provider{ ID: "enabled-provider", Name: "Enabled Provider", APIKey: "key1", BaseURL: "https://enabled.com", Enabled: true, }) require.NoError(t, err) err = providerRepo.Create(&domain.Provider{ ID: "disabled-provider", Name: "Disabled Provider", APIKey: "key2", BaseURL: "https://disabled.com", Enabled: true, }) require.NoError(t, err) // Disable the second provider via Update (GORM default:true skips zero values on Create) err = providerRepo.Update("disabled-provider", map[string]interface{}{"enabled": false}) require.NoError(t, err) // Create models (all start enabled due to gorm:"default:true") err = modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "enabled-provider", ModelName: "gpt-4", Enabled: true}) require.NoError(t, err) err = modelRepo.Create(&domain.Model{ID: "m2", ProviderID: "enabled-provider", ModelName: "gpt-3.5", Enabled: true}) require.NoError(t, err) err = modelRepo.Create(&domain.Model{ID: "m3", ProviderID: "disabled-provider", ModelName: "claude-3", Enabled: true}) require.NoError(t, err) err = modelRepo.Create(&domain.Model{ID: "m4", ProviderID: "disabled-provider", ModelName: "claude-3.5", Enabled: true}) require.NoError(t, err) // Disable m2 via Update err = modelRepo.Update("m2", map[string]interface{}{"enabled": false}) require.NoError(t, err) // ListEnabled should only return models where both model and provider are enabled: // - m1: enabled provider + enabled model -> returned // - m2: enabled provider + disabled model -> filtered out // - m3: disabled provider + enabled model -> filtered out // - m4: disabled provider + enabled model -> filtered out enabled, err := modelRepo.ListEnabled() require.NoError(t, err) require.Len(t, enabled, 1) assert.Equal(t, "m1", enabled[0].ID) assert.Equal(t, "enabled-provider", enabled[0].ProviderID) assert.Equal(t, "gpt-4", enabled[0].ModelName) } func TestModelRepository_Update(t *testing.T) { db := setupTestDB(t) providerRepo := NewProviderRepository(db) repo := NewModelRepository(db) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})) require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})) err := repo.Update("m1", map[string]interface{}{"enabled": false}) require.NoError(t, err) result, _ := repo.GetByID("m1") assert.False(t, result.Enabled) } func TestModelRepository_Delete(t *testing.T) { db := setupTestDB(t) providerRepo := NewProviderRepository(db) repo := NewModelRepository(db) require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})) require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})) err := repo.Delete("m1") require.NoError(t, err) _, err = repo.GetByID("m1") assert.Error(t, err) } // ============ StatsRepository 测试 ============ func TestStatsRepository_Record(t *testing.T) { db := setupTestDB(t) repo := NewStatsRepository(db) err := repo.Record("provider-1", "gpt-4") require.NoError(t, err) // 再次记录应递增 err = repo.Record("provider-1", "gpt-4") require.NoError(t, err) stats, err := repo.Query("provider-1", "", nil, nil) require.NoError(t, err) require.Len(t, stats, 1) assert.Equal(t, 2, stats[0].RequestCount) } func TestStatsRepository_Query(t *testing.T) { db := setupTestDB(t) repo := NewStatsRepository(db) require.NoError(t, repo.Record("p1", "gpt-4")) // 注意:当前 schema 只有 date 字段有唯一约束 // 所以同一 provider + model 只能有一条记录 stats, err := repo.Query("p1", "", nil, nil) require.NoError(t, err) assert.Len(t, stats, 1) } func TestModelRepository_List_EmptyResult(t *testing.T) { db := setupTestDB(t) repo := NewModelRepository(db) result, err := repo.List("") require.NoError(t, err) assert.NotNil(t, result) assert.Empty(t, result) assert.Len(t, result, 0) } func TestProviderRepository_List_EmptyResult(t *testing.T) { db := setupTestDB(t) repo := NewProviderRepository(db) result, err := repo.List() require.NoError(t, err) assert.NotNil(t, result) assert.Empty(t, result) assert.Len(t, result, 0) }