package openai

import (
	"encoding/json"
	"fmt"
	"strings"

	"nex/backend/internal/conversion"
	"nex/backend/internal/conversion/canonical"
)

// Adapter is the OpenAI protocol adapter. It carries no state of its own;
// every method either inspects its arguments or delegates to a package-level
// codec function.
type Adapter struct{}

// NewAdapter creates an OpenAI adapter.
func NewAdapter() *Adapter {
	return &Adapter{}
}

// ProtocolName returns the protocol name, "openai".
func (a *Adapter) ProtocolName() string {
	return "openai"
}

// ProtocolVersion returns the protocol version. The OpenAI adapter reports
// an empty version.
func (a *Adapter) ProtocolVersion() string {
	return ""
}

// SupportsPassthrough reports that same-protocol passthrough is supported.
func (a *Adapter) SupportsPassthrough() bool {
	return true
}

// DetectInterfaceType maps a native request path to an interface type.
// Paths that match no known endpoint are classified as passthrough.
func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType {
	switch {
	case nativePath == "/chat/completions":
		return conversion.InterfaceTypeChat
	case nativePath == "/models":
		return conversion.InterfaceTypeModels
	case isModelInfoPath(nativePath):
		return conversion.InterfaceTypeModelInfo
	case nativePath == "/embeddings":
		return conversion.InterfaceTypeEmbeddings
	case nativePath == "/rerank":
		return conversion.InterfaceTypeRerank
	default:
		return conversion.InterfaceTypePassthrough
	}
}

// isModelInfoPath reports whether path is a model detail path of the form
// /models/{id}. The id must be non-empty and may itself contain "/".
func isModelInfoPath(path string) bool {
	if !strings.HasPrefix(path, "/models/") {
		return false
	}
	return strings.TrimPrefix(path, "/models/") != ""
}

// BuildUrl returns the upstream URL path for the given interface type.
// Interface types without a fixed endpoint (e.g. model info, passthrough)
// reuse nativePath unchanged.
func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string {
	switch interfaceType {
	case conversion.InterfaceTypeChat:
		return "/chat/completions"
	case conversion.InterfaceTypeModels:
		return "/models"
	case conversion.InterfaceTypeEmbeddings:
		return "/embeddings"
	case conversion.InterfaceTypeRerank:
		return "/rerank"
	default:
		return nativePath
	}
}

// BuildHeaders builds the upstream request headers: a Bearer Authorization
// header from provider.APIKey and a JSON Content-Type. When
// provider.AdapterConfig["organization"] is a non-empty string it is also
// sent as the OpenAI-Organization header.
func (a *Adapter) BuildHeaders(provider *conversion.TargetProvider) map[string]string {
	headers := map[string]string{
		"Authorization": "Bearer " + provider.APIKey,
		"Content-Type":  "application/json",
	}
	if org, ok := provider.AdapterConfig["organization"].(string); ok && org != "" {
		headers["OpenAI-Organization"] = org
	}
	return headers
}

// SupportsInterface reports whether the adapter can convert the given
// interface type.
func (a *Adapter) SupportsInterface(interfaceType conversion.InterfaceType) bool {
	switch interfaceType {
	case conversion.InterfaceTypeChat,
		conversion.InterfaceTypeModels,
		conversion.InterfaceTypeModelInfo,
		conversion.InterfaceTypeEmbeddings,
		conversion.InterfaceTypeRerank:
		return true
	default:
		return false
	}
}

// DecodeRequest decodes an OpenAI chat request into its canonical form.
func (a *Adapter) DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error) {
	return decodeRequest(raw)
}

// EncodeRequest encodes a canonical request as an OpenAI chat request.
func (a *Adapter) EncodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) {
	return encodeRequest(req, provider)
}

// DecodeResponse decodes an OpenAI chat response into its canonical form.
func (a *Adapter) DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error) {
	return decodeResponse(raw)
}

// EncodeResponse encodes a canonical response as an OpenAI chat response.
func (a *Adapter) EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
	return encodeResponse(resp)
}

// CreateStreamDecoder creates a streaming decoder.
func (a *Adapter) CreateStreamDecoder() conversion.StreamDecoder {
	return NewStreamDecoder()
}

// CreateStreamEncoder creates a streaming encoder.
func (a *Adapter) CreateStreamEncoder() conversion.StreamEncoder {
	return NewStreamEncoder()
}

// EncodeError renders a conversion error as an OpenAI-style error body and
// returns it together with the HTTP status code to send.
//
// NOTE(review): the status code is always 500, even when the error type is
// "invalid_request_error"; confirm whether client-side errors should map to
// 400 instead.
func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
	errType := mapErrorCode(err.Code)
	statusCode := 500

	errMsg := ErrorResponse{
		Error: ErrorDetail{
			Message: err.Message,
			Type:    errType,
			Param:   nil,
			Code:    string(err.Code),
		},
	}

	// The marshal error is ignored; this assumes ErrorResponse holds only
	// plain JSON-encodable fields (strings and a nil Param) — confirm if
	// those types gain custom marshalers.
	body, _ := json.Marshal(errMsg)
	return body, statusCode
}

// mapErrorCode maps a conversion error code to an OpenAI error type:
// input/validation-related codes become "invalid_request_error" and every
// other code becomes "server_error".
func mapErrorCode(code conversion.ErrorCode) string {
	switch code {
	case conversion.ErrorCodeInvalidInput,
		conversion.ErrorCodeMissingRequiredField,
		conversion.ErrorCodeIncompatibleFeature,
		conversion.ErrorCodeToolCallParseError,
		conversion.ErrorCodeJSONParseError,
		conversion.ErrorCodeProtocolConstraint,
		conversion.ErrorCodeFieldMappingFailure:
		return "invalid_request_error"
	default:
		return "server_error"
	}
}

