package openai

import (
	"encoding/json"
	"net/http"
	"testing"

	"nex/backend/internal/conversion"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestAdapter_ProtocolName checks that the adapter identifies itself as "openai".
func TestAdapter_ProtocolName(t *testing.T) {
	a := NewAdapter()
	assert.Equal(t, "openai", a.ProtocolName())
}

// TestAdapter_SupportsPassthrough checks that the adapter reports passthrough support.
func TestAdapter_SupportsPassthrough(t *testing.T) {
	a := NewAdapter()
	assert.True(t, a.SupportsPassthrough())
}

// TestAdapter_DetectInterfaceType checks the mapping from request path to
// interface type, including the fallback to passthrough for unknown paths.
func TestAdapter_DetectInterfaceType(t *testing.T) {
	a := NewAdapter()
	tests := []struct {
		name     string
		path     string
		expected conversion.InterfaceType
	}{
		{"聊天补全", "/v1/chat/completions", conversion.InterfaceTypeChat},
		{"模型列表", "/v1/models", conversion.InterfaceTypeModels},
		{"模型详情", "/v1/models/gpt-4", conversion.InterfaceTypeModelInfo},
		{"嵌入接口", "/v1/embeddings", conversion.InterfaceTypeEmbeddings},
		{"重排序接口", "/v1/rerank", conversion.InterfaceTypeRerank},
		{"未知路径", "/v1/unknown", conversion.InterfaceTypePassthrough},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := a.DetectInterfaceType(tt.path)
			assert.Equal(t, tt.expected, result)
		})
	}
}

// TestAdapter_BuildUrl checks that BuildUrl returns the native path unchanged
// for every interface type in the table, including passthrough.
func TestAdapter_BuildUrl(t *testing.T) {
	a := NewAdapter()
	tests := []struct {
		name          string
		nativePath    string
		interfaceType conversion.InterfaceType
		expected      string
	}{
		{"聊天", "/v1/chat/completions", conversion.InterfaceTypeChat, "/v1/chat/completions"},
		{"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"},
		{"嵌入", "/v1/embeddings", conversion.InterfaceTypeEmbeddings, "/v1/embeddings"},
		{"重排序", "/v1/rerank", conversion.InterfaceTypeRerank, "/v1/rerank"},
		{"默认透传", "/v1/other", conversion.InterfaceTypePassthrough, "/v1/other"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := a.BuildUrl(tt.nativePath, tt.interfaceType)
			assert.Equal(t, tt.expected, result)
		})
	}
}

// TestAdapter_BuildHeaders checks that BuildHeaders sets a Bearer
// Authorization header and a JSON Content-Type. It also checks that
// OpenAI-Organization is added only when the provider's AdapterConfig
// has an "organization" entry.
func TestAdapter_BuildHeaders(t *testing.T) {
	a := NewAdapter()

	t.Run("基本头", func(t *testing.T) {
		provider := conversion.NewTargetProvider("https://api.openai.com", "sk-test123", "gpt-4")
		headers := a.BuildHeaders(provider)

		assert.Equal(t, "Bearer sk-test123", headers["Authorization"])
		assert.Equal(t, "application/json", headers["Content-Type"])
		_, hasOrg := headers["OpenAI-Organization"]
		assert.False(t, hasOrg)
	})

	t.Run("带组织", func(t *testing.T) {
		provider := conversion.NewTargetProvider("https://api.openai.com", "sk-test123", "gpt-4")
		provider.AdapterConfig["organization"] = "org-abc"
		headers := a.BuildHeaders(provider)

		assert.Equal(t, "org-abc", headers["OpenAI-Organization"])
		// Configuring an organization must only add a header. It must not
		// drop or replace the authentication and content-type headers.
		assert.Equal(t, "Bearer sk-test123", headers["Authorization"])
		assert.Equal(t, "application/json", headers["Content-Type"])
	})
}

// TestAdapter_SupportsInterface checks that every concrete interface type is
// supported and that passthrough is not.
func TestAdapter_SupportsInterface(t *testing.T) {
	a := NewAdapter()
	tests := []struct {
		name          string
		interfaceType conversion.InterfaceType
		expected      bool
	}{
		{"聊天", conversion.InterfaceTypeChat, true},
		{"模型", conversion.InterfaceTypeModels, true},
		{"模型详情", conversion.InterfaceTypeModelInfo, true},
		{"嵌入", conversion.InterfaceTypeEmbeddings, true},
		{"重排序", conversion.InterfaceTypeRerank, true},
		{"透传", conversion.InterfaceTypePassthrough, false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := a.SupportsInterface(tt.interfaceType)
			assert.Equal(t, tt.expected, result)
		})
	}
}

// TestIsModelInfoPath checks that isModelInfoPath accepts exactly one
// non-empty segment after "/v1/models/". It rejects the bare list path,
// nested sub-paths, an empty trailing segment, and unrelated prefixes.
func TestIsModelInfoPath(t *testing.T) {
	tests := []struct {
		name     string
		path     string
		expected bool
	}{
		{"model_info", "/v1/models/gpt-4", true},
		{"model_info_with_dots", "/v1/models/gpt-4.1-preview", true},
		{"models_list", "/v1/models", false},
		{"nested_path", "/v1/models/gpt-4/versions", false},
		{"empty_suffix", "/v1/models/", false},
		{"unrelated", "/v1/chat/completions", false},
		{"partial_prefix", "/v1/model", false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
		})
	}
}

// TestAdapter_EncodeError_InvalidInput checks the status code and JSON error
// body that EncodeError produces for an ErrorCodeInvalidInput conversion error.
func TestAdapter_EncodeError_InvalidInput(t *testing.T) {
	a := NewAdapter()
	convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效")

	body, statusCode := a.EncodeError(convErr)

	// NOTE(review): the body is typed "invalid_request_error", which the
	// upstream OpenAI API typically pairs with HTTP 400. This test pins the
	// adapter's current 500 instead; confirm that mapping is intended.
	require.Equal(t, http.StatusInternalServerError, statusCode)

	var resp ErrorResponse
	require.NoError(t, json.Unmarshal(body, &resp))
	assert.Equal(t, "参数无效", resp.Error.Message)
	assert.Equal(t, "invalid_request_error", resp.Error.Type)
}

// TestAdapter_EncodeError_ServerError checks that a stream-state conversion
// error is encoded as a 500 with an error body of type "server_error".
func TestAdapter_EncodeError_ServerError(t *testing.T) {
	a := NewAdapter()
	convErr := conversion.NewConversionError(conversion.ErrorCodeStreamStateError, "流状态错误")

	body, statusCode := a.EncodeError(convErr)
	require.Equal(t, http.StatusInternalServerError, statusCode)

	var resp ErrorResponse
	require.NoError(t, json.Unmarshal(body, &resp))
	assert.Equal(t, "server_error", resp.Error.Type)
	assert.Equal(t, "流状态错误", resp.Error.Message)
}