package openai

import (
	"encoding/json"
	"net/http"
	"strings"

	"nex/backend/internal/conversion"
	"nex/backend/internal/conversion/canonical"
)

// Adapter is the protocol adapter for the OpenAI API format.
type Adapter struct{}

// NewAdapter creates an OpenAI adapter.
func NewAdapter() *Adapter {
	return &Adapter{}
}

// ProtocolName returns the protocol name, "openai".
func (a *Adapter) ProtocolName() string {
	return "openai"
}

// ProtocolVersion returns the protocol version. This adapter does not pin a
// version, so it always returns the empty string.
func (a *Adapter) ProtocolVersion() string {
	return ""
}

// SupportsPassthrough reports that same-protocol passthrough is supported.
func (a *Adapter) SupportsPassthrough() bool {
	return true
}

// DetectInterfaceType classifies a native request path into an interface
// type. Any path not recognized below is treated as passthrough.
func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType {
	switch {
	case nativePath == "/v1/chat/completions":
		return conversion.InterfaceTypeChat
	case nativePath == "/v1/models":
		return conversion.InterfaceTypeModels
	case isModelInfoPath(nativePath):
		return conversion.InterfaceTypeModelInfo
	case nativePath == "/v1/embeddings":
		return conversion.InterfaceTypeEmbeddings
	case nativePath == "/v1/rerank":
		return conversion.InterfaceTypeRerank
	default:
		return conversion.InterfaceTypePassthrough
	}
}

// isModelInfoPath reports whether path is a model-detail path of the form
// /v1/models/{id}, where {id} is non-empty and contains no further '/'.
func isModelInfoPath(path string) bool {
	const prefix = "/v1/models/"
	if !strings.HasPrefix(path, prefix) {
		return false
	}
	id := path[len(prefix):]
	return id != "" && !strings.Contains(id, "/")
}

// BuildUrl returns the upstream path for the given interface type. Interface
// types without a fixed path here (model info, passthrough, and any others)
// reuse nativePath unchanged.
func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string {
	switch interfaceType {
	case conversion.InterfaceTypeChat:
		return "/v1/chat/completions"
	case conversion.InterfaceTypeModels:
		return "/v1/models"
	case conversion.InterfaceTypeEmbeddings:
		return "/v1/embeddings"
	case conversion.InterfaceTypeRerank:
		return "/v1/rerank"
	default:
		return nativePath
	}
}

// BuildHeaders builds the upstream request headers.
//
// It sets:
//   - a bearer Authorization header from provider.APIKey;
//   - a JSON Content-Type;
//   - OpenAI-Organization, only when provider.AdapterConfig["organization"]
//     is a non-empty string.
func (a *Adapter) BuildHeaders(provider *conversion.TargetProvider) map[string]string {
	headers := map[string]string{
		"Authorization": "Bearer " + provider.APIKey,
		"Content-Type":  "application/json",
	}
	if org, ok := provider.AdapterConfig["organization"].(string); ok && org != "" {
		headers["OpenAI-Organization"] = org
	}
	return headers
}

// SupportsInterface reports whether this adapter handles the given interface
// type natively (chat, models, model info, embeddings, rerank).
func (a *Adapter) SupportsInterface(interfaceType conversion.InterfaceType) bool {
	switch interfaceType {
	case conversion.InterfaceTypeChat,
		conversion.InterfaceTypeModels,
		conversion.InterfaceTypeModelInfo,
		conversion.InterfaceTypeEmbeddings,
		conversion.InterfaceTypeRerank:
		return true
	default:
		return false
	}
}

// DecodeRequest decodes a raw OpenAI chat request into canonical form.
func (a *Adapter) DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error) {
	return decodeRequest(raw)
}

// EncodeRequest encodes a canonical chat request for the target provider.
func (a *Adapter) EncodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) {
	return encodeRequest(req, provider)
}

// DecodeResponse decodes a raw OpenAI chat response into canonical form.
func (a *Adapter) DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error) {
	return decodeResponse(raw)
}

// EncodeResponse encodes a canonical chat response in OpenAI format.
func (a *Adapter) EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
	return encodeResponse(resp)
}

