1
0
Files
nex/backend/internal/conversion/openai/adapter.go
lanyuanxiaoyao b7e205f4b6 refactor: 优化 URL 路径拼接,修复 /v1 重复问题
## 主要变更

**核心修改**:
- 路由定义:/:protocol/v1/*path → /:protocol/*path
- proxy_handler:nativePath 直接使用 path 参数,不添加 /v1 前缀
- OpenAI 适配器:DetectInterfaceType 和 BuildUrl 去掉 /v1 前缀
- Anthropic 适配器:保持 /v1 前缀(Claude Code 兼容)

**URL 格式变化**:
- OpenAI: /openai/v1/chat/completions → /openai/chat/completions
- Anthropic: /anthropic/v1/messages (保持不变)

**base_url 配置**:
- OpenAI: 配置到版本路径,如 https://api.openai.com/v1
- Anthropic: 不配置版本路径,如 https://api.anthropic.com

## 测试验证

- 所有单元测试通过
- 所有集成测试通过
- 真实 API 测试验证成功
- 跨协议转换正常工作

## 文档更新

- 更新 backend/README.md URL 格式说明
- 同步 OpenSpec 规范文件
2026-04-21 20:21:17 +08:00

297 lines
9.2 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package openai
import (
"encoding/json"
"fmt"
"strings"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/canonical"
)
// Adapter is the OpenAI protocol adapter. It is an empty struct, so it carries
// no per-instance state; all behavior lives in its methods.
type Adapter struct{}
// NewAdapter returns a pointer to a freshly allocated, zero-valued Adapter.
func NewAdapter() *Adapter {
	return new(Adapter)
}
// ProtocolName returns the protocol identifier of this adapter, which is
// always the string "openai".
func (a *Adapter) ProtocolName() string {
	return "openai"
}
// ProtocolVersion returns the protocol version of this adapter. The OpenAI
// adapter reports no version, so the result is always the empty string.
func (a *Adapter) ProtocolVersion() string {
	return ""
}
// SupportsPassthrough reports whether same-protocol passthrough is supported.
// The OpenAI adapter always answers true.
func (a *Adapter) SupportsPassthrough() bool {
	return true
}
// DetectInterfaceType maps a native request path to a conversion.InterfaceType.
//
// The fixed paths "/chat/completions", "/models", "/embeddings" and "/rerank"
// are matched exactly. Any path accepted by isModelInfoPath ("/models/" followed
// by a non-empty remainder) is treated as a model-info request. Every other
// path falls back to conversion.InterfaceTypePassthrough.
func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType {
	// Exact-match endpoints first.
	switch nativePath {
	case "/chat/completions":
		return conversion.InterfaceTypeChat
	case "/models":
		return conversion.InterfaceTypeModels
	case "/embeddings":
		return conversion.InterfaceTypeEmbeddings
	case "/rerank":
		return conversion.InterfaceTypeRerank
	}

	// Prefix-based endpoint: /models/{id}.
	if isModelInfoPath(nativePath) {
		return conversion.InterfaceTypeModelInfo
	}
	return conversion.InterfaceTypePassthrough
}
// isModelInfoPath reports whether path has the form "/models/{id}" with a
// non-empty id. The id is everything after the prefix, so it may itself
// contain "/" characters.
func isModelInfoPath(path string) bool {
	const prefix = "/models/"
	// The prefix must be present and followed by at least one more byte.
	return strings.HasPrefix(path, prefix) && len(path) > len(prefix)
}
// BuildUrl returns the upstream request path for the given interface type.
//
// Chat, Models, Embeddings and Rerank map to fixed paths. Every other interface
// type, including model info, returns nativePath unchanged.
func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string {
	// Start from the caller's path and override it only for known endpoints.
	target := nativePath
	switch interfaceType {
	case conversion.InterfaceTypeChat:
		target = "/chat/completions"
	case conversion.InterfaceTypeModels:
		target = "/models"
	case conversion.InterfaceTypeEmbeddings:
		target = "/embeddings"
	case conversion.InterfaceTypeRerank:
		target = "/rerank"
	}
	return target
}
// BuildHeaders builds the HTTP headers for an upstream request.
//
// The result always contains a Bearer "Authorization" header built from
// provider.APIKey and a JSON "Content-Type". The "OpenAI-Organization" header
// is added only when provider.AdapterConfig["organization"] holds a non-empty
// string.
func (a *Adapter) BuildHeaders(provider *conversion.TargetProvider) map[string]string {
	headers := make(map[string]string, 3)
	headers["Content-Type"] = "application/json"
	headers["Authorization"] = "Bearer " + provider.APIKey

	// A failed type assertion leaves org empty, so a missing key and a
	// non-string value are both skipped by the same check.
	if org, _ := provider.AdapterConfig["organization"].(string); org != "" {
		headers["OpenAI-Organization"] = org
	}
	return headers
}
// SupportsInterface reports whether this adapter handles the given interface
// type. Chat, Models, ModelInfo, Embeddings and Rerank are supported; any other
// value yields false.
func (a *Adapter) SupportsInterface(interfaceType conversion.InterfaceType) bool {
	return interfaceType == conversion.InterfaceTypeChat ||
		interfaceType == conversion.InterfaceTypeModels ||
		interfaceType == conversion.InterfaceTypeModelInfo ||
		interfaceType == conversion.InterfaceTypeEmbeddings ||
		interfaceType == conversion.InterfaceTypeRerank
}
// DecodeRequest converts a raw request body into a canonical.CanonicalRequest.
// It delegates directly to the package-level decodeRequest and returns its
// result and error unchanged.
func (a *Adapter) DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error) {
return decodeRequest(raw)
}
// EncodeRequest serializes a canonical request for the given target provider.
// It delegates directly to the package-level encodeRequest and returns its
// result and error unchanged.
func (a *Adapter) EncodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) {
return encodeRequest(req, provider)
}
// DecodeResponse converts a raw response body into a
// canonical.CanonicalResponse. It delegates directly to the package-level
// decodeResponse and returns its result and error unchanged.
func (a *Adapter) DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error) {
return decodeResponse(raw)
}
// EncodeResponse serializes a canonical response into bytes. It delegates
// directly to the package-level encodeResponse and returns its result and
// error unchanged.
func (a *Adapter) EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
return encodeResponse(resp)
}
// CreateStreamDecoder returns a new streaming decoder by calling
// NewStreamDecoder; each call produces a separate decoder value.
func (a *Adapter) CreateStreamDecoder() conversion.StreamDecoder {
return NewStreamDecoder()
}
// CreateStreamEncoder returns a new streaming encoder by calling
// NewStreamEncoder; each call produces a separate encoder value.
func (a *Adapter) CreateStreamEncoder() conversion.StreamEncoder {
return NewStreamEncoder()
}
// EncodeError renders a conversion error as an OpenAI-style error body and
// returns it together with the HTTP status code to send.
//
// The error type comes from mapErrorCode, the message and code are copied from
// err, and the param field is left unset.
//
// NOTE(review): the returned status is 500 for every error code, including the
// ones mapErrorCode classifies as "invalid_request_error" — confirm this is
// intended.
func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
	const statusCode = 500

	detail := ErrorDetail{
		Message: err.Message,
		Type:    mapErrorCode(err.Code),
		Code:    string(err.Code),
	}

	// The marshal error is deliberately discarded, matching the two-value
	// signature; presumably ErrorResponse holds only plain fields — verify.
	body, _ := json.Marshal(ErrorResponse{Error: detail})
	return body, statusCode
}
// mapErrorCode translates a conversion error code into an OpenAI error type
// string. The input, parsing and mapping codes listed below become
// "invalid_request_error"; every other code becomes "server_error".
func mapErrorCode(code conversion.ErrorCode) string {
	// Default to a server-side failure and downgrade only for known
	// request-side problems.
	errType := "server_error"
	switch code {
	case conversion.ErrorCodeInvalidInput,
		conversion.ErrorCodeMissingRequiredField,
		conversion.ErrorCodeIncompatibleFeature,
		conversion.ErrorCodeToolCallParseError,
		conversion.ErrorCodeJSONParseError,
		conversion.ErrorCodeProtocolConstraint,
		conversion.ErrorCodeFieldMappingFailure:
		errType = "invalid_request_error"
	}
	return errType
}
// DecodeModelsResponse converts a raw model-list response body into a
// canonical.CanonicalModelList. It delegates directly to the package-level
// decodeModelsResponse and returns its result and error unchanged.
func (a *Adapter) DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error) {
return decodeModelsResponse(raw)
}
// EncodeModelsResponse serializes a canonical model list into bytes. It
// delegates directly to the package-level encodeModelsResponse and returns its
// result and error unchanged.
func (a *Adapter) EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
return encodeModelsResponse(list)
}
// DecodeModelInfoResponse converts a raw model-detail response body into a
// canonical.CanonicalModelInfo. It delegates directly to the package-level
// decodeModelInfoResponse and returns its result and error unchanged.
func (a *Adapter) DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error) {
return decodeModelInfoResponse(raw)
}
// EncodeModelInfoResponse serializes canonical model details into bytes. It
// delegates directly to the package-level encodeModelInfoResponse and returns
// its result and error unchanged.
func (a *Adapter) EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
return encodeModelInfoResponse(info)
}
// DecodeEmbeddingRequest converts a raw embedding request body into a
// canonical.CanonicalEmbeddingRequest. It delegates directly to the
// package-level decodeEmbeddingRequest and returns its result and error
// unchanged.
func (a *Adapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
return decodeEmbeddingRequest(raw)
}
// EncodeEmbeddingRequest serializes a canonical embedding request for the
// given target provider. It delegates directly to the package-level
// encodeEmbeddingRequest and returns its result and error unchanged.
func (a *Adapter) EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *conversion.TargetProvider) ([]byte, error) {
return encodeEmbeddingRequest(req, provider)
}
// DecodeEmbeddingResponse converts a raw embedding response body into a
// canonical.CanonicalEmbeddingResponse. It delegates directly to the
// package-level decodeEmbeddingResponse and returns its result and error
// unchanged.
func (a *Adapter) DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error) {
return decodeEmbeddingResponse(raw)
}
// EncodeEmbeddingResponse serializes a canonical embedding response into
// bytes. It delegates directly to the package-level encodeEmbeddingResponse and
// returns its result and error unchanged.
func (a *Adapter) EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) {
return encodeEmbeddingResponse(resp)
}
// DecodeRerankRequest converts a raw rerank request body into a
// canonical.CanonicalRerankRequest. It delegates directly to the package-level
// decodeRerankRequest and returns its result and error unchanged.
func (a *Adapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
return decodeRerankRequest(raw)
}
// EncodeRerankRequest serializes a canonical rerank request for the given
// target provider. It delegates directly to the package-level
// encodeRerankRequest and returns its result and error unchanged.
func (a *Adapter) EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *conversion.TargetProvider) ([]byte, error) {
return encodeRerankRequest(req, provider)
}
// DecodeRerankResponse converts a raw rerank response body into a
// canonical.CanonicalRerankResponse. It delegates directly to the package-level
// decodeRerankResponse and returns its result and error unchanged.
func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) {
return decodeRerankResponse(raw)
}
// EncodeRerankResponse serializes a canonical rerank response into bytes. It
// delegates directly to the package-level encodeRerankResponse and returns its
// result and error unchanged.
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
return encodeRerankResponse(resp)
}
// ExtractUnifiedModelID extracts the unified model ID from a model-detail path
// of the form "/models/{provider_id}/{model_name}".
//
// Everything after the "/models/" prefix is returned verbatim, slashes
// included. An error is returned when the path does not start with "/models/"
// or when nothing follows the prefix.
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
	const prefix = "/models/"

	if !strings.HasPrefix(nativePath, prefix) {
		return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
	}

	modelID := strings.TrimPrefix(nativePath, prefix)
	if len(modelID) == 0 {
		return "", fmt.Errorf("路径缺少模型 ID")
	}
	return modelID, nil
}
// locateModelFieldInRequest finds the top-level "model" field of a JSON
// request body and returns its current value together with a function that
// rewrites it.
//
// The body is decoded only at the top level, into raw per-key messages. The
// returned rewrite function replaces "model" with the given name and
// re-serializes the whole object.
//
// Only the Chat, Embeddings and Rerank interface types are accepted. An error
// is returned when the body is not valid JSON, the interface type is anything
// else, the "model" key is absent, or its value cannot be decoded as a string.
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(body, &fields); err != nil {
		return "", nil, err
	}

	carriesModel := ifaceType == conversion.InterfaceTypeChat ||
		ifaceType == conversion.InterfaceTypeEmbeddings ||
		ifaceType == conversion.InterfaceTypeRerank
	if !carriesModel {
		return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
	}

	rawModel, found := fields["model"]
	if !found {
		return "", nil, fmt.Errorf("请求体中缺少 model 字段")
	}

	var model string
	if err := json.Unmarshal(rawModel, &model); err != nil {
		return "", nil, err
	}

	// The closure captures fields, so the rewrite reuses the already decoded
	// object instead of parsing the body a second time.
	rewrite := func(newModel string) ([]byte, error) {
		fields["model"], _ = json.Marshal(newModel)
		return json.Marshal(fields)
	}
	return model, rewrite, nil
}
// ExtractModelName returns the value of the top-level "model" field of a
// request body. It relies on locateModelFieldInRequest and forwards any error
// that function reports.
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
	name, _, err := locateModelFieldInRequest(body, ifaceType)
	if err != nil {
		// locateModelFieldInRequest yields an empty name on every error path.
		return "", err
	}
	return name, nil
}
// RewriteRequestModelName replaces the top-level "model" field of a request
// body with newModel and returns the re-serialized body.
//
// The field is located by locateModelFieldInRequest; if that fails, a nil body
// and the error are returned.
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
	_, rewrite, locateErr := locateModelFieldInRequest(body, ifaceType)
	if locateErr != nil {
		return nil, locateErr
	}
	return rewrite(newModel)
}
// RewriteResponseModelName replaces the top-level "model" field of a response
// body with newModel, touching nothing else in the object.
//
// Behavior by interface type:
//   - Chat, Embeddings: "model" is overwritten when present and added when
//     absent.
//   - Rerank: "model" is overwritten only when already present; otherwise the
//     object is re-serialized without adding the field.
//   - Any other type: body is returned unchanged.
//
// The body is always parsed first, so invalid JSON yields an error for every
// interface type. A JSON null body is rejected with an error for Chat and
// Embeddings: it decodes into a nil map, and assigning into that map would
// panic.
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
	var m map[string]json.RawMessage
	if err := json.Unmarshal(body, &m); err != nil {
		return nil, err
	}

	switch ifaceType {
	case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
		// json.Unmarshal turns a JSON null into a nil map without reporting an
		// error; guard here so the assignment below cannot panic.
		if m == nil {
			return nil, fmt.Errorf("响应体不是 JSON 对象")
		}
	case conversion.InterfaceTypeRerank:
		// Rerank responses keep their shape: never introduce a model field.
		// A nil map is safe to read and is re-serialized as-is.
		if _, exists := m["model"]; !exists {
			return json.Marshal(m)
		}
	default:
		return body, nil
	}

	encoded, err := json.Marshal(newModel)
	if err != nil {
		return nil, err
	}
	m["model"] = encoded
	return json.Marshal(m)
}