package integration

import (
	"bytes"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"path/filepath"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"

	"nex/backend/internal/config"
	"nex/backend/internal/domain"
	"nex/backend/internal/handler"
	"nex/backend/internal/handler/middleware"
	"nex/backend/internal/repository"
	"nex/backend/internal/service"
)

func init() {
	// Put gin into test mode for every test in this package.
	gin.SetMode(gin.TestMode)
}

// setupIntegrationTest builds a gin engine wired to the real repositories,
// services and handlers, backed by a fresh SQLite database file inside
// t.TempDir(). The provider, model and stats tables are auto-migrated and the
// database connection is closed via t.Cleanup.
//
// It returns the engine (to serve requests against) and the *gorm.DB (so
// tests can seed or inspect data directly).
func setupIntegrationTest(t *testing.T) (*gin.Engine, *gorm.DB) {
	t.Helper()

	dbPath := filepath.Join(t.TempDir(), "test.db")
	db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
	require.NoError(t, err)
	require.NoError(t, db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{}))

	// Registered after t.TempDir(), so (cleanups run LIFO) the SQLite handle is
	// released before the temp directory is removed.
	t.Cleanup(func() {
		if sqlDB, err := db.DB(); err == nil {
			_ = sqlDB.Close()
		}
	})

	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	statsRepo := repository.NewStatsRepository(db)

	providerService := service.NewProviderService(providerRepo)
	modelService := service.NewModelService(modelRepo, providerRepo)
	// The routing service is not attached to any route below and its result is
	// discarded. NOTE(review): presumably kept so its constructor keeps
	// compiling against these repositories — confirm it is still needed.
	_ = service.NewRoutingService(modelRepo, providerRepo)
	statsService := service.NewStatsService(statsRepo)

	providerHandler := handler.NewProviderHandler(providerService)
	modelHandler := handler.NewModelHandler(modelService)
	statsHandler := handler.NewStatsHandler(statsService)

	r := gin.New()
	r.Use(middleware.CORS())

	providers := r.Group("/api/providers")
	{
		providers.GET("", providerHandler.ListProviders)
		providers.POST("", providerHandler.CreateProvider)
		providers.GET("/:id", providerHandler.GetProvider)
		providers.PUT("/:id", providerHandler.UpdateProvider)
		providers.DELETE("/:id", providerHandler.DeleteProvider)
	}

	models := r.Group("/api/models")
	{
		models.GET("", modelHandler.ListModels)
		models.POST("", modelHandler.CreateModel)
		models.GET("/:id", modelHandler.GetModel)
		models.PUT("/:id", modelHandler.UpdateModel)
		models.DELETE("/:id", modelHandler.DeleteModel)
	}

	stats := r.Group("/api/stats")
	{
		stats.GET("", statsHandler.GetStats)
		stats.GET("/aggregate", statsHandler.AggregateStats)
	}

	return r, db
}

// serveJSON performs an in-process request against r and returns the
// recorded response. When body is non-nil it is JSON-encoded and sent with
// "Content-Type: application/json"; when nil, the request has no body and no
// Content-Type header.
func serveJSON(t *testing.T, r *gin.Engine, method, target string, body map[string]string) *httptest.ResponseRecorder {
	t.Helper()

	var reader io.Reader
	if body != nil {
		payload, err := json.Marshal(body)
		require.NoError(t, err)
		reader = bytes.NewReader(payload)
	}

	req := httptest.NewRequest(method, target, reader)
	if body != nil {
		req.Header.Set("Content-Type", "application/json")
	}

	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	return w
}

// TestOpenAI_CompleteFlow walks a provider and model through
// create → list → update → delete via the HTTP API.
func TestOpenAI_CompleteFlow(t *testing.T) {
	r, _ := setupIntegrationTest(t)

	// 1. Create the provider. Later steps depend on it, so stop on failure.
	w := serveJSON(t, r, http.MethodPost, "/api/providers", map[string]string{
		"id":       "openai",
		"name":     "OpenAI",
		"api_key":  "sk-test-key",
		"base_url": "https://api.openai.com/v1",
	})
	require.Equal(t, http.StatusCreated, w.Code, w.Body.String())

	// 2. Create a model under that provider.
	w = serveJSON(t, r, http.MethodPost, "/api/models", map[string]string{
		"id":          "gpt4",
		"provider_id": "openai",
		"model_name":  "gpt-4",
	})
	require.Equal(t, http.StatusCreated, w.Code, w.Body.String())

	// 3. List providers; the API key must come back masked.
	w = serveJSON(t, r, http.MethodGet, "/api/providers", nil)
	require.Equal(t, http.StatusOK, w.Code, w.Body.String())
	var providers []domain.Provider
	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &providers))
	require.Len(t, providers, 1)
	assert.Contains(t, providers[0].APIKey, "***")

	// 4. List models filtered by provider.
	w = serveJSON(t, r, http.MethodGet, "/api/models?provider_id=openai", nil)
	require.Equal(t, http.StatusOK, w.Code, w.Body.String())
	var models []domain.Model
	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &models))
	require.Len(t, models, 1)
	assert.Equal(t, "gpt-4", models[0].ModelName)

	// 5. Update the provider.
	w = serveJSON(t, r, http.MethodPut, "/api/providers/openai", map[string]string{"name": "OpenAI Updated"})
	assert.Equal(t, http.StatusOK, w.Code, w.Body.String())

	// 6. Delete the model.
	w = serveJSON(t, r, http.MethodDelete, "/api/models/gpt4", nil)
	assert.Equal(t, http.StatusNoContent, w.Code, w.Body.String())

	// 7. Delete the provider.
	w = serveJSON(t, r, http.MethodDelete, "/api/providers/openai", nil)
	assert.Equal(t, http.StatusNoContent, w.Code, w.Body.String())
}

