1
0

feat: 实现统一模型 ID 机制

实现统一模型 ID 格式 (provider_id/model_name),支持跨协议模型标识和 Smart Passthrough。

核心变更:
- 新增 pkg/modelid 包:解析、格式化、校验统一模型 ID
- 数据库迁移:models 表使用 UUID 主键 + UNIQUE(provider_id, model_name) 约束
- Repository 层:FindByProviderAndModelName、ListEnabled 方法
- Service 层:联合唯一校验、provider ID 字符集校验
- Conversion 层:ExtractModelName、RewriteRequestModelName/RewriteResponseModelName 方法
- Handler 层:统一模型 ID 路由、Smart Passthrough、Models API 本地聚合
- 新增 error-responses、unified-model-id 规范

测试覆盖:
- 单元测试:modelid、conversion、handler、service、repository
- 集成测试:统一模型 ID 路由、Smart Passthrough 保真性、跨协议转换
- 迁移测试:UUID 主键、UNIQUE 约束、级联删除

OpenSpec:
- 归档 unified-model-id 变更到 archive/2026-04-21-unified-model-id
- 同步 11 个 delta specs 到 main specs
- 新增 error-responses、unified-model-id 规范文件
This commit is contained in:
2026-04-21 18:14:10 +08:00
parent 7f0f831226
commit 395887667d
73 changed files with 3360 additions and 1374 deletions

View File

@@ -17,11 +17,11 @@ type Provider struct {
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
}
// Model 模型配置
// Model 模型配置id 为 UUID 自动生成UNIQUE(provider_id, model_name)
type Model struct {
ID string `gorm:"primaryKey" json:"id"`
ProviderID string `gorm:"not null;index" json:"provider_id"`
ModelName string `gorm:"not null;index" json:"model_name"`
ProviderID string `gorm:"not null;index;uniqueIndex:idx_provider_model" json:"provider_id"`
ModelName string `gorm:"not null;index;uniqueIndex:idx_provider_model" json:"model_name"`
Enabled bool `gorm:"default:true" json:"enabled"`
CreatedAt time.Time `json:"created_at"`
}

View File

@@ -40,6 +40,12 @@ type ProtocolAdapter interface {
EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *TargetProvider) ([]byte, error)
DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error)
EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error)
// 统一模型 ID 相关方法
ExtractUnifiedModelID(nativePath string) (string, error)
ExtractModelName(body []byte, ifaceType InterfaceType) (string, error)
RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)
RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)
}
// AdapterRegistry 适配器注册表接口

View File

@@ -2,6 +2,7 @@ package anthropic
import (
"encoding/json"
"fmt"
"strings"
"nex/backend/internal/conversion"
@@ -39,13 +40,13 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp
}
}
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id}
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id},允许 id 含 /
func isModelInfoPath(path string) bool {
if !strings.HasPrefix(path, "/v1/models/") {
return false
}
suffix := path[len("/v1/models/"):]
return suffix != "" && !strings.Contains(suffix, "/")
return suffix != ""
}
// BuildUrl 根据接口类型构建 URL
@@ -203,3 +204,74 @@ func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankRe
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
}
// ExtractUnifiedModelID 从路径中提取统一模型 ID/v1/models/{provider_id}/{model_name}
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
if !strings.HasPrefix(nativePath, "/v1/models/") {
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
}
suffix := nativePath[len("/v1/models/"):]
if suffix == "" {
return "", fmt.Errorf("路径缺少模型 ID")
}
return suffix, nil
}
// locateModelFieldInRequest 定位请求体中 model 字段的值并提供改写函数
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
var m map[string]json.RawMessage
if err := json.Unmarshal(body, &m); err != nil {
return "", nil, err
}
switch ifaceType {
case conversion.InterfaceTypeChat:
raw, exists := m["model"]
if !exists {
return "", nil, fmt.Errorf("请求体中缺少 model 字段")
}
var current string
if err := json.Unmarshal(raw, &current); err != nil {
return "", nil, err
}
rewriteFunc := func(newModel string) ([]byte, error) {
m["model"], _ = json.Marshal(newModel)
return json.Marshal(m)
}
return current, rewriteFunc, nil
default:
return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
}
}
// ExtractModelName 从请求体中提取 model 值
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
model, _, err := locateModelFieldInRequest(body, ifaceType)
return model, err
}
// RewriteRequestModelName 最小化改写请求体中的 model 字段
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
if err != nil {
return nil, err
}
return rewriteFunc(newModel)
}
// RewriteResponseModelName 最小化改写响应体中的 model 字段
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
var m map[string]json.RawMessage
if err := json.Unmarshal(body, &m); err != nil {
return nil, err
}
switch ifaceType {
case conversion.InterfaceTypeChat:
// Chat 响应必须有 model 字段,存在则改写,不存在则添加
m["model"], _ = json.Marshal(newModel)
return json.Marshal(m)
default:
return body, nil
}
}

View File

@@ -0,0 +1,263 @@
package anthropic
import (
"encoding/json"
"testing"
"nex/backend/internal/conversion"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// ExtractUnifiedModelID
// ---------------------------------------------------------------------------
func TestExtractUnifiedModelID(t *testing.T) {
a := NewAdapter()
t.Run("standard_path", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/v1/models/anthropic/claude-3")
require.NoError(t, err)
assert.Equal(t, "anthropic/claude-3", id)
})
t.Run("multi_segment_path", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/v1/models/some/deep/nested/model")
require.NoError(t, err)
assert.Equal(t, "some/deep/nested/model", id)
})
t.Run("single_segment", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/v1/models/claude-3")
require.NoError(t, err)
assert.Equal(t, "claude-3", id)
})
t.Run("non_model_path", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/messages")
require.Error(t, err)
})
t.Run("empty_suffix", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/models/")
require.Error(t, err)
})
t.Run("models_list_no_slash", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/models")
require.Error(t, err)
})
t.Run("unrelated_path", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/other")
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// ExtractModelName (Chat only for Anthropic)
// ---------------------------------------------------------------------------
func TestExtractModelName(t *testing.T) {
a := NewAdapter()
t.Run("chat", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3","messages":[]}`)
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "anthropic/claude-3", model)
})
t.Run("chat_with_max_tokens", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3-opus","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`)
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "anthropic/claude-3-opus", model)
})
t.Run("no_model_field", func(t *testing.T) {
body := []byte(`{"messages":[]}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("invalid_json", func(t *testing.T) {
body := []byte(`{invalid}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("unsupported_interface_type_embedding", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3"}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings)
require.Error(t, err)
})
t.Run("unsupported_interface_type_rerank", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3"}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank)
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// RewriteRequestModelName (Chat only for Anthropic)
// ---------------------------------------------------------------------------
func TestRewriteRequestModelName(t *testing.T) {
a := NewAdapter()
t.Run("chat", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3","messages":[]}`)
rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "claude-3", m["model"])
msgs, ok := m["messages"]
require.True(t, ok)
msgsArr, ok := msgs.([]interface{})
require.True(t, ok)
assert.Len(t, msgsArr, 0)
})
t.Run("preserves_unknown_fields", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3","max_tokens":1024,"temperature":0.7}`)
rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "claude-3", m["model"])
assert.Equal(t, 0.7, m["temperature"])
// max_tokens is encoded as float in JSON numbers
maxTokens, ok := m["max_tokens"]
require.True(t, ok)
assert.Equal(t, float64(1024), maxTokens)
})
t.Run("no_model_field", func(t *testing.T) {
body := []byte(`{"messages":[]}`)
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("invalid_json", func(t *testing.T) {
body := []byte(`{invalid}`)
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("unsupported_interface_type", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3"}`)
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeEmbeddings)
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// RewriteResponseModelName (Chat only for Anthropic)
// ---------------------------------------------------------------------------
func TestRewriteResponseModelName(t *testing.T) {
a := NewAdapter()
t.Run("chat_existing_model", func(t *testing.T) {
body := []byte(`{"model":"claude-3","content":[],"stop_reason":"end_turn"}`)
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "anthropic/claude-3", m["model"])
// other fields preserved
_, hasContent := m["content"]
assert.True(t, hasContent)
assert.Equal(t, "end_turn", m["stop_reason"])
})
t.Run("chat_without_model_field_adds_it", func(t *testing.T) {
body := []byte(`{"content":[],"stop_reason":"end_turn"}`)
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "anthropic/claude-3", m["model"])
})
t.Run("passthrough_returns_body_unchanged", func(t *testing.T) {
body := []byte(`{"model":"claude-3"}`)
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypePassthrough)
require.NoError(t, err)
assert.Equal(t, string(body), string(rewritten))
})
t.Run("invalid_json", func(t *testing.T) {
body := []byte(`{invalid}`)
_, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// ExtractModelName and RewriteRequest consistency
// ---------------------------------------------------------------------------
func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) {
a := NewAdapter()
t.Run("chat_round_trip", func(t *testing.T) {
original := []byte(`{"model":"anthropic/claude-3","messages":[],"max_tokens":1024}`)
// Extract the unified model ID from the body
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "anthropic/claude-3", extracted)
// Rewrite to the native model name
rewritten, err := a.RewriteRequestModelName(original, "claude-3", conversion.InterfaceTypeChat)
require.NoError(t, err)
// Extract again from the rewritten body to verify the same location was targeted
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "claude-3", afterRewrite)
// Verify other fields are preserved
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, float64(1024), m["max_tokens"])
})
}
// ---------------------------------------------------------------------------
// isModelInfoPath (additional unified model ID cases)
// ---------------------------------------------------------------------------
func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
tests := []struct {
name string
path string
expected bool
}{
{"simple_model_id", "/v1/models/claude-3", true},
{"unified_model_id_with_slash", "/v1/models/anthropic/claude-3", true},
{"models_list", "/v1/models", false},
{"models_list_trailing_slash", "/v1/models/", false},
{"messages_path", "/v1/messages", false},
{"deeply_nested", "/v1/models/org/workspace/claude-3-opus", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
})
}
}

View File

