package anthropic import ( "encoding/json" "testing" "nex/backend/internal/conversion" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // --------------------------------------------------------------------------- // ExtractUnifiedModelID // --------------------------------------------------------------------------- func TestExtractUnifiedModelID(t *testing.T) { a := NewAdapter() t.Run("standard_path", func(t *testing.T) { id, err := a.ExtractUnifiedModelID("/v1/models/anthropic/claude-3") require.NoError(t, err) assert.Equal(t, "anthropic/claude-3", id) }) t.Run("multi_segment_path", func(t *testing.T) { id, err := a.ExtractUnifiedModelID("/v1/models/some/deep/nested/model") require.NoError(t, err) assert.Equal(t, "some/deep/nested/model", id) }) t.Run("single_segment", func(t *testing.T) { id, err := a.ExtractUnifiedModelID("/v1/models/claude-3") require.NoError(t, err) assert.Equal(t, "claude-3", id) }) t.Run("non_model_path", func(t *testing.T) { _, err := a.ExtractUnifiedModelID("/v1/messages") require.Error(t, err) }) t.Run("empty_suffix", func(t *testing.T) { _, err := a.ExtractUnifiedModelID("/v1/models/") require.Error(t, err) }) t.Run("models_list_no_slash", func(t *testing.T) { _, err := a.ExtractUnifiedModelID("/v1/models") require.Error(t, err) }) t.Run("unrelated_path", func(t *testing.T) { _, err := a.ExtractUnifiedModelID("/v1/other") require.Error(t, err) }) } // --------------------------------------------------------------------------- // ExtractModelName (Chat only for Anthropic) // --------------------------------------------------------------------------- func TestExtractModelName(t *testing.T) { a := NewAdapter() t.Run("chat", func(t *testing.T) { body := []byte(`{"model":"anthropic/claude-3","messages":[]}`) model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat) require.NoError(t, err) assert.Equal(t, "anthropic/claude-3", model) }) t.Run("chat_with_max_tokens", func(t *testing.T) { body := []byte(`{"model":"anthropic/claude-3-opus","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`) model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat) require.NoError(t, err) assert.Equal(t, "anthropic/claude-3-opus", model) }) t.Run("no_model_field", func(t *testing.T) { body := []byte(`{"messages":[]}`) _, err := a.ExtractModelName(body, conversion.InterfaceTypeChat) require.Error(t, err) }) t.Run("invalid_json", func(t *testing.T) { body := []byte(`{invalid}`) _, err := a.ExtractModelName(body, conversion.InterfaceTypeChat) require.Error(t, err) }) t.Run("unsupported_interface_type_embedding", func(t *testing.T) { body := []byte(`{"model":"anthropic/claude-3"}`) _, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings) require.Error(t, err) }) t.Run("unsupported_interface_type_rerank", func(t *testing.T) { body := []byte(`{"model":"anthropic/claude-3"}`) _, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank) require.Error(t, err) }) } // --------------------------------------------------------------------------- // RewriteRequestModelName (Chat only for Anthropic) // --------------------------------------------------------------------------- func TestRewriteRequestModelName(t *testing.T) { a := NewAdapter() t.Run("chat", func(t *testing.T) { body := []byte(`{"model":"anthropic/claude-3","messages":[]}`) rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat) require.NoError(t, err) var m map[string]interface{} require.NoError(t, json.Unmarshal(rewritten, &m)) assert.Equal(t, "claude-3", m["model"]) msgs, ok := m["messages"] require.True(t, ok) msgsArr, ok := msgs.([]interface{}) require.True(t, ok) assert.Len(t, msgsArr, 0) }) t.Run("preserves_unknown_fields", func(t *testing.T) { body := []byte(`{"model":"anthropic/claude-3","max_tokens":1024,"temperature":0.7}`) rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat) require.NoError(t, err) var m map[string]interface{} require.NoError(t, json.Unmarshal(rewritten, &m)) assert.Equal(t, "claude-3", m["model"]) assert.Equal(t, 0.7, m["temperature"]) // max_tokens is encoded as float in JSON numbers maxTokens, ok := m["max_tokens"] require.True(t, ok) assert.Equal(t, float64(1024), maxTokens) }) t.Run("no_model_field", func(t *testing.T) { body := []byte(`{"messages":[]}`) _, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat) require.Error(t, err) }) t.Run("invalid_json", func(t *testing.T) { body := []byte(`{invalid}`) _, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat) require.Error(t, err) }) t.Run("unsupported_interface_type", func(t *testing.T) { body := []byte(`{"model":"anthropic/claude-3"}`) _, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeEmbeddings) require.Error(t, err) }) } // --------------------------------------------------------------------------- // RewriteResponseModelName (Chat only for Anthropic) // --------------------------------------------------------------------------- func TestRewriteResponseModelName(t *testing.T) { a := NewAdapter() t.Run("chat_existing_model", func(t *testing.T) { body := []byte(`{"model":"claude-3","content":[],"stop_reason":"end_turn"}`) rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat) require.NoError(t, err) var m map[string]interface{} require.NoError(t, json.Unmarshal(rewritten, &m)) assert.Equal(t, "anthropic/claude-3", m["model"]) // other fields preserved _, hasContent := m["content"] assert.True(t, hasContent) assert.Equal(t, "end_turn", m["stop_reason"]) }) t.Run("chat_without_model_field_adds_it", func(t *testing.T) { body := []byte(`{"content":[],"stop_reason":"end_turn"}`) rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat) require.NoError(t, err) var m map[string]interface{} require.NoError(t, json.Unmarshal(rewritten, &m)) assert.Equal(t, "anthropic/claude-3", m["model"]) }) t.Run("passthrough_returns_body_unchanged", func(t *testing.T) { body := []byte(`{"model":"claude-3"}`) rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypePassthrough) require.NoError(t, err) assert.Equal(t, string(body), string(rewritten)) }) t.Run("invalid_json", func(t *testing.T) { body := []byte(`{invalid}`) _, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat) require.Error(t, err) }) } // --------------------------------------------------------------------------- // ExtractModelName and RewriteRequest consistency // --------------------------------------------------------------------------- func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) { a := NewAdapter() t.Run("chat_round_trip", func(t *testing.T) { original := []byte(`{"model":"anthropic/claude-3","messages":[],"max_tokens":1024}`) // Extract the unified model ID from the body extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat) require.NoError(t, err) assert.Equal(t, "anthropic/claude-3", extracted) // Rewrite to the native model name rewritten, err := a.RewriteRequestModelName(original, "claude-3", conversion.InterfaceTypeChat) require.NoError(t, err) // Extract again from the rewritten body to verify the same location was targeted afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat) require.NoError(t, err) assert.Equal(t, "claude-3", afterRewrite) // Verify other fields are preserved var m map[string]interface{} require.NoError(t, json.Unmarshal(rewritten, &m)) assert.Equal(t, float64(1024), m["max_tokens"]) }) } // --------------------------------------------------------------------------- // isModelInfoPath (additional unified model ID cases) // --------------------------------------------------------------------------- func TestIsModelInfoPath_UnifiedModelID(t *testing.T) { tests := []struct { name string path string expected bool }{ {"simple_model_id", "/v1/models/claude-3", true}, {"unified_model_id_with_slash", "/v1/models/anthropic/claude-3", true}, {"models_list", "/v1/models", false}, {"models_list_trailing_slash", "/v1/models/", false}, {"messages_path", "/v1/messages", false}, {"deeply_nested", "/v1/models/org/workspace/claude-3-opus", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assert.Equal(t, tt.expected, isModelInfoPath(tt.path)) }) } }