1
0
Files
nex/backend/internal/conversion/openai/adapter_test.go
lanyuanxiaoyao 395887667d feat: 实现统一模型 ID 机制
实现统一模型 ID 格式 (provider_id/model_name),支持跨协议模型标识和 Smart Passthrough。

核心变更:
- 新增 pkg/modelid 包:解析、格式化、校验统一模型 ID
- 数据库迁移:models 表使用 UUID 主键 + UNIQUE(provider_id, model_name) 约束
- Repository 层:FindByProviderAndModelName、ListEnabled 方法
- Service 层:联合唯一校验、provider ID 字符集校验
- Conversion 层:ExtractModelName、RewriteRequestModelName/RewriteResponseModelName 方法
- Handler 层:统一模型 ID 路由、Smart Passthrough、Models API 本地聚合
- 新增 error-responses、unified-model-id 规范

测试覆盖:
- 单元测试:modelid、conversion、handler、service、repository
- 集成测试:统一模型 ID 路由、Smart Passthrough 保真性、跨协议转换
- 迁移测试:UUID 主键、UNIQUE 约束、级联删除

OpenSpec:
- 归档 unified-model-id 变更到 archive/2026-04-21-unified-model-id
- 同步 11 个 delta specs 到 main specs
- 新增 error-responses、unified-model-id 规范文件
2026-04-21 18:14:10 +08:00

162 lines
4.7 KiB
Go