@@ -79,11 +79,29 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
if err != nil {
return nil, err
}
// Smart Passthrough: 同协议时最小化改写 model 字段
interfaceType := providerAdapter.DetectInterfaceType(nativePath)
rewrittenBody := spec.Body
// 对于 Chat/Embedding/Rerank 接口,改写请求体中的 model 字段
if interfaceType == InterfaceTypeChat || interfaceType == InterfaceTypeEmbeddings || interfaceType == InterfaceTypeRerank {
if len(spec.Body) > 0 && provider.ModelName != "" {
rewrittenBody, err = providerAdapter.RewriteRequestModelName(spec.Body, provider.ModelName, interfaceType)
if err != nil {
e.logger.Warn("Smart Passthrough 改写请求失败,使用原始请求体",
zap.String("error", err.Error()),
zap.String("interface", string(interfaceType)))
rewrittenBody = spec.Body
}
}
}
return &HTTPRequestSpec{
URL: provider.BaseURL + nativePath,
Method: spec.Method,
Headers: providerAdapter.BuildHeaders(provider),
Body: spec.Body,
Body: rewrittenBody,
}, nil
}
@@ -112,9 +130,30 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
}, nil
}
// ConvertHttpResponse 转换 HTTP 响应
func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType) (*HTTPResponseSpec, error) {
// ConvertHttpResponse 转换 HTTP 响应modelOverride 用于跨协议场景覆写 model 字段
func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType, modelOverride string) (*HTTPResponseSpec, error) {
if e.IsPassthrough(clientProtocol, providerProtocol) {
// Smart Passthrough: 同协议时最小化改写 model 字段
if modelOverride != "" && len(spec.Body) > 0 {
adapter, err := e.registry.Get(clientProtocol)
if err != nil {
return &spec, nil
}
rewrittenBody, err := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
if err != nil {
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
zap.String("error", err.Error()),
zap.String("interface", string(interfaceType)))
return &spec, nil
}
return &HTTPResponseSpec{
StatusCode: spec.StatusCode,
Headers: spec.Headers,
Body: rewrittenBody,
}, nil
}
return &spec, nil
}
@@ -127,7 +166,7 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
return nil, err
}
convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body)
convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body, modelOverride)
if err != nil {
return nil, err
}
@@ -139,9 +178,17 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
}, nil
}
// CreateStreamConverter 创建流式转换器
func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string) (StreamConverter, error) {
// CreateStreamConverter 创建流式转换器modelOverride 用于跨协议场景覆写 model 字段
func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string, modelOverride string, interfaceType InterfaceType) (StreamConverter, error) {
if e.IsPassthrough(clientProtocol, providerProtocol) {
// Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段
if modelOverride != "" {
adapter, err := e.registry.Get(clientProtocol)
if err != nil {
return NewPassthroughStreamConverter(), nil
}
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
}
return NewPassthroughStreamConverter(), nil
}
@@ -167,6 +214,7 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
ctx,
clientProtocol,
providerProtocol,
modelOverride,
), nil
}
@@ -192,11 +240,11 @@ func (e *ConversionEngine) convertBody(interfaceType InterfaceType, clientAdapte
}
}
// convertResponseBody 转换响应体
func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
// convertResponseBody 转换响应体modelOverride 非空时在 canonical 层面覆写 Model 字段
func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
switch interfaceType {
case InterfaceTypeChat:
return e.convertChatResponseBody(clientAdapter, providerAdapter, body)
return e.convertChatResponseBody(clientAdapter, providerAdapter, body, modelOverride)
case InterfaceTypeModels:
if !clientAdapter.SupportsInterface(InterfaceTypeModels) || !providerAdapter.SupportsInterface(InterfaceTypeModels) {
return body, nil
@@ -211,12 +259,12 @@ func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clie
if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) {
return body, nil
}
return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body)
return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body, modelOverride)
case InterfaceTypeRerank:
if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) {
return body, nil
}
return e.convertRerankResponseBody(clientAdapter, providerAdapter, body)
return e.convertRerankResponseBody(clientAdapter, providerAdapter, body, modelOverride)
default:
return body, nil
}
@@ -241,11 +289,14 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
return encoded, nil
}
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
canonicalResp, err := providerAdapter.DecodeResponse(body)
if err != nil {
return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(err)
}
if modelOverride != "" {
canonicalResp.Model = modelOverride
}
encoded, err := clientAdapter.EncodeResponse(canonicalResp)
if err != nil {
return nil, NewConversionError(ErrorCodeEncodingFailure, "编码响应失败").WithCause(err)
@@ -290,12 +341,15 @@ func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter P
return providerAdapter.EncodeEmbeddingRequest(req, provider)
}
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
resp, err := providerAdapter.DecodeEmbeddingResponse(body)
if err != nil {
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
return body, nil
}
if modelOverride != "" {
resp.Model = modelOverride
}
return clientAdapter.EncodeEmbeddingResponse(resp)
}
@@ -308,11 +362,14 @@ func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter Prot
return providerAdapter.EncodeRerankRequest(req, provider)
}
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
resp, err := providerAdapter.DecodeRerankResponse(body)
if err != nil {
return body, nil
}
if modelOverride != "" {
resp.Model = modelOverride
}
return clientAdapter.EncodeRerankResponse(resp)
}

View File

@@ -113,7 +113,7 @@ func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
StatusCode: 200, Body: []byte(`{"id":"resp-1"}`),
}, "client", "provider", InterfaceTypeChat)
}, "client", "provider", InterfaceTypeChat, "")
require.NoError(t, err)
assert.Equal(t, 200, result.StatusCode)
assert.Contains(t, string(result.Body), "resp-1")
@@ -129,7 +129,7 @@ func TestConvertHttpResponse_DecodeError(t *testing.T) {
_ = engine.RegisterAdapter(providerAdapter)
_ = engine.RegisterAdapter(newMockAdapter("client", false))
_, err := engine.ConvertHttpResponse(HTTPResponseSpec{Body: []byte(`{}`)}, "client", "provider", InterfaceTypeChat)
_, err := engine.ConvertHttpResponse(HTTPResponseSpec{Body: []byte(`{}`)}, "client", "provider", InterfaceTypeChat, "")
assert.Error(t, err)
}
@@ -189,7 +189,7 @@ func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
StatusCode: 200, Body: []byte(`{"object":"list","data":[],"model":"test"}`),
}, "client", "provider", InterfaceTypeEmbeddings)
}, "client", "provider", InterfaceTypeEmbeddings, "")
require.NoError(t, err)
assert.NotNil(t, result)
}
@@ -207,7 +207,7 @@ func TestConvertHttpResponse_RerankInterface(t *testing.T) {
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
StatusCode: 200, Body: []byte(`{"results":[],"model":"test"}`),
}, "client", "provider", InterfaceTypeRerank)
}, "client", "provider", InterfaceTypeRerank, "")
require.NoError(t, err)
assert.NotNil(t, result)
}
@@ -242,7 +242,7 @@ func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
StatusCode: 200, Body: []byte(`{"object":"list","data":[]}`),
}, "client", "provider", InterfaceTypeModels)
}, "client", "provider", InterfaceTypeModels, "")
require.NoError(t, err)
assert.NotNil(t, result)
}
@@ -259,7 +259,7 @@ func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) {
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
StatusCode: 200, Body: []byte(`{"id":"gpt-4","object":"model"}`),
}, "client", "provider", InterfaceTypeModelInfo)
}, "client", "provider", InterfaceTypeModelInfo, "")
require.NoError(t, err)
assert.NotNil(t, result)
}

View File

