package openai import ( "encoding/json" "testing" "nex/backend/internal/conversion" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // --------------------------------------------------------------------------- // ExtractUnifiedModelID // --------------------------------------------------------------------------- func TestExtractUnifiedModelID(t *testing.T) { a := NewAdapter() t.Run("standard_path", func(t *testing.T) { id, err := a.ExtractUnifiedModelID("/models/openai/gpt-4") require.NoError(t, err) assert.Equal(t, "openai/gpt-4", id) }) t.Run("multi_segment_path", func(t *testing.T) { id, err := a.ExtractUnifiedModelID("/models/azure/accounts/org/models/gpt-4") require.NoError(t, err) assert.Equal(t, "azure/accounts/org/models/gpt-4", id) }) t.Run("single_segment", func(t *testing.T) { id, err := a.ExtractUnifiedModelID("/models/gpt-4") require.NoError(t, err) assert.Equal(t, "gpt-4", id) }) t.Run("non_model_path", func(t *testing.T) { _, err := a.ExtractUnifiedModelID("/chat/completions") require.Error(t, err) }) t.Run("empty_suffix", func(t *testing.T) { _, err := a.ExtractUnifiedModelID("/models/") require.Error(t, err) }) t.Run("models_list_no_slash", func(t *testing.T) { _, err := a.ExtractUnifiedModelID("/models") require.Error(t, err) }) t.Run("unrelated_path", func(t *testing.T) { _, err := a.ExtractUnifiedModelID("/other") require.Error(t, err) }) } // --------------------------------------------------------------------------- // ExtractModelName // --------------------------------------------------------------------------- func TestExtractModelName(t *testing.T) { a := NewAdapter() t.Run("chat", func(t *testing.T) { body := []byte(`{"model":"openai/gpt-4","messages":[]}`) model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat) require.NoError(t, err) assert.Equal(t, "openai/gpt-4", model) }) t.Run("embedding", func(t *testing.T) { body := []byte(`{"model":"openai/text-embedding","input":"hello"}`) model, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings) require.NoError(t, err) assert.Equal(t, "openai/text-embedding", model) }) t.Run("rerank", func(t *testing.T) { body := []byte(`{"model":"openai/rerank","query":"test"}`) model, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank) require.NoError(t, err) assert.Equal(t, "openai/rerank", model) }) t.Run("no_model_field", func(t *testing.T) { body := []byte(`{"messages":[]}`) _, err := a.ExtractModelName(body, conversion.InterfaceTypeChat) require.Error(t, err) }) t.Run("invalid_json", func(t *testing.T) { body := []byte(`{invalid}`) _, err := a.ExtractModelName(body, conversion.InterfaceTypeChat) require.Error(t, err) }) t.Run("unsupported_interface_type", func(t *testing.T) { body := []byte(`{"model":"openai/gpt-4"}`) _, err := a.ExtractModelName(body, conversion.InterfaceTypePassthrough) require.Error(t, err) }) } // --------------------------------------------------------------------------- // RewriteRequestModelName // --------------------------------------------------------------------------- func TestRewriteRequestModelName(t *testing.T) { a := NewAdapter() t.Run("chat", func(t *testing.T) { body := []byte(`{"model":"openai/gpt-4","messages":[]}`) rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat) require.NoError(t, err) var m map[string]interface{} require.NoError(t, json.Unmarshal(rewritten, &m)) assert.Equal(t, "gpt-4", m["model"]) // messages field preserved msgs, ok := m["messages"] require.True(t, ok) msgsArr, ok := msgs.([]interface{}) require.True(t, ok) assert.Len(t, msgsArr, 0) }) t.Run("preserves_unknown_fields", func(t *testing.T) { body := []byte(`{"model":"openai/gpt-4","temperature":0.7}`) rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat) require.NoError(t, err) var m map[string]interface{} require.NoError(t, json.Unmarshal(rewritten, &m)) assert.Equal(t, "gpt-4", m["model"]) assert.Equal(t, 0.7, m["temperature"]) }) t.Run("embedding", func(t *testing.T) { body := []byte(`{"model":"openai/text-embedding","input":"hello"}`) rewritten, err := a.RewriteRequestModelName(body, "text-embedding", conversion.InterfaceTypeEmbeddings) require.NoError(t, err) var m map[string]interface{} require.NoError(t, json.Unmarshal(rewritten, &m)) assert.Equal(t, "text-embedding", m["model"]) assert.Equal(t, "hello", m["input"]) }) t.Run("rerank", func(t *testing.T) { body := []byte(`{"model":"openai/rerank","query":"test"}`) rewritten, err := a.RewriteRequestModelName(body, "rerank", conversion.InterfaceTypeRerank) require.NoError(t, err) var m map[string]interface{} require.NoError(t, json.Unmarshal(rewritten, &m)) assert.Equal(t, "rerank", m["model"]) assert.Equal(t, "test", m["query"]) }) t.Run("no_model_field", func(t *testing.T) { body := []byte(`{"messages":[]}`) _, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat) require.Error(t, err) }) t.Run("invalid_json", func(t *testing.T) { body := []byte(`{invalid}`) _, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat) require.Error(t, err) }) t.Run("unsupported_interface_type", func(t *testing.T) { body := []byte(`{"model":"openai/gpt-4"}`) _, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypePassthrough) require.Error(t, err) }) } // --------------------------------------------------------------------------- // RewriteResponseModelName // --------------------------------------------------------------------------- func TestRewriteResponseModelName(t *testing.T) { a := NewAdapter() t.Run("chat_existing_model", func(t *testing.T) { body := []byte(`{"model":"gpt-4","choices":[]}`) rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat) require.NoError(t, err) var m map[string]interface{} require.NoError(t, json.Unmarshal(rewritten, &m)) assert.Equal(t, "openai/gpt-4", m["model"]) choices, ok := m["choices"] require.True(t, ok) choicesArr, ok := choices.([]interface{}) require.True(t, ok) assert.Len(t, choicesArr, 0) }) t.Run("chat_without_model_field", func(t *testing.T) { body := []byte(`{"choices":[]}`) rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat) require.NoError(t, err) var m map[string]interface{} require.NoError(t, json.Unmarshal(rewritten, &m)) assert.Equal(t, "openai/gpt-4", m["model"]) choices, ok := m["choices"] require.True(t, ok) choicesArr, ok := choices.([]interface{}) require.True(t, ok) assert.Len(t, choicesArr, 0) }) t.Run("rerank_existing_model", func(t *testing.T) { body := []byte(`{"model":"rerank","results":[]}`) rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank) require.NoError(t, err) var m map[string]interface{} require.NoError(t, json.Unmarshal(rewritten, &m)) assert.Equal(t, "openai/rerank", m["model"]) }) t.Run("rerank_without_model_field_should_not_add", func(t *testing.T) { body := []byte(`{"results":[]}`) rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank) require.NoError(t, err) var m map[string]interface{} require.NoError(t, json.Unmarshal(rewritten, &m)) _, hasModel := m["model"] assert.False(t, hasModel, "rerank response without model field should not have one added") }) t.Run("embedding_existing_model", func(t *testing.T) { body := []byte(`{"model":"text-embedding","data":[]}`) rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings) require.NoError(t, err) var m map[string]interface{} require.NoError(t, json.Unmarshal(rewritten, &m)) assert.Equal(t, "openai/text-embedding", m["model"]) }) t.Run("embedding_without_model_field_adds", func(t *testing.T) { body := []byte(`{"data":[]}`) rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings) require.NoError(t, err) var m map[string]interface{} require.NoError(t, json.Unmarshal(rewritten, &m)) assert.Equal(t, "openai/text-embedding", m["model"]) }) t.Run("passthrough_returns_body_unchanged", func(t *testing.T) { body := []byte(`{"model":"gpt-4"}`) rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypePassthrough) require.NoError(t, err) assert.Equal(t, string(body), string(rewritten)) }) t.Run("invalid_json", func(t *testing.T) { body := []byte(`{invalid}`) _, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat) require.Error(t, err) }) } // --------------------------------------------------------------------------- // ExtractModelName and RewriteRequest consistency // --------------------------------------------------------------------------- func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) { a := NewAdapter() t.Run("chat_round_trip", func(t *testing.T) { original := []byte(`{"model":"openai/gpt-4","messages":[],"temperature":0.7}`) // Extract the unified model ID from the body extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat) require.NoError(t, err) assert.Equal(t, "openai/gpt-4", extracted) // Rewrite to the native model name rewritten, err := a.RewriteRequestModelName(original, "gpt-4", conversion.InterfaceTypeChat) require.NoError(t, err) // Extract again from the rewritten body to verify the same location was targeted afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat) require.NoError(t, err) assert.Equal(t, "gpt-4", afterRewrite) // Verify other fields are preserved var m map[string]interface{} require.NoError(t, json.Unmarshal(rewritten, &m)) assert.Equal(t, 0.7, m["temperature"]) }) t.Run("embedding_round_trip", func(t *testing.T) { original := []byte(`{"model":"openai/text-embedding","input":"hello"}`) extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeEmbeddings) require.NoError(t, err) assert.Equal(t, "openai/text-embedding", extracted) rewritten, err := a.RewriteRequestModelName(original, "text-embedding", conversion.InterfaceTypeEmbeddings) require.NoError(t, err) afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeEmbeddings) require.NoError(t, err) assert.Equal(t, "text-embedding", afterRewrite) }) t.Run("rerank_round_trip", func(t *testing.T) { original := []byte(`{"model":"openai/rerank","query":"test"}`) extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeRerank) require.NoError(t, err) assert.Equal(t, "openai/rerank", extracted) rewritten, err := a.RewriteRequestModelName(original, "rerank", conversion.InterfaceTypeRerank) require.NoError(t, err) afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeRerank) require.NoError(t, err) assert.Equal(t, "rerank", afterRewrite) }) } // --------------------------------------------------------------------------- // isModelInfoPath (additional unified model ID cases) // --------------------------------------------------------------------------- func TestIsModelInfoPath_UnifiedModelID(t *testing.T) { tests := []struct { name string path string expected bool }{ {"simple_model_id", "/models/gpt-4", true}, {"unified_model_id_with_slash", "/models/openai/gpt-4", true}, {"models_list", "/models", false}, {"models_list_trailing_slash", "/models/", false}, {"chat_completions", "/chat/completions", false}, {"deeply_nested", "/models/azure/eastus/deployments/my-dept/models/gpt-4", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assert.Equal(t, tt.expected, isModelInfoPath(tt.path)) }) } }