feat: 实现统一模型 ID 机制
实现统一模型 ID 格式 (provider_id/model_name),支持跨协议模型标识和 Smart Passthrough。 核心变更: - 新增 pkg/modelid 包:解析、格式化、校验统一模型 ID - 数据库迁移:models 表使用 UUID 主键 + UNIQUE(provider_id, model_name) 约束 - Repository 层:FindByProviderAndModelName、ListEnabled 方法 - Service 层:联合唯一校验、provider ID 字符集校验 - Conversion 层:ExtractModelName、RewriteRequestModelName/RewriteResponseModelName 方法 - Handler 层:统一模型 ID 路由、Smart Passthrough、Models API 本地聚合 - 新增 error-responses、unified-model-id 规范 测试覆盖: - 单元测试:modelid、conversion、handler、service、repository - 集成测试:统一模型 ID 路由、Smart Passthrough 保真性、跨协议转换 - 迁移测试:UUID 主键、UNIQUE 约束、级联删除 OpenSpec: - 归档 unified-model-id 变更到 archive/2026-04-21-unified-model-id - 同步 11 个 delta specs 到 main specs - 新增 error-responses、unified-model-id 规范文件
This commit is contained in:
@@ -40,6 +40,12 @@ type ProtocolAdapter interface {
|
||||
EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *TargetProvider) ([]byte, error)
|
||||
DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error)
|
||||
EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error)
|
||||
|
||||
// 统一模型 ID 相关方法
|
||||
ExtractUnifiedModelID(nativePath string) (string, error)
|
||||
ExtractModelName(body []byte, ifaceType InterfaceType) (string, error)
|
||||
RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)
|
||||
RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)
|
||||
}
|
||||
|
||||
// AdapterRegistry 适配器注册表接口
|
||||
|
||||
@@ -2,6 +2,7 @@ package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
@@ -39,13 +40,13 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp
|
||||
}
|
||||
}
|
||||
|
||||
// isModelInfoPath reports whether path is a model-info path
// (/v1/models/{id}); the id may itself contain '/' so that unified
// model IDs of the form provider_id/model_name are accepted.
//
// NOTE: the diff residue left both the old return (rejecting '/') and
// the new one in place; the new behavior — any non-empty suffix — is
// the intended one (see TestIsModelInfoPath_UnifiedModelID).
func isModelInfoPath(path string) bool {
	if !strings.HasPrefix(path, "/v1/models/") {
		return false
	}
	suffix := path[len("/v1/models/"):]
	return suffix != ""
}
|
||||
|
||||
// BuildUrl 根据接口类型构建 URL
|
||||
@@ -203,3 +204,74 @@ func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankRe
|
||||
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
|
||||
}
|
||||
|
||||
// ExtractUnifiedModelID 从路径中提取统一模型 ID(/v1/models/{provider_id}/{model_name})
|
||||
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
|
||||
if !strings.HasPrefix(nativePath, "/v1/models/") {
|
||||
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
|
||||
}
|
||||
suffix := nativePath[len("/v1/models/"):]
|
||||
if suffix == "" {
|
||||
return "", fmt.Errorf("路径缺少模型 ID")
|
||||
}
|
||||
return suffix, nil
|
||||
}
|
||||
|
||||
// locateModelFieldInRequest locates the "model" field in a request body
// and returns its current value together with a rewrite function that
// replaces the field and re-serializes the whole body.
//
// Only the Chat interface is supported; any other interface type yields
// an error. Decoding into json.RawMessage keeps all unknown fields
// intact across a rewrite (minimal, lossless modification).
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
	var m map[string]json.RawMessage
	if err := json.Unmarshal(body, &m); err != nil {
		return "", nil, err
	}

	switch ifaceType {
	case conversion.InterfaceTypeChat:
		raw, exists := m["model"]
		if !exists {
			return "", nil, fmt.Errorf("请求体中缺少 model 字段")
		}
		var current string
		if err := json.Unmarshal(raw, &current); err != nil {
			return "", nil, err
		}
		// The closure captures m: invoking it mutates the shared map and
		// re-marshals the entire body with the new model value.
		rewriteFunc := func(newModel string) ([]byte, error) {
			m["model"], _ = json.Marshal(newModel)
			return json.Marshal(m)
		}
		return current, rewriteFunc, nil
	default:
		return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
	}
}
|
||||
|
||||
// ExtractModelName 从请求体中提取 model 值
|
||||
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
|
||||
model, _, err := locateModelFieldInRequest(body, ifaceType)
|
||||
return model, err
|
||||
}
|
||||
|
||||
// RewriteRequestModelName 最小化改写请求体中的 model 字段
|
||||
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||
_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rewriteFunc(newModel)
|
||||
}
|
||||
|
||||
// RewriteResponseModelName 最小化改写响应体中的 model 字段
|
||||
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat:
|
||||
// Chat 响应必须有 model 字段,存在则改写,不存在则添加
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
return json.Marshal(m)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
263
backend/internal/conversion/anthropic/adapter_unified_test.go
Normal file
263
backend/internal/conversion/anthropic/adapter_unified_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractUnifiedModelID
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractUnifiedModelID(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("standard_path", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/v1/models/anthropic/claude-3")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "anthropic/claude-3", id)
|
||||
})
|
||||
|
||||
t.Run("multi_segment_path", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/v1/models/some/deep/nested/model")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "some/deep/nested/model", id)
|
||||
})
|
||||
|
||||
t.Run("single_segment", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/v1/models/claude-3")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "claude-3", id)
|
||||
})
|
||||
|
||||
t.Run("non_model_path", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/messages")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("empty_suffix", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/models/")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("models_list_no_slash", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/models")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unrelated_path", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/other")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractModelName (Chat only for Anthropic)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractModelName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3","messages":[]}`)
|
||||
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "anthropic/claude-3", model)
|
||||
})
|
||||
|
||||
t.Run("chat_with_max_tokens", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3-opus","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`)
|
||||
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "anthropic/claude-3-opus", model)
|
||||
})
|
||||
|
||||
t.Run("no_model_field", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[]}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
body := []byte(`{invalid}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unsupported_interface_type_embedding", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3"}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unsupported_interface_type_rerank", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3"}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RewriteRequestModelName (Chat only for Anthropic)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRewriteRequestModelName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3","messages":[]}`)
|
||||
rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "claude-3", m["model"])
|
||||
|
||||
msgs, ok := m["messages"]
|
||||
require.True(t, ok)
|
||||
msgsArr, ok := msgs.([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Len(t, msgsArr, 0)
|
||||
})
|
||||
|
||||
t.Run("preserves_unknown_fields", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3","max_tokens":1024,"temperature":0.7}`)
|
||||
rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "claude-3", m["model"])
|
||||
assert.Equal(t, 0.7, m["temperature"])
|
||||
|
||||
// max_tokens is encoded as float in JSON numbers
|
||||
maxTokens, ok := m["max_tokens"]
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, float64(1024), maxTokens)
|
||||
})
|
||||
|
||||
t.Run("no_model_field", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[]}`)
|
||||
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
body := []byte(`{invalid}`)
|
||||
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unsupported_interface_type", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3"}`)
|
||||
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeEmbeddings)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RewriteResponseModelName (Chat only for Anthropic)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRewriteResponseModelName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat_existing_model", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3","content":[],"stop_reason":"end_turn"}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "anthropic/claude-3", m["model"])
|
||||
|
||||
// other fields preserved
|
||||
_, hasContent := m["content"]
|
||||
assert.True(t, hasContent)
|
||||
assert.Equal(t, "end_turn", m["stop_reason"])
|
||||
})
|
||||
|
||||
t.Run("chat_without_model_field_adds_it", func(t *testing.T) {
|
||||
body := []byte(`{"content":[],"stop_reason":"end_turn"}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "anthropic/claude-3", m["model"])
|
||||
})
|
||||
|
||||
t.Run("passthrough_returns_body_unchanged", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3"}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypePassthrough)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(body), string(rewritten))
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
body := []byte(`{invalid}`)
|
||||
_, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractModelName and RewriteRequest consistency
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat_round_trip", func(t *testing.T) {
|
||||
original := []byte(`{"model":"anthropic/claude-3","messages":[],"max_tokens":1024}`)
|
||||
|
||||
// Extract the unified model ID from the body
|
||||
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "anthropic/claude-3", extracted)
|
||||
|
||||
// Rewrite to the native model name
|
||||
rewritten, err := a.RewriteRequestModelName(original, "claude-3", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extract again from the rewritten body to verify the same location was targeted
|
||||
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "claude-3", afterRewrite)
|
||||
|
||||
// Verify other fields are preserved
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, float64(1024), m["max_tokens"])
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isModelInfoPath (additional unified model ID cases)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestIsModelInfoPath_UnifiedModelID exercises isModelInfoPath with
// unified model IDs: any non-empty suffix after /v1/models/ counts as a
// model-info path, including suffixes that contain '/'.
func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
	tests := []struct {
		name     string
		path     string
		expected bool
	}{
		{"simple_model_id", "/v1/models/claude-3", true},
		{"unified_model_id_with_slash", "/v1/models/anthropic/claude-3", true},
		{"models_list", "/v1/models", false},
		{"models_list_trailing_slash", "/v1/models/", false},
		{"messages_path", "/v1/messages", false},
		{"deeply_nested", "/v1/models/org/workspace/claude-3-opus", true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
		})
	}
}
|
||||
@@ -79,11 +79,29 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Smart Passthrough: 同协议时最小化改写 model 字段
|
||||
interfaceType := providerAdapter.DetectInterfaceType(nativePath)
|
||||
rewrittenBody := spec.Body
|
||||
|
||||
// 对于 Chat/Embedding/Rerank 接口,改写请求体中的 model 字段
|
||||
if interfaceType == InterfaceTypeChat || interfaceType == InterfaceTypeEmbeddings || interfaceType == InterfaceTypeRerank {
|
||||
if len(spec.Body) > 0 && provider.ModelName != "" {
|
||||
rewrittenBody, err = providerAdapter.RewriteRequestModelName(spec.Body, provider.ModelName, interfaceType)
|
||||
if err != nil {
|
||||
e.logger.Warn("Smart Passthrough 改写请求失败,使用原始请求体",
|
||||
zap.String("error", err.Error()),
|
||||
zap.String("interface", string(interfaceType)))
|
||||
rewrittenBody = spec.Body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &HTTPRequestSpec{
|
||||
URL: provider.BaseURL + nativePath,
|
||||
Method: spec.Method,
|
||||
Headers: providerAdapter.BuildHeaders(provider),
|
||||
Body: spec.Body,
|
||||
Body: rewrittenBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -112,9 +130,30 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ConvertHttpResponse 转换 HTTP 响应
|
||||
func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType) (*HTTPResponseSpec, error) {
|
||||
// ConvertHttpResponse 转换 HTTP 响应,modelOverride 用于跨协议场景覆写 model 字段
|
||||
func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType, modelOverride string) (*HTTPResponseSpec, error) {
|
||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||
// Smart Passthrough: 同协议时最小化改写 model 字段
|
||||
if modelOverride != "" && len(spec.Body) > 0 {
|
||||
adapter, err := e.registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
return &spec, nil
|
||||
}
|
||||
|
||||
rewrittenBody, err := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
|
||||
if err != nil {
|
||||
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
|
||||
zap.String("error", err.Error()),
|
||||
zap.String("interface", string(interfaceType)))
|
||||
return &spec, nil
|
||||
}
|
||||
|
||||
return &HTTPResponseSpec{
|
||||
StatusCode: spec.StatusCode,
|
||||
Headers: spec.Headers,
|
||||
Body: rewrittenBody,
|
||||
}, nil
|
||||
}
|
||||
return &spec, nil
|
||||
}
|
||||
|
||||
@@ -127,7 +166,7 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
|
||||
return nil, err
|
||||
}
|
||||
|
||||
convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body)
|
||||
convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body, modelOverride)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -139,9 +178,17 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateStreamConverter 创建流式转换器
|
||||
func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string) (StreamConverter, error) {
|
||||
// CreateStreamConverter 创建流式转换器,modelOverride 用于跨协议场景覆写 model 字段
|
||||
func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string, modelOverride string, interfaceType InterfaceType) (StreamConverter, error) {
|
||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||
// Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段
|
||||
if modelOverride != "" {
|
||||
adapter, err := e.registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
return NewPassthroughStreamConverter(), nil
|
||||
}
|
||||
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
|
||||
}
|
||||
return NewPassthroughStreamConverter(), nil
|
||||
}
|
||||
|
||||
@@ -167,6 +214,7 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
|
||||
ctx,
|
||||
clientProtocol,
|
||||
providerProtocol,
|
||||
modelOverride,
|
||||
), nil
|
||||
}
|
||||
|
||||
@@ -192,11 +240,11 @@ func (e *ConversionEngine) convertBody(interfaceType InterfaceType, clientAdapte
|
||||
}
|
||||
}
|
||||
|
||||
// convertResponseBody 转换响应体
|
||||
func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
// convertResponseBody 转换响应体,modelOverride 非空时在 canonical 层面覆写 Model 字段
|
||||
func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
switch interfaceType {
|
||||
case InterfaceTypeChat:
|
||||
return e.convertChatResponseBody(clientAdapter, providerAdapter, body)
|
||||
return e.convertChatResponseBody(clientAdapter, providerAdapter, body, modelOverride)
|
||||
case InterfaceTypeModels:
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeModels) || !providerAdapter.SupportsInterface(InterfaceTypeModels) {
|
||||
return body, nil
|
||||
@@ -211,12 +259,12 @@ func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clie
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) {
|
||||
return body, nil
|
||||
}
|
||||
return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body)
|
||||
return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body, modelOverride)
|
||||
case InterfaceTypeRerank:
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) {
|
||||
return body, nil
|
||||
}
|
||||
return e.convertRerankResponseBody(clientAdapter, providerAdapter, body)
|
||||
return e.convertRerankResponseBody(clientAdapter, providerAdapter, body, modelOverride)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
@@ -241,11 +289,14 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
canonicalResp, err := providerAdapter.DecodeResponse(body)
|
||||
if err != nil {
|
||||
return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(err)
|
||||
}
|
||||
if modelOverride != "" {
|
||||
canonicalResp.Model = modelOverride
|
||||
}
|
||||
encoded, err := clientAdapter.EncodeResponse(canonicalResp)
|
||||
if err != nil {
|
||||
return nil, NewConversionError(ErrorCodeEncodingFailure, "编码响应失败").WithCause(err)
|
||||
@@ -290,12 +341,15 @@ func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter P
|
||||
return providerAdapter.EncodeEmbeddingRequest(req, provider)
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
resp, err := providerAdapter.DecodeEmbeddingResponse(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
return body, nil
|
||||
}
|
||||
if modelOverride != "" {
|
||||
resp.Model = modelOverride
|
||||
}
|
||||
return clientAdapter.EncodeEmbeddingResponse(resp)
|
||||
}
|
||||
|
||||
@@ -308,11 +362,14 @@ func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter Prot
|
||||
return providerAdapter.EncodeRerankRequest(req, provider)
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
resp, err := providerAdapter.DecodeRerankResponse(body)
|
||||
if err != nil {
|
||||
return body, nil
|
||||
}
|
||||
if modelOverride != "" {
|
||||
resp.Model = modelOverride
|
||||
}
|
||||
return clientAdapter.EncodeRerankResponse(resp)
|
||||
}
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"id":"resp-1"}`),
|
||||
}, "client", "provider", InterfaceTypeChat)
|
||||
}, "client", "provider", InterfaceTypeChat, "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 200, result.StatusCode)
|
||||
assert.Contains(t, string(result.Body), "resp-1")
|
||||
@@ -129,7 +129,7 @@ func TestConvertHttpResponse_DecodeError(t *testing.T) {
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
||||
|
||||
_, err := engine.ConvertHttpResponse(HTTPResponseSpec{Body: []byte(`{}`)}, "client", "provider", InterfaceTypeChat)
|
||||
_, err := engine.ConvertHttpResponse(HTTPResponseSpec{Body: []byte(`{}`)}, "client", "provider", InterfaceTypeChat, "")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -189,7 +189,7 @@ func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"object":"list","data":[],"model":"test"}`),
|
||||
}, "client", "provider", InterfaceTypeEmbeddings)
|
||||
}, "client", "provider", InterfaceTypeEmbeddings, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
@@ -207,7 +207,7 @@ func TestConvertHttpResponse_RerankInterface(t *testing.T) {
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"results":[],"model":"test"}`),
|
||||
}, "client", "provider", InterfaceTypeRerank)
|
||||
}, "client", "provider", InterfaceTypeRerank, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
@@ -242,7 +242,7 @@ func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"object":"list","data":[]}`),
|
||||
}, "client", "provider", InterfaceTypeModels)
|
||||
}, "client", "provider", InterfaceTypeModels, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
@@ -259,7 +259,7 @@ func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) {
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"id":"gpt-4","object":"model"}`),
|
||||
}, "client", "provider", InterfaceTypeModelInfo)
|
||||
}, "client", "provider", InterfaceTypeModelInfo, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
@@ -13,16 +13,18 @@ import (
|
||||
|
||||
// mockProtocolAdapter 模拟协议适配器
|
||||
type mockProtocolAdapter struct {
|
||||
protocolName string
|
||||
passthrough bool
|
||||
ifaceType InterfaceType
|
||||
supportsIface map[InterfaceType]bool
|
||||
decodeReqFn func([]byte) (*canonical.CanonicalRequest, error)
|
||||
encodeReqFn func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
|
||||
decodeRespFn func([]byte) (*canonical.CanonicalResponse, error)
|
||||
encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error)
|
||||
streamDecoderFn func() StreamDecoder
|
||||
streamEncoderFn func() StreamEncoder
|
||||
protocolName string
|
||||
passthrough bool
|
||||
ifaceType InterfaceType
|
||||
supportsIface map[InterfaceType]bool
|
||||
decodeReqFn func([]byte) (*canonical.CanonicalRequest, error)
|
||||
encodeReqFn func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
|
||||
decodeRespFn func([]byte) (*canonical.CanonicalResponse, error)
|
||||
encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error)
|
||||
streamDecoderFn func() StreamDecoder
|
||||
streamEncoderFn func() StreamEncoder
|
||||
rewriteReqFn func([]byte, string, InterfaceType) ([]byte, error)
|
||||
rewriteRespFn func([]byte, string, InterfaceType) ([]byte, error)
|
||||
}
|
||||
|
||||
func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
|
||||
@@ -155,6 +157,28 @@ func (m *mockProtocolAdapter) EncodeRerankResponse(resp *canonical.CanonicalRera
|
||||
return json.Marshal(resp)
|
||||
}
|
||||
|
||||
// ExtractUnifiedModelID is an inert stub: the mock never parses model
// paths and always reports an empty ID with no error.
func (m *mockProtocolAdapter) ExtractUnifiedModelID(nativePath string) (string, error) {
	return "", nil
}
|
||||
|
||||
// ExtractModelName is an inert stub: the mock never inspects request
// bodies and always reports an empty model name with no error.
func (m *mockProtocolAdapter) ExtractModelName(body []byte, ifaceType InterfaceType) (string, error) {
	return "", nil
}
|
||||
|
||||
func (m *mockProtocolAdapter) RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||
if m.rewriteReqFn != nil {
|
||||
return m.rewriteReqFn(body, newModel, ifaceType)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||
if m.rewriteRespFn != nil {
|
||||
return m.rewriteRespFn(body, newModel, ifaceType)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// noopStreamDecoder 空流式解码器
|
||||
type noopStreamDecoder struct{}
|
||||
|
||||
@@ -309,7 +333,7 @@ func TestConvertHttpResponse_Passthrough(t *testing.T) {
|
||||
Body: []byte(`{"id":"123"}`),
|
||||
}
|
||||
|
||||
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat)
|
||||
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat, "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 200, result.StatusCode)
|
||||
assert.Equal(t, spec.Body, result.Body)
|
||||
@@ -320,7 +344,7 @@ func TestCreateStreamConverter_Passthrough(t *testing.T) {
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
|
||||
converter, err := engine.CreateStreamConverter("openai", "openai")
|
||||
converter, err := engine.CreateStreamConverter("openai", "openai", "", InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
_, ok := converter.(*PassthroughStreamConverter)
|
||||
assert.True(t, ok)
|
||||
@@ -332,7 +356,7 @@ func TestCreateStreamConverter_Canonical(t *testing.T) {
|
||||
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
||||
_ = engine.RegisterAdapter(newMockAdapter("provider", false))
|
||||
|
||||
converter, err := engine.CreateStreamConverter("client", "provider")
|
||||
converter, err := engine.CreateStreamConverter("client", "provider", "", InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
_, ok := converter.(*CanonicalStreamConverter)
|
||||
assert.True(t, ok)
|
||||
@@ -380,3 +404,230 @@ func TestRegistry_GetNonExistent(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "未找到适配器")
|
||||
}
|
||||
|
||||
// ============ modelOverride 测试 ============
|
||||
|
||||
// TestConvertHttpResponse_ModelOverride_CrossProtocol: in a
// cross-protocol conversion, a non-empty modelOverride must replace the
// Model value at the canonical layer before re-encoding for the client.
func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, nil)

	// Client adapter encodes only the model field so the assertion can
	// observe exactly what Model value reached the encode step.
	clientAdapter := newMockAdapter("client", false)
	clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
		return json.Marshal(map[string]any{"model": resp.Model})
	}
	_ = engine.RegisterAdapter(clientAdapter)

	// Provider adapter decodes to a fixed canonical response carrying the
	// provider-native model name.
	providerAdapter := newMockAdapter("provider", false)
	providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
		return &canonical.CanonicalResponse{ID: "test", Model: "native-model", Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}, nil
	}
	_ = engine.RegisterAdapter(providerAdapter)

	spec := HTTPResponseSpec{
		StatusCode: 200,
		Body: []byte(`{"model":"native-model"}`),
	}

	result, err := engine.ConvertHttpResponse(spec, "client", "provider", InterfaceTypeChat, "provider/gpt-4")
	require.NoError(t, err)

	var resp map[string]interface{}
	require.NoError(t, json.Unmarshal(result.Body, &resp))
	assert.Equal(t, "provider/gpt-4", resp["model"])
}
|
||||
|
||||
func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
|
||||
// 使用真实 OpenAI adapter 验证 Smart Passthrough 改写
|
||||
openaiAdapter := newMockAdapter("openai", true)
|
||||
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
return json.Marshal(m)
|
||||
}
|
||||
_ = engine.RegisterAdapter(openaiAdapter)
|
||||
|
||||
spec := HTTPResponseSpec{
|
||||
StatusCode: 200,
|
||||
Body: []byte(`{"id":"resp-1","model":"gpt-4"}`),
|
||||
}
|
||||
|
||||
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat, "openai/gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
var resp map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(result.Body, &resp))
|
||||
assert.Equal(t, "openai/gpt-4", resp["model"])
|
||||
assert.Equal(t, "resp-1", resp["id"])
|
||||
}
|
||||
|
||||
func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
|
||||
openaiAdapter := newMockAdapter("openai", true)
|
||||
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
return json.Marshal(m)
|
||||
}
|
||||
_ = engine.RegisterAdapter(openaiAdapter)
|
||||
|
||||
converter, err := engine.CreateStreamConverter("openai", "openai", "openai/gpt-4", InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, ok := converter.(*SmartPassthroughStreamConverter)
|
||||
assert.True(t, ok)
|
||||
|
||||
// 验证 chunk 改写
|
||||
chunks := converter.ProcessChunk([]byte(`{"model":"gpt-4","choices":[]}`))
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
var resp map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(chunks[0], &resp))
|
||||
assert.Equal(t, "openai/gpt-4", resp["model"])
|
||||
}
|
||||
|
||||
func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
|
||||
// provider adapter 解码出含 model 的流式事件
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.streamDecoderFn = func() StreamDecoder {
|
||||
return &engineTestStreamDecoder{
|
||||
processFn: func(raw []byte) []canonical.CanonicalStreamEvent {
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewMessageStartEvent("msg-1", "native-model"),
|
||||
canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "text", Text: "hi"}),
|
||||
canonical.NewMessageStopEvent(),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
// client adapter 编码时输出 model 字段
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.streamEncoderFn = func() StreamEncoder {
|
||||
return &engineTestStreamEncoder{
|
||||
encodeFn: func(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.Message != nil {
|
||||
data, _ := json.Marshal(map[string]string{
|
||||
"type": string(event.Type),
|
||||
"model": event.Message.Model,
|
||||
})
|
||||
return [][]byte{data}
|
||||
}
|
||||
data, _ := json.Marshal(map[string]string{"type": string(event.Type)})
|
||||
return [][]byte{data}
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
|
||||
converter, err := engine.CreateStreamConverter("client", "provider", "provider/gpt-4", InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证类型是 CanonicalStreamConverter
|
||||
_, ok := converter.(*CanonicalStreamConverter)
|
||||
assert.True(t, ok)
|
||||
|
||||
// 处理一个 chunk,验证 model 被覆写为统一模型 ID
|
||||
chunks := converter.ProcessChunk([]byte("raw"))
|
||||
require.Len(t, chunks, 3) // message_start + content_block_start + message_stop
|
||||
|
||||
var startEvent map[string]string
|
||||
require.NoError(t, json.Unmarshal(chunks[0], &startEvent))
|
||||
assert.Equal(t, "provider/gpt-4", startEvent["model"], "跨协议流式中 modelOverride 应覆写 Message.Model")
|
||||
}
|
||||
|
||||
func TestCreateStreamConverter_ModelOverride_CrossProtocol_Empty(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.streamDecoderFn = func() StreamDecoder {
|
||||
return &engineTestStreamDecoder{
|
||||
processFn: func(raw []byte) []canonical.CanonicalStreamEvent {
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewMessageStartEvent("msg-1", "native-model"),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.streamEncoderFn = func() StreamEncoder {
|
||||
return &engineTestStreamEncoder{
|
||||
encodeFn: func(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.Message != nil {
|
||||
data, _ := json.Marshal(map[string]string{
|
||||
"model": event.Message.Model,
|
||||
})
|
||||
return [][]byte{data}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
|
||||
// modelOverride 为空,不应覆写
|
||||
converter, err := engine.CreateStreamConverter("client", "provider", "", InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
chunks := converter.ProcessChunk([]byte("raw"))
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
var resp map[string]string
|
||||
require.NoError(t, json.Unmarshal(chunks[0], &resp))
|
||||
assert.Equal(t, "native-model", resp["model"], "modelOverride 为空时不应覆写")
|
||||
}
|
||||
|
||||
// engineTestStreamDecoder 可控的流式解码器(用于 engine_test)
|
||||
type engineTestStreamDecoder struct {
|
||||
processFn func([]byte) []canonical.CanonicalStreamEvent
|
||||
flushFn func() []canonical.CanonicalStreamEvent
|
||||
}
|
||||
|
||||
func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.CanonicalStreamEvent {
|
||||
if d.processFn != nil {
|
||||
return d.processFn(raw)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
|
||||
if d.flushFn != nil {
|
||||
return d.flushFn()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// engineTestStreamEncoder 可控的流式编码器(用于 engine_test)
|
||||
type engineTestStreamEncoder struct {
|
||||
encodeFn func(canonical.CanonicalStreamEvent) [][]byte
|
||||
flushFn func() [][]byte
|
||||
}
|
||||
|
||||
func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if e.encodeFn != nil {
|
||||
return e.encodeFn(event)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (e *engineTestStreamEncoder) Flush() [][]byte {
|
||||
if e.flushFn != nil {
|
||||
return e.flushFn()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
@@ -43,13 +44,13 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp
|
||||
}
|
||||
}
|
||||
|
||||
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id})
|
||||
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id},允许 id 含 /)
|
||||
func isModelInfoPath(path string) bool {
|
||||
if !strings.HasPrefix(path, "/v1/models/") {
|
||||
return false
|
||||
}
|
||||
suffix := path[len("/v1/models/"):]
|
||||
return suffix != "" && !strings.Contains(suffix, "/")
|
||||
return suffix != ""
|
||||
}
|
||||
|
||||
// BuildUrl 根据接口类型构建 URL
|
||||
@@ -216,3 +217,80 @@ func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankRe
|
||||
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
|
||||
return encodeRerankResponse(resp)
|
||||
}
|
||||
|
||||
// ExtractUnifiedModelID 从路径中提取统一模型 ID(/v1/models/{provider_id}/{model_name})
|
||||
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
|
||||
if !strings.HasPrefix(nativePath, "/v1/models/") {
|
||||
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
|
||||
}
|
||||
suffix := nativePath[len("/v1/models/"):]
|
||||
if suffix == "" {
|
||||
return "", fmt.Errorf("路径缺少模型 ID")
|
||||
}
|
||||
return suffix, nil
|
||||
}
|
||||
|
||||
// locateModelFieldInRequest 定位请求体中 model 字段的值并提供改写函数
|
||||
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
|
||||
raw, exists := m["model"]
|
||||
if !exists {
|
||||
return "", nil, fmt.Errorf("请求体中缺少 model 字段")
|
||||
}
|
||||
var current string
|
||||
if err := json.Unmarshal(raw, ¤t); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
rewriteFunc := func(newModel string) ([]byte, error) {
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
return json.Marshal(m)
|
||||
}
|
||||
return current, rewriteFunc, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractModelName 从请求体中提取 model 值
|
||||
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
|
||||
model, _, err := locateModelFieldInRequest(body, ifaceType)
|
||||
return model, err
|
||||
}
|
||||
|
||||
// RewriteRequestModelName 最小化改写请求体中的 model 字段
|
||||
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||
_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rewriteFunc(newModel)
|
||||
}
|
||||
|
||||
// RewriteResponseModelName 最小化改写响应体中的 model 字段
|
||||
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
|
||||
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
return json.Marshal(m)
|
||||
case conversion.InterfaceTypeRerank:
|
||||
// Rerank 响应:存在 model 字段则改写,不存在则不添加
|
||||
if _, exists := m["model"]; exists {
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
}
|
||||
return json.Marshal(m)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ func TestIsModelInfoPath(t *testing.T) {
|
||||
{"model_info", "/v1/models/gpt-4", true},
|
||||
{"model_info_with_dots", "/v1/models/gpt-4.1-preview", true},
|
||||
{"models_list", "/v1/models", false},
|
||||
{"nested_path", "/v1/models/gpt-4/versions", false},
|
||||
{"nested_path", "/v1/models/gpt-4/versions", true},
|
||||
{"empty_suffix", "/v1/models/", false},
|
||||
{"unrelated", "/v1/chat/completions", false},
|
||||
{"partial_prefix", "/v1/model", false},
|
||||
|
||||
360
backend/internal/conversion/openai/adapter_unified_test.go
Normal file
360
backend/internal/conversion/openai/adapter_unified_test.go
Normal file
@@ -0,0 +1,360 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractUnifiedModelID
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractUnifiedModelID(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("standard_path", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/gpt-4", id)
|
||||
})
|
||||
|
||||
t.Run("multi_segment_path", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "azure/accounts/org/models/gpt-4", id)
|
||||
})
|
||||
|
||||
t.Run("single_segment", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/v1/models/gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4", id)
|
||||
})
|
||||
|
||||
t.Run("non_model_path", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/chat/completions")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("empty_suffix", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/models/")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("models_list_no_slash", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/models")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unrelated_path", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/other")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractModelName
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractModelName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/gpt-4","messages":[]}`)
|
||||
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/gpt-4", model)
|
||||
})
|
||||
|
||||
t.Run("embedding", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
|
||||
model, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/text-embedding", model)
|
||||
})
|
||||
|
||||
t.Run("rerank", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/rerank","query":"test"}`)
|
||||
model, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/rerank", model)
|
||||
})
|
||||
|
||||
t.Run("no_model_field", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[]}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
body := []byte(`{invalid}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unsupported_interface_type", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/gpt-4"}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypePassthrough)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RewriteRequestModelName
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRewriteRequestModelName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/gpt-4","messages":[]}`)
|
||||
rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "gpt-4", m["model"])
|
||||
|
||||
// messages field preserved
|
||||
msgs, ok := m["messages"]
|
||||
require.True(t, ok)
|
||||
msgsArr, ok := msgs.([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Len(t, msgsArr, 0)
|
||||
})
|
||||
|
||||
t.Run("preserves_unknown_fields", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/gpt-4","temperature":0.7}`)
|
||||
rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "gpt-4", m["model"])
|
||||
assert.Equal(t, 0.7, m["temperature"])
|
||||
})
|
||||
|
||||
t.Run("embedding", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
|
||||
rewritten, err := a.RewriteRequestModelName(body, "text-embedding", conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "text-embedding", m["model"])
|
||||
assert.Equal(t, "hello", m["input"])
|
||||
})
|
||||
|
||||
t.Run("rerank", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/rerank","query":"test"}`)
|
||||
rewritten, err := a.RewriteRequestModelName(body, "rerank", conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "rerank", m["model"])
|
||||
assert.Equal(t, "test", m["query"])
|
||||
})
|
||||
|
||||
t.Run("no_model_field", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[]}`)
|
||||
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
body := []byte(`{invalid}`)
|
||||
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unsupported_interface_type", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/gpt-4"}`)
|
||||
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypePassthrough)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RewriteResponseModelName
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRewriteResponseModelName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat_existing_model", func(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-4","choices":[]}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "openai/gpt-4", m["model"])
|
||||
|
||||
choices, ok := m["choices"]
|
||||
require.True(t, ok)
|
||||
choicesArr, ok := choices.([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Len(t, choicesArr, 0)
|
||||
})
|
||||
|
||||
t.Run("chat_without_model_field", func(t *testing.T) {
|
||||
body := []byte(`{"choices":[]}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "openai/gpt-4", m["model"])
|
||||
|
||||
choices, ok := m["choices"]
|
||||
require.True(t, ok)
|
||||
choicesArr, ok := choices.([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Len(t, choicesArr, 0)
|
||||
})
|
||||
|
||||
t.Run("rerank_existing_model", func(t *testing.T) {
|
||||
body := []byte(`{"model":"rerank","results":[]}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "openai/rerank", m["model"])
|
||||
})
|
||||
|
||||
t.Run("rerank_without_model_field_should_not_add", func(t *testing.T) {
|
||||
body := []byte(`{"results":[]}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
_, hasModel := m["model"]
|
||||
assert.False(t, hasModel, "rerank response without model field should not have one added")
|
||||
})
|
||||
|
||||
t.Run("embedding_existing_model", func(t *testing.T) {
|
||||
body := []byte(`{"model":"text-embedding","data":[]}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "openai/text-embedding", m["model"])
|
||||
})
|
||||
|
||||
t.Run("embedding_without_model_field_adds", func(t *testing.T) {
|
||||
body := []byte(`{"data":[]}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "openai/text-embedding", m["model"])
|
||||
})
|
||||
|
||||
t.Run("passthrough_returns_body_unchanged", func(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-4"}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypePassthrough)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(body), string(rewritten))
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
body := []byte(`{invalid}`)
|
||||
_, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractModelName and RewriteRequest consistency
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat_round_trip", func(t *testing.T) {
|
||||
original := []byte(`{"model":"openai/gpt-4","messages":[],"temperature":0.7}`)
|
||||
|
||||
// Extract the unified model ID from the body
|
||||
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/gpt-4", extracted)
|
||||
|
||||
// Rewrite to the native model name
|
||||
rewritten, err := a.RewriteRequestModelName(original, "gpt-4", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extract again from the rewritten body to verify the same location was targeted
|
||||
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4", afterRewrite)
|
||||
|
||||
// Verify other fields are preserved
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, 0.7, m["temperature"])
|
||||
})
|
||||
|
||||
t.Run("embedding_round_trip", func(t *testing.T) {
|
||||
original := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
|
||||
|
||||
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/text-embedding", extracted)
|
||||
|
||||
rewritten, err := a.RewriteRequestModelName(original, "text-embedding", conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
|
||||
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "text-embedding", afterRewrite)
|
||||
})
|
||||
|
||||
t.Run("rerank_round_trip", func(t *testing.T) {
|
||||
original := []byte(`{"model":"openai/rerank","query":"test"}`)
|
||||
|
||||
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/rerank", extracted)
|
||||
|
||||
rewritten, err := a.RewriteRequestModelName(original, "rerank", conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
|
||||
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "rerank", afterRewrite)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isModelInfoPath (additional unified model ID cases)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expected bool
|
||||
}{
|
||||
{"simple_model_id", "/v1/models/gpt-4", true},
|
||||
{"unified_model_id_with_slash", "/v1/models/openai/gpt-4", true},
|
||||
{"models_list", "/v1/models", false},
|
||||
{"models_list_trailing_slash", "/v1/models/", false},
|
||||
{"chat_completions", "/v1/chat/completions", false},
|
||||
{"deeply_nested", "/v1/models/azure/eastus/deployments/my-dept/models/gpt-4", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -38,14 +38,52 @@ func (c *PassthroughStreamConverter) Flush() [][]byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器
|
||||
// 逐 chunk 改写 model 字段
|
||||
type SmartPassthroughStreamConverter struct {
|
||||
adapter ProtocolAdapter
|
||||
modelOverride string
|
||||
interfaceType InterfaceType
|
||||
}
|
||||
|
||||
// NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器
|
||||
func NewSmartPassthroughStreamConverter(adapter ProtocolAdapter, modelOverride string, interfaceType InterfaceType) *SmartPassthroughStreamConverter {
|
||||
return &SmartPassthroughStreamConverter{
|
||||
adapter: adapter,
|
||||
modelOverride: modelOverride,
|
||||
interfaceType: interfaceType,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessChunk 改写 chunk 中的 model 字段
|
||||
func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
||||
if len(rawChunk) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
rewrittenChunk, err := c.adapter.RewriteResponseModelName(rawChunk, c.modelOverride, c.interfaceType)
|
||||
if err != nil {
|
||||
// 改写失败,返回原始 chunk
|
||||
return [][]byte{rawChunk}
|
||||
}
|
||||
|
||||
return [][]byte{rewrittenChunk}
|
||||
}
|
||||
|
||||
// Flush 无缓冲数据
|
||||
func (c *SmartPassthroughStreamConverter) Flush() [][]byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CanonicalStreamConverter 跨协议规范流式转换器
|
||||
type CanonicalStreamConverter struct {
|
||||
decoder StreamDecoder
|
||||
encoder StreamEncoder
|
||||
chain *MiddlewareChain
|
||||
ctx ConversionContext
|
||||
clientProtocol string
|
||||
decoder StreamDecoder
|
||||
encoder StreamEncoder
|
||||
chain *MiddlewareChain
|
||||
ctx ConversionContext
|
||||
clientProtocol string
|
||||
providerProtocol string
|
||||
modelOverride string
|
||||
}
|
||||
|
||||
// NewCanonicalStreamConverter 创建规范流式转换器
|
||||
@@ -57,18 +95,19 @@ func NewCanonicalStreamConverter(decoder StreamDecoder, encoder StreamEncoder) *
|
||||
}
|
||||
|
||||
// NewCanonicalStreamConverterWithMiddleware 创建带中间件的规范流式转换器
|
||||
func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol string) *CanonicalStreamConverter {
|
||||
func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol, modelOverride string) *CanonicalStreamConverter {
|
||||
return &CanonicalStreamConverter{
|
||||
decoder: decoder,
|
||||
encoder: encoder,
|
||||
chain: chain,
|
||||
ctx: ctx,
|
||||
clientProtocol: clientProtocol,
|
||||
decoder: decoder,
|
||||
encoder: encoder,
|
||||
chain: chain,
|
||||
ctx: ctx,
|
||||
clientProtocol: clientProtocol,
|
||||
providerProtocol: providerProtocol,
|
||||
modelOverride: modelOverride,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessChunk 解码 → 中间件 → 编码管道
|
||||
// ProcessChunk 解码 → 中间件 → modelOverride → 编码管道
|
||||
func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
||||
events := c.decoder.ProcessChunk(rawChunk)
|
||||
var result [][]byte
|
||||
@@ -80,6 +119,7 @@ func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
||||
}
|
||||
events[i] = *processed
|
||||
}
|
||||
c.applyModelOverride(&events[i])
|
||||
chunks := c.encoder.EncodeEvent(events[i])
|
||||
result = append(result, chunks...)
|
||||
}
|
||||
@@ -98,6 +138,7 @@ func (c *CanonicalStreamConverter) Flush() [][]byte {
|
||||
}
|
||||
events[i] = *processed
|
||||
}
|
||||
c.applyModelOverride(&events[i])
|
||||
chunks := c.encoder.EncodeEvent(events[i])
|
||||
result = append(result, chunks...)
|
||||
}
|
||||
@@ -105,3 +146,10 @@ func (c *CanonicalStreamConverter) Flush() [][]byte {
|
||||
result = append(result, encoderChunks...)
|
||||
return result
|
||||
}
|
||||
|
||||
// applyModelOverride 在跨协议场景下覆写流式事件中的 Model 字段
|
||||
func (c *CanonicalStreamConverter) applyModelOverride(event *canonical.CanonicalStreamEvent) {
|
||||
if c.modelOverride != "" && event.Message != nil {
|
||||
event.Message.Model = c.modelOverride
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ func TestCanonicalStreamConverter_WithMiddleware(t *testing.T) {
|
||||
chain.Use(&recordingMiddleware{name: "mw1", records: &records})
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
|
||||
result := converter.ProcessChunk([]byte("raw"))
|
||||
|
||||
assert.Len(t, result, 1)
|
||||
@@ -143,7 +143,7 @@ func TestCanonicalStreamConverter_MiddlewareError_Continue(t *testing.T) {
|
||||
chain.Use(&errorMiddleware{})
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
|
||||
result := converter.ProcessChunk([]byte("raw"))
|
||||
|
||||
assert.Nil(t, result, "middleware error should cause the event to be skipped (continue)")
|
||||
@@ -163,7 +163,7 @@ func TestCanonicalStreamConverter_Flush_MiddlewareError_Continue(t *testing.T) {
|
||||
chain.Use(&errorMiddleware{})
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
|
||||
result := converter.Flush()
|
||||
|
||||
assert.Len(t, result, 1)
|
||||
|
||||
Reference in New Issue
Block a user