@@ -13,16 +13,18 @@ import (
// mockProtocolAdapter 模拟协议适配器
type mockProtocolAdapter struct {
protocolName string
passthrough bool
ifaceType InterfaceType
supportsIface map[InterfaceType]bool
decodeReqFn func([]byte) (*canonical.CanonicalRequest, error)
encodeReqFn func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
decodeRespFn func([]byte) (*canonical.CanonicalResponse, error)
encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error)
streamDecoderFn func() StreamDecoder
streamEncoderFn func() StreamEncoder
protocolName string
passthrough bool
ifaceType InterfaceType
supportsIface map[InterfaceType]bool
decodeReqFn func([]byte) (*canonical.CanonicalRequest, error)
encodeReqFn func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
decodeRespFn func([]byte) (*canonical.CanonicalResponse, error)
encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error)
streamDecoderFn func() StreamDecoder
streamEncoderFn func() StreamEncoder
rewriteReqFn func([]byte, string, InterfaceType) ([]byte, error)
rewriteRespFn func([]byte, string, InterfaceType) ([]byte, error)
}
func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
@@ -155,6 +157,28 @@ func (m *mockProtocolAdapter) EncodeRerankResponse(resp *canonical.CanonicalRera
return json.Marshal(resp)
}
func (m *mockProtocolAdapter) ExtractUnifiedModelID(nativePath string) (string, error) {
return "", nil
}
func (m *mockProtocolAdapter) ExtractModelName(body []byte, ifaceType InterfaceType) (string, error) {
return "", nil
}
func (m *mockProtocolAdapter) RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
if m.rewriteReqFn != nil {
return m.rewriteReqFn(body, newModel, ifaceType)
}
return body, nil
}
func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
if m.rewriteRespFn != nil {
return m.rewriteRespFn(body, newModel, ifaceType)
}
return body, nil
}
// noopStreamDecoder 空流式解码器
type noopStreamDecoder struct{}
@@ -309,7 +333,7 @@ func TestConvertHttpResponse_Passthrough(t *testing.T) {
Body: []byte(`{"id":"123"}`),
}
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat)
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat, "")
require.NoError(t, err)
assert.Equal(t, 200, result.StatusCode)
assert.Equal(t, spec.Body, result.Body)
@@ -320,7 +344,7 @@ func TestCreateStreamConverter_Passthrough(t *testing.T) {
engine := NewConversionEngine(registry, nil)
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
converter, err := engine.CreateStreamConverter("openai", "openai")
converter, err := engine.CreateStreamConverter("openai", "openai", "", InterfaceTypeChat)
require.NoError(t, err)
_, ok := converter.(*PassthroughStreamConverter)
assert.True(t, ok)
@@ -332,7 +356,7 @@ func TestCreateStreamConverter_Canonical(t *testing.T) {
_ = engine.RegisterAdapter(newMockAdapter("client", false))
_ = engine.RegisterAdapter(newMockAdapter("provider", false))
converter, err := engine.CreateStreamConverter("client", "provider")
converter, err := engine.CreateStreamConverter("client", "provider", "", InterfaceTypeChat)
require.NoError(t, err)
_, ok := converter.(*CanonicalStreamConverter)
assert.True(t, ok)
@@ -380,3 +404,230 @@ func TestRegistry_GetNonExistent(t *testing.T) {
assert.Error(t, err)
assert.Contains(t, err.Error(), "未找到适配器")
}
// ============ modelOverride 测试 ============
func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
clientAdapter := newMockAdapter("client", false)
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
return json.Marshal(map[string]any{"model": resp.Model})
}
_ = engine.RegisterAdapter(clientAdapter)
providerAdapter := newMockAdapter("provider", false)
providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
return &canonical.CanonicalResponse{ID: "test", Model: "native-model", Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}, nil
}
_ = engine.RegisterAdapter(providerAdapter)
spec := HTTPResponseSpec{
StatusCode: 200,
Body: []byte(`{"model":"native-model"}`),
}
result, err := engine.ConvertHttpResponse(spec, "client", "provider", InterfaceTypeChat, "provider/gpt-4")
require.NoError(t, err)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(result.Body, &resp))
assert.Equal(t, "provider/gpt-4", resp["model"])
}
func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
// 使用真实 OpenAI adapter 验证 Smart Passthrough 改写
openaiAdapter := newMockAdapter("openai", true)
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
var m map[string]json.RawMessage
if err := json.Unmarshal(body, &m); err != nil {
return nil, err
}
m["model"], _ = json.Marshal(newModel)
return json.Marshal(m)
}
_ = engine.RegisterAdapter(openaiAdapter)
spec := HTTPResponseSpec{
StatusCode: 200,
Body: []byte(`{"id":"resp-1","model":"gpt-4"}`),
}
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat, "openai/gpt-4")
require.NoError(t, err)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(result.Body, &resp))
assert.Equal(t, "openai/gpt-4", resp["model"])
assert.Equal(t, "resp-1", resp["id"])
}
func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
openaiAdapter := newMockAdapter("openai", true)
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
var m map[string]json.RawMessage
if err := json.Unmarshal(body, &m); err != nil {
return nil, err
}
m["model"], _ = json.Marshal(newModel)
return json.Marshal(m)
}
_ = engine.RegisterAdapter(openaiAdapter)
converter, err := engine.CreateStreamConverter("openai", "openai", "openai/gpt-4", InterfaceTypeChat)
require.NoError(t, err)
_, ok := converter.(*SmartPassthroughStreamConverter)
assert.True(t, ok)
// 验证 chunk 改写
chunks := converter.ProcessChunk([]byte(`{"model":"gpt-4","choices":[]}`))
require.Len(t, chunks, 1)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(chunks[0], &resp))
assert.Equal(t, "openai/gpt-4", resp["model"])
}
func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
// provider adapter 解码出含 model 的流式事件
providerAdapter := newMockAdapter("provider", false)
providerAdapter.streamDecoderFn = func() StreamDecoder {
return &engineTestStreamDecoder{
processFn: func(raw []byte) []canonical.CanonicalStreamEvent {
return []canonical.CanonicalStreamEvent{
canonical.NewMessageStartEvent("msg-1", "native-model"),
canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "text", Text: "hi"}),
canonical.NewMessageStopEvent(),
}
},
}
}
_ = engine.RegisterAdapter(providerAdapter)
// client adapter 编码时输出 model 字段
clientAdapter := newMockAdapter("client", false)
clientAdapter.streamEncoderFn = func() StreamEncoder {
return &engineTestStreamEncoder{
encodeFn: func(event canonical.CanonicalStreamEvent) [][]byte {
if event.Message != nil {
data, _ := json.Marshal(map[string]string{
"type": string(event.Type),
"model": event.Message.Model,
})
return [][]byte{data}
}
data, _ := json.Marshal(map[string]string{"type": string(event.Type)})
return [][]byte{data}
},
}
}
_ = engine.RegisterAdapter(clientAdapter)
converter, err := engine.CreateStreamConverter("client", "provider", "provider/gpt-4", InterfaceTypeChat)
require.NoError(t, err)
// 验证类型是 CanonicalStreamConverter
_, ok := converter.(*CanonicalStreamConverter)
assert.True(t, ok)
// 处理一个 chunk验证 model 被覆写为统一模型 ID
chunks := converter.ProcessChunk([]byte("raw"))
require.Len(t, chunks, 3) // message_start + content_block_start + message_stop
var startEvent map[string]string
require.NoError(t, json.Unmarshal(chunks[0], &startEvent))
assert.Equal(t, "provider/gpt-4", startEvent["model"], "跨协议流式中 modelOverride 应覆写 Message.Model")
}
func TestCreateStreamConverter_ModelOverride_CrossProtocol_Empty(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
providerAdapter := newMockAdapter("provider", false)
providerAdapter.streamDecoderFn = func() StreamDecoder {
return &engineTestStreamDecoder{
processFn: func(raw []byte) []canonical.CanonicalStreamEvent {
return []canonical.CanonicalStreamEvent{
canonical.NewMessageStartEvent("msg-1", "native-model"),
}
},
}
}
_ = engine.RegisterAdapter(providerAdapter)
clientAdapter := newMockAdapter("client", false)
clientAdapter.streamEncoderFn = func() StreamEncoder {
return &engineTestStreamEncoder{
encodeFn: func(event canonical.CanonicalStreamEvent) [][]byte {
if event.Message != nil {
data, _ := json.Marshal(map[string]string{
"model": event.Message.Model,
})
return [][]byte{data}
}
return nil
},
}
}
_ = engine.RegisterAdapter(clientAdapter)
// modelOverride 为空,不应覆写
converter, err := engine.CreateStreamConverter("client", "provider", "", InterfaceTypeChat)
require.NoError(t, err)
chunks := converter.ProcessChunk([]byte("raw"))
require.Len(t, chunks, 1)
var resp map[string]string
require.NoError(t, json.Unmarshal(chunks[0], &resp))
assert.Equal(t, "native-model", resp["model"], "modelOverride 为空时不应覆写")
}
// engineTestStreamDecoder 可控的流式解码器(用于 engine_test
type engineTestStreamDecoder struct {
processFn func([]byte) []canonical.CanonicalStreamEvent
flushFn func() []canonical.CanonicalStreamEvent
}
func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.CanonicalStreamEvent {
if d.processFn != nil {
return d.processFn(raw)
}
return nil
}
func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
if d.flushFn != nil {
return d.flushFn()
}
return nil
}
// engineTestStreamEncoder 可控的流式编码器(用于 engine_test
type engineTestStreamEncoder struct {
encodeFn func(canonical.CanonicalStreamEvent) [][]byte
flushFn func() [][]byte
}
func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
if e.encodeFn != nil {
return e.encodeFn(event)
}
return nil
}
func (e *engineTestStreamEncoder) Flush() [][]byte {
if e.flushFn != nil {
return e.flushFn()
}
return nil
}

View File

@@ -2,6 +2,7 @@ package openai
import (
"encoding/json"
"fmt"
"strings"
"nex/backend/internal/conversion"
@@ -43,13 +44,13 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp
}
}
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id}
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id},允许 id 含 /
func isModelInfoPath(path string) bool {
if !strings.HasPrefix(path, "/v1/models/") {
return false
}
suffix := path[len("/v1/models/"):]
return suffix != "" && !strings.Contains(suffix, "/")
return suffix != ""
}
// BuildUrl 根据接口类型构建 URL
@@ -216,3 +217,80 @@ func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankRe
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
return encodeRerankResponse(resp)
}
// ExtractUnifiedModelID 从路径中提取统一模型 ID/v1/models/{provider_id}/{model_name}
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
if !strings.HasPrefix(nativePath, "/v1/models/") {
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
}
suffix := nativePath[len("/v1/models/"):]
if suffix == "" {
return "", fmt.Errorf("路径缺少模型 ID")
}
return suffix, nil
}
// locateModelFieldInRequest 定位请求体中 model 字段的值并提供改写函数
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
var m map[string]json.RawMessage
if err := json.Unmarshal(body, &m); err != nil {
return "", nil, err
}
switch ifaceType {
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
raw, exists := m["model"]
if !exists {
return "", nil, fmt.Errorf("请求体中缺少 model 字段")
}
var current string
if err := json.Unmarshal(raw, &current); err != nil {
return "", nil, err
}
rewriteFunc := func(newModel string) ([]byte, error) {
m["model"], _ = json.Marshal(newModel)
return json.Marshal(m)
}
return current, rewriteFunc, nil
default:
return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
}
}
// ExtractModelName 从请求体中提取 model 值
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
model, _, err := locateModelFieldInRequest(body, ifaceType)
return model, err
}
// RewriteRequestModelName 最小化改写请求体中的 model 字段
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
if err != nil {
return nil, err
}
return rewriteFunc(newModel)
}
// RewriteResponseModelName 最小化改写响应体中的 model 字段
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
var m map[string]json.RawMessage
if err := json.Unmarshal(body, &m); err != nil {
return nil, err
}
switch ifaceType {
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
m["model"], _ = json.Marshal(newModel)
return json.Marshal(m)
case conversion.InterfaceTypeRerank:
// Rerank 响应:存在 model 字段则改写,不存在则不添加
if _, exists := m["model"]; exists {
m["model"], _ = json.Marshal(newModel)
}
return json.Marshal(m)
default:
return body, nil
}
}

View File

