实现统一模型 ID 格式 (provider_id/model_name),支持跨协议模型标识和 Smart Passthrough。

核心变更:
- 新增 pkg/modelid 包:解析、格式化、校验统一模型 ID
- 数据库迁移:models 表使用 UUID 主键 + UNIQUE(provider_id, model_name) 约束
- Repository 层:FindByProviderAndModelName、ListEnabled 方法
- Service 层:联合唯一校验、provider ID 字符集校验
- Conversion 层:ExtractModelName、RewriteRequestModelName/RewriteResponseModelName 方法
- Handler 层:统一模型 ID 路由、Smart Passthrough、Models API 本地聚合
- 新增 error-responses、unified-model-id 规范

测试覆盖:
- 单元测试:modelid、conversion、handler、service、repository
- 集成测试:统一模型 ID 路由、Smart Passthrough 保真性、跨协议转换
- 迁移测试:UUID 主键、UNIQUE 约束、级联删除

OpenSpec:
- 归档 unified-model-id 变更到 archive/2026-04-21-unified-model-id
- 同步 11 个 delta specs 到 main specs
- 新增 error-responses、unified-model-id 规范文件
297 lines
9.3 KiB
Go
297 lines
9.3 KiB
Go
package openai
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"strings"
|
||
|
||
"nex/backend/internal/conversion"
|
||
"nex/backend/internal/conversion/canonical"
|
||
)
|
||
|
||
// Adapter OpenAI 协议适配器
|
||
// Adapter is the OpenAI protocol adapter. The struct declares no fields, so
// an Adapter value carries no per-instance state.
type Adapter struct{}
|
||
|
||
// NewAdapter 创建 OpenAI 适配器
|
||
func NewAdapter() *Adapter {
|
||
return &Adapter{}
|
||
}
|
||
|
||
// ProtocolName 返回协议名称
|
||
func (a *Adapter) ProtocolName() string { return "openai" }
|
||
|
||
// ProtocolVersion 返回协议版本
|
||
func (a *Adapter) ProtocolVersion() string { return "" }
|
||
|
||
// SupportsPassthrough 支持同协议透传
|
||
func (a *Adapter) SupportsPassthrough() bool { return true }
|
||
|
||
// DetectInterfaceType 根据路径检测接口类型
|
||
func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType {
|
||
switch {
|
||
case nativePath == "/v1/chat/completions":
|
||
return conversion.InterfaceTypeChat
|
||
case nativePath == "/v1/models":
|
||
return conversion.InterfaceTypeModels
|
||
case isModelInfoPath(nativePath):
|
||
return conversion.InterfaceTypeModelInfo
|
||
case nativePath == "/v1/embeddings":
|
||
return conversion.InterfaceTypeEmbeddings
|
||
case nativePath == "/v1/rerank":
|
||
return conversion.InterfaceTypeRerank
|
||
default:
|
||
return conversion.InterfaceTypePassthrough
|
||
}
|
||
}
|
||
|
||
// isModelInfoPath reports whether path is a model-detail path, i.e. it has
// the form /v1/models/{id} with a non-empty id. The id may itself contain
// "/" characters; only the prefix and non-emptiness are checked.
func isModelInfoPath(path string) bool {
	const prefix = "/v1/models/"
	// A path longer than the prefix that starts with it has a non-empty id.
	return len(path) > len(prefix) && strings.HasPrefix(path, prefix)
}
|
||
|
||
// BuildUrl 根据接口类型构建 URL
|
||
func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string {
|
||
switch interfaceType {
|
||
case conversion.InterfaceTypeChat:
|
||
return "/v1/chat/completions"
|
||
case conversion.InterfaceTypeModels:
|
||
return "/v1/models"
|
||
case conversion.InterfaceTypeEmbeddings:
|
||
return "/v1/embeddings"
|
||
case conversion.InterfaceTypeRerank:
|
||
return "/v1/rerank"
|
||
default:
|
||
return nativePath
|
||
}
|
||
}
|
||
|
||
// BuildHeaders 构建请求头
|
||
func (a *Adapter) BuildHeaders(provider *conversion.TargetProvider) map[string]string {
|
||
headers := map[string]string{
|
||
"Authorization": "Bearer " + provider.APIKey,
|
||
"Content-Type": "application/json",
|
||
}
|
||
if org, ok := provider.AdapterConfig["organization"].(string); ok && org != "" {
|
||
headers["OpenAI-Organization"] = org
|
||
}
|
||
return headers
|
||
}
|
||
|
||
// SupportsInterface 检查是否支持接口类型
|
||
func (a *Adapter) SupportsInterface(interfaceType conversion.InterfaceType) bool {
|
||
switch interfaceType {
|
||
case conversion.InterfaceTypeChat,
|
||
conversion.InterfaceTypeModels,
|
||
conversion.InterfaceTypeModelInfo,
|
||
conversion.InterfaceTypeEmbeddings,
|
||
conversion.InterfaceTypeRerank:
|
||
return true
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
// DecodeRequest 解码请求
|
||
func (a *Adapter) DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error) {
|
||
return decodeRequest(raw)
|
||
}
|
||
|
||
// EncodeRequest 编码请求
|
||
func (a *Adapter) EncodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) {
|
||
return encodeRequest(req, provider)
|
||
}
|
||
|
||
// DecodeResponse 解码响应
|
||
func (a *Adapter) DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error) {
|
||
return decodeResponse(raw)
|
||
}
|
||
|
||
// EncodeResponse 编码响应
|
||
func (a *Adapter) EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||
return encodeResponse(resp)
|
||
}
|
||
|
||
// CreateStreamDecoder 创建流式解码器
|
||
func (a *Adapter) CreateStreamDecoder() conversion.StreamDecoder {
|
||
return NewStreamDecoder()
|
||
}
|
||
|
||
// CreateStreamEncoder 创建流式编码器
|
||
func (a *Adapter) CreateStreamEncoder() conversion.StreamEncoder {
|
||
return NewStreamEncoder()
|
||
}
|
||
|
||
// EncodeError 编码错误
|
||
func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
|
||
errType := mapErrorCode(err.Code)
|
||
statusCode := 500
|
||
|
||
errMsg := ErrorResponse{
|
||
Error: ErrorDetail{
|
||
Message: err.Message,
|
||
Type: errType,
|
||
Param: nil,
|
||
Code: string(err.Code),
|
||
},
|
||
}
|
||
body, _ := json.Marshal(errMsg)
|
||
return body, statusCode
|
||
}
|
||
|
||
// mapErrorCode 映射错误码到 OpenAI 错误类型
|
||
func mapErrorCode(code conversion.ErrorCode) string {
|
||
switch code {
|
||
case conversion.ErrorCodeInvalidInput,
|
||
conversion.ErrorCodeMissingRequiredField,
|
||
conversion.ErrorCodeIncompatibleFeature,
|
||
conversion.ErrorCodeToolCallParseError,
|
||
conversion.ErrorCodeJSONParseError,
|
||
conversion.ErrorCodeProtocolConstraint,
|
||
conversion.ErrorCodeFieldMappingFailure:
|
||
return "invalid_request_error"
|
||
default:
|
||
return "server_error"
|
||
}
|
||
}
|
||
|
||
// DecodeModelsResponse 解码模型列表响应
|
||
func (a *Adapter) DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error) {
|
||
return decodeModelsResponse(raw)
|
||
}
|
||
|
||
// EncodeModelsResponse 编码模型列表响应
|
||
func (a *Adapter) EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
|
||
return encodeModelsResponse(list)
|
||
}
|
||
|
||
// DecodeModelInfoResponse 解码模型详情响应
|
||
func (a *Adapter) DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error) {
|
||
return decodeModelInfoResponse(raw)
|
||
}
|
||
|
||
// EncodeModelInfoResponse 编码模型详情响应
|
||
func (a *Adapter) EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
|
||
return encodeModelInfoResponse(info)
|
||
}
|
||
|
||
// DecodeEmbeddingRequest 解码嵌入请求
|
||
func (a *Adapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
|
||
return decodeEmbeddingRequest(raw)
|
||
}
|
||
|
||
// EncodeEmbeddingRequest 编码嵌入请求
|
||
func (a *Adapter) EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *conversion.TargetProvider) ([]byte, error) {
|
||
return encodeEmbeddingRequest(req, provider)
|
||
}
|
||
|
||
// DecodeEmbeddingResponse 解码嵌入响应
|
||
func (a *Adapter) DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error) {
|
||
return decodeEmbeddingResponse(raw)
|
||
}
|
||
|
||
// EncodeEmbeddingResponse 编码嵌入响应
|
||
func (a *Adapter) EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) {
|
||
return encodeEmbeddingResponse(resp)
|
||
}
|
||
|
||
// DecodeRerankRequest 解码重排序请求
|
||
func (a *Adapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
|
||
return decodeRerankRequest(raw)
|
||
}
|
||
|
||
// EncodeRerankRequest 编码重排序请求
|
||
func (a *Adapter) EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *conversion.TargetProvider) ([]byte, error) {
|
||
return encodeRerankRequest(req, provider)
|
||
}
|
||
|
||
// DecodeRerankResponse 解码重排序响应
|
||
func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) {
|
||
return decodeRerankResponse(raw)
|
||
}
|
||
|
||
// EncodeRerankResponse 编码重排序响应
|
||
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
|
||
return encodeRerankResponse(resp)
|
||
}
|
||
|
||
// ExtractUnifiedModelID 从路径中提取统一模型 ID(/v1/models/{provider_id}/{model_name})
|
||
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
|
||
if !strings.HasPrefix(nativePath, "/v1/models/") {
|
||
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
|
||
}
|
||
suffix := nativePath[len("/v1/models/"):]
|
||
if suffix == "" {
|
||
return "", fmt.Errorf("路径缺少模型 ID")
|
||
}
|
||
return suffix, nil
|
||
}
|
||
|
||
// locateModelFieldInRequest 定位请求体中 model 字段的值并提供改写函数
|
||
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
|
||
var m map[string]json.RawMessage
|
||
if err := json.Unmarshal(body, &m); err != nil {
|
||
return "", nil, err
|
||
}
|
||
|
||
switch ifaceType {
|
||
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
|
||
raw, exists := m["model"]
|
||
if !exists {
|
||
return "", nil, fmt.Errorf("请求体中缺少 model 字段")
|
||
}
|
||
var current string
|
||
if err := json.Unmarshal(raw, ¤t); err != nil {
|
||
return "", nil, err
|
||
}
|
||
rewriteFunc := func(newModel string) ([]byte, error) {
|
||
m["model"], _ = json.Marshal(newModel)
|
||
return json.Marshal(m)
|
||
}
|
||
return current, rewriteFunc, nil
|
||
default:
|
||
return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
|
||
}
|
||
}
|
||
|
||
// ExtractModelName 从请求体中提取 model 值
|
||
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
|
||
model, _, err := locateModelFieldInRequest(body, ifaceType)
|
||
return model, err
|
||
}
|
||
|
||
// RewriteRequestModelName 最小化改写请求体中的 model 字段
|
||
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||
_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return rewriteFunc(newModel)
|
||
}
|
||
|
||
// RewriteResponseModelName 最小化改写响应体中的 model 字段
|
||
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||
var m map[string]json.RawMessage
|
||
if err := json.Unmarshal(body, &m); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
switch ifaceType {
|
||
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
|
||
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
|
||
m["model"], _ = json.Marshal(newModel)
|
||
return json.Marshal(m)
|
||
case conversion.InterfaceTypeRerank:
|
||
// Rerank 响应:存在 model 字段则改写,不存在则不添加
|
||
if _, exists := m["model"]; exists {
|
||
m["model"], _ = json.Marshal(newModel)
|
||
}
|
||
return json.Marshal(m)
|
||
default:
|
||
return body, nil
|
||
}
|
||
}
|