package repository

import (
	"path/filepath"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"

	"nex/backend/internal/config"
	"nex/backend/internal/domain"
)

// setupTestDB opens a fresh SQLite database file inside a per-test temporary
// directory and auto-migrates the config.Provider, config.Model and
// config.UsageStats tables. Any failure aborts the calling test.
func setupTestDB(t *testing.T) *gorm.DB {
	t.Helper()

	dbPath := filepath.Join(t.TempDir(), "test.db")
	db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
	require.NoError(t, err)

	err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
	require.NoError(t, err)

	// Close the database connection so the TempDir cleanup can remove the file.
	// Cleanup functions run in LIFO order, so this runs before TempDir's own
	// removal (TempDir registered its cleanup first). Closing is best-effort:
	// failures are logged rather than failing an otherwise passing test.
	t.Cleanup(func() {
		sqlDB, err := db.DB()
		if err != nil {
			t.Logf("obtaining underlying sql.DB: %v", err)
			return
		}
		if err := sqlDB.Close(); err != nil {
			t.Logf("closing test database: %v", err)
		}
	})

	return db
}

// ============ ProviderRepository tests ============

func TestProviderRepository_Create(t *testing.T) {
	db := setupTestDB(t)
	repo := NewProviderRepository(db)

	provider := &domain.Provider{
		ID:      "test-provider",
		Name:    "Test Provider",
		APIKey:  "sk-test-key",
		BaseURL: "https://api.test.com",
		Enabled: true,
	}
	err := repo.Create(provider)
	require.NoError(t, err)

	// Read the row back so the test proves the provider was persisted,
	// not merely that Create returned no error.
	result, err := repo.GetByID("test-provider")
	require.NoError(t, err)
	assert.Equal(t, "test-provider", result.ID)
	assert.Equal(t, "Test Provider", result.Name)
}

func TestProviderRepository_GetByID(t *testing.T) {
	db := setupTestDB(t)
	repo := NewProviderRepository(db)

	provider := &domain.Provider{
		ID:      "test-provider",
		Name:    "Test",
		APIKey:  "sk-test-key",
		BaseURL: "https://api.test.com",
	}
	err := repo.Create(provider)
	require.NoError(t, err)

	result, err := repo.GetByID("test-provider")
	require.NoError(t, err)
	assert.Equal(t, "test-provider", result.ID)
	assert.Equal(t, "Test", result.Name)
}

func TestProviderRepository_GetByID_NotFound(t *testing.T) {
	db := setupTestDB(t)
	repo := NewProviderRepository(db)

	_, err := repo.GetByID("nonexistent")
	assert.Error(t, err)
}

func TestProviderRepository_List(t *testing.T) {
	db := setupTestDB(t)
	repo := NewProviderRepository(db)

	for _, id := range []string{"pA", "pB", "pC"} {
		err := repo.Create(&domain.Provider{ID: id, Name: id, APIKey: "key", BaseURL: "https://test.com"})
		require.NoError(t, err)
	}

	providers, err := repo.List()
	require.NoError(t, err)
	assert.Len(t, providers, 3)
}

func TestProviderRepository_Update(t *testing.T) {
	db := setupTestDB(t)
	repo := NewProviderRepository(db)

	// Fail fast on setup errors so a broken Create is not misreported as an
	// Update failure.
	err := repo.Create(&domain.Provider{ID: "p1", Name: "Old", APIKey: "key", BaseURL: "https://old.com"})
	require.NoError(t, err)

	err = repo.Update("p1", map[string]interface{}{"name": "New"})
	require.NoError(t, err)

	result, err := repo.GetByID("p1")
	require.NoError(t, err)
	assert.Equal(t, "New", result.Name)
}

func TestProviderRepository_Update_NotFound(t *testing.T) {
	db := setupTestDB(t)
	repo := NewProviderRepository(db)

	err := repo.Update("nonexistent", map[string]interface{}{"name": "New"})
	assert.Error(t, err)
}

func TestProviderRepository_Delete(t *testing.T) {
	db := setupTestDB(t)
	repo := NewProviderRepository(db)

	err := repo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
	require.NoError(t, err)

	err = repo.Delete("p1")
	require.NoError(t, err)

	// After deletion the provider must no longer be retrievable.
	_, err = repo.GetByID("p1")
	assert.Error(t, err)
}