@@ -121,7 +121,7 @@ func TestIsModelInfoPath(t *testing.T) {
{"model_info", "/v1/models/gpt-4", true},
{"model_info_with_dots", "/v1/models/gpt-4.1-preview", true},
{"models_list", "/v1/models", false},
{"nested_path", "/v1/models/gpt-4/versions", false},
{"nested_path", "/v1/models/gpt-4/versions", true},
{"empty_suffix", "/v1/models/", false},
{"unrelated", "/v1/chat/completions", false},
{"partial_prefix", "/v1/model", false},

View File

@@ -0,0 +1,360 @@
package openai
import (
"encoding/json"
"testing"
"nex/backend/internal/conversion"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// ExtractUnifiedModelID
// ---------------------------------------------------------------------------
func TestExtractUnifiedModelID(t *testing.T) {
a := NewAdapter()
t.Run("standard_path", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
require.NoError(t, err)
assert.Equal(t, "openai/gpt-4", id)
})
t.Run("multi_segment_path", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
require.NoError(t, err)
assert.Equal(t, "azure/accounts/org/models/gpt-4", id)
})
t.Run("single_segment", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/v1/models/gpt-4")
require.NoError(t, err)
assert.Equal(t, "gpt-4", id)
})
t.Run("non_model_path", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/chat/completions")
require.Error(t, err)
})
t.Run("empty_suffix", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/models/")
require.Error(t, err)
})
t.Run("models_list_no_slash", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/models")
require.Error(t, err)
})
t.Run("unrelated_path", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/other")
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// ExtractModelName
// ---------------------------------------------------------------------------
func TestExtractModelName(t *testing.T) {
a := NewAdapter()
t.Run("chat", func(t *testing.T) {
body := []byte(`{"model":"openai/gpt-4","messages":[]}`)
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "openai/gpt-4", model)
})
t.Run("embedding", func(t *testing.T) {
body := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
model, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
assert.Equal(t, "openai/text-embedding", model)
})
t.Run("rerank", func(t *testing.T) {
body := []byte(`{"model":"openai/rerank","query":"test"}`)
model, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank)
require.NoError(t, err)
assert.Equal(t, "openai/rerank", model)
})
t.Run("no_model_field", func(t *testing.T) {
body := []byte(`{"messages":[]}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("invalid_json", func(t *testing.T) {
body := []byte(`{invalid}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("unsupported_interface_type", func(t *testing.T) {
body := []byte(`{"model":"openai/gpt-4"}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypePassthrough)
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// RewriteRequestModelName
// ---------------------------------------------------------------------------
func TestRewriteRequestModelName(t *testing.T) {
a := NewAdapter()
t.Run("chat", func(t *testing.T) {
body := []byte(`{"model":"openai/gpt-4","messages":[]}`)
rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "gpt-4", m["model"])
// messages field preserved
msgs, ok := m["messages"]
require.True(t, ok)
msgsArr, ok := msgs.([]interface{})
require.True(t, ok)
assert.Len(t, msgsArr, 0)
})
t.Run("preserves_unknown_fields", func(t *testing.T) {
body := []byte(`{"model":"openai/gpt-4","temperature":0.7}`)
rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "gpt-4", m["model"])
assert.Equal(t, 0.7, m["temperature"])
})
t.Run("embedding", func(t *testing.T) {
body := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
rewritten, err := a.RewriteRequestModelName(body, "text-embedding", conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "text-embedding", m["model"])
assert.Equal(t, "hello", m["input"])
})
t.Run("rerank", func(t *testing.T) {
body := []byte(`{"model":"openai/rerank","query":"test"}`)
rewritten, err := a.RewriteRequestModelName(body, "rerank", conversion.InterfaceTypeRerank)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "rerank", m["model"])
assert.Equal(t, "test", m["query"])
})
t.Run("no_model_field", func(t *testing.T) {
body := []byte(`{"messages":[]}`)
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("invalid_json", func(t *testing.T) {
body := []byte(`{invalid}`)
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("unsupported_interface_type", func(t *testing.T) {
body := []byte(`{"model":"openai/gpt-4"}`)
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypePassthrough)
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// RewriteResponseModelName
// ---------------------------------------------------------------------------
func TestRewriteResponseModelName(t *testing.T) {
a := NewAdapter()
t.Run("chat_existing_model", func(t *testing.T) {
body := []byte(`{"model":"gpt-4","choices":[]}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "openai/gpt-4", m["model"])
choices, ok := m["choices"]
require.True(t, ok)
choicesArr, ok := choices.([]interface{})
require.True(t, ok)
assert.Len(t, choicesArr, 0)
})
t.Run("chat_without_model_field", func(t *testing.T) {
body := []byte(`{"choices":[]}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "openai/gpt-4", m["model"])
choices, ok := m["choices"]
require.True(t, ok)
choicesArr, ok := choices.([]interface{})
require.True(t, ok)
assert.Len(t, choicesArr, 0)
})
t.Run("rerank_existing_model", func(t *testing.T) {
body := []byte(`{"model":"rerank","results":[]}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "openai/rerank", m["model"])
})
t.Run("rerank_without_model_field_should_not_add", func(t *testing.T) {
body := []byte(`{"results":[]}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
_, hasModel := m["model"]
assert.False(t, hasModel, "rerank response without model field should not have one added")
})
t.Run("embedding_existing_model", func(t *testing.T) {
body := []byte(`{"model":"text-embedding","data":[]}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "openai/text-embedding", m["model"])
})
t.Run("embedding_without_model_field_adds", func(t *testing.T) {
body := []byte(`{"data":[]}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "openai/text-embedding", m["model"])
})
t.Run("passthrough_returns_body_unchanged", func(t *testing.T) {
body := []byte(`{"model":"gpt-4"}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypePassthrough)
require.NoError(t, err)
assert.Equal(t, string(body), string(rewritten))
})
t.Run("invalid_json", func(t *testing.T) {
body := []byte(`{invalid}`)
_, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// ExtractModelName and RewriteRequest consistency
// ---------------------------------------------------------------------------
func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) {
a := NewAdapter()
t.Run("chat_round_trip", func(t *testing.T) {
original := []byte(`{"model":"openai/gpt-4","messages":[],"temperature":0.7}`)
// Extract the unified model ID from the body
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "openai/gpt-4", extracted)
// Rewrite to the native model name
rewritten, err := a.RewriteRequestModelName(original, "gpt-4", conversion.InterfaceTypeChat)
require.NoError(t, err)
// Extract again from the rewritten body to verify the same location was targeted
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "gpt-4", afterRewrite)
// Verify other fields are preserved
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, 0.7, m["temperature"])
})
t.Run("embedding_round_trip", func(t *testing.T) {
original := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
assert.Equal(t, "openai/text-embedding", extracted)
rewritten, err := a.RewriteRequestModelName(original, "text-embedding", conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
assert.Equal(t, "text-embedding", afterRewrite)
})
t.Run("rerank_round_trip", func(t *testing.T) {
original := []byte(`{"model":"openai/rerank","query":"test"}`)
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeRerank)
require.NoError(t, err)
assert.Equal(t, "openai/rerank", extracted)
rewritten, err := a.RewriteRequestModelName(original, "rerank", conversion.InterfaceTypeRerank)
require.NoError(t, err)
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeRerank)
require.NoError(t, err)
assert.Equal(t, "rerank", afterRewrite)
})
}
// ---------------------------------------------------------------------------
// isModelInfoPath (additional unified model ID cases)
// ---------------------------------------------------------------------------
func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
tests := []struct {
name string
path string
expected bool
}{
{"simple_model_id", "/v1/models/gpt-4", true},
{"unified_model_id_with_slash", "/v1/models/openai/gpt-4", true},
{"models_list", "/v1/models", false},
{"models_list_trailing_slash", "/v1/models/", false},
{"chat_completions", "/v1/chat/completions", false},
{"deeply_nested", "/v1/models/azure/eastus/deployments/my-dept/models/gpt-4", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
})
}
}

View File

@@ -38,14 +38,52 @@ func (c *PassthroughStreamConverter) Flush() [][]byte {
return nil
}
// SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器
// 逐 chunk 改写 model 字段
type SmartPassthroughStreamConverter struct {
adapter ProtocolAdapter
modelOverride string
interfaceType InterfaceType
}
// NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器
func NewSmartPassthroughStreamConverter(adapter ProtocolAdapter, modelOverride string, interfaceType InterfaceType) *SmartPassthroughStreamConverter {
return &SmartPassthroughStreamConverter{
adapter: adapter,
modelOverride: modelOverride,
interfaceType: interfaceType,
}
}
// ProcessChunk 改写 chunk 中的 model 字段
func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
if len(rawChunk) == 0 {
return nil
}
rewrittenChunk, err := c.adapter.RewriteResponseModelName(rawChunk, c.modelOverride, c.interfaceType)
if err != nil {
// 改写失败,返回原始 chunk
return [][]byte{rawChunk}
}
return [][]byte{rewrittenChunk}
}
// Flush 无缓冲数据
func (c *SmartPassthroughStreamConverter) Flush() [][]byte {
return nil
}
// CanonicalStreamConverter 跨协议规范流式转换器
type CanonicalStreamConverter struct {
decoder StreamDecoder
encoder StreamEncoder
chain *MiddlewareChain
ctx ConversionContext
clientProtocol string
decoder StreamDecoder
encoder StreamEncoder
chain *MiddlewareChain
ctx ConversionContext
clientProtocol string
providerProtocol string
modelOverride string
}
// NewCanonicalStreamConverter 创建规范流式转换器
@@ -57,18 +95,19 @@ func NewCanonicalStreamConverter(decoder StreamDecoder, encoder StreamEncoder) *
}
// NewCanonicalStreamConverterWithMiddleware 创建带中间件的规范流式转换器
func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol string) *CanonicalStreamConverter {
func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol, modelOverride string) *CanonicalStreamConverter {
return &CanonicalStreamConverter{
decoder: decoder,
encoder: encoder,
chain: chain,
ctx: ctx,
clientProtocol: clientProtocol,
decoder: decoder,
encoder: encoder,
chain: chain,
ctx: ctx,
clientProtocol: clientProtocol,
providerProtocol: providerProtocol,
modelOverride: modelOverride,
}
}
// ProcessChunk 解码 → 中间件 → 编码管道
// ProcessChunk 解码 → 中间件 → modelOverride → 编码管道
func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
events := c.decoder.ProcessChunk(rawChunk)
var result [][]byte
@@ -80,6 +119,7 @@ func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
}
events[i] = *processed
}
c.applyModelOverride(&events[i])
chunks := c.encoder.EncodeEvent(events[i])
result = append(result, chunks...)
}
@@ -98,6 +138,7 @@ func (c *CanonicalStreamConverter) Flush() [][]byte {
}
events[i] = *processed
}
c.applyModelOverride(&events[i])
chunks := c.encoder.EncodeEvent(events[i])
result = append(result, chunks...)
}
@@ -105,3 +146,10 @@ func (c *CanonicalStreamConverter) Flush() [][]byte {
result = append(result, encoderChunks...)
return result
}
// applyModelOverride 在跨协议场景下覆写流式事件中的 Model 字段
func (c *CanonicalStreamConverter) applyModelOverride(event *canonical.CanonicalStreamEvent) {
if c.modelOverride != "" && event.Message != nil {
event.Message.Model = c.modelOverride
}
}

View File

