1
0
Files
nex/backend/internal/conversion/openai/adapter.go
lanyuanxiaoyao 395887667d feat: 实现统一模型 ID 机制
实现统一模型 ID 格式 (provider_id/model_name),支持跨协议模型标识和 Smart Passthrough。

核心变更:
- 新增 pkg/modelid 包:解析、格式化、校验统一模型 ID
- 数据库迁移:models 表使用 UUID 主键 + UNIQUE(provider_id, model_name) 约束
- Repository 层:FindByProviderAndModelName、ListEnabled 方法
- Service 层:联合唯一校验、provider ID 字符集校验
- Conversion 层:ExtractModelName、RewriteRequestModelName/RewriteResponseModelName 方法
- Handler 层:统一模型 ID 路由、Smart Passthrough、Models API 本地聚合
- 新增 error-responses、unified-model-id 规范

测试覆盖:
- 单元测试:modelid、conversion、handler、service、repository
- 集成测试:统一模型 ID 路由、Smart Passthrough 保真性、跨协议转换
- 迁移测试:UUID 主键、UNIQUE 约束、级联删除

OpenSpec:
- 归档 unified-model-id 变更到 archive/2026-04-21-unified-model-id
- 同步 11 个 delta specs 到 main specs
- 新增 error-responses、unified-model-id 规范文件
2026-04-21 18:14:10 +08:00

