1
0
Files
nex/backend/tests/integration/integration_test.go
lanyuanxiaoyao 395887667d feat: 实现统一模型 ID 机制
实现统一模型 ID 格式 (provider_id/model_name),支持跨协议模型标识和 Smart Passthrough。

核心变更:
- 新增 pkg/modelid 包:解析、格式化、校验统一模型 ID
- 数据库迁移:models 表使用 UUID 主键 + UNIQUE(provider_id, model_name) 约束
- Repository 层:FindByProviderAndModelName、ListEnabled 方法
- Service 层:联合唯一校验、provider ID 字符集校验
- Conversion 层:ExtractModelName、RewriteRequestModelName/RewriteResponseModelName 方法
- Handler 层:统一模型 ID 路由、Smart Passthrough、Models API 本地聚合
- 新增 error-responses、unified-model-id 规范

测试覆盖:
- 单元测试:modelid、conversion、handler、service、repository
- 集成测试:统一模型 ID 路由、Smart Passthrough 保真性、跨协议转换
- 迁移测试:UUID 主键、UNIQUE 约束、级联删除

OpenSpec:
- 归档 unified-model-id 变更到 archive/2026-04-21-unified-model-id
- 同步 11 个 delta specs 到 main specs
- 新增 error-responses、unified-model-id 规范文件
2026-04-21 18:14:10 +08:00