// CreateStreamDecoder creates a streaming decoder for OpenAI responses.
func (a *Adapter) CreateStreamDecoder() conversion.StreamDecoder {
	return NewStreamDecoder()
}

// CreateStreamEncoder creates a streaming encoder producing OpenAI output.
func (a *Adapter) CreateStreamEncoder() conversion.StreamEncoder {
	return NewStreamEncoder()
}

// EncodeError encodes a conversion error as an OpenAI-style error body.
//
// It returns the body together with the HTTP status code to send:
//   - 400 Bad Request for errors classified as "invalid_request_error"
//     (problems with the request itself);
//   - 500 Internal Server Error for everything else.
//
// Previously the status was always 500, which mislabelled client errors as
// server failures.
func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
	errType := mapErrorCode(err.Code)

	statusCode := http.StatusInternalServerError
	if errType == "invalid_request_error" {
		statusCode = http.StatusBadRequest
	}

	errMsg := ErrorResponse{
		Error: ErrorDetail{
			Message: err.Message,
			Type:    errType,
			Param:   nil,
			Code:    string(err.Code),
		},
	}
	// The error is deliberately ignored: errMsg is built only from plain
	// values, so Marshal is not expected to fail.
	// NOTE(review): confirm ErrorResponse/ErrorDetail contain no
	// unmarshalable field types.
	body, _ := json.Marshal(errMsg)
	return body, statusCode
}

// mapErrorCode maps a conversion error code to an OpenAI error type.
// Request-side problems map to "invalid_request_error"; every other code
// maps to "server_error".
func mapErrorCode(code conversion.ErrorCode) string {
	switch code {
	case conversion.ErrorCodeInvalidInput,
		conversion.ErrorCodeMissingRequiredField,
		conversion.ErrorCodeIncompatibleFeature,
		conversion.ErrorCodeToolCallParseError,
		conversion.ErrorCodeJSONParseError,
		conversion.ErrorCodeProtocolConstraint,
		conversion.ErrorCodeFieldMappingFailure:
		return "invalid_request_error"
	default:
		return "server_error"
	}
}

// DecodeModelsResponse decodes a raw model-list response into canonical form.
func (a *Adapter) DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error) {
	return decodeModelsResponse(raw)
}

// EncodeModelsResponse encodes a canonical model list in OpenAI format.
func (a *Adapter) EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
	return encodeModelsResponse(list)
}

// DecodeModelInfoResponse decodes a raw model-detail response into canonical form.
func (a *Adapter) DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error) {
	return decodeModelInfoResponse(raw)
}

// EncodeModelInfoResponse encodes canonical model details in OpenAI format.
func (a *Adapter) EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
	return encodeModelInfoResponse(info)
}

// DecodeEmbeddingRequest decodes a raw embedding request into canonical form.
func (a *Adapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
	return decodeEmbeddingRequest(raw)
}

// EncodeEmbeddingRequest encodes a canonical embedding request for the target provider.
func (a *Adapter) EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *conversion.TargetProvider) ([]byte, error) {
	return encodeEmbeddingRequest(req, provider)
}

// DecodeEmbeddingResponse decodes a raw embedding response into canonical form.
func (a *Adapter) DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error) {
	return decodeEmbeddingResponse(raw)
}

// EncodeEmbeddingResponse encodes a canonical embedding response in OpenAI format.
func (a *Adapter) EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) {
	return encodeEmbeddingResponse(resp)
}

// DecodeRerankRequest decodes a raw rerank request into canonical form.
func (a *Adapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
	return decodeRerankRequest(raw)
}

// EncodeRerankRequest encodes a canonical rerank request for the target provider.
func (a *Adapter) EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *conversion.TargetProvider) ([]byte, error) {
	return encodeRerankRequest(req, provider)
}

// DecodeRerankResponse decodes a raw rerank response into canonical form.
func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) {
	return decodeRerankResponse(raw)
}

// EncodeRerankResponse encodes a canonical rerank response in OpenAI format.
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
	return encodeRerankResponse(resp)
}