实现统一模型 ID 格式 (provider_id/model_name),支持跨协议模型标识和 Smart Passthrough。 核心变更: - 新增 pkg/modelid 包:解析、格式化、校验统一模型 ID - 数据库迁移:models 表使用 UUID 主键 + UNIQUE(provider_id, model_name) 约束 - Repository 层:FindByProviderAndModelName、ListEnabled 方法 - Service 层:联合唯一校验、provider ID 字符集校验 - Conversion 层:ExtractModelName、RewriteRequestModelName/RewriteResponseModelName 方法 - Handler 层:统一模型 ID 路由、Smart Passthrough、Models API 本地聚合 - 新增 error-responses、unified-model-id 规范 测试覆盖: - 单元测试:modelid、conversion、handler、service、repository - 集成测试:统一模型 ID 路由、Smart Passthrough 保真性、跨协议转换 - 迁移测试:UUID 主键、UNIQUE 约束、级联删除 OpenSpec: - 归档 unified-model-id 变更到 archive/2026-04-21-unified-model-id - 同步 11 个 delta specs 到 main specs - 新增 error-responses、unified-model-id 规范文件
401 lines
14 KiB
Go
401 lines
14 KiB
Go
package conversion
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"time"
|
||
|
||
"github.com/google/uuid"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// HTTPRequestSpec is a transport-neutral description of an HTTP request.
// On input to ConvertHttpRequest, URL holds the client's native path; on
// output it holds the full upstream URL.
type HTTPRequestSpec struct {
	URL     string            `json:"url"`     // native path (input) or absolute upstream URL (output)
	Method  string            `json:"method"`  // HTTP method
	Headers map[string]string `json:"headers"` // single value per header name
	Body    []byte            `json:"body"`    // raw payload; encoding/json renders it as base64
}
|
||
|
||
// HTTPResponseSpec is a transport-neutral description of an HTTP response:
// status code, a flat header map and the raw body bytes.
type HTTPResponseSpec struct {
	StatusCode int               `json:"status_code"` // HTTP status code
	Headers    map[string]string `json:"headers"`     // single value per header name
	Body       []byte            `json:"body"`        // raw payload; encoding/json renders it as base64
}
|
||
|
||
// ConversionEngine 转换引擎门面
|
||
type ConversionEngine struct {
|
||
registry AdapterRegistry
|
||
middlewareChain *MiddlewareChain
|
||
logger *zap.Logger
|
||
}
|
||
|
||
// NewConversionEngine 创建转换引擎
|
||
func NewConversionEngine(registry AdapterRegistry, logger *zap.Logger) *ConversionEngine {
|
||
if logger == nil {
|
||
logger = zap.L()
|
||
}
|
||
return &ConversionEngine{
|
||
registry: registry,
|
||
middlewareChain: NewMiddlewareChain(),
|
||
logger: logger,
|
||
}
|
||
}
|
||
|
||
// RegisterAdapter 注册协议适配器
|
||
func (e *ConversionEngine) RegisterAdapter(adapter ProtocolAdapter) error {
|
||
return e.registry.Register(adapter)
|
||
}
|
||
|
||
// GetRegistry 返回注册表(供外部使用)
|
||
func (e *ConversionEngine) GetRegistry() AdapterRegistry {
|
||
return e.registry
|
||
}
|
||
|
||
// Use 添加中间件
|
||
func (e *ConversionEngine) Use(mw ConversionMiddleware) {
|
||
e.middlewareChain.Use(mw)
|
||
}
|
||
|
||
// IsPassthrough 判断是否同协议透传
|
||
func (e *ConversionEngine) IsPassthrough(clientProtocol, providerProtocol string) bool {
|
||
if clientProtocol != providerProtocol {
|
||
return false
|
||
}
|
||
adapter, err := e.registry.Get(clientProtocol)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
return adapter.SupportsPassthrough()
|
||
}
|
||
|
||
// ConvertHttpRequest translates a client HTTP request into the request that
// should be sent to the upstream provider. spec.URL is treated as the
// client's native path.
//
// Passthrough case (same protocol, see IsPassthrough) — "Smart Passthrough":
//   - The body is forwarded as-is.
//   - For Chat/Embeddings/Rerank requests with a non-empty body and a
//     non-empty provider.ModelName, the adapter rewrites the model field.
//   - If that rewrite fails, a warning is logged and the original body is kept.
//
// Cross-protocol case: the body is decoded with the client adapter and
// re-encoded for the provider by convertBody.
//
// In both cases the URL is provider.BaseURL followed by the (possibly
// translated) path. Headers come from the provider adapter's BuildHeaders;
// spec.Headers is not forwarded.
//
// An error is returned if provider is nil, if a required adapter is not
// registered, or if body conversion fails.
func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtocol, providerProtocol string, provider *TargetProvider) (*HTTPRequestSpec, error) {
	// Every path below dereferences provider; fail with an error instead of panicking.
	if provider == nil {
		return nil, fmt.Errorf("目标供应商为空")
	}

	nativePath := spec.URL

	if e.IsPassthrough(clientProtocol, providerProtocol) {
		providerAdapter, err := e.registry.Get(providerProtocol)
		if err != nil {
			return nil, fmt.Errorf("未找到服务端适配器 %s: %w", providerProtocol, err)
		}

		// Smart Passthrough: same protocol, so only the model field is rewritten.
		interfaceType := providerAdapter.DetectInterfaceType(nativePath)
		body := spec.Body

		rewritable := interfaceType == InterfaceTypeChat ||
			interfaceType == InterfaceTypeEmbeddings ||
			interfaceType == InterfaceTypeRerank
		if rewritable && len(spec.Body) > 0 && provider.ModelName != "" {
			rewritten, rewriteErr := providerAdapter.RewriteRequestModelName(spec.Body, provider.ModelName, interfaceType)
			if rewriteErr != nil {
				// Best effort: keep the untouched body rather than failing the request.
				e.logger.Warn("Smart Passthrough 改写请求失败,使用原始请求体",
					zap.String("error", rewriteErr.Error()),
					zap.String("interface", string(interfaceType)))
			} else {
				body = rewritten
			}
		}

		return &HTTPRequestSpec{
			URL:     provider.BaseURL + nativePath,
			Method:  spec.Method,
			Headers: providerAdapter.BuildHeaders(provider),
			Body:    body,
		}, nil
	}

	clientAdapter, err := e.registry.Get(clientProtocol)
	if err != nil {
		return nil, fmt.Errorf("未找到客户端适配器 %s: %w", clientProtocol, err)
	}
	providerAdapter, err := e.registry.Get(providerProtocol)
	if err != nil {
		return nil, fmt.Errorf("未找到服务端适配器 %s: %w", providerProtocol, err)
	}

	// The interface type is derived from the client's path, then mapped onto the provider's URL scheme.
	interfaceType := clientAdapter.DetectInterfaceType(nativePath)
	providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
	providerHeaders := providerAdapter.BuildHeaders(provider)
	providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
	if err != nil {
		return nil, err
	}

	return &HTTPRequestSpec{
		URL:     provider.BaseURL + providerURL,
		Method:  spec.Method,
		Headers: providerHeaders,
		Body:    providerBody,
	}, nil
}
|
||
|
||
// ConvertHttpResponse converts an upstream provider response into the form
// the client protocol expects. When modelOverride is non-empty, it replaces
// the model field of the response.
//
// Passthrough case (same protocol, see IsPassthrough) — Smart Passthrough:
//   - Only the model field is rewritten, by the adapter.
//   - If there is no override, the body is empty, the adapter cannot be
//     found, or the rewrite fails, spec is returned unchanged. A failed
//     rewrite is also logged.
//
// Cross-protocol case: the body is decoded and re-encoded by
// convertResponseBody, which applies the override on the canonical form.
//
// StatusCode and Headers are always copied from spec.
func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType, modelOverride string) (*HTTPResponseSpec, error) {
	if e.IsPassthrough(clientProtocol, providerProtocol) {
		// Smart Passthrough: same protocol, only the model field may change.
		if modelOverride == "" || len(spec.Body) == 0 {
			return &spec, nil
		}

		adapter, err := e.registry.Get(clientProtocol)
		if err != nil {
			// Best effort: an unmodified response is better than no response.
			return &spec, nil
		}

		rewrittenBody, err := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
		if err != nil {
			e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
				zap.String("error", err.Error()),
				zap.String("interface", string(interfaceType)))
			return &spec, nil
		}

		return &HTTPResponseSpec{
			StatusCode: spec.StatusCode,
			Headers:    spec.Headers,
			Body:       rewrittenBody,
		}, nil
	}

	// Wrap lookup failures the same way ConvertHttpRequest does; %w keeps errors.Is/As working.
	clientAdapter, err := e.registry.Get(clientProtocol)
	if err != nil {
		return nil, fmt.Errorf("未找到客户端适配器 %s: %w", clientProtocol, err)
	}
	providerAdapter, err := e.registry.Get(providerProtocol)
	if err != nil {
		return nil, fmt.Errorf("未找到服务端适配器 %s: %w", providerProtocol, err)
	}

	convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body, modelOverride)
	if err != nil {
		return nil, err
	}

	return &HTTPResponseSpec{
		StatusCode: spec.StatusCode,
		Headers:    spec.Headers,
		Body:       convertedBody,
	}, nil
}
|
||
|
||
// CreateStreamConverter returns the StreamConverter to use for a streaming
// response.
//
// Same-protocol (passthrough) streams:
//   - A plain passthrough converter is used when modelOverride is empty.
//   - Otherwise a Smart Passthrough converter rewrites the model field of
//     each chunk for interfaceType.
//   - If the adapter cannot be found, it falls back to the plain passthrough
//     converter.
//
// Cross-protocol streams: the converter decodes with the provider adapter's
// stream decoder and encodes with the client adapter's stream encoder, using
// the engine's middleware chain. modelOverride is passed through to it.
//
// An error is returned when a required adapter is not registered.
func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string, modelOverride string, interfaceType InterfaceType) (StreamConverter, error) {
	if e.IsPassthrough(clientProtocol, providerProtocol) {
		if modelOverride == "" {
			return NewPassthroughStreamConverter(), nil
		}
		// Smart Passthrough: same-protocol streams rewrite the model field chunk by chunk.
		adapter, err := e.registry.Get(clientProtocol)
		if err != nil {
			return NewPassthroughStreamConverter(), nil
		}
		return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
	}

	// Wrap lookup failures the same way ConvertHttpRequest does; %w keeps errors.Is/As working.
	providerAdapter, err := e.registry.Get(providerProtocol)
	if err != nil {
		return nil, fmt.Errorf("未找到服务端适配器 %s: %w", providerProtocol, err)
	}
	clientAdapter, err := e.registry.Get(clientProtocol)
	if err != nil {
		return nil, fmt.Errorf("未找到客户端适配器 %s: %w", clientProtocol, err)
	}

	// The stream codecs below take no interface type, and the context is tagged
	// as Chat regardless of interfaceType.
	// NOTE(review): presumably only chat responses stream across protocols — confirm.
	convCtx := ConversionContext{
		ConversionID:  uuid.New().String(),
		InterfaceType: InterfaceTypeChat,
		Timestamp:     time.Now(),
	}

	return NewCanonicalStreamConverterWithMiddleware(
		providerAdapter.CreateStreamDecoder(),
		clientAdapter.CreateStreamEncoder(),
		e.middlewareChain,
		convCtx,
		clientProtocol,
		providerProtocol,
		modelOverride,
	), nil
}
|
||
|
||
// convertBody 转换请求体
|
||
func (e *ConversionEngine) convertBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||
switch interfaceType {
|
||
case InterfaceTypeChat:
|
||
return e.convertChatBody(clientAdapter, providerAdapter, provider, body)
|
||
case InterfaceTypeModels, InterfaceTypeModelInfo:
|
||
return body, nil
|
||
case InterfaceTypeEmbeddings:
|
||
if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) {
|
||
return body, nil
|
||
}
|
||
return e.convertEmbeddingBody(clientAdapter, providerAdapter, provider, body)
|
||
case InterfaceTypeRerank:
|
||
if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) {
|
||
return body, nil
|
||
}
|
||
return e.convertRerankBody(clientAdapter, providerAdapter, provider, body)
|
||
default:
|
||
return body, nil
|
||
}
|
||
}
|
||
|
||
// convertResponseBody 转换响应体,modelOverride 非空时在 canonical 层面覆写 Model 字段
|
||
func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||
switch interfaceType {
|
||
case InterfaceTypeChat:
|
||
return e.convertChatResponseBody(clientAdapter, providerAdapter, body, modelOverride)
|
||
case InterfaceTypeModels:
|
||
if !clientAdapter.SupportsInterface(InterfaceTypeModels) || !providerAdapter.SupportsInterface(InterfaceTypeModels) {
|
||
return body, nil
|
||
}
|
||
return e.convertModelsResponseBody(clientAdapter, providerAdapter, body)
|
||
case InterfaceTypeModelInfo:
|
||
if !clientAdapter.SupportsInterface(InterfaceTypeModelInfo) || !providerAdapter.SupportsInterface(InterfaceTypeModelInfo) {
|
||
return body, nil
|
||
}
|
||
return e.convertModelInfoResponseBody(clientAdapter, providerAdapter, body)
|
||
case InterfaceTypeEmbeddings:
|
||
if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) {
|
||
return body, nil
|
||
}
|
||
return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body, modelOverride)
|
||
case InterfaceTypeRerank:
|
||
if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) {
|
||
return body, nil
|
||
}
|
||
return e.convertRerankResponseBody(clientAdapter, providerAdapter, body, modelOverride)
|
||
default:
|
||
return body, nil
|
||
}
|
||
}
|
||
|
||
func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||
canonicalReq, err := clientAdapter.DecodeRequest(body)
|
||
if err != nil {
|
||
return nil, NewConversionError(ErrorCodeJSONParseError, "解码请求失败").WithCause(err)
|
||
}
|
||
|
||
ctx := NewConversionContext(InterfaceTypeChat)
|
||
canonicalReq, err = e.middlewareChain.Apply(canonicalReq, clientAdapter.ProtocolName(), providerAdapter.ProtocolName(), ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
encoded, err := providerAdapter.EncodeRequest(canonicalReq, provider)
|
||
if err != nil {
|
||
return nil, NewConversionError(ErrorCodeEncodingFailure, "编码请求失败").WithCause(err)
|
||
}
|
||
return encoded, nil
|
||
}
|
||
|
||
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||
canonicalResp, err := providerAdapter.DecodeResponse(body)
|
||
if err != nil {
|
||
return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(err)
|
||
}
|
||
if modelOverride != "" {
|
||
canonicalResp.Model = modelOverride
|
||
}
|
||
encoded, err := clientAdapter.EncodeResponse(canonicalResp)
|
||
if err != nil {
|
||
return nil, NewConversionError(ErrorCodeEncodingFailure, "编码响应失败").WithCause(err)
|
||
}
|
||
return encoded, nil
|
||
}
|
||
|
||
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||
models, err := providerAdapter.DecodeModelsResponse(body)
|
||
if err != nil {
|
||
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||
return body, nil
|
||
}
|
||
encoded, err := clientAdapter.EncodeModelsResponse(models)
|
||
if err != nil {
|
||
e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||
return body, nil
|
||
}
|
||
return encoded, nil
|
||
}
|
||
|
||
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||
info, err := providerAdapter.DecodeModelInfoResponse(body)
|
||
if err != nil {
|
||
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||
return body, nil
|
||
}
|
||
encoded, err := clientAdapter.EncodeModelInfoResponse(info)
|
||
if err != nil {
|
||
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||
return body, nil
|
||
}
|
||
return encoded, nil
|
||
}
|
||
|
||
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||
req, err := clientAdapter.DecodeEmbeddingRequest(body)
|
||
if err != nil {
|
||
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error()))
|
||
return body, nil
|
||
}
|
||
return providerAdapter.EncodeEmbeddingRequest(req, provider)
|
||
}
|
||
|
||
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||
resp, err := providerAdapter.DecodeEmbeddingResponse(body)
|
||
if err != nil {
|
||
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||
return body, nil
|
||
}
|
||
if modelOverride != "" {
|
||
resp.Model = modelOverride
|
||
}
|
||
return clientAdapter.EncodeEmbeddingResponse(resp)
|
||
}
|
||
|
||
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||
req, err := clientAdapter.DecodeRerankRequest(body)
|
||
if err != nil {
|
||
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error()))
|
||
return body, nil
|
||
}
|
||
return providerAdapter.EncodeRerankRequest(req, provider)
|
||
}
|
||
|
||
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||
resp, err := providerAdapter.DecodeRerankResponse(body)
|
||
if err != nil {
|
||
return body, nil
|
||
}
|
||
if modelOverride != "" {
|
||
resp.Model = modelOverride
|
||
}
|
||
return clientAdapter.EncodeRerankResponse(resp)
|
||
}
|
||
|
||
// DetectInterfaceType 检测接口类型
|
||
func (e *ConversionEngine) DetectInterfaceType(nativePath, clientProtocol string) (InterfaceType, error) {
|
||
adapter, err := e.registry.Get(clientProtocol)
|
||
if err != nil {
|
||
return InterfaceTypePassthrough, err
|
||
}
|
||
return adapter.DetectInterfaceType(nativePath), nil
|
||
}
|
||
|
||
// EncodeError renders err in the wire format of clientProtocol and returns the
// body together with the HTTP status chosen by the adapter.
//
// If no adapter is registered for clientProtocol, it falls back to a generic
// JSON body {"error":{"message":...,"type":"internal_error"}} with status 500.
// The returned error is always nil.
func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol string) ([]byte, int, error) {
	if adapter, lookupErr := e.registry.Get(clientProtocol); lookupErr == nil {
		payload, status := adapter.EncodeError(err)
		return payload, status, nil
	}

	// Marshalling a map of plain strings cannot fail, so the error is discarded.
	payload, _ := json.Marshal(map[string]any{
		"error": map[string]string{
			"message": err.Error(),
			"type":    "internal_error",
		},
	})
	return payload, 500, nil
}
|