package anthropic

import (
	"fmt"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"nex/backend/internal/conversion/canonical"
)

// Tests in this file use require (which stops the test immediately) instead
// of assert before every slice index or pointer dereference, so a decoder
// regression is reported as a clean test failure rather than a panic from an
// out-of-range index or a nil pointer.

// TestDecodeRequest_Basic checks that a minimal request decodes the model,
// a single user message, and max_tokens.
func TestDecodeRequest_Basic(t *testing.T) {
	body := []byte(`{ "model": "claude-3", "max_tokens": 1024, "messages": [ {"role": "user", "content": "你好"} ] }`)
	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Equal(t, "claude-3", req.Model)
	require.Len(t, req.Messages, 1)
	assert.Equal(t, canonical.RoleUser, req.Messages[0].Role)
	require.NotNil(t, req.Parameters.MaxTokens)
	assert.Equal(t, 1024, *req.Parameters.MaxTokens)
}

// TestDecodeRequest_System checks that a plain-string system prompt is
// carried through unchanged.
func TestDecodeRequest_System(t *testing.T) {
	body := []byte(`{ "model": "claude-3", "max_tokens": 1024, "system": "你是助手", "messages": [ {"role": "user", "content": "你好"} ] }`)
	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Equal(t, "你是助手", req.System)
}

// TestDecodeRequest_SystemBlocks checks that an array-form system prompt
// decodes to []canonical.SystemBlock with the blocks kept in order.
func TestDecodeRequest_SystemBlocks(t *testing.T) {
	body := []byte(`{ "model": "claude-3", "max_tokens": 1024, "system": [{"text": "指令1"}, {"text": "指令2"}], "messages": [ {"role": "user", "content": "你好"} ] }`)
	req, err := decodeRequest(body)
	require.NoError(t, err)
	blocks, ok := req.System.([]canonical.SystemBlock)
	require.True(t, ok)
	require.Len(t, blocks, 2)
	assert.Equal(t, "指令1", blocks[0].Text)
}

// TestDecodeRequest_ToolResultSplit checks that tool_result blocks inside a
// user message are split out of it.
func TestDecodeRequest_ToolResultSplit(t *testing.T) {
	body := []byte(`{ "model": "claude-3", "max_tokens": 1024, "messages": [ { "role": "user", "content": [ {"type": "text", "text": "查询天气"}, {"type": "tool_result", "tool_use_id": "tool_1", "content": "晴天"} ] } ] }`)
	req, err := decodeRequest(body)
	require.NoError(t, err)
	// A tool_result inside a user message must be split out into its own
	// tool-role message, leaving the text behind in the user message.
	require.Len(t, req.Messages, 2)
	assert.Equal(t, canonical.RoleUser, req.Messages[0].Role)
	assert.Equal(t, canonical.RoleTool, req.Messages[1].Role)
}

// TestDecodeRequest_MissingModel checks that a request without a model is
// rejected with an INVALID_INPUT error.
func TestDecodeRequest_MissingModel(t *testing.T) {
	body := []byte(`{"max_tokens": 1024, "messages": [{"role": "user", "content": "hi"}]}`)
	_, err := decodeRequest(body)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "INVALID_INPUT")
}

// TestDecodeRequest_MissingMessages checks that a request without messages
// is rejected with an INVALID_INPUT error.
func TestDecodeRequest_MissingMessages(t *testing.T) {
	body := []byte(`{"model": "claude-3", "max_tokens": 1024}`)
	_, err := decodeRequest(body)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "INVALID_INPUT")
}

// TestDecodeResponse_Basic checks the ID, model, text content, stop reason
// and input-token usage of a simple text response.
func TestDecodeResponse_Basic(t *testing.T) {
	body := []byte(`{ "id": "msg_123", "type": "message", "role": "assistant", "model": "claude-3", "content": [{"type": "text", "text": "你好"}], "stop_reason": "end_turn", "usage": {"input_tokens": 10, "output_tokens": 5} }`)
	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Equal(t, "msg_123", resp.ID)
	assert.Equal(t, "claude-3", resp.Model)
	require.Len(t, resp.Content, 1)
	assert.Equal(t, "你好", resp.Content[0].Text)
	require.NotNil(t, resp.StopReason)
	assert.Equal(t, canonical.StopReasonEndTurn, *resp.StopReason)
	assert.Equal(t, 10, resp.Usage.InputTokens)
}