package openai
import (
"encoding/json"
"testing"
"nex/backend/internal/conversion"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestAdapter_ProtocolName checks that the adapter identifies itself as the
// "openai" protocol.
func TestAdapter_ProtocolName(t *testing.T) {
	adapter := NewAdapter()
	got := adapter.ProtocolName()
	assert.Equal(t, "openai", got)
}
// TestAdapter_SupportsPassthrough checks that the OpenAI adapter reports
// passthrough support.
func TestAdapter_SupportsPassthrough(t *testing.T) {
	supported := NewAdapter().SupportsPassthrough()
	assert.True(t, supported)
}
// TestAdapter_DetectInterfaceType verifies how request paths are classified
// into interface types. A path the adapter does not recognize (here
// "/v1/unknown") is expected to fall back to passthrough.
func TestAdapter_DetectInterfaceType(t *testing.T) {
	adapter := NewAdapter()

	cases := []struct {
		name string
		path string
		want conversion.InterfaceType
	}{
		{name: "聊天补全", path: "/v1/chat/completions", want: conversion.InterfaceTypeChat},
		{name: "模型列表", path: "/v1/models", want: conversion.InterfaceTypeModels},
		{name: "模型详情", path: "/v1/models/gpt-4", want: conversion.InterfaceTypeModelInfo},
		{name: "嵌入接口", path: "/v1/embeddings", want: conversion.InterfaceTypeEmbeddings},
		{name: "重排序接口", path: "/v1/rerank", want: conversion.InterfaceTypeRerank},
		{name: "未知路径", path: "/v1/unknown", want: conversion.InterfaceTypePassthrough},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := adapter.DetectInterfaceType(tc.path)
			assert.Equal(t, tc.want, got)
		})
	}
}
// TestAdapter_BuildUrl checks that, for every interface type listed below
// (including passthrough), BuildUrl returns the native path unchanged.
func TestAdapter_BuildUrl(t *testing.T) {
	adapter := NewAdapter()

	cases := []struct {
		name          string
		nativePath    string
		interfaceType conversion.InterfaceType
		want          string
	}{
		{name: "聊天", nativePath: "/v1/chat/completions", interfaceType: conversion.InterfaceTypeChat, want: "/v1/chat/completions"},
		{name: "模型", nativePath: "/v1/models", interfaceType: conversion.InterfaceTypeModels, want: "/v1/models"},
		{name: "嵌入", nativePath: "/v1/embeddings", interfaceType: conversion.InterfaceTypeEmbeddings, want: "/v1/embeddings"},
		{name: "重排序", nativePath: "/v1/rerank", interfaceType: conversion.InterfaceTypeRerank, want: "/v1/rerank"},
		{name: "默认透传", nativePath: "/v1/other", interfaceType: conversion.InterfaceTypePassthrough, want: "/v1/other"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := adapter.BuildUrl(tc.nativePath, tc.interfaceType)
			assert.Equal(t, tc.want, got)
		})
	}
}
// TestAdapter_BuildHeaders covers upstream header construction.
//
// The checks are:
//   - The API key is sent as "Bearer <key>" with a JSON content type.
//   - No OpenAI-Organization header is present by default.
//   - Setting AdapterConfig["organization"] produces that header.
func TestAdapter_BuildHeaders(t *testing.T) {
	adapter := NewAdapter()

	t.Run("基本头", func(t *testing.T) {
		headers := adapter.BuildHeaders(
			conversion.NewTargetProvider("https://api.openai.com", "sk-test123", "gpt-4"),
		)

		assert.Equal(t, "Bearer sk-test123", headers["Authorization"])
		assert.Equal(t, "application/json", headers["Content-Type"])

		_, present := headers["OpenAI-Organization"]
		assert.False(t, present)
	})

	t.Run("带组织", func(t *testing.T) {
		provider := conversion.NewTargetProvider("https://api.openai.com", "sk-test123", "gpt-4")
		// NOTE(review): assumes NewTargetProvider returns a non-nil
		// AdapterConfig map; writing into a nil map would panic here.
		provider.AdapterConfig["organization"] = "org-abc"

		headers := adapter.BuildHeaders(provider)
		assert.Equal(t, "org-abc", headers["OpenAI-Organization"])
	})
}
// TestAdapter_SupportsInterface checks which interface types the adapter
// reports as supported. Chat, models, model info, embeddings and rerank are
// expected to be supported; passthrough is not.
func TestAdapter_SupportsInterface(t *testing.T) {
	adapter := NewAdapter()

	cases := []struct {
		name          string
		interfaceType conversion.InterfaceType
		want          bool
	}{
		{name: "聊天", interfaceType: conversion.InterfaceTypeChat, want: true},
		{name: "模型", interfaceType: conversion.InterfaceTypeModels, want: true},
		{name: "模型详情", interfaceType: conversion.InterfaceTypeModelInfo, want: true},
		{name: "嵌入", interfaceType: conversion.InterfaceTypeEmbeddings, want: true},
		{name: "重排序", interfaceType: conversion.InterfaceTypeRerank, want: true},
		{name: "透传", interfaceType: conversion.InterfaceTypePassthrough, want: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := adapter.SupportsInterface(tc.interfaceType)
			assert.Equal(t, tc.want, got)
		})
	}
}
// TestIsModelInfoPath checks which request paths are treated as model-detail
// lookups.
//
// As pinned by the cases below, a non-empty suffix after "/v1/models/" is a
// model-info path, even when that suffix itself contains "/". The bare list
// endpoint, an empty suffix, and unrelated or partial prefixes are not.
func TestIsModelInfoPath(t *testing.T) {
	tests := []struct {
		name     string
		path     string
		expected bool
	}{
		{"model_info", "/v1/models/gpt-4", true},
		{"model_info_with_dots", "/v1/models/gpt-4.1-preview", true},
		// Unified model IDs have the form provider_id/model_name, so the model
		// segment of a model-info path can itself contain a slash.
		{"unified_model_id", "/v1/models/openai/gpt-4", true},
		{"models_list", "/v1/models", false},
		// Multi-segment suffixes are accepted as model-info paths; presumably
		// this is what lets slash-containing unified model IDs through.
		{"nested_path", "/v1/models/gpt-4/versions", true},
		{"empty_suffix", "/v1/models/", false},
		{"unrelated", "/v1/chat/completions", false},
		{"partial_prefix", "/v1/model", false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
		})
	}
}
// TestAdapter_EncodeError_InvalidInput checks that an ErrorCodeInvalidInput
// conversion error is encoded as a JSON ErrorResponse. The response type must
// be "invalid_request_error", the original message must be preserved, and the
// HTTP status must be 500.
//
// NOTE(review): the expected status is 500 even though the error type is
// invalid_request_error, which OpenAI's API normally pairs with 400. Confirm
// against EncodeError whether 500 is intentional for conversion-layer errors.
func TestAdapter_EncodeError_InvalidInput(t *testing.T) {
	adapter := NewAdapter()
	convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效")

	payload, status := adapter.EncodeError(convErr)
	require.Equal(t, 500, status)

	var decoded ErrorResponse
	err := json.Unmarshal(payload, &decoded)
	require.NoError(t, err)

	assert.Equal(t, "参数无效", decoded.Error.Message)
	assert.Equal(t, "invalid_request_error", decoded.Error.Type)
}
// TestAdapter_EncodeError_ServerError checks that an ErrorCodeStreamStateError
// conversion error is encoded with HTTP status 500. The JSON ErrorResponse must
// have type "server_error" and carry the original message.
func TestAdapter_EncodeError_ServerError(t *testing.T) {
	adapter := NewAdapter()
	convErr := conversion.NewConversionError(conversion.ErrorCodeStreamStateError, "流状态错误")

	payload, status := adapter.EncodeError(convErr)
	require.Equal(t, 500, status)

	var decoded ErrorResponse
	err := json.Unmarshal(payload, &decoded)
	require.NoError(t, err)

	assert.Equal(t, "server_error", decoded.Error.Type)
	assert.Equal(t, "流状态错误", decoded.Error.Message)
}