@@ -93,7 +93,7 @@ func TestCanonicalStreamConverter_WithMiddleware(t *testing.T) {
chain.Use(&recordingMiddleware{name: "mw1", records: &records})
ctx := NewConversionContext(InterfaceTypeChat)
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
result := converter.ProcessChunk([]byte("raw"))
assert.Len(t, result, 1)
@@ -143,7 +143,7 @@ func TestCanonicalStreamConverter_MiddlewareError_Continue(t *testing.T) {
chain.Use(&errorMiddleware{})
ctx := NewConversionContext(InterfaceTypeChat)
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
result := converter.ProcessChunk([]byte("raw"))
assert.Nil(t, result, "middleware error should cause the event to be skipped (continue)")
@@ -163,7 +163,7 @@ func TestCanonicalStreamConverter_Flush_MiddlewareError_Continue(t *testing.T) {
chain.Use(&errorMiddleware{})
ctx := NewConversionContext(InterfaceTypeChat)
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
result := converter.Flush()
assert.Len(t, result, 1)

View File

@@ -1,8 +1,12 @@
package domain
import "time"
import (
"time"
// Model 模型领域模型
"nex/backend/pkg/modelid"
)
// Model 模型领域模型id 为 UUID 自动生成)
type Model struct {
ID string `json:"id"`
ProviderID string `json:"provider_id"`
@@ -10,3 +14,8 @@ type Model struct {
Enabled bool `json:"enabled"`
CreatedAt time.Time `json:"created_at"`
}
// UnifiedModelID 返回统一模型 ID格式provider_id/model_name
func (m *Model) UnifiedModelID() string {
return modelid.FormatUnifiedModelID(m.ProviderID, m.ModelName)
}

View File

@@ -113,7 +113,6 @@ func TestModelHandler_CreateModel_Success(t *testing.T) {
h := NewModelHandler(&mockModelService{})
body, _ := json.Marshal(map[string]string{
"id": "m1",
"provider_id": "p1",
"model_name": "gpt-4",
})
@@ -127,7 +126,7 @@ func TestModelHandler_CreateModel_Success(t *testing.T) {
var result domain.Model
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "m1", result.ID)
assert.NotEmpty(t, result.ID)
}
func TestModelHandler_GetModel(t *testing.T) {

View File

@@ -13,7 +13,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
@@ -31,7 +30,7 @@ type mockRoutingService struct {
err error
}
func (m *mockRoutingService) Route(modelName string) (*domain.RouteResult, error) {
func (m *mockRoutingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
return m.result, m.err
}
@@ -57,6 +56,14 @@ type mockProviderService struct {
err error
}
func (m *mockProviderService) ListEnabledModels() ([]domain.Model, error) {
return nil, nil
}
func (m *mockProviderService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
return nil, nil
}
func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err }
func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
return m.provider, m.err
@@ -73,13 +80,21 @@ type mockModelService struct {
err error
}
func (m *mockModelService) Create(model *domain.Model) error { return m.err }
func (m *mockModelService) Create(model *domain.Model) error {
if m.err == nil {
model.ID = "mock-uuid-1234"
}
return m.err
}
func (m *mockModelService) Get(id string) (*domain.Model, error) {
return m.model, m.err
}
func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
return m.models, m.err
}
func (m *mockModelService) ListEnabled() ([]domain.Model, error) {
return []domain.Model{}, nil
}
func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
return m.err
}
@@ -163,8 +178,8 @@ func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
func TestModelHandler_ListModels(t *testing.T) {
h := NewModelHandler(&mockModelService{
models: []domain.Model{
{ID: "m1", ModelName: "gpt-4"},
{ID: "m2", ModelName: "gpt-3.5"},
{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3"},
},
})
@@ -174,6 +189,72 @@ func TestModelHandler_ListModels(t *testing.T) {
h.ListModels(c)
assert.Equal(t, 200, w.Code)
var result []modelResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
require.Len(t, result, 2)
assert.Equal(t, "openai/gpt-4", result[0].UnifiedModelID)
assert.Equal(t, "anthropic/claude-3", result[1].UnifiedModelID)
}
func TestModelHandler_GetModel_UnifiedID(t *testing.T) {
h := NewModelHandler(&mockModelService{
model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "m1"}}
c.Request = httptest.NewRequest("GET", "/api/models/m1", nil)
h.GetModel(c)
assert.Equal(t, 200, w.Code)
var result modelResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "m1", result.ID)
assert.Equal(t, "openai/gpt-4", result.UnifiedModelID)
}
func TestModelHandler_CreateModel_UnifiedID(t *testing.T) {
h := NewModelHandler(&mockModelService{})
body, _ := json.Marshal(map[string]string{
"provider_id": "openai",
"model_name": "gpt-4",
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateModel(c)
assert.Equal(t, 201, w.Code)
var result modelResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "mock-uuid-1234", result.ID)
assert.Equal(t, "openai/gpt-4", result.UnifiedModelID)
}
func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) {
h := NewModelHandler(&mockModelService{
model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4-turbo"},
})
body, _ := json.Marshal(map[string]interface{}{"enabled": false})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "m1"}}
c.Request = httptest.NewRequest("PUT", "/api/models/m1", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.UpdateModel(c)
assert.Equal(t, 200, w.Code)
var result modelResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "openai/gpt-4-turbo", result.UnifiedModelID)
}
// ============ Stats Handler 测试 ============
@@ -256,7 +337,7 @@ func formatMapErrors(errs map[string]string) string {
func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
h := NewProviderHandler(&mockProviderService{
err: gorm.ErrDuplicatedKey,
err: appErrors.ErrConflict,
})
body, _ := json.Marshal(map[string]string{

View File

@@ -1,6 +1,7 @@
package handler
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
@@ -22,23 +23,35 @@ func NewModelHandler(modelService service.ModelService) *ModelHandler {
return &ModelHandler{modelService: modelService}
}
// modelResponse 模型响应 DTO扩展 unified_id 字段
type modelResponse struct {
domain.Model
UnifiedModelID string `json:"unified_id"`
}
// newModelResponse 从 domain.Model 构造响应 DTO
func newModelResponse(m *domain.Model) modelResponse {
return modelResponse{
Model: *m,
UnifiedModelID: m.UnifiedModelID(),
}
}
// CreateModel 创建模型
func (h *ModelHandler) CreateModel(c *gin.Context) {
var req struct {
ID string `json:"id" binding:"required"`
ProviderID string `json:"provider_id" binding:"required"`
ModelName string `json:"model_name" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "缺少必需字段: id, provider_id, model_name",
"error": "缺少必需字段: provider_id, model_name",
})
return
}
model := &domain.Model{
ID: req.ID,
ProviderID: req.ProviderID,
ModelName: req.ModelName,
}
@@ -51,11 +64,18 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
})
return
}
if err == appErrors.ErrDuplicateModel {
c.JSON(http.StatusConflict, gin.H{
"error": "同一供应商下模型名称已存在",
"code": appErrors.ErrDuplicateModel.Code,
})
return
}
writeError(c, err)
return
}
c.JSON(http.StatusCreated, model)
c.JSON(http.StatusCreated, newModelResponse(model))
}
// ListModels 列出模型
@@ -68,7 +88,11 @@ func (h *ModelHandler) ListModels(c *gin.Context) {
return
}
c.JSON(http.StatusOK, models)
resp := make([]modelResponse, len(models))
for i, m := range models {
resp[i] = newModelResponse(&m)
}
c.JSON(http.StatusOK, resp)
}
// GetModel 获取模型
@@ -87,7 +111,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
return
}
c.JSON(http.StatusOK, model)
c.JSON(http.StatusOK, newModelResponse(model))
}
// UpdateModel 更新模型
@@ -104,18 +128,25 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
err := h.modelService.Update(id, req)
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, appErrors.ErrModelNotFound) {
c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到",
})
return
}
if err == appErrors.ErrProviderNotFound {
if errors.Is(err, appErrors.ErrProviderNotFound) {
c.JSON(http.StatusBadRequest, gin.H{
"error": "供应商不存在",
})
return
}
if errors.Is(err, appErrors.ErrDuplicateModel) {
c.JSON(http.StatusConflict, gin.H{
"error": appErrors.ErrDuplicateModel.Message,
"code": appErrors.ErrDuplicateModel.Code,
})
return
}
writeError(c, err)
return
}
@@ -126,7 +157,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
return
}
c.JSON(http.StatusOK, model)
c.JSON(http.StatusOK, newModelResponse(model))
}
// DeleteModel 删除模型

View File

@@ -55,9 +55,10 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
err := h.providerService.Create(provider)
if err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
c.JSON(http.StatusConflict, gin.H{
"error": "供应商 ID 已存在",
if err == appErrors.ErrInvalidProviderID {
c.JSON(http.StatusBadRequest, gin.H{
"error": appErrors.ErrInvalidProviderID.Message,
"code": appErrors.ErrInvalidProviderID.Code,
})
return
}
@@ -119,6 +120,13 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
})
return
}
if errors.Is(err, appErrors.ErrImmutableField) {
c.JSON(http.StatusBadRequest, gin.H{
"error": appErrors.ErrImmutableField.Message,
"code": appErrors.ErrImmutableField.Code,
})
return
}
writeError(c, err)
return
}

View File

@@ -11,9 +11,11 @@ import (
"go.uber.org/zap"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/canonical"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
"nex/backend/internal/service"
"nex/backend/pkg/modelid"
)
// ProxyHandler 统一代理处理器
@@ -54,6 +56,34 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
}
nativePath := "/v1/" + path
// 获取 client adapter
registry := h.engine.GetRegistry()
clientAdapter, err := registry.Get(clientProtocol)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
return
}
// 检测接口类型
ifaceType := clientAdapter.DetectInterfaceType(nativePath)
// 处理 Models 接口:本地聚合
if ifaceType == conversion.InterfaceTypeModels {
h.handleModelsList(c, clientAdapter)
return
}
// 处理 ModelInfo 接口:本地查询
if ifaceType == conversion.InterfaceTypeModelInfo {
unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的模型 ID 格式"})
return
}
h.handleModelInfo(c, unifiedID, clientAdapter)
return
}
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
@@ -61,10 +91,17 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
return
}
// 解析 model 名称(从 JSON body 中提取GET 请求无 body
modelName := ""
// 解析统一模型 ID使用 adapter.ExtractModelName
var providerID, modelName string
if len(body) > 0 {
modelName = extractModelName(body)
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
if err == nil && unifiedID != "" {
pid, mn, err := modelid.ParseUnifiedModelID(unifiedID)
if err == nil {
providerID = pid
modelName = mn
}
}
}
// 构建输入 HTTPRequestSpec
@@ -76,7 +113,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
}
// 路由
routeResult, err := h.routingService.Route(modelName)
routeResult, err := h.routingService.RouteByModelName(providerID, modelName)
if err != nil {
// GET 请求或无法提取 model 时,直接转发到上游
if len(body) == 0 || modelName == "" {
@@ -94,24 +131,30 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
}
// 构建 TargetProvider
// 注意ModelName 字段用于 Smart Passthrough 场景改写请求体
// 同协议:请求体中的统一 ID 会被改写为 ModelName上游名
// 跨协议:全量转换时 ModelName 会被编码到请求体中
targetProvider := conversion.NewTargetProvider(
routeResult.Provider.BaseURL,
routeResult.Provider.APIKey,
routeResult.Model.ModelName,
routeResult.Model.ModelName, // 上游模型名,用于请求改写
)
// 判断是否流式
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
// 计算统一模型 ID用于响应覆写
unifiedModelID := routeResult.Model.UnifiedModelID()
if isStream {
h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
} else {
h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
}
}
// handleNonStream 处理非流式请求
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
// 转换请求
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
if err != nil {
@@ -128,9 +171,8 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
return
}
// 转换响应
interfaceType, _ := h.engine.DetectInterfaceType(inSpec.URL, clientProtocol)
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, interfaceType)
// 转换响应,传入 modelOverride跨协议场景覆写 model 字段)
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
if err != nil {
h.logger.Error("转换响应失败", zap.String("error", err.Error()))
h.writeConversionError(c, err, clientProtocol)
@@ -153,7 +195,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
}
// handleStream 处理流式请求
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
// 转换请求
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
if err != nil {
@@ -161,8 +203,8 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
return
}
// 创建流式转换器
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol)
// 创建流式转换器,传入 modelOverride跨协议场景覆写 model 字段)
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
@@ -224,6 +266,79 @@ func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath s
return req.Stream
}
// handleModelsList 处理 GET /v1/models 本地聚合
func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.ProtocolAdapter) {
// 从数据库查询所有启用的模型
models, err := h.providerService.ListEnabledModels()
if err != nil {
h.logger.Error("查询启用模型失败", zap.String("error", err.Error()))
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"})
return
}
// 构建 CanonicalModelList
modelList := &canonical.CanonicalModelList{
Models: make([]canonical.CanonicalModel, 0, len(models)),
}
for _, m := range models {
modelList.Models = append(modelList.Models, canonical.CanonicalModel{
ID: m.UnifiedModelID(),
Name: m.ModelName,
Created: m.CreatedAt.Unix(),
OwnedBy: m.ProviderID,
})
}
// 使用 adapter 编码返回
body, err := adapter.EncodeModelsResponse(modelList)
if err != nil {
h.logger.Error("编码 Models 响应失败", zap.String("error", err.Error()))
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
return
}
c.Data(http.StatusOK, "application/json", body)
}
// handleModelInfo 处理 GET /v1/models/{unified_id} 本地查询
func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter conversion.ProtocolAdapter) {
// 解析统一模型 ID
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "无效的统一模型 ID 格式",
"code": "INVALID_MODEL_ID",
})
return
}
// 从数据库查询模型
model, err := h.providerService.GetModelByProviderAndName(providerID, modelName)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "模型未找到"})
return
}
// 构建 CanonicalModelInfo
modelInfo := &canonical.CanonicalModelInfo{
ID: model.UnifiedModelID(),
Name: model.ModelName,
Created: model.CreatedAt.Unix(),
OwnedBy: model.ProviderID,
}
// 使用 adapter 编码返回
body, err := adapter.EncodeModelInfoResponse(modelInfo)
if err != nil {
h.logger.Error("编码 ModelInfo 响应失败", zap.String("error", err.Error()))
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
return
}
c.Data(http.StatusOK, "application/json", body)
}
// writeConversionError 写入转换错误
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
if convErr, ok := err.(*conversion.ConversionError); ok {
@@ -292,7 +407,7 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
return
}
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType)
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, "")
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
@@ -307,17 +422,6 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
}
// extractModelName 从 JSON body 中提取 model
func extractModelName(body []byte) string {
var req struct {
Model string `json:"model"`
}
if err := json.Unmarshal(body, &req); err != nil {
return ""
}
return req.Model
}
// extractHeaders 从 Gin context 提取请求头
func extractHeaders(c *gin.Context) map[string]string {
headers := make(map[string]string)

View File

@@ -60,13 +60,23 @@ type mockProxyRoutingService struct {
err error
}
func (m *mockProxyRoutingService) Route(modelName string) (*domain.RouteResult, error) {
func (m *mockProxyRoutingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
return m.result, m.err
}
type mockProxyProviderService struct {
providers []domain.Provider
err error
providers []domain.Provider
err error
enabledModels []domain.Model
modelByProvName *domain.Model
}
func (m *mockProxyProviderService) ListEnabledModels() ([]domain.Model, error) {
return m.enabledModels, nil
}
func (m *mockProxyProviderService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
return m.modelByProvName, nil
}
func (m *mockProxyProviderService) Create(p *domain.Provider) error { return nil }
@@ -319,7 +329,8 @@ func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) {
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 404, w.Code)
// Models 接口现在本地聚合,返回空列表 200
assert.Equal(t, 200, w.Code)
}
func TestExtractHeaders(t *testing.T) {
@@ -716,58 +727,6 @@ func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) {
assert.Equal(t, 200, w.Code)
}
// ============ extractModelName 测试 ============
func TestExtractModelName(t *testing.T) {
tests := []struct {
name string
body []byte
expected string
}{
{
name: "valid model",
body: []byte(`{"model": "gpt-4", "messages": []}`),
expected: "gpt-4",
},
{
name: "empty body",
body: []byte(`{}`),
expected: "",
},
{
name: "invalid json",
body: []byte(`{invalid}`),
expected: "",
},
{
name: "nested structure",
body: []byte(`{"model": "claude-3", "messages": [{"role": "user", "content": "hello"}]}`),
expected: "claude-3",
},
{
name: "model with special chars",
body: []byte(`{"model": "gpt-4-0125-preview", "stream": true}`),
expected: "gpt-4-0125-preview",
},
{
name: "empty body bytes",
body: []byte{},
expected: "",
},
{
name: "model is null",
body: []byte(`{"model": null}`),
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractModelName(tt.body)
assert.Equal(t, tt.expected, result)
})
}
}
// ============ isStreamRequest 测试 ============
func TestIsStreamRequest(t *testing.T) {
@@ -831,3 +790,270 @@ func TestIsStreamRequest(t *testing.T) {
})
}
}
// ============ Models / ModelInfo 本地聚合测试 ============
func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
engine := setupProxyEngine(t)
providerSvc := &mockProxyProviderService{
enabledModels: []domain.Model{
{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3", Enabled: true},
},
}
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, &mockProxyRoutingService{}, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
data, ok := resp["data"].([]interface{})
require.True(t, ok)
assert.Len(t, data, 2)
// 验证统一模型 ID 格式
first := data[0].(map[string]interface{})
assert.Equal(t, "openai/gpt-4", first["id"])
}
func TestProxyHandler_HandleProxy_ModelInfo_LocalQuery(t *testing.T) {
engine := setupProxyEngine(t)
providerSvc := &mockProxyProviderService{
modelByProvName: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, &mockProxyRoutingService{}, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/openai/gpt-4"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models/openai/gpt-4", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "openai/gpt-4", resp["id"])
}
func TestProxyHandler_HandleProxy_Models_EmptySuffix_ForwardPassthrough(t *testing.T) {
engine := setupProxyEngine(t)
providerSvc := &mockProxyProviderService{
providers: []domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai"},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Body: []byte(`{"object":"list","data":[]}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, &mockProxyRoutingService{err: appErrors.ErrModelNotFound}, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models/", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
}
// ============ Smart Passthrough 统一模型 ID 路由测试 ============
func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "openai_p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "openai_p", ModelName: "gpt-4", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
// 验证请求体中的 model 已被改写为上游模型名
var req map[string]interface{}
json.Unmarshal(spec.Body, &req)
assert.Equal(t, "gpt-4", req["model"])
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"id":"resp-1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8}}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
// 客户端发送统一模型 ID
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
// 验证响应中的 model 已被改写为统一模型 ID
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "openai_p/gpt-4", resp["model"])
}
// ============ 跨协议统一模型 ID 路由测试 ============
func TestProxyHandler_HandleProxy_CrossProtocol_NonStream_UnifiedID(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "anthropic_p", Name: "Anthropic", APIKey: "sk-test", BaseURL: "https://api.anthropic.com", Protocol: "anthropic", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"id":"msg-1","type":"message","role":"assistant","model":"claude-3","content":[{"type":"text","text":"Hello"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":3}}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
// OpenAI 客户端使用统一模型 ID 路由到 Anthropic 供应商
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
// 验证跨协议转换后响应中的 model 被覆写为统一模型 ID
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "anthropic_p/claude-3", resp["model"])
}
func TestProxyHandler_HandleProxy_CrossProtocol_Stream_UnifiedID(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "anthropic_p", Name: "Anthropic", APIKey: "sk-test", BaseURL: "https://api.anthropic.com", Protocol: "anthropic", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendStreamFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
ch := make(chan provider.StreamEvent, 10)
go func() {
defer close(ch)
ch <- provider.StreamEvent{Data: []byte(`event: message_start
data: {"type":"message_start","message":{"id":"msg-1","type":"message","role":"assistant","model":"claude-3","content":[]}}
`)}
ch <- provider.StreamEvent{Data: []byte(`event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}
`)}
ch <- provider.StreamEvent{Data: []byte(`event: message_stop
data: {"type":"message_stop"}
`)}
ch <- provider.StreamEvent{Done: true}
}()
return ch, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type"))
body := w.Body.String()
// 验证跨协议流式中 model 被覆写为统一模型 ID
assert.Contains(t, body, "anthropic_p/claude-3", "跨协议流式响应中 model 应被覆写为统一模型 ID")
}
func TestProxyHandler_HandleProxy_SmartPassthrough_Fidelity(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "openai_p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "openai_p", ModelName: "gpt-4", Enabled: true},
},
}
var capturedRequestBody []byte
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
capturedRequestBody = spec.Body
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"id":"resp-1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8},"unknown_field":"preserved"}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
// 包含未知参数,验证 Smart Passthrough 保真性
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}],"custom_param":"should_be_preserved"}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
// 验证请求中 model 被改写为上游模型名,但未知参数保留
var reqBody map[string]interface{}
require.NoError(t, json.Unmarshal(capturedRequestBody, &reqBody))
assert.Equal(t, "gpt-4", reqBody["model"], "请求中 model 应被改写为上游模型名")
assert.Equal(t, "should_be_preserved", reqBody["custom_param"], "Smart Passthrough 应保留未知参数")
// 验证响应中 model 被改写为统一模型 ID但未知参数保留
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "openai_p/gpt-4", resp["model"], "响应中 model 应被改写为统一模型 ID")
assert.Equal(t, "preserved", resp["unknown_field"], "Smart Passthrough 应保留未知响应字段")
}
func TestProxyHandler_HandleProxy_UnifiedID_ModelNotFound(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{err: appErrors.ErrModelNotFound}
providerSvc := &mockProxyProviderService{
providers: []domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai"},
},
}
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, routingSvc, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
// 使用统一模型 ID 格式但模型不存在
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 404, w.Code)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Contains(t, resp, "error")
}

