引入 Canonical Model 和 ProtocolAdapter 架构,支持 OpenAI/Anthropic 协议间 无缝转换,统一 ProxyHandler 替代分散的 OpenAI/Anthropic Handler,简化 ProviderClient 为协议无关的 HTTP 发送器,Provider 新增 protocol 字段。
339 lines
11 KiB
Go
339 lines
11 KiB
Go
package conversion
|
|
|
|
import (
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"github.com/google/uuid"
	"go.uber.org/zap"
)
|
|
|
|
// HTTPRequestSpec HTTP 请求规格
|
|
type HTTPRequestSpec struct {
|
|
URL string `json:"url"`
|
|
Method string `json:"method"`
|
|
Headers map[string]string `json:"headers"`
|
|
Body []byte `json:"body"`
|
|
}
|
|
|
|
// HTTPResponseSpec HTTP 响应规格
|
|
type HTTPResponseSpec struct {
|
|
StatusCode int `json:"status_code"`
|
|
Headers map[string]string `json:"headers"`
|
|
Body []byte `json:"body"`
|
|
}
|
|
|
|
// ConversionEngine 转换引擎门面
|
|
type ConversionEngine struct {
|
|
registry AdapterRegistry
|
|
middlewareChain *MiddlewareChain
|
|
}
|
|
|
|
// NewConversionEngine 创建转换引擎
|
|
func NewConversionEngine(registry AdapterRegistry) *ConversionEngine {
|
|
return &ConversionEngine{
|
|
registry: registry,
|
|
middlewareChain: NewMiddlewareChain(),
|
|
}
|
|
}
|
|
|
|
// RegisterAdapter 注册协议适配器
|
|
func (e *ConversionEngine) RegisterAdapter(adapter ProtocolAdapter) error {
|
|
return e.registry.Register(adapter)
|
|
}
|
|
|
|
// GetRegistry 返回注册表(供外部使用)
|
|
func (e *ConversionEngine) GetRegistry() AdapterRegistry {
|
|
return e.registry
|
|
}
|
|
|
|
// Use 添加中间件
|
|
func (e *ConversionEngine) Use(mw ConversionMiddleware) {
|
|
e.middlewareChain.Use(mw)
|
|
}
|
|
|
|
// IsPassthrough 判断是否同协议透传
|
|
func (e *ConversionEngine) IsPassthrough(clientProtocol, providerProtocol string) bool {
|
|
if clientProtocol != providerProtocol {
|
|
return false
|
|
}
|
|
adapter, err := e.registry.Get(clientProtocol)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return adapter.SupportsPassthrough()
|
|
}
|
|
|
|
// ConvertHttpRequest 转换 HTTP 请求
|
|
func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtocol, providerProtocol string, provider *TargetProvider) (*HTTPRequestSpec, error) {
|
|
nativePath := spec.URL
|
|
|
|
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
|
providerAdapter, err := e.registry.Get(providerProtocol)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &HTTPRequestSpec{
|
|
URL: provider.BaseURL + nativePath,
|
|
Method: spec.Method,
|
|
Headers: providerAdapter.BuildHeaders(provider),
|
|
Body: spec.Body,
|
|
}, nil
|
|
}
|
|
|
|
clientAdapter, err := e.registry.Get(clientProtocol)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("未找到客户端适配器 %s: %w", clientProtocol, err)
|
|
}
|
|
providerAdapter, err := e.registry.Get(providerProtocol)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("未找到服务端适配器 %s: %w", providerProtocol, err)
|
|
}
|
|
|
|
interfaceType := clientAdapter.DetectInterfaceType(nativePath)
|
|
providerUrl := providerAdapter.BuildUrl(nativePath, interfaceType)
|
|
providerHeaders := providerAdapter.BuildHeaders(provider)
|
|
providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &HTTPRequestSpec{
|
|
URL: provider.BaseURL + providerUrl,
|
|
Method: spec.Method,
|
|
Headers: providerHeaders,
|
|
Body: providerBody,
|
|
}, nil
|
|
}
|
|
|
|
// ConvertHttpResponse 转换 HTTP 响应
|
|
func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType) (*HTTPResponseSpec, error) {
|
|
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
|
return &spec, nil
|
|
}
|
|
|
|
clientAdapter, err := e.registry.Get(clientProtocol)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
providerAdapter, err := e.registry.Get(providerProtocol)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &HTTPResponseSpec{
|
|
StatusCode: spec.StatusCode,
|
|
Headers: spec.Headers,
|
|
Body: convertedBody,
|
|
}, nil
|
|
}
|
|
|
|
// CreateStreamConverter 创建流式转换器
|
|
func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string) (StreamConverter, error) {
|
|
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
|
return NewPassthroughStreamConverter(), nil
|
|
}
|
|
|
|
providerAdapter, err := e.registry.Get(providerProtocol)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
clientAdapter, err := e.registry.Get(clientProtocol)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ctx := ConversionContext{
|
|
ConversionID: uuid.New().String(),
|
|
InterfaceType: InterfaceTypeChat,
|
|
Timestamp: time.Now(),
|
|
}
|
|
|
|
return NewCanonicalStreamConverterWithMiddleware(
|
|
providerAdapter.CreateStreamDecoder(),
|
|
clientAdapter.CreateStreamEncoder(),
|
|
e.middlewareChain,
|
|
ctx,
|
|
clientProtocol,
|
|
providerProtocol,
|
|
), nil
|
|
}
|
|
|
|
// convertBody 转换请求体
|
|
func (e *ConversionEngine) convertBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
|
switch interfaceType {
|
|
case InterfaceTypeChat:
|
|
return e.convertChatBody(clientAdapter, providerAdapter, provider, body)
|
|
case InterfaceTypeModels, InterfaceTypeModelInfo:
|
|
return body, nil
|
|
case InterfaceTypeEmbeddings:
|
|
if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) {
|
|
return body, nil
|
|
}
|
|
return e.convertEmbeddingBody(clientAdapter, providerAdapter, provider, body)
|
|
case InterfaceTypeRerank:
|
|
if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) {
|
|
return body, nil
|
|
}
|
|
return e.convertRerankBody(clientAdapter, providerAdapter, provider, body)
|
|
default:
|
|
return body, nil
|
|
}
|
|
}
|
|
|
|
// convertResponseBody 转换响应体
|
|
func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
|
switch interfaceType {
|
|
case InterfaceTypeChat:
|
|
return e.convertChatResponseBody(clientAdapter, providerAdapter, body)
|
|
case InterfaceTypeModels:
|
|
if !clientAdapter.SupportsInterface(InterfaceTypeModels) || !providerAdapter.SupportsInterface(InterfaceTypeModels) {
|
|
return body, nil
|
|
}
|
|
return e.convertModelsResponseBody(clientAdapter, providerAdapter, body)
|
|
case InterfaceTypeModelInfo:
|
|
if !clientAdapter.SupportsInterface(InterfaceTypeModelInfo) || !providerAdapter.SupportsInterface(InterfaceTypeModelInfo) {
|
|
return body, nil
|
|
}
|
|
return e.convertModelInfoResponseBody(clientAdapter, providerAdapter, body)
|
|
case InterfaceTypeEmbeddings:
|
|
if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) {
|
|
return body, nil
|
|
}
|
|
return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body)
|
|
case InterfaceTypeRerank:
|
|
if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) {
|
|
return body, nil
|
|
}
|
|
return e.convertRerankResponseBody(clientAdapter, providerAdapter, body)
|
|
default:
|
|
return body, nil
|
|
}
|
|
}
|
|
|
|
func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
|
canonicalReq, err := clientAdapter.DecodeRequest(body)
|
|
if err != nil {
|
|
return nil, NewConversionError(ErrorCodeJSONParseError, "解码请求失败").WithCause(err)
|
|
}
|
|
|
|
ctx := NewConversionContext(InterfaceTypeChat)
|
|
canonicalReq, err = e.middlewareChain.Apply(canonicalReq, clientAdapter.ProtocolName(), providerAdapter.ProtocolName(), ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
encoded, err := providerAdapter.EncodeRequest(canonicalReq, provider)
|
|
if err != nil {
|
|
return nil, NewConversionError(ErrorCodeEncodingFailure, "编码请求失败").WithCause(err)
|
|
}
|
|
return encoded, nil
|
|
}
|
|
|
|
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
|
canonicalResp, err := providerAdapter.DecodeResponse(body)
|
|
if err != nil {
|
|
return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(err)
|
|
}
|
|
encoded, err := clientAdapter.EncodeResponse(canonicalResp)
|
|
if err != nil {
|
|
return nil, NewConversionError(ErrorCodeEncodingFailure, "编码响应失败").WithCause(err)
|
|
}
|
|
return encoded, nil
|
|
}
|
|
|
|
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
|
models, err := providerAdapter.DecodeModelsResponse(body)
|
|
if err != nil {
|
|
zap.L().Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
|
|
return body, nil
|
|
}
|
|
encoded, err := clientAdapter.EncodeModelsResponse(models)
|
|
if err != nil {
|
|
zap.L().Warn("编码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
|
|
return body, nil
|
|
}
|
|
return encoded, nil
|
|
}
|
|
|
|
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
|
info, err := providerAdapter.DecodeModelInfoResponse(body)
|
|
if err != nil {
|
|
zap.L().Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
|
|
return body, nil
|
|
}
|
|
encoded, err := clientAdapter.EncodeModelInfoResponse(info)
|
|
if err != nil {
|
|
zap.L().Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
|
|
return body, nil
|
|
}
|
|
return encoded, nil
|
|
}
|
|
|
|
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
|
req, err := clientAdapter.DecodeEmbeddingRequest(body)
|
|
if err != nil {
|
|
zap.L().Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error()))
|
|
return body, nil
|
|
}
|
|
return providerAdapter.EncodeEmbeddingRequest(req, provider)
|
|
}
|
|
|
|
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
|
resp, err := providerAdapter.DecodeEmbeddingResponse(body)
|
|
if err != nil {
|
|
zap.L().Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
|
|
return body, nil
|
|
}
|
|
return clientAdapter.EncodeEmbeddingResponse(resp)
|
|
}
|
|
|
|
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
|
req, err := clientAdapter.DecodeRerankRequest(body)
|
|
if err != nil {
|
|
zap.L().Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error()))
|
|
return body, nil
|
|
}
|
|
return providerAdapter.EncodeRerankRequest(req, provider)
|
|
}
|
|
|
|
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
|
resp, err := providerAdapter.DecodeRerankResponse(body)
|
|
if err != nil {
|
|
return body, nil
|
|
}
|
|
return clientAdapter.EncodeRerankResponse(resp)
|
|
}
|
|
|
|
// DetectInterfaceType 检测接口类型
|
|
func (e *ConversionEngine) DetectInterfaceType(nativePath, clientProtocol string) (InterfaceType, error) {
|
|
adapter, err := e.registry.Get(clientProtocol)
|
|
if err != nil {
|
|
return InterfaceTypePassthrough, err
|
|
}
|
|
return adapter.DetectInterfaceType(nativePath), nil
|
|
}
|
|
|
|
// EncodeError 使用客户端适配器编码错误
|
|
func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol string) ([]byte, int, error) {
|
|
adapter, adapterErr := e.registry.Get(clientProtocol)
|
|
if adapterErr != nil {
|
|
fallback := map[string]any{
|
|
"error": map[string]string{
|
|
"message": err.Error(),
|
|
"type": "internal_error",
|
|
},
|
|
}
|
|
body, _ := json.Marshal(fallback)
|
|
return body, 500, nil
|
|
}
|
|
body, statusCode := adapter.EncodeError(err)
|
|
return body, statusCode, nil
|
|
}
|