1
0
Files
nex/backend/internal/conversion/engine.go
lanyuanxiaoyao 395887667d feat: 实现统一模型 ID 机制
实现统一模型 ID 格式 (provider_id/model_name),支持跨协议模型标识和 Smart Passthrough。

核心变更:
- 新增 pkg/modelid 包:解析、格式化、校验统一模型 ID
- 数据库迁移:models 表使用 UUID 主键 + UNIQUE(provider_id, model_name) 约束
- Repository 层:FindByProviderAndModelName、ListEnabled 方法
- Service 层:联合唯一校验、provider ID 字符集校验
- Conversion 层:ExtractModelName、RewriteRequestModelName/RewriteResponseModelName 方法
- Handler 层:统一模型 ID 路由、Smart Passthrough、Models API 本地聚合
- 新增 error-responses、unified-model-id 规范

测试覆盖:
- 单元测试:modelid、conversion、handler、service、repository
- 集成测试:统一模型 ID 路由、Smart Passthrough 保真性、跨协议转换
- 迁移测试:UUID 主键、UNIQUE 约束、级联删除

OpenSpec:
- 归档 unified-model-id 变更到 archive/2026-04-21-unified-model-id
- 同步 11 个 delta specs 到 main specs
- 新增 error-responses、unified-model-id 规范文件
2026-04-21 18:14:10 +08:00

