package openai import ( "encoding/json" "testing" "nex/backend/internal/conversion" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestAdapter_ProtocolName(t *testing.T) { a := NewAdapter() assert.Equal(t, "openai", a.ProtocolName()) } func TestAdapter_SupportsPassthrough(t *testing.T) { a := NewAdapter() assert.True(t, a.SupportsPassthrough()) } func TestAdapter_DetectInterfaceType(t *testing.T) { a := NewAdapter() tests := []struct { name string path string expected conversion.InterfaceType }{ {"聊天补全", "/chat/completions", conversion.InterfaceTypeChat}, {"模型列表", "/models", conversion.InterfaceTypeModels}, {"模型详情", "/models/gpt-4", conversion.InterfaceTypeModelInfo}, {"嵌入接口", "/embeddings", conversion.InterfaceTypeEmbeddings}, {"重排序接口", "/rerank", conversion.InterfaceTypeRerank}, {"未知路径", "/unknown", conversion.InterfaceTypePassthrough}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := a.DetectInterfaceType(tt.path) assert.Equal(t, tt.expected, result) }) } } func TestAdapter_BuildUrl(t *testing.T) { a := NewAdapter() tests := []struct { name string nativePath string interfaceType conversion.InterfaceType expected string }{ {"聊天", "/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"}, {"模型", "/models", conversion.InterfaceTypeModels, "/models"}, {"嵌入", "/embeddings", conversion.InterfaceTypeEmbeddings, "/embeddings"}, {"重排序", "/rerank", conversion.InterfaceTypeRerank, "/rerank"}, {"默认透传", "/other", conversion.InterfaceTypePassthrough, "/other"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := a.BuildUrl(tt.nativePath, tt.interfaceType) assert.Equal(t, tt.expected, result) }) } } func TestAdapter_BuildHeaders(t *testing.T) { a := NewAdapter() t.Run("基本头", func(t *testing.T) { provider := conversion.NewTargetProvider("https://api.openai.com", "sk-test123", "gpt-4") headers := a.BuildHeaders(provider) assert.Equal(t, "Bearer sk-test123", headers["Authorization"]) assert.Equal(t, "application/json", headers["Content-Type"]) _, hasOrg := headers["OpenAI-Organization"] assert.False(t, hasOrg) }) t.Run("带组织", func(t *testing.T) { provider := conversion.NewTargetProvider("https://api.openai.com", "sk-test123", "gpt-4") provider.AdapterConfig["organization"] = "org-abc" headers := a.BuildHeaders(provider) assert.Equal(t, "org-abc", headers["OpenAI-Organization"]) }) } func TestAdapter_SupportsInterface(t *testing.T) { a := NewAdapter() tests := []struct { name string interfaceType conversion.InterfaceType expected bool }{ {"聊天", conversion.InterfaceTypeChat, true}, {"模型", conversion.InterfaceTypeModels, true}, {"模型详情", conversion.InterfaceTypeModelInfo, true}, {"嵌入", conversion.InterfaceTypeEmbeddings, true}, {"重排序", conversion.InterfaceTypeRerank, true}, {"透传", conversion.InterfaceTypePassthrough, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := a.SupportsInterface(tt.interfaceType) assert.Equal(t, tt.expected, result) }) } } func TestIsModelInfoPath(t *testing.T) { tests := []struct { name string path string expected bool }{ {"model_info", "/models/gpt-4", true}, {"model_info_with_dots", "/models/gpt-4.1-preview", true}, {"models_list", "/models", false}, {"nested_path", "/models/gpt-4/versions", true}, {"empty_suffix", "/models/", false}, {"unrelated", "/chat/completions", false}, {"partial_prefix", "/model", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assert.Equal(t, tt.expected, isModelInfoPath(tt.path)) }) } } func TestAdapter_EncodeError_InvalidInput(t *testing.T) { a := NewAdapter() convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效") body, statusCode := a.EncodeError(convErr) require.Equal(t, 500, statusCode) var resp ErrorResponse require.NoError(t, json.Unmarshal(body, &resp)) assert.Equal(t, "参数无效", resp.Error.Message) assert.Equal(t, "invalid_request_error", resp.Error.Type) } func TestAdapter_EncodeError_ServerError(t *testing.T) { a := NewAdapter() convErr := conversion.NewConversionError(conversion.ErrorCodeStreamStateError, "流状态错误") body, statusCode := a.EncodeError(convErr) require.Equal(t, 500, statusCode) var resp ErrorResponse require.NoError(t, json.Unmarshal(body, &resp)) assert.Equal(t, "server_error", resp.Error.Type) assert.Equal(t, "流状态错误", resp.Error.Message) }