// TestDecodeResponse_Thinking checks that thinking and text blocks are both
// kept, in their original order.
func TestDecodeResponse_Thinking(t *testing.T) {
	body := []byte(`{ "id": "msg_456", "type": "message", "role": "assistant", "model": "claude-3", "content": [ {"type": "thinking", "thinking": "思考过程"}, {"type": "text", "text": "回答"} ], "stop_reason": "end_turn", "usage": {"input_tokens": 10, "output_tokens": 20} }`)
	resp, err := decodeResponse(body)
	require.NoError(t, err)
	require.Len(t, resp.Content, 2)
	assert.Equal(t, "thinking", resp.Content[0].Type)
	assert.Equal(t, "思考过程", resp.Content[0].Thinking)
	assert.Equal(t, "text", resp.Content[1].Type)
	assert.Equal(t, "回答", resp.Content[1].Text)
}

// TestDecodeModelsResponse checks model-list decoding: display_name maps to
// Name, created_at is converted from RFC3339 to a non-zero value, and Name
// falls back to the ID when display_name is absent.
func TestDecodeModelsResponse(t *testing.T) {
	body := []byte(`{ "data": [ {"id": "claude-3-opus", "type": "model", "display_name": "Claude 3 Opus", "created_at": "2024-01-15T00:00:00Z"}, {"id": "claude-3-sonnet", "type": "model", "created_at": "2024-02-01T00:00:00Z"} ], "has_more": false }`)
	list, err := decodeModelsResponse(body)
	require.NoError(t, err)
	require.Len(t, list.Models, 2)
	assert.Equal(t, "claude-3-opus", list.Models[0].ID)
	assert.Equal(t, "Claude 3 Opus", list.Models[0].Name)
	// created_at (RFC3339) is converted to a Unix timestamp. NotZero is used
	// instead of NotEqual(int64(0), ...) so the check stays meaningful
	// whatever integer type Created has.
	assert.NotZero(t, list.Models[0].Created)
	// Without display_name, the ID is used as the name.
	assert.Equal(t, "claude-3-sonnet", list.Models[1].Name)
}

// TestDecodeRequest_InvalidJSON checks that malformed JSON is reported as a
// JSON_PARSE_ERROR.
func TestDecodeRequest_InvalidJSON(t *testing.T) {
	_, err := decodeRequest([]byte(`invalid json`))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "JSON_PARSE_ERROR")
}

// TestDecodeRequest_Thinking checks that an enabled thinking config keeps
// its type and budget_tokens.
func TestDecodeRequest_Thinking(t *testing.T) {
	body := []byte(`{ "model": "claude-3", "max_tokens": 1024, "messages": [{"role": "user", "content": "hi"}], "thinking": {"type": "enabled", "budget_tokens": 5000} }`)
	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.Thinking)
	assert.Equal(t, "enabled", req.Thinking.Type)
	require.NotNil(t, req.Thinking.BudgetTokens)
	assert.Equal(t, 5000, *req.Thinking.BudgetTokens)
}

// TestDecodeRequest_ThinkingAdaptive checks that an adaptive thinking config
// (no budget) is decoded.
func TestDecodeRequest_ThinkingAdaptive(t *testing.T) {
	body := []byte(`{ "model": "claude-3", "max_tokens": 1024, "messages": [{"role": "user", "content": "hi"}], "thinking": {"type": "adaptive"} }`)
	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.Thinking)
	assert.Equal(t, "adaptive", req.Thinking.Type)
}

// TestDecodeRequest_OutputConfig checks that output_config.format is mapped
// onto OutputFormat, including its schema.
func TestDecodeRequest_OutputConfig(t *testing.T) {
	body := []byte(`{ "model": "claude-3", "max_tokens": 1024, "messages": [{"role": "user", "content": "hi"}], "output_config": { "format": { "type": "json_schema", "schema": {"type": "object", "properties": {"name": {"type": "string"}}} } } }`)
	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.OutputFormat)
	assert.Equal(t, "json_schema", req.OutputFormat.Type)
	assert.NotNil(t, req.OutputFormat.Schema)
}

