1
0
Files
nex/backend/tests/integration/integration_test.go
lanyuanxiaoyao f18904af1e feat: 实现分层架构,包含 domain、service、repository 和 pkg 层
- 新增 domain 层:model、provider、route、stats 实体
- 新增 service 层:models、providers、routing、stats 业务逻辑
- 新增 repository 层:models、providers、stats 数据访问
- 新增 pkg 工具包:errors、logger、validator
- 新增中间件:CORS、logging、recovery、request ID
- 新增数据库迁移:初始 schema 和索引
- 新增单元测试和集成测试
- 新增规范文档:config-management、database-migration、error-handling、layered-architecture、middleware-system、request-validation、structured-logging、test-coverage
- 移除 config 子包和 model_router(已迁移至分层架构)
2026-04-16 00:47:20 +08:00

264 lines
7.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package integration
import (
"bytes"
"encoding/json"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"nex/backend/internal/config"
"nex/backend/internal/domain"
"nex/backend/internal/handler"
"nex/backend/internal/handler/middleware"
"nex/backend/internal/repository"
"nex/backend/internal/service"
)
// init puts Gin into test mode, rather than its default debug mode, for
// every test in this package before any engine is constructed.
func init() {
	mode := gin.TestMode
	gin.SetMode(mode)
}
// setupIntegrationTest wires the real repository, service and handler layers
// to a fresh SQLite database stored in a per-test temporary directory, and
// mounts the provider, model and stats REST routes on a new Gin engine that
// uses only the CORS middleware.
//
// It returns the engine (drive it with ServeHTTP) and the *gorm.DB so tests
// can seed data below the HTTP layer. The database handle is closed through
// t.Cleanup; t.TempDir removes the directory itself when the test ends.
func setupIntegrationTest(t *testing.T) (*gin.Engine, *gorm.DB) {
	t.Helper()

	dbPath := t.TempDir() + "/test.db"
	db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
	require.NoError(t, err)

	// Register the close hook before anything else can fail. require aborts
	// the test via FailNow, so a cleanup registered after a failing require
	// would never be installed and the SQLite file would stay open.
	t.Cleanup(func() {
		sqlDB, err := db.DB()
		if err != nil {
			t.Logf("retrieving sql.DB for cleanup: %v", err)
			return
		}
		if err := sqlDB.Close(); err != nil {
			t.Logf("closing test database: %v", err)
		}
	})

	require.NoError(t, db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{}))

	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	statsRepo := repository.NewStatsRepository(db)

	// Only the CRUD and stats handlers are mounted here; no routing/proxy
	// endpoint is exposed, so no routing service is constructed.
	providerHandler := handler.NewProviderHandler(service.NewProviderService(providerRepo))
	modelHandler := handler.NewModelHandler(service.NewModelService(modelRepo, providerRepo))
	statsHandler := handler.NewStatsHandler(service.NewStatsService(statsRepo))

	r := gin.New()
	r.Use(middleware.CORS())

	providers := r.Group("/api/providers")
	{
		providers.GET("", providerHandler.ListProviders)
		providers.POST("", providerHandler.CreateProvider)
		providers.GET("/:id", providerHandler.GetProvider)
		providers.PUT("/:id", providerHandler.UpdateProvider)
		providers.DELETE("/:id", providerHandler.DeleteProvider)
	}

	models := r.Group("/api/models")
	{
		models.GET("", modelHandler.ListModels)
		models.POST("", modelHandler.CreateModel)
		models.GET("/:id", modelHandler.GetModel)
		models.PUT("/:id", modelHandler.UpdateModel)
		models.DELETE("/:id", modelHandler.DeleteModel)
	}

	stats := r.Group("/api/stats")
	{
		stats.GET("", statsHandler.GetStats)
		stats.GET("/aggregate", statsHandler.AggregateStats)
	}

	return r, db
}
// TestOpenAI_CompleteFlow walks an OpenAI-style provider and one of its
// models through the CRUD lifecycle over the REST API:
//  1. create the provider and the model;
//  2. list both, checking the listed API key contains "***" (masked);
//  3. update the provider;
//  4. delete the model and then the provider.
func TestOpenAI_CompleteFlow(t *testing.T) {
	r, _ := setupIntegrationTest(t)

	// send issues one request against the engine. A non-nil payload is
	// JSON-encoded and sent with a JSON Content-Type; nil sends no body.
	send := func(method, target string, payload map[string]string) *httptest.ResponseRecorder {
		t.Helper()
		req := httptest.NewRequest(method, target, nil)
		if payload != nil {
			raw, err := json.Marshal(payload)
			require.NoError(t, err)
			req = httptest.NewRequest(method, target, bytes.NewReader(raw))
			req.Header.Set("Content-Type", "application/json")
		}
		rec := httptest.NewRecorder()
		r.ServeHTTP(rec, req)
		return rec
	}

	// 1. Create the provider. Every later step depends on it, so stop on failure.
	w := send("POST", "/api/providers", map[string]string{
		"id": "openai", "name": "OpenAI", "api_key": "sk-test-key", "base_url": "https://api.openai.com/v1",
	})
	require.Equal(t, 201, w.Code)

	// 2. Create a model that belongs to the provider.
	w = send("POST", "/api/models", map[string]string{
		"id": "gpt4", "provider_id": "openai", "model_name": "gpt-4",
	})
	require.Equal(t, 201, w.Code)

	// 3. List providers; the API key must come back masked.
	w = send("GET", "/api/providers", nil)
	assert.Equal(t, 200, w.Code)
	var providers []domain.Provider
	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &providers))
	require.Len(t, providers, 1) // guards the providers[0] index below
	assert.Contains(t, providers[0].APIKey, "***")

	// 4. List models filtered by provider.
	w = send("GET", "/api/models?provider_id=openai", nil)
	assert.Equal(t, 200, w.Code)
	var models []domain.Model
	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &models))
	require.Len(t, models, 1) // guards the models[0] index below
	assert.Equal(t, "gpt-4", models[0].ModelName)

	// 5. Update the provider's name.
	w = send("PUT", "/api/providers/openai", map[string]string{"name": "OpenAI Updated"})
	assert.Equal(t, 200, w.Code)

	// 6. Delete the model.
	w = send("DELETE", "/api/models/gpt4", nil)
	assert.Equal(t, 204, w.Code)

	// 7. Delete the provider.
	w = send("DELETE", "/api/providers/openai", nil)
	assert.Equal(t, 204, w.Code)
}
// TestAnthropic_ModelCreation registers an Anthropic provider and a Claude
// model through the REST API, then fetches the model by its ID to confirm
// it was persisted.
func TestAnthropic_ModelCreation(t *testing.T) {
	engine, _ := setupIntegrationTest(t)

	// postJSON sends payload as a JSON POST to target and returns the status.
	postJSON := func(target string, payload map[string]string) int {
		// Marshalling a map[string]string cannot fail, so the error is dropped.
		raw, _ := json.Marshal(payload)
		req := httptest.NewRequest("POST", target, bytes.NewReader(raw))
		req.Header.Set("Content-Type", "application/json")
		rec := httptest.NewRecorder()
		engine.ServeHTTP(rec, req)
		return rec.Code
	}

	// Seed a provider and a model that points at it.
	assert.Equal(t, 201, postJSON("/api/providers", map[string]string{
		"id": "anthropic", "name": "Anthropic", "api_key": "sk-ant-test", "base_url": "https://api.anthropic.com/v1",
	}))
	assert.Equal(t, 201, postJSON("/api/models", map[string]string{
		"id": "claude3", "provider_id": "anthropic", "model_name": "claude-3-opus-20240229",
	}))

	// The model must now be retrievable by ID.
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, httptest.NewRequest("GET", "/api/models/claude3", nil))
	assert.Equal(t, 200, rec.Code)
}
// TestStats_RecordingAndQuery seeds a provider and a model over HTTP, then
// records three usage events directly through the stats repository (standing
// in for what a proxied request would record). It checks that the stats
// query filtered by provider returns one row with a request count of 3, and
// that the aggregate endpoint answers 200.
func TestStats_RecordingAndQuery(t *testing.T) {
	r, db := setupIntegrationTest(t)

	// call issues one request against the engine. A non-nil payload is
	// JSON-encoded and sent with a JSON Content-Type; nil sends no body.
	call := func(method, target string, payload map[string]string) *httptest.ResponseRecorder {
		t.Helper()
		req := httptest.NewRequest(method, target, nil)
		if payload != nil {
			raw, err := json.Marshal(payload)
			require.NoError(t, err)
			req = httptest.NewRequest(method, target, bytes.NewReader(raw))
			req.Header.Set("Content-Type", "application/json")
		}
		rec := httptest.NewRecorder()
		r.ServeHTTP(rec, req)
		return rec
	}

	// Seed data. Fail fast here so a broken seed is not misreported as a
	// wrong stats count further down.
	w := call("POST", "/api/providers", map[string]string{
		"id": "p1", "name": "Provider1", "api_key": "key", "base_url": "https://test.com",
	})
	require.Equal(t, 201, w.Code)
	w = call("POST", "/api/models", map[string]string{
		"id": "m1", "provider_id": "p1", "model_name": "gpt-4",
	})
	require.Equal(t, 201, w.Code)

	// Record usage directly through the repository to simulate the stats a
	// proxied request would produce.
	statsRepo := repository.NewStatsRepository(db)
	for i := 0; i < 3; i++ {
		statsRepo.Record("p1", "gpt-4")
	}

	// Query stats for the provider.
	w = call("GET", "/api/stats?provider_id=p1", nil)
	assert.Equal(t, 200, w.Code)
	var stats []domain.UsageStats
	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &stats))
	require.Len(t, stats, 1) // guards the stats[0] index below
	assert.Equal(t, 3, stats[0].RequestCount)

	// Aggregated stats grouped by provider.
	w = call("GET", "/api/stats/aggregate?group_by=provider", nil)
	assert.Equal(t, 200, w.Code)
}
// TestProvider_DuplicateCreation checks that creating a provider whose ID
// already exists is rejected with 409 Conflict.
func TestProvider_DuplicateCreation(t *testing.T) {
	r, _ := setupIntegrationTest(t)

	providerBody, err := json.Marshal(map[string]string{
		"id": "p1", "name": "P1", "api_key": "key", "base_url": "https://test.com",
	})
	require.NoError(t, err)

	// post submits the same provider payload and returns the status code.
	post := func() int {
		req := httptest.NewRequest("POST", "/api/providers", bytes.NewReader(providerBody))
		req.Header.Set("Content-Type", "application/json")
		w := httptest.NewRecorder()
		r.ServeHTTP(w, req)
		return w.Code
	}

	// The first creation must succeed; otherwise the conflict check below
	// would fail for the wrong reason.
	require.Equal(t, 201, post())

	// The second creation with the same ID must fail (UNIQUE constraint).
	assert.Equal(t, 409, post())
}
// TestProvider_NotFound checks that fetching a provider ID that was never
// created yields 404.
func TestProvider_NotFound(t *testing.T) {
	engine, _ := setupIntegrationTest(t)

	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, httptest.NewRequest("GET", "/api/providers/nonexistent", nil))

	assert.Equal(t, 404, rec.Code)
}
// TestStats_InvalidDate checks that the stats endpoint rejects a malformed
// start_date query parameter with 400.
func TestStats_InvalidDate(t *testing.T) {
	engine, _ := setupIntegrationTest(t)

	const target = "/api/stats?start_date=not-a-date"
	rec := httptest.NewRecorder()
	req := httptest.NewRequest("GET", target, nil)
	engine.ServeHTTP(rec, req)

	assert.Equal(t, 400, rec.Code)
}
// Go rejects unused imports at compile time (it is an error, not a warning);
// this blank declaration keeps the "time" import referenced.
var _ = time.Now