297 lines
9.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package openai
import (
"encoding/json"
"fmt"
"strings"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/canonical"
)
// Adapter is the OpenAI protocol adapter. It is an empty struct and carries no state.
type Adapter struct{}
// NewAdapter creates an OpenAI adapter and returns a pointer to it.
func NewAdapter() *Adapter {
return &Adapter{}
}
// ProtocolName returns the protocol name handled by this adapter, "openai".
func (a *Adapter) ProtocolName() string { return "openai" }
// ProtocolVersion returns the protocol version. This adapter reports an empty string.
func (a *Adapter) ProtocolVersion() string { return "" }
// SupportsPassthrough reports whether same-protocol passthrough is supported.
// It always returns true for this adapter.
func (a *Adapter) SupportsPassthrough() bool { return true }
// DetectInterfaceType classifies a native request path into an interface type.
//
// The chat, models, embeddings and rerank endpoints are matched exactly. Any
// path accepted by isModelInfoPath is reported as a model-info request, and
// every other path falls back to passthrough.
func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType {
	switch nativePath {
	case "/v1/chat/completions":
		return conversion.InterfaceTypeChat
	case "/v1/models":
		return conversion.InterfaceTypeModels
	case "/v1/embeddings":
		return conversion.InterfaceTypeEmbeddings
	case "/v1/rerank":
		return conversion.InterfaceTypeRerank
	}

	// Model-info paths always extend "/v1/models/", so they can never collide
	// with the exact matches above and may safely be checked afterwards.
	if isModelInfoPath(nativePath) {
		return conversion.InterfaceTypeModelInfo
	}
	return conversion.InterfaceTypePassthrough
}
// isModelInfoPath reports whether path is a model-detail path of the form
// /v1/models/{id}. The id only has to be non-empty, so it may itself contain "/".
func isModelInfoPath(path string) bool {
	const prefix = "/v1/models/"
	// Strictly longer than the prefix means there is a non-empty id after it.
	return len(path) > len(prefix) && strings.HasPrefix(path, prefix)
}
// BuildUrl returns the upstream URL path for the given interface type.
//
// Chat, models, embeddings and rerank map to their fixed /v1 endpoints. For
// every other interface type nativePath is returned unchanged.
func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string {
	// Start from the caller's path and override it only for known endpoints.
	target := nativePath
	switch interfaceType {
	case conversion.InterfaceTypeChat:
		target = "/v1/chat/completions"
	case conversion.InterfaceTypeModels:
		target = "/v1/models"
	case conversion.InterfaceTypeEmbeddings:
		target = "/v1/embeddings"
	case conversion.InterfaceTypeRerank:
		target = "/v1/rerank"
	}
	return target
}
// BuildHeaders builds the request headers for the target provider.
//
// The result always contains a Bearer Authorization header built from
// provider.APIKey and a JSON Content-Type. OpenAI-Organization is added only
// when provider.AdapterConfig["organization"] is a non-empty string.
func (a *Adapter) BuildHeaders(provider *conversion.TargetProvider) map[string]string {
	result := make(map[string]string, 3)
	result["Authorization"] = "Bearer " + provider.APIKey
	result["Content-Type"] = "application/json"

	organization, isString := provider.AdapterConfig["organization"].(string)
	if isString && organization != "" {
		result["OpenAI-Organization"] = organization
	}
	return result
}
// SupportsInterface reports whether this adapter supports the given interface
// type. Chat, models, model-info, embeddings and rerank are supported; every
// other type is not.
func (a *Adapter) SupportsInterface(interfaceType conversion.InterfaceType) bool {
	switch interfaceType {
	case conversion.InterfaceTypeChat,
		conversion.InterfaceTypeModels,
		conversion.InterfaceTypeModelInfo:
		return true
	case conversion.InterfaceTypeEmbeddings,
		conversion.InterfaceTypeRerank:
		return true
	}
	return false
}
// DecodeRequest decodes raw request bytes into a canonical request.
// It delegates directly to decodeRequest.
func (a *Adapter) DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error) {
return decodeRequest(raw)
}
// EncodeRequest encodes a canonical request into bytes for the given target
// provider. It delegates directly to encodeRequest.
func (a *Adapter) EncodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) {
return encodeRequest(req, provider)
}
// DecodeResponse decodes raw response bytes into a canonical response.
// It delegates directly to decodeResponse.
func (a *Adapter) DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error) {
return decodeResponse(raw)
}
// EncodeResponse encodes a canonical response into bytes.
// It delegates directly to encodeResponse.
func (a *Adapter) EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
return encodeResponse(resp)
}
// CreateStreamDecoder creates a streaming decoder by calling NewStreamDecoder.
func (a *Adapter) CreateStreamDecoder() conversion.StreamDecoder {
return NewStreamDecoder()
}
// CreateStreamEncoder creates a streaming encoder by calling NewStreamEncoder.
func (a *Adapter) CreateStreamEncoder() conversion.StreamEncoder {
return NewStreamEncoder()
}
// EncodeError encodes a conversion error as an ErrorResponse JSON body and
// returns it together with an HTTP status code.
//
// The error type is derived from err.Code via mapErrorCode; the code itself is
// emitted as a string and Param is left nil. The returned status code is
// always 500, regardless of the mapped error type.
func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
	detail := ErrorDetail{
		Message: err.Message,
		Type:    mapErrorCode(err.Code),
		Param:   nil,
		Code:    string(err.Code),
	}

	// The marshal error is intentionally discarded, matching the two-value
	// signature; if marshalling failed the returned body would be nil.
	payload, _ := json.Marshal(ErrorResponse{Error: detail})
	return payload, 500
}
// mapErrorCode maps a conversion error code to an OpenAI error type string.
//
// Input, field, feature, tool-call, JSON-parse, protocol-constraint and
// field-mapping errors become "invalid_request_error"; every other code
// becomes "server_error".
func mapErrorCode(code conversion.ErrorCode) string {
	errType := "server_error"
	switch code {
	case conversion.ErrorCodeInvalidInput,
		conversion.ErrorCodeMissingRequiredField,
		conversion.ErrorCodeIncompatibleFeature,
		conversion.ErrorCodeToolCallParseError,
		conversion.ErrorCodeJSONParseError,
		conversion.ErrorCodeProtocolConstraint,
		conversion.ErrorCodeFieldMappingFailure:
		errType = "invalid_request_error"
	}
	return errType
}
// DecodeModelsResponse decodes raw model-list response bytes into a canonical
// model list. It delegates directly to decodeModelsResponse.
func (a *Adapter) DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error) {
return decodeModelsResponse(raw)
}
// EncodeModelsResponse encodes a canonical model list into response bytes.
// It delegates directly to encodeModelsResponse.
func (a *Adapter) EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
return encodeModelsResponse(list)
}
// DecodeModelInfoResponse decodes raw model-detail response bytes into a
// canonical model info value. It delegates directly to decodeModelInfoResponse.
func (a *Adapter) DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error) {
return decodeModelInfoResponse(raw)
}
// EncodeModelInfoResponse encodes a canonical model info value into response
// bytes. It delegates directly to encodeModelInfoResponse.
func (a *Adapter) EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
return encodeModelInfoResponse(info)
}
// DecodeEmbeddingRequest decodes raw embedding request bytes into a canonical
// embedding request. It delegates directly to decodeEmbeddingRequest.
func (a *Adapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
return decodeEmbeddingRequest(raw)
}
// EncodeEmbeddingRequest encodes a canonical embedding request into bytes for
// the given target provider. It delegates directly to encodeEmbeddingRequest.
func (a *Adapter) EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *conversion.TargetProvider) ([]byte, error) {
return encodeEmbeddingRequest(req, provider)
}
// DecodeEmbeddingResponse decodes raw embedding response bytes into a canonical
// embedding response. It delegates directly to decodeEmbeddingResponse.
func (a *Adapter) DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error) {
return decodeEmbeddingResponse(raw)
}
// EncodeEmbeddingResponse encodes a canonical embedding response into bytes.
// It delegates directly to encodeEmbeddingResponse.
func (a *Adapter) EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) {
return encodeEmbeddingResponse(resp)
}
// DecodeRerankRequest decodes raw rerank request bytes into a canonical rerank
// request. It delegates directly to decodeRerankRequest.
func (a *Adapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
return decodeRerankRequest(raw)
}
// EncodeRerankRequest encodes a canonical rerank request into bytes for the
// given target provider. It delegates directly to encodeRerankRequest.
func (a *Adapter) EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *conversion.TargetProvider) ([]byte, error) {
return encodeRerankRequest(req, provider)
}
// DecodeRerankResponse decodes raw rerank response bytes into a canonical
// rerank response. It delegates directly to decodeRerankResponse.
func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) {
return decodeRerankResponse(raw)
}
// EncodeRerankResponse encodes a canonical rerank response into bytes.
// It delegates directly to encodeRerankResponse.
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
return encodeRerankResponse(resp)
}
// ExtractUnifiedModelID extracts the unified model ID from a model-detail path
// of the form /v1/models/{provider_id}/{model_name}.
//
// Everything after the "/v1/models/" prefix is returned verbatim. An error is
// returned when the path does not start with that prefix or when nothing
// follows it.
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
	const prefix = "/v1/models/"

	if !strings.HasPrefix(nativePath, prefix) {
		return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
	}

	modelID := strings.TrimPrefix(nativePath, prefix)
	if len(modelID) == 0 {
		return "", fmt.Errorf("路径缺少模型 ID")
	}
	return modelID, nil
}
// locateModelFieldInRequest finds the top-level "model" field of a request body
// and returns its current value together with a function that rewrites it.
//
// The body is parsed into a map of raw JSON values so that fields other than
// "model" are carried through without being decoded. Only chat, embeddings and
// rerank requests are supported. An error is returned when the body cannot be
// parsed, the interface type is unsupported, the "model" field is missing, or
// its value cannot be decoded into a string.
//
// The returned function replaces "model" with the given name and re-marshals
// the whole map.
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(body, &fields); err != nil {
		return "", nil, err
	}

	supported := ifaceType == conversion.InterfaceTypeChat ||
		ifaceType == conversion.InterfaceTypeEmbeddings ||
		ifaceType == conversion.InterfaceTypeRerank
	if !supported {
		return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
	}

	rawModel, found := fields["model"]
	if !found {
		return "", nil, fmt.Errorf("请求体中缺少 model 字段")
	}

	var model string
	if err := json.Unmarshal(rawModel, &model); err != nil {
		return "", nil, err
	}

	rewrite := func(replacement string) ([]byte, error) {
		// The error from marshalling the replacement string is discarded here.
		fields["model"], _ = json.Marshal(replacement)
		return json.Marshal(fields)
	}
	return model, rewrite, nil
}
// ExtractModelName returns the value of the model field in a request body.
// It calls locateModelFieldInRequest, discards the rewrite function, and
// returns the located value together with that helper's error.
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
model, _, err := locateModelFieldInRequest(body, ifaceType)
return model, err
}
// RewriteRequestModelName replaces the model field of a request body with
// newModel. It locates the field via locateModelFieldInRequest and returns
// that helper's error unchanged when the field cannot be located; otherwise it
// returns the result of the rewrite function the helper supplied.
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
if err != nil {
return nil, err
}
return rewriteFunc(newModel)
}
// RewriteResponseModelName replaces the top-level "model" field of a response
// body with newModel.
//
// The body is parsed into a map of raw JSON values so that fields other than
// "model" are carried through without being decoded. Behaviour per interface
// type:
//   - Chat and embeddings: "model" is overwritten, or added when absent.
//   - Rerank: "model" is overwritten only when it already exists.
//   - Any other type: body is returned unchanged.
//
// An error is returned when the body cannot be parsed into a JSON object map,
// or when a chat/embeddings body is the JSON literal null, since there is no
// object to add the field to.
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
	var m map[string]json.RawMessage
	if err := json.Unmarshal(body, &m); err != nil {
		return nil, err
	}

	switch ifaceType {
	case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
		// json.Unmarshal accepts a JSON null and leaves m nil without reporting
		// an error; writing into a nil map would panic, so reject it instead.
		if m == nil {
			return nil, fmt.Errorf("响应体不是 JSON 对象")
		}
		// The error from marshalling the model string is discarded here.
		m["model"], _ = json.Marshal(newModel)
		return json.Marshal(m)
	case conversion.InterfaceTypeRerank:
		// Reading from a nil map is safe, so a null body simply has no "model"
		// field and is re-marshalled as-is.
		if _, exists := m["model"]; exists {
			m["model"], _ = json.Marshal(newModel)
		}
		return json.Marshal(m)
	default:
		return body, nil
	}
}