func TestProviderRepository_Delete_NotFound(t *testing.T) {
	db := setupTestDB(t)
	repo := NewProviderRepository(db)

	err := repo.Delete("nonexistent")
	assert.Error(t, err)
}

// ============ ModelRepository tests ============

func TestModelRepository_Create(t *testing.T) {
	db := setupTestDB(t)
	repo := NewModelRepository(db)

	err := repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
	require.NoError(t, err)

	// Read the row back so the test proves the model was persisted.
	result, err := repo.GetByID("m1")
	require.NoError(t, err)
	assert.Equal(t, "m1", result.ID)
	assert.Equal(t, "gpt-4", result.ModelName)
}

func TestModelRepository_GetByID(t *testing.T) {
	db := setupTestDB(t)
	repo := NewModelRepository(db)

	err := repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
	require.NoError(t, err)

	result, err := repo.GetByID("m1")
	require.NoError(t, err)
	assert.Equal(t, "m1", result.ID)
	assert.Equal(t, "gpt-4", result.ModelName)
}

func TestModelRepository_GetByModelName(t *testing.T) {
	db := setupTestDB(t)
	repo := NewModelRepository(db)

	err := repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
	require.NoError(t, err)

	result, err := repo.GetByModelName("gpt-4")
	require.NoError(t, err)
	assert.Equal(t, "m1", result.ID)
}

func TestModelRepository_List(t *testing.T) {
	db := setupTestDB(t)
	repo := NewModelRepository(db)

	// Two models for provider p1 and one for p2, so both the unfiltered and
	// the provider-filtered listings can be checked.
	for _, m := range []*domain.Model{
		{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"},
		{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"},
		{ID: "m3", ProviderID: "p2", ModelName: "claude-3"},
	} {
		require.NoError(t, repo.Create(m))
	}

	// An empty provider ID lists every model.
	all, err := repo.List("")
	require.NoError(t, err)
	assert.Len(t, all, 3)

	p1Models, err := repo.List("p1")
	require.NoError(t, err)
	assert.Len(t, p1Models, 2)
}

func TestModelRepository_Update(t *testing.T) {
	db := setupTestDB(t)
	repo := NewModelRepository(db)

	err := repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
	require.NoError(t, err)

	err = repo.Update("m1", map[string]interface{}{"enabled": false})
	require.NoError(t, err)

	result, err := repo.GetByID("m1")
	require.NoError(t, err)
	assert.False(t, result.Enabled)
}

func TestModelRepository_Delete(t *testing.T) {
	db := setupTestDB(t)
	repo := NewModelRepository(db)

	err := repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
	require.NoError(t, err)

	err = repo.Delete("m1")
	require.NoError(t, err)

	// After deletion the model must no longer be retrievable.
	_, err = repo.GetByID("m1")
	assert.Error(t, err)
}

// ============ StatsRepository tests ============

func TestStatsRepository_Record(t *testing.T) {
	db := setupTestDB(t)
	repo := NewStatsRepository(db)

	err := repo.Record("provider-1", "gpt-4")
	require.NoError(t, err)

	// Recording the same provider/model again should increment the existing
	// row's count rather than insert a new row.
	err = repo.Record("provider-1", "gpt-4")
	require.NoError(t, err)

	stats, err := repo.Query("provider-1", "", nil, nil)
	require.NoError(t, err)
	require.Len(t, stats, 1)
	assert.Equal(t, 2, stats[0].RequestCount)
}

func TestStatsRepository_Query(t *testing.T) {
	db := setupTestDB(t)
	repo := NewStatsRepository(db)

	err := repo.Record("p1", "gpt-4")
	require.NoError(t, err)

	// NOTE(review): the original note says the current schema only has a
	// unique constraint on the date field, so a given provider + model pair
	// yields a single record here. Confirm against config.UsageStats before
	// extending this test with multi-model or multi-day cases.
	stats, err := repo.Query("p1", "", nil, nil)
	require.NoError(t, err)
	assert.Len(t, stats, 1)
}