256 lines
7.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package integration
import (
"bytes"
"encoding/json"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
"nex/backend/internal/domain"
"nex/backend/internal/handler"
"nex/backend/internal/handler/middleware"
"nex/backend/internal/repository"
"nex/backend/internal/service"
)
// init puts gin into test mode (rather than its default debug mode) for every
// test in this package.
func init() {
	const mode = gin.TestMode
	gin.SetMode(mode)
}
// setupIntegrationTest returns a gin engine backed by a fresh database from
// setupTestDB. The engine has the provider, model and stats admin endpoints
// mounted under /api. The *gorm.DB is returned as well so tests can seed data
// directly through repositories.
func setupIntegrationTest(t *testing.T) (*gin.Engine, *gorm.DB) {
	t.Helper()
	db := setupTestDB(t)

	// Repositories over the shared test database.
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	statsRepo := repository.NewStatsRepository(db)

	// Services, constructed in a fixed order.
	providerService := service.NewProviderService(providerRepo, modelRepo)
	modelService := service.NewModelService(modelRepo, providerRepo)
	// NOTE(review): the routing service is constructed and immediately
	// discarded; no proxy routes are mounted on this engine.
	_ = service.NewRoutingService(modelRepo, providerRepo)
	statsService := service.NewStatsService(statsRepo)

	// HTTP handlers.
	providerHandler := handler.NewProviderHandler(providerService)
	modelHandler := handler.NewModelHandler(modelService)
	statsHandler := handler.NewStatsHandler(statsService)

	engine := gin.New()
	engine.Use(middleware.CORS())

	// Provider CRUD.
	p := engine.Group("/api/providers")
	p.GET("", providerHandler.ListProviders)
	p.POST("", providerHandler.CreateProvider)
	p.GET("/:id", providerHandler.GetProvider)
	p.PUT("/:id", providerHandler.UpdateProvider)
	p.DELETE("/:id", providerHandler.DeleteProvider)

	// Model CRUD.
	m := engine.Group("/api/models")
	m.GET("", modelHandler.ListModels)
	m.POST("", modelHandler.CreateModel)
	m.GET("/:id", modelHandler.GetModel)
	m.PUT("/:id", modelHandler.UpdateModel)
	m.DELETE("/:id", modelHandler.DeleteModel)

	// Usage statistics queries.
	s := engine.Group("/api/stats")
	s.GET("", statsHandler.GetStats)
	s.GET("/aggregate", statsHandler.AggregateStats)

	return engine, db
}
// TestOpenAI_CompleteFlow drives the full admin CRUD lifecycle for an
// OpenAI-style provider:
//
//  1. create the provider,
//  2. create a model under it,
//  3. list providers,
//  4. list models,
//  5. update the provider,
//  6. delete the model,
//  7. delete the provider.
//
// Each step depends on the previous one. Failures therefore abort the test
// immediately rather than cascading into confusing follow-up errors, or into
// an index-out-of-range panic when a list comes back empty.
func TestOpenAI_CompleteFlow(t *testing.T) {
	r, _ := setupIntegrationTest(t)

	// send issues a request against r. A non-nil payload is JSON-encoded as
	// the body and sent with a JSON Content-Type.
	send := func(method, path string, payload any) *httptest.ResponseRecorder {
		t.Helper()
		w := httptest.NewRecorder()
		if payload == nil {
			r.ServeHTTP(w, httptest.NewRequest(method, path, nil))
			return w
		}
		body, err := json.Marshal(payload)
		if err != nil {
			t.Fatalf("marshal %s %s payload: %v", method, path, err)
		}
		req := httptest.NewRequest(method, path, bytes.NewReader(body))
		req.Header.Set("Content-Type", "application/json")
		r.ServeHTTP(w, req)
		return w
	}
	// requireStatus aborts the test when w's status differs from want.
	requireStatus := func(w *httptest.ResponseRecorder, want int) {
		t.Helper()
		if !assert.Equal(t, want, w.Code, "response body: %s", w.Body.String()) {
			t.FailNow()
		}
	}
	// decode unmarshals w's JSON body into v, aborting the test on failure.
	decode := func(w *httptest.ResponseRecorder, v any) {
		t.Helper()
		if err := json.Unmarshal(w.Body.Bytes(), v); err != nil {
			t.Fatalf("decode response body %q: %v", w.Body.String(), err)
		}
	}

	// 1. Create the provider.
	requireStatus(send("POST", "/api/providers", map[string]string{
		"id": "openai", "name": "OpenAI", "api_key": "sk-test-key", "base_url": "https://api.openai.com/v1",
	}), 201)

	// 2. Create a model; its generated ID is needed for the delete in step 6.
	w := send("POST", "/api/models", map[string]string{
		"provider_id": "openai", "model_name": "gpt-4",
	})
	requireStatus(w, 201)
	var createdModel domain.Model
	decode(w, &createdModel)
	if !assert.NotEmpty(t, createdModel.ID) {
		t.FailNow()
	}

	// 3. List providers; the API key must come back masked.
	w = send("GET", "/api/providers", nil)
	requireStatus(w, 200)
	var providers []domain.Provider
	decode(w, &providers)
	if !assert.Len(t, providers, 1) {
		t.FailNow()
	}
	assert.Contains(t, providers[0].APIKey, "***")

	// 4. List models filtered by provider.
	w = send("GET", "/api/models?provider_id=openai", nil)
	requireStatus(w, 200)
	var models []domain.Model
	decode(w, &models)
	if !assert.Len(t, models, 1) {
		t.FailNow()
	}
	assert.Equal(t, "gpt-4", models[0].ModelName)

	// 5. Update the provider.
	requireStatus(send("PUT", "/api/providers/openai", map[string]string{"name": "OpenAI Updated"}), 200)

	// 6. Delete the model.
	requireStatus(send("DELETE", "/api/models/"+createdModel.ID, nil), 204)

	// 7. Delete the provider.
	requireStatus(send("DELETE", "/api/providers/openai", nil), 204)
}
// TestAnthropic_ModelCreation registers an Anthropic provider plus a Claude
// model through the admin API. It then checks that the model can be fetched
// back by the ID returned from the create call.
func TestAnthropic_ModelCreation(t *testing.T) {
	r, _ := setupIntegrationTest(t)

	// send issues a request against r. A non-nil payload is JSON-encoded as
	// the body and sent with a JSON Content-Type.
	send := func(method, path string, payload any) *httptest.ResponseRecorder {
		t.Helper()
		w := httptest.NewRecorder()
		if payload == nil {
			r.ServeHTTP(w, httptest.NewRequest(method, path, nil))
			return w
		}
		body, err := json.Marshal(payload)
		if err != nil {
			t.Fatalf("marshal %s %s payload: %v", method, path, err)
		}
		req := httptest.NewRequest(method, path, bytes.NewReader(body))
		req.Header.Set("Content-Type", "application/json")
		r.ServeHTTP(w, req)
		return w
	}
	// requireStatus aborts the test when w's status differs from want.
	requireStatus := func(w *httptest.ResponseRecorder, want int) {
		t.Helper()
		if !assert.Equal(t, want, w.Code, "response body: %s", w.Body.String()) {
			t.FailNow()
		}
	}
	// decode unmarshals w's JSON body into v, aborting the test on failure.
	decode := func(w *httptest.ResponseRecorder, v any) {
		t.Helper()
		if err := json.Unmarshal(w.Body.Bytes(), v); err != nil {
			t.Fatalf("decode response body %q: %v", w.Body.String(), err)
		}
	}

	// Create the provider and model that Anthropic proxy traffic would use.
	requireStatus(send("POST", "/api/providers", map[string]string{
		"id": "anthropic", "name": "Anthropic", "api_key": "sk-ant-test", "base_url": "https://api.anthropic.com/v1",
	}), 201)

	w := send("POST", "/api/models", map[string]string{
		"provider_id": "anthropic", "model_name": "claude-3-opus-20240229",
	})
	requireStatus(w, 201)
	var createdModel domain.Model
	decode(w, &createdModel)
	// Without an ID the GET below would target "/api/models/" and fail for an
	// unrelated reason, so stop here instead.
	if !assert.NotEmpty(t, createdModel.ID) {
		t.FailNow()
	}

	// The model must be retrievable by the ID the create call returned.
	w = send("GET", "/api/models/"+createdModel.ID, nil)
	requireStatus(w, 200)
	var fetched domain.Model
	decode(w, &fetched)
	assert.Equal(t, createdModel.ID, fetched.ID)
	assert.Equal(t, "claude-3-opus-20240229", fetched.ModelName)
}
// TestStats_RecordingAndQuery seeds a provider and model, then records three
// usage entries directly via the stats repository. It checks that the stats
// query returns one row with a request count of 3, and that the aggregate
// endpoint answers 200.
func TestStats_RecordingAndQuery(t *testing.T) {
	r, db := setupIntegrationTest(t)

	// send issues a request against r. A non-nil payload is JSON-encoded as
	// the body and sent with a JSON Content-Type.
	send := func(method, path string, payload any) *httptest.ResponseRecorder {
		t.Helper()
		w := httptest.NewRecorder()
		if payload == nil {
			r.ServeHTTP(w, httptest.NewRequest(method, path, nil))
			return w
		}
		body, err := json.Marshal(payload)
		if err != nil {
			t.Fatalf("marshal %s %s payload: %v", method, path, err)
		}
		req := httptest.NewRequest(method, path, bytes.NewReader(body))
		req.Header.Set("Content-Type", "application/json")
		r.ServeHTTP(w, req)
		return w
	}
	// requireStatus aborts the test when w's status differs from want.
	requireStatus := func(w *httptest.ResponseRecorder, want int) {
		t.Helper()
		if !assert.Equal(t, want, w.Code, "response body: %s", w.Body.String()) {
			t.FailNow()
		}
	}

	// Seed the provider and model the usage records refer to. These were
	// previously unchecked, so a broken setup could go unnoticed.
	requireStatus(send("POST", "/api/providers", map[string]string{
		"id": "p1", "name": "Provider1", "api_key": "key", "base_url": "https://test.com",
	}), 201)
	requireStatus(send("POST", "/api/models", map[string]string{
		"provider_id": "p1", "model_name": "gpt-4",
	}), 201)

	// Record usage directly through the repository, simulating the records
	// written after proxied requests.
	// NOTE(review): if Record returns an error it is discarded here; confirm
	// its signature and assert on it.
	statsRepo := repository.NewStatsRepository(db)
	statsRepo.Record("p1", "gpt-4")
	statsRepo.Record("p1", "gpt-4")
	statsRepo.Record("p1", "gpt-4")

	// Query per-provider stats.
	w := send("GET", "/api/stats?provider_id=p1", nil)
	requireStatus(w, 200)
	var stats []domain.UsageStats
	if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
		t.Fatalf("decode stats %q: %v", w.Body.String(), err)
	}
	// Guard the index below: a non-fatal Len failure would otherwise panic.
	if !assert.Len(t, stats, 1) {
		t.FailNow()
	}
	assert.Equal(t, 3, stats[0].RequestCount)

	// Aggregate stats grouped by provider.
	w = send("GET", "/api/stats/aggregate?group_by=provider", nil)
	assert.Equal(t, 200, w.Code)
}
// TestProvider_DuplicateCreation checks that creating a provider whose ID
// already exists is rejected with 409 Conflict (UNIQUE constraint on the ID).
func TestProvider_DuplicateCreation(t *testing.T) {
	r, _ := setupIntegrationTest(t)

	providerBody, err := json.Marshal(map[string]string{
		"id": "p1", "name": "P1", "api_key": "key", "base_url": "https://test.com",
	})
	if err != nil {
		t.Fatalf("marshal provider: %v", err)
	}

	// create POSTs the same provider payload and returns the recorded response.
	create := func() *httptest.ResponseRecorder {
		t.Helper()
		w := httptest.NewRecorder()
		req := httptest.NewRequest("POST", "/api/providers", bytes.NewReader(providerBody))
		req.Header.Set("Content-Type", "application/json")
		r.ServeHTTP(w, req)
		return w
	}

	// The first create must succeed; otherwise the duplicate check proves nothing.
	if w := create(); !assert.Equal(t, 201, w.Code, "response body: %s", w.Body.String()) {
		t.FailNow()
	}

	// The second create with the same ID must fail (UNIQUE constraint).
	w := create()
	assert.Equal(t, 409, w.Code, "response body: %s", w.Body.String())
}
// TestProvider_NotFound checks that fetching an unknown provider ID yields 404.
func TestProvider_NotFound(t *testing.T) {
	router, _ := setupIntegrationTest(t)

	request := httptest.NewRequest("GET", "/api/providers/nonexistent", nil)
	recorder := httptest.NewRecorder()
	router.ServeHTTP(recorder, request)

	assert.Equal(t, 404, recorder.Code)
}
// TestStats_InvalidDate checks that an unparsable start_date query parameter
// on the stats endpoint is rejected with 400.
func TestStats_InvalidDate(t *testing.T) {
	router, _ := setupIntegrationTest(t)

	request := httptest.NewRequest("GET", "/api/stats?start_date=not-a-date", nil)
	recorder := httptest.NewRecorder()
	router.ServeHTTP(recorder, request)

	assert.Equal(t, 400, recorder.Code)
}
// The "time" import is otherwise unused in this file. Referencing one of its
// types keeps the compiler from rejecting the unused import.
var _ time.Duration