1
0
Files
nex/backend/internal/conversion/anthropic/stream_encoder_test.go
lanyuanxiaoyao 38a2555c7b fix: Anthropic 流式编码器补全 message_start/message_delta 必填字段
跨协议流式转换时,Anthropic 客户端 Zod 校验因 SSE 事件缺少必填字段报错。
由 Anthropic encoder 层(而非 OpenAI decoder 层)负责补全协议默认值,保持权责分离。

- encodeMessageStart 补全 type/content/stop_reason/stop_sequence,usage nil 时输出零值
- encodeMessageDelta usage nil 时输出零值
- 更新相关测试覆盖新增行为
2026-04-26 23:27:34 +08:00

299 lines
8.0 KiB
Go

package anthropic
import (
"encoding/json"
"strings"
"testing"
"nex/backend/internal/conversion/canonical"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestStreamEncoder_MessageStart checks that a message_start event is framed as
// an SSE "message_start" event and that its message object carries id, type,
// role, content, model, stop_reason, stop_sequence and usage. When the
// canonical event has no usage, usage is expected to be zero-valued.
func TestStreamEncoder_MessageStart(t *testing.T) {
	e := NewStreamEncoder()
	event := canonical.NewMessageStartEvent("msg_1", "claude-3")
	chunks := e.EncodeEvent(event)
	require.Len(t, chunks, 1)
	s := string(chunks[0])
	assert.True(t, strings.HasPrefix(s, "event: message_start\n"))
	assert.Contains(t, s, "data: ")

	var payload map[string]any
	for _, l := range strings.Split(s, "\n") {
		if strings.HasPrefix(l, "data: ") {
			require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
			break
		}
	}
	require.NotNil(t, payload, "message_start chunk has no data line")

	msg, ok := payload["message"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "msg_1", msg["id"])
	assert.Equal(t, "message", msg["type"])
	assert.Equal(t, "assistant", msg["role"])
	assert.Equal(t, []any{}, msg["content"])
	assert.Equal(t, "claude-3", msg["model"])
	// Indexing a map yields nil both for an explicit JSON null and for a
	// missing key, so first assert the keys are present, then that they are null.
	assert.Contains(t, msg, "stop_reason")
	assert.Nil(t, msg["stop_reason"])
	assert.Contains(t, msg, "stop_sequence")
	assert.Nil(t, msg["stop_sequence"])
	usage, okU := msg["usage"].(map[string]any)
	require.True(t, okU)
	assert.Equal(t, float64(0), usage["input_tokens"])
	assert.Equal(t, float64(0), usage["output_tokens"])
}
// TestStreamEncoder_MessageStart_WithUsage verifies that token counts supplied
// on the canonical message_start event are copied into message.usage of the
// encoded chunk's JSON payload.
func TestStreamEncoder_MessageStart_WithUsage(t *testing.T) {
	enc := NewStreamEncoder()
	in := &canonical.CanonicalUsage{InputTokens: 100, OutputTokens: 50}
	out := enc.EncodeEvent(canonical.NewMessageStartEventWithUsage("msg_2", "gpt-4", in))
	require.Len(t, out, 1)

	// Decode the first SSE data line; payload stays nil if there is none.
	var payload map[string]any
	for _, line := range strings.Split(string(out[0]), "\n") {
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		require.NoError(t, json.Unmarshal([]byte(line[len("data: "):]), &payload))
		break
	}

	message, isMap := payload["message"].(map[string]any)
	require.True(t, isMap)
	usage, isUsageMap := message["usage"].(map[string]any)
	require.True(t, isUsageMap)
	assert.Equal(t, float64(100), usage["input_tokens"])
	assert.Equal(t, float64(50), usage["output_tokens"])
}
// TestStreamEncoder_ContentBlockDelta verifies that a text delta is framed as a
// content_block_delta SSE event, keeps non-ASCII text intact, and carries a
// JSON payload whose "type" is "content_block_delta".
func TestStreamEncoder_ContentBlockDelta(t *testing.T) {
	enc := NewStreamEncoder()
	delta := canonical.StreamDelta{Type: "text_delta", Text: "你好"}
	out := enc.EncodeEvent(canonical.NewContentBlockDeltaEvent(0, delta))
	require.Len(t, out, 1)
	frame := string(out[0])
	assert.True(t, strings.HasPrefix(frame, "event: content_block_delta\n"))
	assert.Contains(t, frame, "你好")

	// Verify the JSON format: locate the first SSE data line. It stays empty
	// if no such line exists, which makes the Unmarshal below fail.
	dataLine := ""
	lines := strings.Split(frame, "\n")
	for i := 0; i < len(lines); i++ {
		if strings.HasPrefix(lines[i], "data: ") {
			dataLine = lines[i][len("data: "):]
			break
		}
	}
	var payload map[string]any
	require.NoError(t, json.Unmarshal([]byte(dataLine), &payload))
	assert.Equal(t, "content_block_delta", payload["type"])
}
// TestStreamEncoder_MessageStop verifies that a message_stop event encodes to a
// single chunk framed as an SSE "message_stop" event.
func TestStreamEncoder_MessageStop(t *testing.T) {
	enc := NewStreamEncoder()
	out := enc.EncodeEvent(canonical.NewMessageStopEvent())
	require.Len(t, out, 1)
	frame := string(out[0])
	assert.True(t, strings.HasPrefix(frame, "event: message_stop\n"))
}
// TestStreamEncoder_ContentBlockStart_Text verifies that starting an empty text
// block emits a content_block_start SSE event whose content_block has type "text".
func TestStreamEncoder_ContentBlockStart_Text(t *testing.T) {
	enc := NewStreamEncoder()
	block := canonical.StreamContentBlock{Type: "text", Text: ""}
	out := enc.EncodeEvent(canonical.NewContentBlockStartEvent(0, block))
	require.Len(t, out, 1)
	frame := string(out[0])
	assert.True(t, strings.HasPrefix(frame, "event: content_block_start\n"))
	assert.Contains(t, frame, "data: ")

	var payload map[string]any
	for _, line := range strings.Split(frame, "\n") {
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		require.NoError(t, json.Unmarshal([]byte(line[len("data: "):]), &payload))
		break
	}

	contentBlock, isMap := payload["content_block"].(map[string]any)
	require.True(t, isMap)
	assert.Equal(t, "text", contentBlock["type"])
}
// TestStreamEncoder_ContentBlockStart_ToolUse verifies that a tool_use block
// start carries the tool id and name through to the encoded content_block.
func TestStreamEncoder_ContentBlockStart_ToolUse(t *testing.T) {
	enc := NewStreamEncoder()
	block := canonical.StreamContentBlock{
		Type: "tool_use",
		ID:   "toolu_1",
		Name: "search",
	}
	out := enc.EncodeEvent(canonical.NewContentBlockStartEvent(1, block))
	require.Len(t, out, 1)
	frame := string(out[0])
	assert.Contains(t, frame, "toolu_1")
	assert.Contains(t, frame, "search")

	var payload map[string]any
	lines := strings.Split(frame, "\n")
	for i := range lines {
		if strings.HasPrefix(lines[i], "data: ") {
			require.NoError(t, json.Unmarshal([]byte(lines[i][len("data: "):]), &payload))
			break
		}
	}

	contentBlock, isMap := payload["content_block"].(map[string]any)
	require.True(t, isMap)
	// Ordered key/expected-value pairs so assertions run in a fixed order.
	for _, kv := range [][2]string{
		{"type", "tool_use"},
		{"id", "toolu_1"},
		{"name", "search"},
	} {
		assert.Equal(t, kv[1], contentBlock[kv[0]])
	}
}
// TestStreamEncoder_ContentBlockStart_Thinking verifies that starting an empty
// thinking block emits a content_block whose type is "thinking".
func TestStreamEncoder_ContentBlockStart_Thinking(t *testing.T) {
	enc := NewStreamEncoder()
	block := canonical.StreamContentBlock{Type: "thinking", Thinking: ""}
	out := enc.EncodeEvent(canonical.NewContentBlockStartEvent(0, block))
	require.Len(t, out, 1)
	frame := string(out[0])
	assert.Contains(t, frame, "thinking")

	var payload map[string]any
	for _, line := range strings.Split(frame, "\n") {
		// TrimPrefix changes the line only when the "data: " prefix is present.
		if data := strings.TrimPrefix(line, "data: "); data != line {
			require.NoError(t, json.Unmarshal([]byte(data), &payload))
			break
		}
	}

	contentBlock, isMap := payload["content_block"].(map[string]any)
	require.True(t, isMap)
	assert.Equal(t, "thinking", contentBlock["type"])
}
// TestStreamEncoder_ContentBlockStop verifies that a content_block_stop event is
// framed as an SSE "content_block_stop" event. It also checks that the JSON
// payload carries that type and the block index taken from the canonical event.
func TestStreamEncoder_ContentBlockStop(t *testing.T) {
	e := NewStreamEncoder()
	idx := 2
	event := canonical.CanonicalStreamEvent{
		Type:  canonical.EventContentBlockStop,
		Index: &idx,
	}
	chunks := e.EncodeEvent(event)
	require.Len(t, chunks, 1)
	s := string(chunks[0])
	assert.True(t, strings.HasPrefix(s, "event: content_block_stop\n"))

	var payload map[string]any
	for _, l := range strings.Split(s, "\n") {
		if strings.HasPrefix(l, "data: ") {
			require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
			break
		}
	}
	require.NotNil(t, payload, "content_block_stop chunk has no data line")
	assert.Equal(t, "content_block_stop", payload["type"])
	// The index set on the event must round-trip; JSON numbers decode into
	// interface values as float64.
	assert.Equal(t, float64(idx), payload["index"])
}
// TestStreamEncoder_MessageDelta_WithStopReason verifies that a stop reason on a
// message_delta event lands in delta.stop_reason. It also checks that a usage
// object with zero output_tokens is emitted even though the event has no usage.
func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
	enc := NewStreamEncoder()
	stopReason := canonical.StopReasonEndTurn
	out := enc.EncodeEvent(canonical.CanonicalStreamEvent{
		Type:       canonical.EventMessageDelta,
		StopReason: &stopReason,
	})
	require.Len(t, out, 1)
	frame := string(out[0])
	assert.Contains(t, frame, "stop_reason")

	var payload map[string]any
	for _, line := range strings.Split(frame, "\n") {
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		require.NoError(t, json.Unmarshal([]byte(line[len("data: "):]), &payload))
		break
	}

	delta, isDelta := payload["delta"].(map[string]any)
	require.True(t, isDelta)
	assert.Equal(t, "end_turn", delta["stop_reason"])
	usage, isUsage := payload["usage"].(map[string]any)
	require.True(t, isUsage, "message_delta SHALL always include usage")
	assert.Equal(t, float64(0), usage["output_tokens"])
}
// TestStreamEncoder_MessageDelta_WithUsage verifies that output_tokens supplied
// on a message_delta event appear in the encoded payload's usage object.
func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
	enc := NewStreamEncoder()
	tokens := canonical.CanonicalUsage{OutputTokens: 88}
	out := enc.EncodeEvent(canonical.CanonicalStreamEvent{
		Type:  canonical.EventMessageDelta,
		Usage: &tokens,
	})
	require.Len(t, out, 1)
	frame := string(out[0])
	assert.Contains(t, frame, "output_tokens")

	lines := strings.Split(frame, "\n")
	var payload map[string]any
	for i := range lines {
		if strings.HasPrefix(lines[i], "data: ") {
			require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(lines[i], "data: ")), &payload))
			break
		}
	}

	usage, isUsage := payload["usage"].(map[string]any)
	require.True(t, isUsage)
	assert.Equal(t, float64(88), usage["output_tokens"])
}
// TestStreamEncoder_Ping verifies that a ping event encodes to a single chunk
// framed as an SSE "ping" event.
func TestStreamEncoder_Ping(t *testing.T) {
	enc := NewStreamEncoder()
	out := enc.EncodeEvent(canonical.NewPingEvent())
	require.Len(t, out, 1)
	frame := string(out[0])
	assert.True(t, strings.HasPrefix(frame, "event: ping\n"))
	assert.Contains(t, frame, "ping")
}
// TestStreamEncoder_Error verifies that an error event is framed as an SSE
// "error" event. It also checks that both the error type and the (non-ASCII)
// message text appear in the encoded chunk.
func TestStreamEncoder_Error(t *testing.T) {
	enc := NewStreamEncoder()
	out := enc.EncodeEvent(canonical.NewErrorEvent("overloaded_error", "服务过载"))
	require.Len(t, out, 1)
	frame := string(out[0])
	assert.True(t, strings.HasPrefix(frame, "event: error\n"))
	for _, want := range []string{"overloaded_error", "服务过载"} {
		assert.Contains(t, frame, want)
	}
}
// TestStreamEncoder_Flush_ReturnsNil verifies that calling Flush on a freshly
// constructed encoder yields a nil chunk slice.
func TestStreamEncoder_Flush_ReturnsNil(t *testing.T) {
	enc := NewStreamEncoder()
	assert.Nil(t, enc.Flush())
}
// TestStreamEncoder_UnknownEvent_ReturnsNil verifies that an event type the
// encoder does not recognise produces no output (a nil chunk slice).
func TestStreamEncoder_UnknownEvent_ReturnsNil(t *testing.T) {
	enc := NewStreamEncoder()
	unknown := canonical.CanonicalStreamEvent{Type: "unknown_event_type"}
	assert.Nil(t, enc.EncodeEvent(unknown))
}