1
0

feat: 实现统一模型 ID 机制

实现统一模型 ID 格式 (provider_id/model_name),支持跨协议模型标识和 Smart Passthrough。

核心变更:
- 新增 pkg/modelid 包:解析、格式化、校验统一模型 ID
- 数据库迁移:models 表使用 UUID 主键 + UNIQUE(provider_id, model_name) 约束
- Repository 层:FindByProviderAndModelName、ListEnabled 方法
- Service 层:联合唯一校验、provider ID 字符集校验
- Conversion 层:ExtractModelName、RewriteRequestModelName/RewriteResponseModelName 方法
- Handler 层:统一模型 ID 路由、Smart Passthrough、Models API 本地聚合
- 新增 error-responses、unified-model-id 规范

测试覆盖:
- 单元测试:modelid、conversion、handler、service、repository
- 集成测试:统一模型 ID 路由、Smart Passthrough 保真性、跨协议转换
- 迁移测试:UUID 主键、UNIQUE 约束、级联删除

OpenSpec:
- 归档 unified-model-id 变更到 archive/2026-04-21-unified-model-id
- 同步 11 个 delta specs 到 main specs
- 新增 error-responses、unified-model-id 规范文件
This commit is contained in:
2026-04-21 18:14:10 +08:00
parent 7f0f831226
commit 395887667d
73 changed files with 3360 additions and 1374 deletions

View File