// DecodeModelsResponse decodes a model list response.
func (a *Adapter) DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error) {
	return decodeModelsResponse(raw)
}

// EncodeModelsResponse encodes a model list response.
func (a *Adapter) EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
	return encodeModelsResponse(list)
}

// DecodeModelInfoResponse decodes a model detail response.
func (a *Adapter) DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error) {
	return decodeModelInfoResponse(raw)
}

// EncodeModelInfoResponse encodes a model detail response.
func (a *Adapter) EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
	return encodeModelInfoResponse(info)
}

// DecodeEmbeddingRequest decodes an embedding request.
func (a *Adapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
	return decodeEmbeddingRequest(raw)
}

// EncodeEmbeddingRequest encodes an embedding request.
func (a *Adapter) EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *conversion.TargetProvider) ([]byte, error) {
	return encodeEmbeddingRequest(req, provider)
}

// DecodeEmbeddingResponse decodes an embedding response.
func (a *Adapter) DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error) {
	return decodeEmbeddingResponse(raw)
}

// EncodeEmbeddingResponse encodes an embedding response.
func (a *Adapter) EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) {
	return encodeEmbeddingResponse(resp)
}

// DecodeRerankRequest decodes a rerank request.
func (a *Adapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
	return decodeRerankRequest(raw)
}

// EncodeRerankRequest encodes a rerank request.
func (a *Adapter) EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *conversion.TargetProvider) ([]byte, error) {
	return encodeRerankRequest(req, provider)
}

// DecodeRerankResponse decodes a rerank response.
func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) {
	return decodeRerankResponse(raw)
}

// EncodeRerankResponse encodes a rerank response.
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
	return encodeRerankResponse(resp)
}

// ExtractUnifiedModelID extracts the unified model ID from a model detail
// path (/models/{provider_id}/{model_name}). Everything after "/models/" is
// returned verbatim, so the ID may contain "/". It fails when the path lacks
// the "/models/" prefix or nothing follows it.
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
	if !strings.HasPrefix(nativePath, "/models/") {
		return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
	}
	modelID := strings.TrimPrefix(nativePath, "/models/")
	if modelID == "" {
		return "", fmt.Errorf("路径缺少模型 ID")
	}
	return modelID, nil
}

// locateModelFieldInRequest finds the top-level "model" string in a request
// body and returns it together with a function that re-serializes the body
// with "model" replaced. Only chat, embedding and rerank requests are
// supported.
//
// The rewrite re-marshals the decoded top-level map, so other top-level
// values are preserved as raw JSON but key order follows encoding/json's
// sorted-map output rather than the original byte order.
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
	var m map[string]json.RawMessage
	if err := json.Unmarshal(body, &m); err != nil {
		return "", nil, err
	}

	switch ifaceType {
	case conversion.InterfaceTypeChat,
		conversion.InterfaceTypeEmbeddings,
		conversion.InterfaceTypeRerank:
		// A `null` body yields a nil map; the lookup below is safe on it and
		// reports the field as missing.
		raw, exists := m["model"]
		if !exists {
			return "", nil, fmt.Errorf("请求体中缺少 model 字段")
		}
		var current string
		if err := json.Unmarshal(raw, &current); err != nil {
			return "", nil, err
		}
		rewriteFunc := func(newModel string) ([]byte, error) {
			// Marshaling a plain string cannot fail, so the error is ignored.
			m["model"], _ = json.Marshal(newModel)
			return json.Marshal(m)
		}
		return current, rewriteFunc, nil
	default:
		return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
	}
}

// ExtractModelName returns the value of the "model" field in a request body.
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
	model, _, err := locateModelFieldInRequest(body, ifaceType)
	return model, err
}

// RewriteRequestModelName replaces the "model" field in a request body with
// newModel, leaving other top-level values as-is.
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
	_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
	if err != nil {
		return nil, err
	}
	return rewriteFunc(newModel)
}

// RewriteResponseModelName replaces the "model" field in a response body
// with newModel.
//
//   - Chat/Embeddings: the field is overwritten, or added when absent.
//   - Rerank: the field is overwritten only when already present.
//   - Other interface types: the body is returned unchanged.
//
// The body is parsed as a JSON object first for every interface type, so
// invalid JSON is reported as an error even when no rewrite would happen.
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
	var m map[string]json.RawMessage
	if err := json.Unmarshal(body, &m); err != nil {
		return nil, err
	}

	switch ifaceType {
	case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
		// A literal `null` body decodes to a nil map; assigning into it would
		// panic, so reject it explicitly.
		if m == nil {
			return nil, fmt.Errorf("响应体不是 JSON 对象")
		}
		// The protocol requires a model field on chat/embedding responses:
		// overwrite it if present, add it if absent. Marshaling a string
		// cannot fail, so the error is ignored.
		m["model"], _ = json.Marshal(newModel)
		return json.Marshal(m)
	case conversion.InterfaceTypeRerank:
		// Rerank responses: rewrite the model field only if it exists; never
		// add one. (Safe on a nil map: the lookup simply misses.)
		if _, exists := m["model"]; exists {
			m["model"], _ = json.Marshal(newModel)
		}
		return json.Marshal(m)
	default:
		return body, nil
	}
}