// TestAnthropic_ModelCreation creates an Anthropic provider and model and
// checks the model can be fetched by ID.
func TestAnthropic_ModelCreation(t *testing.T) {
	r, _ := setupIntegrationTest(t)

	w := serveJSON(t, r, http.MethodPost, "/api/providers", map[string]string{
		"id":       "anthropic",
		"name":     "Anthropic",
		"api_key":  "sk-ant-test",
		"base_url": "https://api.anthropic.com/v1",
	})
	require.Equal(t, http.StatusCreated, w.Code, w.Body.String())

	w = serveJSON(t, r, http.MethodPost, "/api/models", map[string]string{
		"id":          "claude3",
		"provider_id": "anthropic",
		"model_name":  "claude-3-opus-20240229",
	})
	require.Equal(t, http.StatusCreated, w.Code, w.Body.String())

	// The created model must be retrievable.
	w = serveJSON(t, r, http.MethodGet, "/api/models/claude3", nil)
	assert.Equal(t, http.StatusOK, w.Code, w.Body.String())
}

// TestStats_RecordingAndQuery records usage directly through the stats
// repository and then reads it back through the stats endpoints.
func TestStats_RecordingAndQuery(t *testing.T) {
	r, db := setupIntegrationTest(t)

	// Seed a provider and a model; fail fast if seeding does not succeed so a
	// setup problem is not misreported as a stats problem.
	w := serveJSON(t, r, http.MethodPost, "/api/providers", map[string]string{
		"id":       "p1",
		"name":     "Provider1",
		"api_key":  "key",
		"base_url": "https://test.com",
	})
	require.Equal(t, http.StatusCreated, w.Code, w.Body.String())

	w = serveJSON(t, r, http.MethodPost, "/api/models", map[string]string{
		"id":          "m1",
		"provider_id": "p1",
		"model_name":  "gpt-4",
	})
	require.Equal(t, http.StatusCreated, w.Code, w.Body.String())

	// Record usage directly via the repository, standing in for the stats a
	// proxied request would produce.
	// NOTE(review): if Record returns an error, assert it with require.NoError.
	statsRepo := repository.NewStatsRepository(db)
	statsRepo.Record("p1", "gpt-4")
	statsRepo.Record("p1", "gpt-4")
	statsRepo.Record("p1", "gpt-4")

	// Per-provider query: one row with three requests.
	w = serveJSON(t, r, http.MethodGet, "/api/stats?provider_id=p1", nil)
	require.Equal(t, http.StatusOK, w.Code, w.Body.String())
	var stats []domain.UsageStats
	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &stats))
	require.Len(t, stats, 1)
	assert.Equal(t, 3, stats[0].RequestCount)

	// Aggregated query grouped by provider.
	w = serveJSON(t, r, http.MethodGet, "/api/stats/aggregate?group_by=provider", nil)
	assert.Equal(t, http.StatusOK, w.Code, w.Body.String())
}

// TestProvider_DuplicateCreation checks that creating the same provider ID
// twice yields 409 Conflict on the second attempt.
func TestProvider_DuplicateCreation(t *testing.T) {
	r, _ := setupIntegrationTest(t)

	body := map[string]string{
		"id":       "p1",
		"name":     "P1",
		"api_key":  "key",
		"base_url": "https://test.com",
	}

	// First creation succeeds.
	w := serveJSON(t, r, http.MethodPost, "/api/providers", body)
	require.Equal(t, http.StatusCreated, w.Code, w.Body.String())

	// Second creation must be rejected by the uniqueness constraint.
	w = serveJSON(t, r, http.MethodPost, "/api/providers", body)
	assert.Equal(t, http.StatusConflict, w.Code, w.Body.String())
}

// TestProvider_NotFound checks that fetching an unknown provider yields 404.
func TestProvider_NotFound(t *testing.T) {
	r, _ := setupIntegrationTest(t)

	w := serveJSON(t, r, http.MethodGet, "/api/providers/nonexistent", nil)
	assert.Equal(t, http.StatusNotFound, w.Code, w.Body.String())
}

// TestStats_InvalidDate checks that an unparseable start_date yields 400.
func TestStats_InvalidDate(t *testing.T) {
	r, _ := setupIntegrationTest(t)

	w := serveJSON(t, r, http.MethodGet, "/api/stats?start_date=not-a-date", nil)
	assert.Equal(t, http.StatusBadRequest, w.Code, w.Body.String())
}

// Keeps the "time" import referenced; drop both this line and the import once
// nothing in the package needs it.
var _ = time.Second