// TestDecodeRequest_DisableParallelToolUse checks that
// disable_parallel_tool_use=true becomes ParallelToolUse=false.
func TestDecodeRequest_DisableParallelToolUse(t *testing.T) {
	body := []byte(`{ "model": "claude-3", "max_tokens": 1024, "messages": [{"role": "user", "content": "hi"}], "disable_parallel_tool_use": true }`)
	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.ParallelToolUse)
	assert.False(t, *req.ParallelToolUse)
}

// TestDecodeResponse_ToolUse checks that a tool_use block keeps its ID, name
// and input.
func TestDecodeResponse_ToolUse(t *testing.T) {
	body := []byte(`{ "id": "msg_tool", "type": "message", "role": "assistant", "model": "claude-3", "content": [ {"type": "tool_use", "id": "tool_1", "name": "search", "input": {"q": "test"}} ], "stop_reason": "tool_use", "usage": {"input_tokens": 10, "output_tokens": 5} }`)
	resp, err := decodeResponse(body)
	require.NoError(t, err)
	require.Len(t, resp.Content, 1)
	assert.Equal(t, "tool_use", resp.Content[0].Type)
	assert.Equal(t, "tool_1", resp.Content[0].ID)
	assert.Equal(t, "search", resp.Content[0].Name)
	assert.NotNil(t, resp.Content[0].Input)
}

// TestDecodeResponse_RedactedThinking checks that redacted_thinking blocks
// are dropped, leaving only the text block.
func TestDecodeResponse_RedactedThinking(t *testing.T) {
	body := []byte(`{ "id": "msg_redacted", "type": "message", "role": "assistant", "model": "claude-3", "content": [ {"type": "redacted_thinking", "data": "..."}, {"type": "text", "text": "回答"} ], "stop_reason": "end_turn", "usage": {"input_tokens": 10, "output_tokens": 5} }`)
	resp, err := decodeResponse(body)
	require.NoError(t, err)
	require.Len(t, resp.Content, 1)
	assert.Equal(t, "text", resp.Content[0].Type)
	assert.Equal(t, "回答", resp.Content[0].Text)
}

// TestDecodeResponse_StopReasons checks the mapping of each Anthropic
// stop_reason string to its canonical.StopReason.
func TestDecodeResponse_StopReasons(t *testing.T) {
	tests := []struct {
		name   string
		reason string
		want   canonical.StopReason
	}{
		{"end_turn→end_turn", "end_turn", canonical.StopReasonEndTurn},
		{"max_tokens→max_tokens", "max_tokens", canonical.StopReasonMaxTokens},
		{"tool_use→tool_use", "tool_use", canonical.StopReasonToolUse},
		{"stop_sequence→stop_sequence", "stop_sequence", canonical.StopReasonStopSequence},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			body := []byte(fmt.Sprintf(`{ "id": "msg-1", "type": "message", "role": "assistant", "model": "claude-3", "content": [{"type": "text", "text": "ok"}], "stop_reason": "%s", "usage": {"input_tokens": 1, "output_tokens": 1} }`, tt.reason))
			resp, err := decodeResponse(body)
			require.NoError(t, err)
			require.NotNil(t, resp.StopReason)
			assert.Equal(t, tt.want, *resp.StopReason)
		})
	}
}

// TestDecodeResponse_Usage checks input/output token counts and that
// cache_read_input_tokens is mapped to CacheReadTokens.
func TestDecodeResponse_Usage(t *testing.T) {
	body := []byte(`{ "id": "msg_usage", "type": "message", "role": "assistant", "model": "claude-3", "content": [{"type": "text", "text": "ok"}], "stop_reason": "end_turn", "usage": { "input_tokens": 100, "output_tokens": 50, "cache_read_input_tokens": 30 } }`)
	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Equal(t, 100, resp.Usage.InputTokens)
	assert.Equal(t, 50, resp.Usage.OutputTokens)
	require.NotNil(t, resp.Usage.CacheReadTokens)
	assert.Equal(t, 30, *resp.Usage.CacheReadTokens)
}