View File

@@ -7,7 +7,8 @@ type ModelRepository interface {
Create(model *domain.Model) error
GetByID(id string) (*domain.Model, error)
List(providerID string) ([]domain.Model, error)
GetByModelName(modelName string) (*domain.Model, error)
FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error)
ListEnabled() ([]domain.Model, error)
Update(id string, updates map[string]interface{}) error
Delete(id string) error
}

View File

@@ -52,9 +52,9 @@ func (r *modelRepository) List(providerID string) ([]domain.Model, error) {
return result, nil
}
func (r *modelRepository) GetByModelName(modelName string) (*domain.Model, error) {
func (r *modelRepository) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
var m config.Model
err := r.db.Where("model_name = ?", modelName).First(&m).Error
err := r.db.Where("provider_id = ? AND model_name = ?", providerID, modelName).First(&m).Error
if err != nil {
return nil, err
}
@@ -62,6 +62,21 @@ func (r *modelRepository) GetByModelName(modelName string) (*domain.Model, error
return &d, nil
}
func (r *modelRepository) ListEnabled() ([]domain.Model, error) {
var models []config.Model
err := r.db.Joins("JOIN providers ON providers.id = models.provider_id").
Where("models.enabled = ? AND providers.enabled = ?", true, true).
Find(&models).Error
if err != nil {
return nil, err
}
result := make([]domain.Model, len(models))
for i := range models {
result[i] = toDomainModel(&models[i])
}
return result, nil
}
func (r *modelRepository) Update(id string, updates map[string]interface{}) error {
result := r.db.Model(&config.Model{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {

View File

@@ -9,4 +9,7 @@ type ProviderRepository interface {
List() ([]domain.Provider, error)
Update(id string, updates map[string]interface{}) error
Delete(id string) error
// 统一模型 ID 相关方法
ListEnabledModels() ([]domain.Model, error)
FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error)
}

View File

@@ -71,6 +71,25 @@ func (r *providerRepository) Delete(id string) error {
return nil
}
// ListEnabledModels 返回所有启用的模型(关联启用的供应商)
func (r *providerRepository) ListEnabledModels() ([]domain.Model, error) {
var models []domain.Model
err := r.db.Joins("JOIN providers ON providers.id = models.provider_id").
Where("models.enabled = ? AND providers.enabled = ?", true, true).
Find(&models).Error
return models, err
}
// FindByProviderAndModelName 按 provider_id 和 model_name 查询模型
func (r *providerRepository) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
var model domain.Model
err := r.db.Where("provider_id = ? AND model_name = ?", providerID, modelName).First(&model).Error
if err != nil {
return nil, err
}
return &model, nil
}
func toDomainProvider(p *config.Provider) domain.Provider {
return domain.Provider{
ID: p.ID,

View File

@@ -147,15 +147,36 @@ func TestModelRepository_GetByID(t *testing.T) {
assert.Equal(t, "gpt-4", result.ModelName)
}
func TestModelRepository_GetByModelName(t *testing.T) {
func TestModelRepository_FindByProviderAndModelName(t *testing.T) {
db := setupTestDB(t)
repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
result, err := repo.GetByModelName("gpt-4")
result, err := repo.FindByProviderAndModelName("p1", "gpt-4")
require.NoError(t, err)
assert.Equal(t, "m1", result.ID)
assert.Equal(t, "p1", result.ProviderID)
assert.Equal(t, "gpt-4", result.ModelName)
}
func TestModelRepository_FindByProviderAndModelName_NotFound(t *testing.T) {
db := setupTestDB(t)
repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
// Wrong provider_id
_, err := repo.FindByProviderAndModelName("p2", "gpt-4")
assert.Error(t, err)
// Wrong model_name
_, err = repo.FindByProviderAndModelName("p1", "gpt-3.5")
assert.Error(t, err)
// Both wrong
_, err = repo.FindByProviderAndModelName("p2", "claude-3")
assert.Error(t, err)
}
func TestModelRepository_List(t *testing.T) {
@@ -175,6 +196,54 @@ func TestModelRepository_List(t *testing.T) {
assert.Len(t, p1Models, 2)
}
func TestModelRepository_ListEnabled(t *testing.T) {
db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
modelRepo := NewModelRepository(db)
// Create two providers (both start enabled due to gorm:"default:true")
err := providerRepo.Create(&domain.Provider{
ID: "enabled-provider", Name: "Enabled Provider",
APIKey: "key1", BaseURL: "https://enabled.com", Enabled: true,
})
require.NoError(t, err)
err = providerRepo.Create(&domain.Provider{
ID: "disabled-provider", Name: "Disabled Provider",
APIKey: "key2", BaseURL: "https://disabled.com", Enabled: true,
})
require.NoError(t, err)
// Disable the second provider via Update (GORM default:true skips zero values on Create)
err = providerRepo.Update("disabled-provider", map[string]interface{}{"enabled": false})
require.NoError(t, err)
// Create models (all start enabled due to gorm:"default:true")
err = modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "enabled-provider", ModelName: "gpt-4", Enabled: true})
require.NoError(t, err)
err = modelRepo.Create(&domain.Model{ID: "m2", ProviderID: "enabled-provider", ModelName: "gpt-3.5", Enabled: true})
require.NoError(t, err)
err = modelRepo.Create(&domain.Model{ID: "m3", ProviderID: "disabled-provider", ModelName: "claude-3", Enabled: true})
require.NoError(t, err)
err = modelRepo.Create(&domain.Model{ID: "m4", ProviderID: "disabled-provider", ModelName: "claude-3.5", Enabled: true})
require.NoError(t, err)
// Disable m2 via Update
err = modelRepo.Update("m2", map[string]interface{}{"enabled": false})
require.NoError(t, err)
// ListEnabled should only return models where both model and provider are enabled:
// - m1: enabled provider + enabled model -> returned
// - m2: enabled provider + disabled model -> filtered out
// - m3: disabled provider + enabled model -> filtered out
// - m4: disabled provider + enabled model -> filtered out
enabled, err := modelRepo.ListEnabled()
require.NoError(t, err)
require.Len(t, enabled, 1)
assert.Equal(t, "m1", enabled[0].ID)
assert.Equal(t, "enabled-provider", enabled[0].ProviderID)
assert.Equal(t, "gpt-4", enabled[0].ModelName)
}
func TestModelRepository_Update(t *testing.T) {
db := setupTestDB(t)
repo := NewModelRepository(db)

View File

@@ -7,6 +7,7 @@ type ModelService interface {
Create(model *domain.Model) error
Get(id string) (*domain.Model, error)
List(providerID string) ([]domain.Model, error)
ListEnabled() ([]domain.Model, error)
Update(id string, updates map[string]interface{}) error
Delete(id string) error
}

View File

@@ -1,6 +1,7 @@
package service
import (
"github.com/google/uuid"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
@@ -17,11 +18,18 @@ func NewModelService(modelRepo repository.ModelRepository, providerRepo reposito
}
func (s *modelService) Create(model *domain.Model) error {
// Verify provider exists
_, err := s.providerRepo.GetByID(model.ProviderID)
if err != nil {
// 校验供应商存在
if _, err := s.providerRepo.GetByID(model.ProviderID); err != nil {
return appErrors.ErrProviderNotFound
}
// 联合唯一校验:同一供应商下 model_name 不重复
if err := s.checkDuplicateModelName(model.ProviderID, model.ModelName, ""); err != nil {
return err
}
// 自动生成 UUID 作为 id
model.ID = uuid.New().String()
model.Enabled = true
return s.modelRepo.Create(model)
}
@@ -34,17 +42,57 @@ func (s *modelService) List(providerID string) ([]domain.Model, error) {
return s.modelRepo.List(providerID)
}
func (s *modelService) ListEnabled() ([]domain.Model, error) {
return s.modelRepo.ListEnabled()
}
func (s *modelService) Update(id string, updates map[string]interface{}) error {
// If updating provider_id, verify new provider exists
// 获取当前模型
current, err := s.modelRepo.GetByID(id)
if err != nil {
return appErrors.ErrModelNotFound
}
// 如果更新 provider_id校验新供应商存在
if providerID, ok := updates["provider_id"].(string); ok {
_, err := s.providerRepo.GetByID(providerID)
if err != nil {
if _, err := s.providerRepo.GetByID(providerID); err != nil {
return appErrors.ErrProviderNotFound
}
}
// 确定更新后的 provider_id 和 model_name
newProviderID := current.ProviderID
if v, ok := updates["provider_id"].(string); ok {
newProviderID = v
}
newModelName := current.ModelName
if v, ok := updates["model_name"].(string); ok {
newModelName = v
}
// 如果 provider_id 或 model_name 发生变化,校验联合唯一
if newProviderID != current.ProviderID || newModelName != current.ModelName {
if err := s.checkDuplicateModelName(newProviderID, newModelName, id); err != nil {
return err
}
}
return s.modelRepo.Update(id, updates)
}
func (s *modelService) Delete(id string) error {
return s.modelRepo.Delete(id)
}
// checkDuplicateModelName 校验同一供应商下 model_name 是否重复
// excludeID 用于更新时排除自身
func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error {
existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
if err != nil {
return nil // 未找到,不重复
}
if excludeID != "" && existing.ID == excludeID {
return nil // 排除自身
}
return appErrors.ErrDuplicateModel
}

View File

@@ -9,4 +9,7 @@ type ProviderService interface {
List() ([]domain.Provider, error)
Update(id string, updates map[string]interface{}) error
Delete(id string) error
// 统一模型 ID 相关方法
ListEnabledModels() ([]domain.Model, error)
GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error)
}

View File

@@ -1,21 +1,35 @@
package service
import (
"strings"
"nex/backend/pkg/modelid"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
appErrors "nex/backend/pkg/errors"
)
type providerService struct {
providerRepo repository.ProviderRepository
modelRepo repository.ModelRepository
}
func NewProviderService(providerRepo repository.ProviderRepository) ProviderService {
return &providerService{providerRepo: providerRepo}
func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository) ProviderService {
return &providerService{providerRepo: providerRepo, modelRepo: modelRepo}
}
func (s *providerService) Create(provider *domain.Provider) error {
// 校验 provider_id 字符集
if err := modelid.ValidateProviderID(provider.ID); err != nil {
return appErrors.ErrInvalidProviderID
}
provider.Enabled = true
return s.providerRepo.Create(provider)
err := s.providerRepo.Create(provider)
if err != nil && isUniqueConstraintError(err) {
return appErrors.ErrConflict
}
return err
}
func (s *providerService) Get(id string, maskKey bool) (*domain.Provider, error) {
@@ -41,9 +55,31 @@ func (s *providerService) List() ([]domain.Provider, error) {
}
func (s *providerService) Update(id string, updates map[string]interface{}) error {
if _, ok := updates["id"]; ok {
return appErrors.ErrImmutableField
}
return s.providerRepo.Update(id, updates)
}
func (s *providerService) Delete(id string) error {
return s.providerRepo.Delete(id)
}
// ListEnabledModels 返回所有启用的模型(用于 Models 接口本地聚合)
func (s *providerService) ListEnabledModels() ([]domain.Model, error) {
return s.modelRepo.ListEnabled()
}
// GetModelByProviderAndName 按 provider_id 和 model_name 查询模型(用于 ModelInfo 接口本地查询)
func (s *providerService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
return s.modelRepo.FindByProviderAndModelName(providerID, modelName)
}
// isUniqueConstraintError 判断是否为数据库唯一约束冲突错误
func isUniqueConstraintError(err error) bool {
if err == nil {
return false
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "unique constraint") || strings.Contains(msg, "duplicate")
}

View File

@@ -4,5 +4,5 @@ import "nex/backend/internal/domain"
// RoutingService 路由服务接口
type RoutingService interface {
Route(modelName string) (*domain.RouteResult, error)
RouteByModelName(providerID, modelName string) (*domain.RouteResult, error)
}

View File

@@ -16,8 +16,8 @@ func NewRoutingService(modelRepo repository.ModelRepository, providerRepo reposi
return &routingService{modelRepo: modelRepo, providerRepo: providerRepo}
}
func (s *routingService) Route(modelName string) (*domain.RouteResult, error) {
model, err := s.modelRepo.GetByModelName(modelName)
func (s *routingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
model, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
if err != nil {
return nil, appErrors.ErrModelNotFound
}

View File

@@ -13,7 +13,8 @@ import (
func TestProviderService_Update(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"})
@@ -28,7 +29,8 @@ func TestProviderService_Update(t *testing.T) {
func TestProviderService_Update_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
err := svc.Update("nonexistent", map[string]interface{}{"name": "test"})
assert.Error(t, err)
@@ -41,11 +43,12 @@ func TestModelService_Get(t *testing.T) {
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model))
model, err := svc.Get("m1")
result, err := svc.Get(model.ID)
require.NoError(t, err)
assert.Equal(t, "gpt-4", model.ModelName)
assert.Equal(t, "gpt-4", result.ModelName)
}
func TestModelService_Update(t *testing.T) {
@@ -55,14 +58,15 @@ func TestModelService_Update(t *testing.T) {
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model))
err := svc.Update("m1", map[string]interface{}{"model_name": "gpt-4o"})
err := svc.Update(model.ID, map[string]interface{}{"model_name": "gpt-4o"})
require.NoError(t, err)
model, err := svc.Get("m1")
result, err := svc.Get(model.ID)
require.NoError(t, err)
assert.Equal(t, "gpt-4o", model.ModelName)
assert.Equal(t, "gpt-4o", result.ModelName)
}
func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
@@ -72,9 +76,10 @@ func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model))
err := svc.Update("m1", map[string]interface{}{"provider_id": "nonexistent"})
err := svc.Update(model.ID, map[string]interface{}{"provider_id": "nonexistent"})
assert.Error(t, err)
}
@@ -85,12 +90,13 @@ func TestModelService_Delete(t *testing.T) {
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model))
err := svc.Delete("m1")
err := svc.Delete(model.ID)
require.NoError(t, err)
_, err = svc.Get("m1")
_, err = svc.Get(model.ID)
assert.Error(t, err)
}

View File

@@ -1,8 +1,10 @@
package service
import (
"errors"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
@@ -11,6 +13,7 @@ import (
"nex/backend/internal/config"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
appErrors "nex/backend/pkg/errors"
)
func setupServiceTestDB(t *testing.T) *gorm.DB {
@@ -29,80 +32,106 @@ func setupServiceTestDB(t *testing.T) *gorm.DB {
return db
}
// ============ ProviderService 测试 ============
// ============ RoutingService - RouteByModelName 测试 ============
func TestProviderService_Create(t *testing.T) {
func TestRoutingService_RouteByModelName_Success(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
provider := &domain.Provider{
ID: "test-p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com",
}
err := svc.Create(provider)
// 创建供应商和模型
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
result, err := svc.RouteByModelName("openai", "gpt-4")
require.NoError(t, err)
assert.True(t, provider.Enabled)
assert.Equal(t, "openai", result.Provider.ID)
assert.Equal(t, "gpt-4", result.Model.ModelName)
}
func TestProviderService_Get_MaskKey(t *testing.T) {
func TestRoutingService_RouteByModelName_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
svc.Create(&domain.Provider{
ID: "p1", Name: "Test", APIKey: "sk-long-api-key-12345", BaseURL: "https://test.com",
})
result, err := svc.Get("p1", true)
require.NoError(t, err)
assert.Equal(t, "***2345", result.APIKey)
result, err = svc.Get("p1", false)
require.NoError(t, err)
assert.Equal(t, "sk-long-api-key-12345", result.APIKey)
_, err := svc.RouteByModelName("openai", "nonexistent-model")
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
}
func TestProviderService_List(t *testing.T) {
func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
svc.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key123", BaseURL: "https://a.com"})
svc.Create(&domain.Provider{ID: "p2", Name: "P2", APIKey: "key456", BaseURL: "https://b.com"})
// 创建启用的供应商和禁用的模型
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
providers, err := svc.List()
require.NoError(t, err)
assert.Len(t, providers, 2)
assert.Contains(t, providers[0].APIKey, "***")
_, err := svc.RouteByModelName("openai", "gpt-4")
assert.True(t, errors.Is(err, appErrors.ErrModelDisabled))
}
func TestProviderService_Delete(t *testing.T) {
func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
svc.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
err := svc.Delete("p1")
require.NoError(t, err)
// 创建启用的供应商和模型,然后禁用供应商
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
providerRepo.Update("openai", map[string]interface{}{"enabled": false})
_, err = svc.Get("p1", false)
assert.Error(t, err)
_, err := svc.RouteByModelName("openai", "gpt-4")
assert.True(t, errors.Is(err, appErrors.ErrProviderDisabled))
}
// ============ ModelService 测试 ============
// ============ ModelService - Create with UUID 测试 ============
func TestModelService_Create(t *testing.T) {
func TestModelService_Create_GeneratesUUID(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
model := &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model)
require.NoError(t, err)
assert.True(t, model.Enabled)
// 验证返回的 model 拥有有效的 UUID
assert.NotEmpty(t, model.ID)
_, err = uuid.Parse(model.ID)
assert.NoError(t, err, "model.ID should be a valid UUID")
// 通过 Get 验证持久化
stored, err := svc.Get(model.ID)
require.NoError(t, err)
assert.Equal(t, model.ID, stored.ID)
assert.Equal(t, "gpt-4", stored.ModelName)
}
func TestModelService_Create_DuplicateModelName(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model1)
require.NoError(t, err)
// 使用相同的 (providerID, modelName) 创建第二个模型应失败
model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err = svc.Create(model2)
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
}
func TestModelService_Create_ProviderNotFound(t *testing.T) {
@@ -111,160 +140,135 @@ func TestModelService_Create_ProviderNotFound(t *testing.T) {
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
model := &domain.Model{ID: "m1", ProviderID: "nonexistent", ModelName: "gpt-4"}
model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
err := svc.Create(model)
assert.Error(t, err)
assert.True(t, errors.Is(err, appErrors.ErrProviderNotFound))
}
func TestModelService_List(t *testing.T) {
// ============ ProviderService - Create with validation 测试 ============
func TestProviderService_Create_InvalidID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
assert.True(t, errors.Is(err, appErrors.ErrInvalidProviderID))
}
func TestProviderService_Create_ValidID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
require.NoError(t, err)
assert.Equal(t, "openai", provider.ID)
assert.True(t, provider.Enabled)
}
// ============ ModelService - Update with duplicate check 测试 ============
func TestModelService_Update_DuplicateModelName(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
svc.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"})
models, err := svc.List("p1")
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model1)
require.NoError(t, err)
assert.Len(t, models, 2)
model2 := &domain.Model{ProviderID: "anthropic", ModelName: "claude-3"}
err = svc.Create(model2)
require.NoError(t, err)
// 将 model2 的 model_name 改为 "gpt-4" 且 provider_id 改为 "openai",与 model1 冲突
err = svc.Update(model2.ID, map[string]interface{}{
"provider_id": "openai",
"model_name": "gpt-4",
})
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
}
// ============ RoutingService 测试 ============
func TestRoutingService_Route(t *testing.T) {
func TestModelService_Update_ModelNotFound(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
result, err := svc.Route("gpt-4")
require.NoError(t, err)
assert.Equal(t, "p1", result.Provider.ID)
assert.Equal(t, "gpt-4", result.Model.ModelName)
err := svc.Update("nonexistent-id", map[string]interface{}{
"model_name": "gpt-4",
})
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
}
func TestRoutingService_Route_ModelNotFound(t *testing.T) {
func TestModelService_Update_Success(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
svc := NewModelService(modelRepo, providerRepo)
_, err := svc.Route("nonexistent-model")
assert.Error(t, err)
}
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
func TestRoutingService_Route_ModelDisabled(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
// 先创建启用的模型,然后通过 Update 禁用
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
_, err := svc.Route("gpt-4")
assert.Error(t, err)
}
func TestRoutingService_Route_ProviderDisabled(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
// 先创建启用的 provider然后禁用
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
providerRepo.Update("p1", map[string]interface{}{"enabled": false})
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
_, err := svc.Route("gpt-4")
assert.Error(t, err)
}
// ============ StatsService 测试 ============
func TestStatsService_RecordAndGet(t *testing.T) {
db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db)
svc := NewStatsService(statsRepo)
err := svc.Record("p1", "gpt-4")
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model)
require.NoError(t, err)
stats, err := svc.Get("p1", "", nil, nil)
// 更新 model_name 为不冲突的值
err = svc.Update(model.ID, map[string]interface{}{
"model_name": "gpt-4-turbo",
})
require.NoError(t, err)
assert.Len(t, stats, 1)
updated, err := svc.Get(model.ID)
require.NoError(t, err)
assert.Equal(t, "gpt-4-turbo", updated.ModelName)
}
func TestStatsService_Aggregate_ByProvider(t *testing.T) {
statsRepo := repository.NewStatsRepository(nil)
svc := NewStatsService(statsRepo)
// ============ ProviderService - Update immutable ID 测试 ============
stats := []domain.UsageStats{
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
{ProviderID: "p1", ModelName: "gpt-3.5", RequestCount: 5},
{ProviderID: "p2", ModelName: "claude-3", RequestCount: 8},
}
func TestProviderService_Update_ImmutableID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
result := svc.Aggregate(stats, "provider")
assert.Len(t, result, 2)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
require.NoError(t, err)
p1Count := 0
p2Count := 0
for _, r := range result {
if r["provider_id"] == "p1" {
p1Count = r["request_count"].(int)
}
if r["provider_id"] == "p2" {
p2Count = r["request_count"].(int)
}
}
assert.Equal(t, 15, p1Count)
assert.Equal(t, 8, p2Count)
// 尝试更新 id 字段
err = svc.Update("openai", map[string]interface{}{
"id": "new-id",
})
assert.True(t, errors.Is(err, appErrors.ErrImmutableField))
}
func TestStatsService_Aggregate_ByDate(t *testing.T) {
statsRepo := repository.NewStatsRepository(nil)
svc := NewStatsService(statsRepo)
func TestProviderService_Update_Success(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
stats := []domain.UsageStats{
{ProviderID: "p1", RequestCount: 10},
{ProviderID: "p2", RequestCount: 5},
}
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
require.NoError(t, err)
result := svc.Aggregate(stats, "date")
assert.Len(t, result, 1)
assert.Equal(t, 15, result[0]["request_count"])
}
func TestStatsService_Aggregate_ByModel(t *testing.T) {
statsRepo := repository.NewStatsRepository(nil)
svc := NewStatsService(statsRepo)
stats := []domain.UsageStats{
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10},
{ProviderID: "openai", ModelName: "gpt-3.5", RequestCount: 5},
{ProviderID: "anthropic", ModelName: "claude-3", RequestCount: 8},
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 3},
}
result := svc.Aggregate(stats, "model")
assert.Len(t, result, 3)
// 验证每个 provider/model 组合的计数
counts := make(map[string]int)
for _, r := range result {
key := r["provider_id"].(string) + "/" + r["model_name"].(string)
counts[key] = r["request_count"].(int)
}
assert.Equal(t, 13, counts["openai/gpt-4"])
assert.Equal(t, 5, counts["openai/gpt-3.5"])
assert.Equal(t, 8, counts["anthropic/claude-3"])
// 更新 name
err = svc.Update("openai", map[string]interface{}{
"name": "OpenAI Updated",
})
require.NoError(t, err)
updated, err := svc.Get("openai", false)
require.NoError(t, err)
assert.Equal(t, "OpenAI Updated", updated.Name)
}