401 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package conversion
import (
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Protocol-neutral descriptions of one HTTP exchange. ConversionEngine accepts
// and returns these when converting requests and responses.
type (
	// HTTPRequestSpec describes an HTTP request: target URL, method, header
	// map and raw body bytes.
	HTTPRequestSpec struct {
		URL     string            `json:"url"`
		Method  string            `json:"method"`
		Headers map[string]string `json:"headers"`
		Body    []byte            `json:"body"`
	}

	// HTTPResponseSpec describes an HTTP response: status code, header map
	// and raw body bytes.
	HTTPResponseSpec struct {
		StatusCode int               `json:"status_code"`
		Headers    map[string]string `json:"headers"`
		Body       []byte            `json:"body"`
	}
)
// ConversionEngine is the facade that converts HTTP requests, responses and
// streams between a client protocol and a provider protocol. It resolves
// protocol adapters through an AdapterRegistry and holds a middleware chain
// that is applied during chat request conversion.
type ConversionEngine struct {
	registry        AdapterRegistry  // protocol name -> ProtocolAdapter lookup
	middlewareChain *MiddlewareChain // applied to chat requests and handed to stream converters
	logger          *zap.Logger      // non-nil when built via NewConversionEngine
}
// NewConversionEngine builds a ConversionEngine backed by registry with a
// freshly created middleware chain. A nil logger is replaced by zap's global
// logger (zap.L()).
func NewConversionEngine(registry AdapterRegistry, logger *zap.Logger) *ConversionEngine {
	engine := &ConversionEngine{
		registry:        registry,
		middlewareChain: NewMiddlewareChain(),
		logger:          logger,
	}
	if engine.logger == nil {
		engine.logger = zap.L()
	}
	return engine
}
// RegisterAdapter adds a protocol adapter to the engine's registry and
// returns whatever error the registry reports.
func (e *ConversionEngine) RegisterAdapter(adapter ProtocolAdapter) error {
	if err := e.registry.Register(adapter); err != nil {
		return err
	}
	return nil
}
// GetRegistry exposes the adapter registry backing this engine so code
// outside the package can look adapters up directly.
func (e *ConversionEngine) GetRegistry() AdapterRegistry {
	registry := e.registry
	return registry
}
// Use registers mw on the engine's middleware chain.
func (e *ConversionEngine) Use(mw ConversionMiddleware) {
	chain := e.middlewareChain
	chain.Use(mw)
}
// IsPassthrough reports whether traffic can bypass full conversion. That
// requires both sides to speak the same protocol, and the registered adapter
// for that protocol must declare passthrough support. A protocol missing from
// the registry is never treated as passthrough.
func (e *ConversionEngine) IsPassthrough(clientProtocol, providerProtocol string) bool {
	if clientProtocol == providerProtocol {
		if adapter, err := e.registry.Get(clientProtocol); err == nil {
			return adapter.SupportsPassthrough()
		}
	}
	return false
}
// ConvertHttpRequest turns a client request (spec, in clientProtocol) into the
// request that should be sent to provider (which speaks providerProtocol).
//
// Same-protocol passthrough ("Smart Passthrough", see IsPassthrough): the path
// is kept and headers are rebuilt for the provider. For chat, embeddings and
// rerank requests, only the body's model field is rewritten to
// provider.ModelName. If that rewrite fails, the original body is forwarded
// and a warning is logged.
//
// Cross-protocol: the body is fully converted via convertBody, and the URL is
// rebuilt with the provider adapter's BuildUrl.
//
// provider must be non-nil; a nil provider returns an error instead of
// panicking.
func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtocol, providerProtocol string, provider *TargetProvider) (*HTTPRequestSpec, error) {
	// Every path below dereferences provider (BaseURL, ModelName, BuildHeaders),
	// so reject nil up front instead of panicking mid-conversion.
	if provider == nil {
		return nil, fmt.Errorf("目标供应商不能为空")
	}
	nativePath := spec.URL

	if e.IsPassthrough(clientProtocol, providerProtocol) {
		providerAdapter, err := e.registry.Get(providerProtocol)
		if err != nil {
			// Wrapped the same way as the cross-protocol lookup below.
			return nil, fmt.Errorf("未找到服务端适配器 %s: %w", providerProtocol, err)
		}
		// Smart Passthrough: with identical protocols, only the model field is
		// rewritten, and only for interfaces whose request body carries one.
		interfaceType := providerAdapter.DetectInterfaceType(nativePath)
		rewrittenBody := spec.Body
		if interfaceType == InterfaceTypeChat || interfaceType == InterfaceTypeEmbeddings || interfaceType == InterfaceTypeRerank {
			if len(spec.Body) > 0 && provider.ModelName != "" {
				rewrittenBody, err = providerAdapter.RewriteRequestModelName(spec.Body, provider.ModelName, interfaceType)
				if err != nil {
					// Best effort: forward the untouched body rather than fail the request.
					e.logger.Warn("Smart Passthrough 改写请求失败,使用原始请求体",
						zap.String("error", err.Error()),
						zap.String("interface", string(interfaceType)))
					rewrittenBody = spec.Body
				}
			}
		}
		return &HTTPRequestSpec{
			URL:     provider.BaseURL + nativePath,
			Method:  spec.Method,
			Headers: providerAdapter.BuildHeaders(provider),
			Body:    rewrittenBody,
		}, nil
	}

	clientAdapter, err := e.registry.Get(clientProtocol)
	if err != nil {
		return nil, fmt.Errorf("未找到客户端适配器 %s: %w", clientProtocol, err)
	}
	providerAdapter, err := e.registry.Get(providerProtocol)
	if err != nil {
		return nil, fmt.Errorf("未找到服务端适配器 %s: %w", providerProtocol, err)
	}
	// The interface type is detected from the client's native path, then the
	// provider adapter maps it onto its own URL layout.
	interfaceType := clientAdapter.DetectInterfaceType(nativePath)
	providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
	providerHeaders := providerAdapter.BuildHeaders(provider)
	providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
	if err != nil {
		return nil, err
	}
	return &HTTPRequestSpec{
		URL:     provider.BaseURL + providerURL,
		Method:  spec.Method,
		Headers: providerHeaders,
		Body:    providerBody,
	}, nil
}
// ConvertHttpResponse converts a provider response (in providerProtocol) back
// into the client's protocol (clientProtocol). A non-empty modelOverride
// replaces the model field reported to the client.
//
// Passthrough: if modelOverride or the body is empty, spec is returned as-is.
// Otherwise the client adapter rewrites only the model field. If the adapter
// lookup or the rewrite fails, the original response is returned; a rewrite
// failure is also logged as a warning.
//
// Cross-protocol: the body goes through convertResponseBody, and the status
// code and headers are carried over unchanged.
func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType, modelOverride string) (*HTTPResponseSpec, error) {
	if e.IsPassthrough(clientProtocol, providerProtocol) {
		// Smart Passthrough: nothing to rewrite without an override and a body.
		if modelOverride == "" || len(spec.Body) == 0 {
			return &spec, nil
		}
		adapter, err := e.registry.Get(clientProtocol)
		if err != nil {
			return &spec, nil
		}
		body, err := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
		if err != nil {
			e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
				zap.String("error", err.Error()),
				zap.String("interface", string(interfaceType)))
			return &spec, nil
		}
		rewritten := spec
		rewritten.Body = body
		return &rewritten, nil
	}

	clientAdapter, err := e.registry.Get(clientProtocol)
	if err != nil {
		return nil, err
	}
	providerAdapter, err := e.registry.Get(providerProtocol)
	if err != nil {
		return nil, err
	}
	body, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body, modelOverride)
	if err != nil {
		return nil, err
	}
	converted := spec
	converted.Body = body
	return &converted, nil
}
// CreateStreamConverter returns a StreamConverter for a streaming response
// flowing from providerProtocol back to clientProtocol.
//
// Passthrough: without modelOverride the stream is forwarded verbatim. With
// modelOverride, a Smart Passthrough converter rewrites the model field chunk
// by chunk; if the client adapter cannot be found, verbatim forwarding is used
// instead.
//
// Cross-protocol: chunks are decoded by the provider adapter's stream decoder
// and re-encoded by the client adapter's stream encoder. The engine's
// middleware chain and a fresh ConversionContext are passed along. A
// non-empty modelOverride is handed to the converter to override the model
// field.
//
// The context's InterfaceType is taken from interfaceType. When interfaceType
// is the zero value, it defaults to InterfaceTypeChat.
func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string, modelOverride string, interfaceType InterfaceType) (StreamConverter, error) {
	if e.IsPassthrough(clientProtocol, providerProtocol) {
		// Smart Passthrough: same-protocol streams only need the model field
		// rewritten in each chunk.
		if modelOverride != "" {
			adapter, err := e.registry.Get(clientProtocol)
			if err != nil {
				return NewPassthroughStreamConverter(), nil
			}
			return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
		}
		return NewPassthroughStreamConverter(), nil
	}
	providerAdapter, err := e.registry.Get(providerProtocol)
	if err != nil {
		return nil, err
	}
	clientAdapter, err := e.registry.Get(clientProtocol)
	if err != nil {
		return nil, err
	}
	// Previously the context always claimed InterfaceTypeChat, ignoring the
	// caller's interfaceType. Honour the argument and keep Chat only as the
	// fallback when it is unset.
	ctxInterface := interfaceType
	var unset InterfaceType
	if ctxInterface == unset {
		ctxInterface = InterfaceTypeChat
	}
	ctx := ConversionContext{
		ConversionID:  uuid.New().String(),
		InterfaceType: ctxInterface,
		Timestamp:     time.Now(),
	}
	return NewCanonicalStreamConverterWithMiddleware(
		providerAdapter.CreateStreamDecoder(),
		clientAdapter.CreateStreamEncoder(),
		e.middlewareChain,
		ctx,
		clientProtocol,
		providerProtocol,
		modelOverride,
	), nil
}
// convertBody converts a request body from the client protocol to the
// provider protocol according to interfaceType.
//
// Chat bodies are always converted. Embeddings and rerank bodies are
// converted only when both adapters support that interface. Every other case
// (Models, ModelInfo, unknown) returns body unchanged.
func (e *ConversionEngine) convertBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
	// The client is asked first; the provider is only consulted when the
	// client supports the interface.
	bothSupport := func() bool {
		return clientAdapter.SupportsInterface(interfaceType) && providerAdapter.SupportsInterface(interfaceType)
	}

	switch interfaceType {
	case InterfaceTypeChat:
		return e.convertChatBody(clientAdapter, providerAdapter, provider, body)
	case InterfaceTypeEmbeddings:
		if bothSupport() {
			return e.convertEmbeddingBody(clientAdapter, providerAdapter, provider, body)
		}
	case InterfaceTypeRerank:
		if bothSupport() {
			return e.convertRerankBody(clientAdapter, providerAdapter, provider, body)
		}
	}
	return body, nil
}
// convertResponseBody converts a response body from the provider protocol to
// the client protocol according to interfaceType.
//
// Chat responses are always converted. Models, ModelInfo, embeddings and
// rerank responses are converted only when both adapters support that
// interface; otherwise, and for unknown interfaces, body is returned
// unchanged. A non-empty modelOverride is forwarded to the chat, embeddings
// and rerank converters, which set it on the decoded response.
func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
	// The client is asked first; the provider is only consulted when the
	// client supports the interface.
	bothSupport := func() bool {
		return clientAdapter.SupportsInterface(interfaceType) && providerAdapter.SupportsInterface(interfaceType)
	}

	switch interfaceType {
	case InterfaceTypeChat:
		return e.convertChatResponseBody(clientAdapter, providerAdapter, body, modelOverride)
	case InterfaceTypeModels:
		if bothSupport() {
			return e.convertModelsResponseBody(clientAdapter, providerAdapter, body)
		}
	case InterfaceTypeModelInfo:
		if bothSupport() {
			return e.convertModelInfoResponseBody(clientAdapter, providerAdapter, body)
		}
	case InterfaceTypeEmbeddings:
		if bothSupport() {
			return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body, modelOverride)
		}
	case InterfaceTypeRerank:
		if bothSupport() {
			return e.convertRerankResponseBody(clientAdapter, providerAdapter, body, modelOverride)
		}
	}
	return body, nil
}
// convertChatBody decodes a chat request with the client adapter, runs it
// through the middleware chain (with a chat ConversionContext), and encodes
// the result for provider with the provider adapter.
//
// Decode failures become ErrorCodeJSONParseError and encode failures become
// ErrorCodeEncodingFailure, both wrapping the underlying cause. Middleware
// errors are returned unchanged.
func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
	req, decodeErr := clientAdapter.DecodeRequest(body)
	if decodeErr != nil {
		return nil, NewConversionError(ErrorCodeJSONParseError, "解码请求失败").WithCause(decodeErr)
	}

	from, to := clientAdapter.ProtocolName(), providerAdapter.ProtocolName()
	req, mwErr := e.middlewareChain.Apply(req, from, to, NewConversionContext(InterfaceTypeChat))
	if mwErr != nil {
		return nil, mwErr
	}

	out, encodeErr := providerAdapter.EncodeRequest(req, provider)
	if encodeErr != nil {
		return nil, NewConversionError(ErrorCodeEncodingFailure, "编码请求失败").WithCause(encodeErr)
	}
	return out, nil
}
// convertChatResponseBody decodes a chat response with the provider adapter,
// optionally overwrites its Model with modelOverride, and re-encodes it with
// the client adapter.
//
// Decode failures become ErrorCodeJSONParseError and encode failures become
// ErrorCodeEncodingFailure, both wrapping the underlying cause.
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
	resp, decodeErr := providerAdapter.DecodeResponse(body)
	if decodeErr != nil {
		return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(decodeErr)
	}
	if len(modelOverride) > 0 {
		resp.Model = modelOverride
	}
	out, encodeErr := clientAdapter.EncodeResponse(resp)
	if encodeErr != nil {
		return nil, NewConversionError(ErrorCodeEncodingFailure, "编码响应失败").WithCause(encodeErr)
	}
	return out, nil
}
// convertModelsResponseBody re-encodes a Models listing from the provider
// format into the client format. This is best effort: any decode or encode
// failure is logged as a warning and the original body is returned with a nil
// error.
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
	decoded, err := providerAdapter.DecodeModelsResponse(body)
	if err != nil {
		e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
		return body, nil
	}
	reencoded, err := clientAdapter.EncodeModelsResponse(decoded)
	if err == nil {
		return reencoded, nil
	}
	e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
	return body, nil
}
// convertModelInfoResponseBody re-encodes a single-model info response from
// the provider format into the client format. This is best effort: any decode
// or encode failure is logged as a warning and the original body is returned
// with a nil error.
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
	decoded, err := providerAdapter.DecodeModelInfoResponse(body)
	if err != nil {
		e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
		return body, nil
	}
	reencoded, err := clientAdapter.EncodeModelInfoResponse(decoded)
	if err == nil {
		return reencoded, nil
	}
	e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
	return body, nil
}
// convertEmbeddingBody converts an embeddings request from the client format
// to the provider format. If the client adapter cannot decode it, a warning is
// logged and the original body is forwarded. Encode errors from the provider
// adapter are returned to the caller as-is.
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
	req, err := clientAdapter.DecodeEmbeddingRequest(body)
	if err == nil {
		return providerAdapter.EncodeEmbeddingRequest(req, provider)
	}
	e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error()))
	return body, nil
}
// convertEmbeddingResponseBody converts an embeddings response from the
// provider format to the client format, overwriting Model when modelOverride
// is non-empty. If the provider adapter cannot decode it, a warning is logged
// and the original body is returned. Encode errors are returned as-is.
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
	resp, err := providerAdapter.DecodeEmbeddingResponse(body)
	if err == nil {
		if len(modelOverride) > 0 {
			resp.Model = modelOverride
		}
		return clientAdapter.EncodeEmbeddingResponse(resp)
	}
	e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
	return body, nil
}
// convertRerankBody converts a rerank request from the client format to the
// provider format. If the client adapter cannot decode it, a warning is logged
// and the original body is forwarded. Encode errors from the provider adapter
// are returned to the caller as-is.
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
	req, err := clientAdapter.DecodeRerankRequest(body)
	if err == nil {
		return providerAdapter.EncodeRerankRequest(req, provider)
	}
	e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error()))
	return body, nil
}
// convertRerankResponseBody converts a rerank response from the provider
// format to the client format, overwriting Model when modelOverride is
// non-empty.
//
// If the provider adapter cannot decode the body, the original body is
// returned with a nil error. The failure is logged as a warning, matching the
// sibling best-effort converters; previously it was swallowed silently. Encode
// errors are returned as-is.
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
	resp, err := providerAdapter.DecodeRerankResponse(body)
	if err != nil {
		e.logger.Warn("解码 Rerank 响应失败,返回原始响应", zap.String("error", err.Error()))
		return body, nil
	}
	if modelOverride != "" {
		resp.Model = modelOverride
	}
	return clientAdapter.EncodeRerankResponse(resp)
}
// DetectInterfaceType asks the adapter registered for clientProtocol to
// classify nativePath. If no adapter is registered, it returns
// InterfaceTypePassthrough together with the registry's lookup error.
func (e *ConversionEngine) DetectInterfaceType(nativePath, clientProtocol string) (InterfaceType, error) {
	adapter, err := e.registry.Get(clientProtocol)
	if err == nil {
		return adapter.DetectInterfaceType(nativePath), nil
	}
	return InterfaceTypePassthrough, err
}
// EncodeError renders err in the client protocol's error format and returns
// the body together with the HTTP status code chosen by the client adapter.
//
// If no adapter is registered for clientProtocol, it falls back to a generic
// JSON body {"error":{"message":...,"type":"internal_error"}} with status 500.
// The returned error is always nil.
func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol string) ([]byte, int, error) {
	adapter, lookupErr := e.registry.Get(clientProtocol)
	if lookupErr == nil {
		body, statusCode := adapter.EncodeError(err)
		return body, statusCode, nil
	}

	details := map[string]string{
		"message": err.Error(),
		"type":    "internal_error",
	}
	// Marshalling a map whose leaf values are plain strings cannot fail, so
	// the error is intentionally discarded.
	body, _ := json.Marshal(map[string]any{"error": details})
	return body, 500, nil
}