@@ -113,7 +113,6 @@ func TestModelHandler_CreateModel_Success(t *testing.T) {
h := NewModelHandler(&mockModelService{})
body, _ := json.Marshal(map[string]string{
"id": "m1",
"provider_id": "p1",
"model_name": "gpt-4",
})
@@ -127,7 +126,7 @@ func TestModelHandler_CreateModel_Success(t *testing.T) {
var result domain.Model
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "m1", result.ID)
assert.NotEmpty(t, result.ID)
}
func TestModelHandler_GetModel(t *testing.T) {

View File

@@ -13,7 +13,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
@@ -31,7 +30,7 @@ type mockRoutingService struct {
err error
}
func (m *mockRoutingService) Route(modelName string) (*domain.RouteResult, error) {
func (m *mockRoutingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
return m.result, m.err
}
@@ -57,6 +56,14 @@ type mockProviderService struct {
err error
}
func (m *mockProviderService) ListEnabledModels() ([]domain.Model, error) {
return nil, nil
}
func (m *mockProviderService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
return nil, nil
}
func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err }
func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
return m.provider, m.err
@@ -73,13 +80,21 @@ type mockModelService struct {
err error
}
func (m *mockModelService) Create(model *domain.Model) error { return m.err }
func (m *mockModelService) Create(model *domain.Model) error {
if m.err == nil {
model.ID = "mock-uuid-1234"
}
return m.err
}
func (m *mockModelService) Get(id string) (*domain.Model, error) {
return m.model, m.err
}
func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
return m.models, m.err
}
func (m *mockModelService) ListEnabled() ([]domain.Model, error) {
return []domain.Model{}, nil
}
func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
return m.err
}
@@ -163,8 +178,8 @@ func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
func TestModelHandler_ListModels(t *testing.T) {
h := NewModelHandler(&mockModelService{
models: []domain.Model{
{ID: "m1", ModelName: "gpt-4"},
{ID: "m2", ModelName: "gpt-3.5"},
{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3"},
},
})
@@ -174,6 +189,72 @@ func TestModelHandler_ListModels(t *testing.T) {
h.ListModels(c)
assert.Equal(t, 200, w.Code)
var result []modelResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
require.Len(t, result, 2)
assert.Equal(t, "openai/gpt-4", result[0].UnifiedModelID)
assert.Equal(t, "anthropic/claude-3", result[1].UnifiedModelID)
}
func TestModelHandler_GetModel_UnifiedID(t *testing.T) {
h := NewModelHandler(&mockModelService{
model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "m1"}}
c.Request = httptest.NewRequest("GET", "/api/models/m1", nil)
h.GetModel(c)
assert.Equal(t, 200, w.Code)
var result modelResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "m1", result.ID)
assert.Equal(t, "openai/gpt-4", result.UnifiedModelID)
}
func TestModelHandler_CreateModel_UnifiedID(t *testing.T) {
h := NewModelHandler(&mockModelService{})
body, _ := json.Marshal(map[string]string{
"provider_id": "openai",
"model_name": "gpt-4",
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateModel(c)
assert.Equal(t, 201, w.Code)
var result modelResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "mock-uuid-1234", result.ID)
assert.Equal(t, "openai/gpt-4", result.UnifiedModelID)
}
func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) {
h := NewModelHandler(&mockModelService{
model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4-turbo"},
})
body, _ := json.Marshal(map[string]interface{}{"enabled": false})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "m1"}}
c.Request = httptest.NewRequest("PUT", "/api/models/m1", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.UpdateModel(c)
assert.Equal(t, 200, w.Code)
var result modelResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "openai/gpt-4-turbo", result.UnifiedModelID)
}
// ============ Stats Handler 测试 ============
@@ -256,7 +337,7 @@ func formatMapErrors(errs map[string]string) string {
func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
h := NewProviderHandler(&mockProviderService{
err: gorm.ErrDuplicatedKey,
err: appErrors.ErrConflict,
})
body, _ := json.Marshal(map[string]string{

View File

@@ -1,6 +1,7 @@
package handler
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
@@ -22,23 +23,35 @@ func NewModelHandler(modelService service.ModelService) *ModelHandler {
return &ModelHandler{modelService: modelService}
}
// modelResponse 模型响应 DTO扩展 unified_id 字段
type modelResponse struct {
domain.Model
UnifiedModelID string `json:"unified_id"`
}
// newModelResponse 从 domain.Model 构造响应 DTO
func newModelResponse(m *domain.Model) modelResponse {
return modelResponse{
Model: *m,
UnifiedModelID: m.UnifiedModelID(),
}
}
// CreateModel 创建模型
func (h *ModelHandler) CreateModel(c *gin.Context) {
var req struct {
ID string `json:"id" binding:"required"`
ProviderID string `json:"provider_id" binding:"required"`
ModelName string `json:"model_name" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "缺少必需字段: id, provider_id, model_name",
"error": "缺少必需字段: provider_id, model_name",
})
return
}
model := &domain.Model{
ID: req.ID,
ProviderID: req.ProviderID,
ModelName: req.ModelName,
}
@@ -51,11 +64,18 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
})
return
}
if err == appErrors.ErrDuplicateModel {
c.JSON(http.StatusConflict, gin.H{
"error": "同一供应商下模型名称已存在",
"code": appErrors.ErrDuplicateModel.Code,
})
return
}
writeError(c, err)
return
}
c.JSON(http.StatusCreated, model)
c.JSON(http.StatusCreated, newModelResponse(model))
}
// ListModels 列出模型
@@ -68,7 +88,11 @@ func (h *ModelHandler) ListModels(c *gin.Context) {
return
}
c.JSON(http.StatusOK, models)
resp := make([]modelResponse, len(models))
for i, m := range models {
resp[i] = newModelResponse(&m)
}
c.JSON(http.StatusOK, resp)
}
// GetModel 获取模型
@@ -87,7 +111,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
return
}
c.JSON(http.StatusOK, model)
c.JSON(http.StatusOK, newModelResponse(model))
}
// UpdateModel 更新模型
@@ -104,18 +128,25 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
err := h.modelService.Update(id, req)
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, appErrors.ErrModelNotFound) {
c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到",
})
return
}
if err == appErrors.ErrProviderNotFound {
if errors.Is(err, appErrors.ErrProviderNotFound) {
c.JSON(http.StatusBadRequest, gin.H{
"error": "供应商不存在",
})
return
}
if errors.Is(err, appErrors.ErrDuplicateModel) {
c.JSON(http.StatusConflict, gin.H{
"error": appErrors.ErrDuplicateModel.Message,
"code": appErrors.ErrDuplicateModel.Code,
})
return
}
writeError(c, err)
return
}
@@ -126,7 +157,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
return
}
c.JSON(http.StatusOK, model)
c.JSON(http.StatusOK, newModelResponse(model))
}
// DeleteModel 删除模型

View File

@@ -55,9 +55,10 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
err := h.providerService.Create(provider)
if err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
c.JSON(http.StatusConflict, gin.H{
"error": "供应商 ID 已存在",
if err == appErrors.ErrInvalidProviderID {
c.JSON(http.StatusBadRequest, gin.H{
"error": appErrors.ErrInvalidProviderID.Message,
"code": appErrors.ErrInvalidProviderID.Code,
})
return
}
@@ -119,6 +120,13 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
})
return
}
if errors.Is(err, appErrors.ErrImmutableField) {
c.JSON(http.StatusBadRequest, gin.H{
"error": appErrors.ErrImmutableField.Message,
"code": appErrors.ErrImmutableField.Code,
})
return
}
writeError(c, err)
return
}

View File

@@ -11,9 +11,11 @@ import (
"go.uber.org/zap"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/canonical"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
"nex/backend/internal/service"
"nex/backend/pkg/modelid"
)
// ProxyHandler 统一代理处理器
@@ -54,6 +56,34 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
}
nativePath := "/v1/" + path
// 获取 client adapter
registry := h.engine.GetRegistry()
clientAdapter, err := registry.Get(clientProtocol)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
return
}
// 检测接口类型
ifaceType := clientAdapter.DetectInterfaceType(nativePath)
// 处理 Models 接口:本地聚合
if ifaceType == conversion.InterfaceTypeModels {
h.handleModelsList(c, clientAdapter)
return
}
// 处理 ModelInfo 接口:本地查询
if ifaceType == conversion.InterfaceTypeModelInfo {
unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的模型 ID 格式"})
return
}
h.handleModelInfo(c, unifiedID, clientAdapter)
return
}
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
@@ -61,10 +91,17 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
return
}
// 解析 model 名称(从 JSON body 中提取GET 请求无 body
modelName := ""
// 解析统一模型 ID使用 adapter.ExtractModelName
var providerID, modelName string
if len(body) > 0 {
modelName = extractModelName(body)
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
if err == nil && unifiedID != "" {
pid, mn, err := modelid.ParseUnifiedModelID(unifiedID)
if err == nil {
providerID = pid
modelName = mn
}
}
}
// 构建输入 HTTPRequestSpec
@@ -76,7 +113,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
}
// 路由
routeResult, err := h.routingService.Route(modelName)
routeResult, err := h.routingService.RouteByModelName(providerID, modelName)
if err != nil {
// GET 请求或无法提取 model 时,直接转发到上游
if len(body) == 0 || modelName == "" {
@@ -94,24 +131,30 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
}
// 构建 TargetProvider
// 注意ModelName 字段用于 Smart Passthrough 场景改写请求体
// 同协议:请求体中的统一 ID 会被改写为 ModelName上游名
// 跨协议:全量转换时 ModelName 会被编码到请求体中
targetProvider := conversion.NewTargetProvider(
routeResult.Provider.BaseURL,
routeResult.Provider.APIKey,
routeResult.Model.ModelName,
routeResult.Model.ModelName, // 上游模型名,用于请求改写
)
// 判断是否流式
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
// 计算统一模型 ID用于响应覆写
unifiedModelID := routeResult.Model.UnifiedModelID()
if isStream {
h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
} else {
h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
}
}
// handleNonStream 处理非流式请求
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
// 转换请求
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
if err != nil {
@@ -128,9 +171,8 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
return
}
// 转换响应
interfaceType, _ := h.engine.DetectInterfaceType(inSpec.URL, clientProtocol)
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, interfaceType)
// 转换响应,传入 modelOverride跨协议场景覆写 model 字段)
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
if err != nil {
h.logger.Error("转换响应失败", zap.String("error", err.Error()))
h.writeConversionError(c, err, clientProtocol)
@@ -153,7 +195,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
}
// handleStream 处理流式请求
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
// 转换请求
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
if err != nil {
@@ -161,8 +203,8 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
return
}
// 创建流式转换器
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol)
// 创建流式转换器,传入 modelOverride跨协议场景覆写 model 字段)
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
@@ -224,6 +266,79 @@ func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath s
return req.Stream
}
// handleModelsList 处理 GET /v1/models 本地聚合
func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.ProtocolAdapter) {
// 从数据库查询所有启用的模型
models, err := h.providerService.ListEnabledModels()
if err != nil {
h.logger.Error("查询启用模型失败", zap.String("error", err.Error()))
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"})
return
}
// 构建 CanonicalModelList
modelList := &canonical.CanonicalModelList{
Models: make([]canonical.CanonicalModel, 0, len(models)),
}
for _, m := range models {
modelList.Models = append(modelList.Models, canonical.CanonicalModel{
ID: m.UnifiedModelID(),
Name: m.ModelName,
Created: m.CreatedAt.Unix(),
OwnedBy: m.ProviderID,
})
}
// 使用 adapter 编码返回
body, err := adapter.EncodeModelsResponse(modelList)
if err != nil {
h.logger.Error("编码 Models 响应失败", zap.String("error", err.Error()))
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
return
}
c.Data(http.StatusOK, "application/json", body)
}
// handleModelInfo 处理 GET /v1/models/{unified_id} 本地查询
func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter conversion.ProtocolAdapter) {
// 解析统一模型 ID
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "无效的统一模型 ID 格式",
"code": "INVALID_MODEL_ID",
})
return
}
// 从数据库查询模型
model, err := h.providerService.GetModelByProviderAndName(providerID, modelName)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "模型未找到"})
return
}
// 构建 CanonicalModelInfo
modelInfo := &canonical.CanonicalModelInfo{
ID: model.UnifiedModelID(),
Name: model.ModelName,
Created: model.CreatedAt.Unix(),
OwnedBy: model.ProviderID,
}
// 使用 adapter 编码返回
body, err := adapter.EncodeModelInfoResponse(modelInfo)
if err != nil {
h.logger.Error("编码 ModelInfo 响应失败", zap.String("error", err.Error()))
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
return
}
c.Data(http.StatusOK, "application/json", body)
}
// writeConversionError 写入转换错误
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
if convErr, ok := err.(*conversion.ConversionError); ok {
@@ -292,7 +407,7 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
return
}
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType)
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, "")
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
@@ -307,17 +422,6 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
}
// extractModelName 从 JSON body 中提取 model
func extractModelName(body []byte) string {
var req struct {
Model string `json:"model"`
}
if err := json.Unmarshal(body, &req); err != nil {
return ""
}
return req.Model
}
// extractHeaders 从 Gin context 提取请求头
func extractHeaders(c *gin.Context) map[string]string {
headers := make(map[string]string)

View File

@@ -60,13 +60,23 @@ type mockProxyRoutingService struct {
err error
}
func (m *mockProxyRoutingService) Route(modelName string) (*domain.RouteResult, error) {
func (m *mockProxyRoutingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
return m.result, m.err
}
type mockProxyProviderService struct {
providers []domain.Provider
err error
providers []domain.Provider
err error
enabledModels []domain.Model
modelByProvName *domain.Model
}
func (m *mockProxyProviderService) ListEnabledModels() ([]domain.Model, error) {
return m.enabledModels, nil
}
func (m *mockProxyProviderService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
return m.modelByProvName, nil
}
func (m *mockProxyProviderService) Create(p *domain.Provider) error { return nil }
@@ -319,7 +329,8 @@ func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) {
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 404, w.Code)
// Models 接口现在本地聚合,返回空列表 200
assert.Equal(t, 200, w.Code)
}
func TestExtractHeaders(t *testing.T) {
@@ -716,58 +727,6 @@ func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) {
assert.Equal(t, 200, w.Code)
}
// ============ extractModelName 测试 ============
func TestExtractModelName(t *testing.T) {
tests := []struct {
name string
body []byte
expected string
}{
{
name: "valid model",
body: []byte(`{"model": "gpt-4", "messages": []}`),
expected: "gpt-4",
},
{
name: "empty body",
body: []byte(`{}`),
expected: "",
},
{
name: "invalid json",
body: []byte(`{invalid}`),
expected: "",
},
{
name: "nested structure",
body: []byte(`{"model": "claude-3", "messages": [{"role": "user", "content": "hello"}]}`),
expected: "claude-3",
},
{
name: "model with special chars",
body: []byte(`{"model": "gpt-4-0125-preview", "stream": true}`),
expected: "gpt-4-0125-preview",
},
{
name: "empty body bytes",
body: []byte{},
expected: "",
},
{
name: "model is null",
body: []byte(`{"model": null}`),
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractModelName(tt.body)
assert.Equal(t, tt.expected, result)
})
}
}
// ============ isStreamRequest 测试 ============
func TestIsStreamRequest(t *testing.T) {
@@ -831,3 +790,270 @@ func TestIsStreamRequest(t *testing.T) {
})
}
}
// ============ Models / ModelInfo 本地聚合测试 ============
func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
engine := setupProxyEngine(t)
providerSvc := &mockProxyProviderService{
enabledModels: []domain.Model{
{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3", Enabled: true},
},
}
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, &mockProxyRoutingService{}, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
data, ok := resp["data"].([]interface{})
require.True(t, ok)
assert.Len(t, data, 2)
// 验证统一模型 ID 格式
first := data[0].(map[string]interface{})
assert.Equal(t, "openai/gpt-4", first["id"])
}
func TestProxyHandler_HandleProxy_ModelInfo_LocalQuery(t *testing.T) {
engine := setupProxyEngine(t)
providerSvc := &mockProxyProviderService{
modelByProvName: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, &mockProxyRoutingService{}, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/openai/gpt-4"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models/openai/gpt-4", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "openai/gpt-4", resp["id"])
}
func TestProxyHandler_HandleProxy_Models_EmptySuffix_ForwardPassthrough(t *testing.T) {
engine := setupProxyEngine(t)
providerSvc := &mockProxyProviderService{
providers: []domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai"},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Body: []byte(`{"object":"list","data":[]}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, &mockProxyRoutingService{err: appErrors.ErrModelNotFound}, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models/", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
}
// ============ Smart Passthrough 统一模型 ID 路由测试 ============
func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "openai_p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "openai_p", ModelName: "gpt-4", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
// 验证请求体中的 model 已被改写为上游模型名
var req map[string]interface{}
json.Unmarshal(spec.Body, &req)
assert.Equal(t, "gpt-4", req["model"])
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"id":"resp-1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8}}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
// 客户端发送统一模型 ID
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
// 验证响应中的 model 已被改写为统一模型 ID
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "openai_p/gpt-4", resp["model"])
}
// ============ 跨协议统一模型 ID 路由测试 ============
func TestProxyHandler_HandleProxy_CrossProtocol_NonStream_UnifiedID(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "anthropic_p", Name: "Anthropic", APIKey: "sk-test", BaseURL: "https://api.anthropic.com", Protocol: "anthropic", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"id":"msg-1","type":"message","role":"assistant","model":"claude-3","content":[{"type":"text","text":"Hello"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":3}}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
// OpenAI 客户端使用统一模型 ID 路由到 Anthropic 供应商
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
// 验证跨协议转换后响应中的 model 被覆写为统一模型 ID
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "anthropic_p/claude-3", resp["model"])
}
func TestProxyHandler_HandleProxy_CrossProtocol_Stream_UnifiedID(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "anthropic_p", Name: "Anthropic", APIKey: "sk-test", BaseURL: "https://api.anthropic.com", Protocol: "anthropic", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendStreamFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
ch := make(chan provider.StreamEvent, 10)
go func() {
defer close(ch)
ch <- provider.StreamEvent{Data: []byte(`event: message_start
data: {"type":"message_start","message":{"id":"msg-1","type":"message","role":"assistant","model":"claude-3","content":[]}}
`)}
ch <- provider.StreamEvent{Data: []byte(`event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}
`)}
ch <- provider.StreamEvent{Data: []byte(`event: message_stop
data: {"type":"message_stop"}
`)}
ch <- provider.StreamEvent{Done: true}
}()
return ch, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type"))
body := w.Body.String()
// 验证跨协议流式中 model 被覆写为统一模型 ID
assert.Contains(t, body, "anthropic_p/claude-3", "跨协议流式响应中 model 应被覆写为统一模型 ID")
}
func TestProxyHandler_HandleProxy_SmartPassthrough_Fidelity(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "openai_p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "openai_p", ModelName: "gpt-4", Enabled: true},
},
}
var capturedRequestBody []byte
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
capturedRequestBody = spec.Body
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"id":"resp-1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8},"unknown_field":"preserved"}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
// 包含未知参数,验证 Smart Passthrough 保真性
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}],"custom_param":"should_be_preserved"}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
// 验证请求中 model 被改写为上游模型名,但未知参数保留
var reqBody map[string]interface{}
require.NoError(t, json.Unmarshal(capturedRequestBody, &reqBody))
assert.Equal(t, "gpt-4", reqBody["model"], "请求中 model 应被改写为上游模型名")
assert.Equal(t, "should_be_preserved", reqBody["custom_param"], "Smart Passthrough 应保留未知参数")
// 验证响应中 model 被改写为统一模型 ID但未知参数保留
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "openai_p/gpt-4", resp["model"], "响应中 model 应被改写为统一模型 ID")
assert.Equal(t, "preserved", resp["unknown_field"], "Smart Passthrough 应保留未知响应字段")
}
func TestProxyHandler_HandleProxy_UnifiedID_ModelNotFound(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{err: appErrors.ErrModelNotFound}
providerSvc := &mockProxyProviderService{
providers: []domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai"},
},
}
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, routingSvc, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
// 使用统一模型 ID 格式但模型不存在
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 404, w.Code)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Contains(t, resp, "error")
}