refactor: 实现 ConversionEngine 协议转换引擎,替代旧 protocol 包
引入 Canonical Model 和 ProtocolAdapter 架构,支持 OpenAI/Anthropic 协议间 无缝转换,统一 ProxyHandler 替代分散的 OpenAI/Anthropic Handler,简化 ProviderClient 为协议无关的 HTTP 发送器,Provider 新增 protocol 字段。
This commit is contained in:
199
backend/internal/conversion/anthropic/adapter.go
Normal file
199
backend/internal/conversion/anthropic/adapter.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// Adapter Anthropic 协议适配器
|
||||
type Adapter struct{}
|
||||
|
||||
// NewAdapter 创建 Anthropic 适配器
|
||||
func NewAdapter() *Adapter {
|
||||
return &Adapter{}
|
||||
}
|
||||
|
||||
var modelInfoRegex = regexp.MustCompile(`^/v1/models/[^/]+$`)
|
||||
|
||||
// ProtocolName 返回协议名称
|
||||
func (a *Adapter) ProtocolName() string { return "anthropic" }
|
||||
|
||||
// ProtocolVersion 返回协议版本
|
||||
func (a *Adapter) ProtocolVersion() string { return "2023-06-01" }
|
||||
|
||||
// SupportsPassthrough 支持同协议透传
|
||||
func (a *Adapter) SupportsPassthrough() bool { return true }
|
||||
|
||||
// DetectInterfaceType 根据路径检测接口类型
|
||||
func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType {
|
||||
switch {
|
||||
case nativePath == "/v1/messages":
|
||||
return conversion.InterfaceTypeChat
|
||||
case nativePath == "/v1/models":
|
||||
return conversion.InterfaceTypeModels
|
||||
case modelInfoRegex.MatchString(nativePath):
|
||||
return conversion.InterfaceTypeModelInfo
|
||||
default:
|
||||
return conversion.InterfaceTypePassthrough
|
||||
}
|
||||
}
|
||||
|
||||
// BuildUrl 根据接口类型构建 URL
|
||||
func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string {
|
||||
switch interfaceType {
|
||||
case conversion.InterfaceTypeChat:
|
||||
return "/v1/messages"
|
||||
case conversion.InterfaceTypeModels:
|
||||
return "/v1/models"
|
||||
default:
|
||||
return nativePath
|
||||
}
|
||||
}
|
||||
|
||||
// BuildHeaders 构建请求头
|
||||
func (a *Adapter) BuildHeaders(provider *conversion.TargetProvider) map[string]string {
|
||||
headers := map[string]string{
|
||||
"x-api-key": provider.APIKey,
|
||||
"anthropic-version": "2023-06-01",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if v, ok := provider.AdapterConfig["anthropic_version"].(string); ok && v != "" {
|
||||
headers["anthropic-version"] = v
|
||||
}
|
||||
if betas, ok := provider.AdapterConfig["anthropic_beta"].([]string); ok && len(betas) > 0 {
|
||||
headers["anthropic-beta"] = strings.Join(betas, ",")
|
||||
} else if betas, ok := provider.AdapterConfig["anthropic_beta"].([]any); ok && len(betas) > 0 {
|
||||
var parts []string
|
||||
for _, b := range betas {
|
||||
if s, ok := b.(string); ok {
|
||||
parts = append(parts, s)
|
||||
}
|
||||
}
|
||||
if len(parts) > 0 {
|
||||
headers["anthropic-beta"] = strings.Join(parts, ",")
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
// SupportsInterface 检查是否支持接口类型
|
||||
func (a *Adapter) SupportsInterface(interfaceType conversion.InterfaceType) bool {
|
||||
switch interfaceType {
|
||||
case conversion.InterfaceTypeChat,
|
||||
conversion.InterfaceTypeModels,
|
||||
conversion.InterfaceTypeModelInfo:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeRequest 解码请求
|
||||
func (a *Adapter) DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error) {
|
||||
return decodeRequest(raw)
|
||||
}
|
||||
|
||||
// EncodeRequest 编码请求
|
||||
func (a *Adapter) EncodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) {
|
||||
return encodeRequest(req, provider)
|
||||
}
|
||||
|
||||
// DecodeResponse 解码响应
|
||||
func (a *Adapter) DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error) {
|
||||
return decodeResponse(raw)
|
||||
}
|
||||
|
||||
// EncodeResponse 编码响应
|
||||
func (a *Adapter) EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||||
return encodeResponse(resp)
|
||||
}
|
||||
|
||||
// CreateStreamDecoder 创建流式解码器
|
||||
func (a *Adapter) CreateStreamDecoder() conversion.StreamDecoder {
|
||||
return NewStreamDecoder()
|
||||
}
|
||||
|
||||
// CreateStreamEncoder 创建流式编码器
|
||||
func (a *Adapter) CreateStreamEncoder() conversion.StreamEncoder {
|
||||
return NewStreamEncoder()
|
||||
}
|
||||
|
||||
// EncodeError 编码错误
|
||||
func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
|
||||
errType := string(err.Code)
|
||||
statusCode := 500
|
||||
|
||||
errMsg := ErrorResponse{
|
||||
Type: "error",
|
||||
Error: ErrorDetail{
|
||||
Type: errType,
|
||||
Message: err.Message,
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(errMsg)
|
||||
return body, statusCode
|
||||
}
|
||||
|
||||
// DecodeModelsResponse 解码模型列表响应
|
||||
func (a *Adapter) DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error) {
|
||||
return decodeModelsResponse(raw)
|
||||
}
|
||||
|
||||
// EncodeModelsResponse 编码模型列表响应
|
||||
func (a *Adapter) EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
|
||||
return encodeModelsResponse(list)
|
||||
}
|
||||
|
||||
// DecodeModelInfoResponse 解码模型详情响应
|
||||
func (a *Adapter) DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error) {
|
||||
return decodeModelInfoResponse(raw)
|
||||
}
|
||||
|
||||
// EncodeModelInfoResponse 编码模型详情响应
|
||||
func (a *Adapter) EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
|
||||
return encodeModelInfoResponse(info)
|
||||
}
|
||||
|
||||
// DecodeEmbeddingRequest Anthropic 不支持嵌入
|
||||
func (a *Adapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Embeddings 接口")
|
||||
}
|
||||
|
||||
// EncodeEmbeddingRequest Anthropic 不支持嵌入
|
||||
func (a *Adapter) EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *conversion.TargetProvider) ([]byte, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Embeddings 接口")
|
||||
}
|
||||
|
||||
// DecodeEmbeddingResponse Anthropic 不支持嵌入
|
||||
func (a *Adapter) DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Embeddings 接口")
|
||||
}
|
||||
|
||||
// EncodeEmbeddingResponse Anthropic 不支持嵌入
|
||||
func (a *Adapter) EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Embeddings 接口")
|
||||
}
|
||||
|
||||
// DecodeRerankRequest Anthropic 不支持重排序
|
||||
func (a *Adapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
|
||||
}
|
||||
|
||||
// EncodeRerankRequest Anthropic 不支持重排序
|
||||
func (a *Adapter) EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *conversion.TargetProvider) ([]byte, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
|
||||
}
|
||||
|
||||
// DecodeRerankResponse Anthropic 不支持重排序
|
||||
func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
|
||||
}
|
||||
|
||||
// EncodeRerankResponse Anthropic 不支持重排序
|
||||
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
|
||||
}
|
||||
210
backend/internal/conversion/anthropic/adapter_test.go
Normal file
210
backend/internal/conversion/anthropic/adapter_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAdapter_ProtocolName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
assert.Equal(t, "anthropic", a.ProtocolName())
|
||||
}
|
||||
|
||||
func TestAdapter_ProtocolVersion(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
assert.Equal(t, "2023-06-01", a.ProtocolVersion())
|
||||
}
|
||||
|
||||
func TestAdapter_SupportsPassthrough(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
assert.True(t, a.SupportsPassthrough())
|
||||
}
|
||||
|
||||
func TestAdapter_DetectInterfaceType(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expected conversion.InterfaceType
|
||||
}{
|
||||
{"聊天消息", "/v1/messages", conversion.InterfaceTypeChat},
|
||||
{"模型列表", "/v1/models", conversion.InterfaceTypeModels},
|
||||
{"模型详情", "/v1/models/claude-3", conversion.InterfaceTypeModelInfo},
|
||||
{"未知路径", "/v1/unknown", conversion.InterfaceTypePassthrough},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := a.DetectInterfaceType(tt.path)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_BuildUrl(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
nativePath string
|
||||
interfaceType conversion.InterfaceType
|
||||
expected string
|
||||
}{
|
||||
{"聊天", "/v1/messages", conversion.InterfaceTypeChat, "/v1/messages"},
|
||||
{"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"},
|
||||
{"默认透传", "/v1/other", conversion.InterfaceTypePassthrough, "/v1/other"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := a.BuildUrl(tt.nativePath, tt.interfaceType)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_BuildHeaders_Basic(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "sk-ant-test", "claude-3")
|
||||
|
||||
headers := a.BuildHeaders(provider)
|
||||
assert.Equal(t, "sk-ant-test", headers["x-api-key"])
|
||||
assert.Equal(t, "2023-06-01", headers["anthropic-version"])
|
||||
assert.Equal(t, "application/json", headers["Content-Type"])
|
||||
}
|
||||
|
||||
func TestAdapter_BuildHeaders_CustomVersion(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "sk-ant-test", "claude-3")
|
||||
provider.AdapterConfig["anthropic_version"] = "2024-01-01"
|
||||
|
||||
headers := a.BuildHeaders(provider)
|
||||
assert.Equal(t, "2024-01-01", headers["anthropic-version"])
|
||||
}
|
||||
|
||||
func TestAdapter_BuildHeaders_AnthropicBeta(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "sk-ant-test", "claude-3")
|
||||
provider.AdapterConfig["anthropic_beta"] = []string{"prompt-caching-2024-07-31", "max-tokens-3-5-sonnet-2024-07-15"}
|
||||
|
||||
headers := a.BuildHeaders(provider)
|
||||
assert.Equal(t, "prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15", headers["anthropic-beta"])
|
||||
}
|
||||
|
||||
func TestAdapter_SupportsInterface(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
interfaceType conversion.InterfaceType
|
||||
expected bool
|
||||
}{
|
||||
{"聊天", conversion.InterfaceTypeChat, true},
|
||||
{"模型", conversion.InterfaceTypeModels, true},
|
||||
{"模型详情", conversion.InterfaceTypeModelInfo, true},
|
||||
{"嵌入", conversion.InterfaceTypeEmbeddings, false},
|
||||
{"重排序", conversion.InterfaceTypeRerank, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := a.SupportsInterface(tt.interfaceType)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_EncodeError(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效")
|
||||
|
||||
body, statusCode := a.EncodeError(convErr)
|
||||
require.Equal(t, 500, statusCode)
|
||||
|
||||
var resp ErrorResponse
|
||||
require.NoError(t, json.Unmarshal(body, &resp))
|
||||
assert.Equal(t, "error", resp.Type)
|
||||
assert.Equal(t, "INVALID_INPUT", resp.Error.Type)
|
||||
assert.Equal(t, "参数无效", resp.Error.Message)
|
||||
}
|
||||
|
||||
func TestAdapter_UnsupportedEmbedding(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("解码嵌入请求", func(t *testing.T) {
|
||||
_, err := a.DecodeEmbeddingRequest([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("编码嵌入请求", func(t *testing.T) {
|
||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
||||
_, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider)
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("解码嵌入响应", func(t *testing.T) {
|
||||
_, err := a.DecodeEmbeddingResponse([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("编码嵌入响应", func(t *testing.T) {
|
||||
_, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{})
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdapter_UnsupportedRerank(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("解码重排序请求", func(t *testing.T) {
|
||||
_, err := a.DecodeRerankRequest([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("编码重排序请求", func(t *testing.T) {
|
||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
||||
_, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider)
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("解码重排序响应", func(t *testing.T) {
|
||||
_, err := a.DecodeRerankResponse([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("编码重排序响应", func(t *testing.T) {
|
||||
_, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{})
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
}
|
||||
427
backend/internal/conversion/anthropic/decoder.go
Normal file
427
backend/internal/conversion/anthropic/decoder.go
Normal file
@@ -0,0 +1,427 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// decodeRequest 将 Anthropic 请求解码为 Canonical 请求
|
||||
func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
|
||||
var req MessagesRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 Anthropic 请求失败").WithCause(err)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(req.Model) == "" {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "model 字段不能为空")
|
||||
}
|
||||
if len(req.Messages) == 0 {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "messages 字段不能为空")
|
||||
}
|
||||
|
||||
system := decodeSystem(req.System)
|
||||
|
||||
var canonicalMsgs []canonical.CanonicalMessage
|
||||
for _, msg := range req.Messages {
|
||||
decoded := decodeMessage(msg)
|
||||
canonicalMsgs = append(canonicalMsgs, decoded...)
|
||||
}
|
||||
|
||||
tools := decodeTools(req.Tools)
|
||||
toolChoice := decodeToolChoice(req.ToolChoice)
|
||||
params := decodeParameters(&req)
|
||||
thinking := decodeThinking(req.Thinking, req.OutputConfig)
|
||||
outputFormat := decodeOutputFormat(req.OutputConfig)
|
||||
|
||||
var parallelToolUse *bool
|
||||
if req.DisableParallelToolUse != nil && *req.DisableParallelToolUse {
|
||||
val := false
|
||||
parallelToolUse = &val
|
||||
}
|
||||
|
||||
var userID string
|
||||
if req.Metadata != nil {
|
||||
userID = req.Metadata.UserID
|
||||
}
|
||||
|
||||
return &canonical.CanonicalRequest{
|
||||
Model: req.Model,
|
||||
System: system,
|
||||
Messages: canonicalMsgs,
|
||||
Tools: tools,
|
||||
ToolChoice: toolChoice,
|
||||
Parameters: params,
|
||||
Thinking: thinking,
|
||||
Stream: req.Stream,
|
||||
UserID: userID,
|
||||
OutputFormat: outputFormat,
|
||||
ParallelToolUse: parallelToolUse,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// decodeSystem 解码系统消息
|
||||
func decodeSystem(system any) any {
|
||||
if system == nil {
|
||||
return nil
|
||||
}
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
case []any:
|
||||
var blocks []canonical.SystemBlock
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if text, ok := m["text"].(string); ok {
|
||||
blocks = append(blocks, canonical.SystemBlock{Text: text})
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(blocks) == 0 {
|
||||
return nil
|
||||
}
|
||||
return blocks
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// decodeMessage 解码 Anthropic 消息
|
||||
func decodeMessage(msg Message) []canonical.CanonicalMessage {
|
||||
switch msg.Role {
|
||||
case "user":
|
||||
blocks := decodeContentBlocks(msg.Content)
|
||||
var toolResults []canonical.ContentBlock
|
||||
var others []canonical.ContentBlock
|
||||
for _, b := range blocks {
|
||||
if b.Type == "tool_result" {
|
||||
toolResults = append(toolResults, b)
|
||||
} else {
|
||||
others = append(others, b)
|
||||
}
|
||||
}
|
||||
var result []canonical.CanonicalMessage
|
||||
if len(others) > 0 {
|
||||
result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: others})
|
||||
}
|
||||
if len(toolResults) > 0 {
|
||||
result = append(result, canonical.CanonicalMessage{Role: canonical.RoleTool, Content: toolResults})
|
||||
}
|
||||
if len(result) == 0 {
|
||||
result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}})
|
||||
}
|
||||
return result
|
||||
|
||||
case "assistant":
|
||||
blocks := decodeContentBlocks(msg.Content)
|
||||
if len(blocks) == 0 {
|
||||
blocks = append(blocks, canonical.NewTextBlock(""))
|
||||
}
|
||||
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeContentBlocks 解码内容块列表
|
||||
func decodeContentBlocks(content any) []canonical.ContentBlock {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(v)}
|
||||
case []any:
|
||||
var blocks []canonical.ContentBlock
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
block := decodeSingleContentBlock(m)
|
||||
if block != nil {
|
||||
blocks = append(blocks, *block)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(blocks) > 0 {
|
||||
return blocks
|
||||
}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
||||
case nil:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
||||
default:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}
|
||||
}
|
||||
}
|
||||
|
||||
// decodeSingleContentBlock 解码单个内容块
|
||||
func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
|
||||
t, _ := m["type"].(string)
|
||||
switch t {
|
||||
case "text":
|
||||
text, _ := m["text"].(string)
|
||||
return &canonical.ContentBlock{Type: "text", Text: text}
|
||||
case "tool_use":
|
||||
id, _ := m["id"].(string)
|
||||
name, _ := m["name"].(string)
|
||||
input, _ := json.Marshal(m["input"])
|
||||
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}
|
||||
case "tool_result":
|
||||
toolUseID, _ := m["tool_use_id"].(string)
|
||||
isErr := false
|
||||
if ie, ok := m["is_error"].(bool); ok {
|
||||
isErr = ie
|
||||
}
|
||||
var content json.RawMessage
|
||||
if c, ok := m["content"]; ok {
|
||||
switch cv := c.(type) {
|
||||
case string:
|
||||
content = json.RawMessage(fmt.Sprintf("%q", cv))
|
||||
default:
|
||||
content, _ = json.Marshal(cv)
|
||||
}
|
||||
} else {
|
||||
content = json.RawMessage(`""`)
|
||||
}
|
||||
return &canonical.ContentBlock{
|
||||
Type: "tool_result",
|
||||
ToolUseID: toolUseID,
|
||||
Content: content,
|
||||
IsError: &isErr,
|
||||
}
|
||||
case "thinking":
|
||||
thinking, _ := m["thinking"].(string)
|
||||
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}
|
||||
case "redacted_thinking":
|
||||
// 丢弃
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeTools 解码工具定义
|
||||
func decodeTools(tools []Tool) []canonical.CanonicalTool {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]canonical.CanonicalTool, len(tools))
|
||||
for i, t := range tools {
|
||||
result[i] = canonical.CanonicalTool{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
InputSchema: t.InputSchema,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// decodeToolChoice 解码工具选择
|
||||
func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
||||
if toolChoice == nil {
|
||||
return nil
|
||||
}
|
||||
switch v := toolChoice.(type) {
|
||||
case string:
|
||||
switch v {
|
||||
case "auto":
|
||||
return canonical.NewToolChoiceAuto()
|
||||
case "none":
|
||||
return canonical.NewToolChoiceNone()
|
||||
case "any":
|
||||
return canonical.NewToolChoiceAny()
|
||||
}
|
||||
case map[string]any:
|
||||
t, _ := v["type"].(string)
|
||||
switch t {
|
||||
case "auto":
|
||||
return canonical.NewToolChoiceAuto()
|
||||
case "none":
|
||||
return canonical.NewToolChoiceNone()
|
||||
case "any":
|
||||
return canonical.NewToolChoiceAny()
|
||||
case "tool":
|
||||
name, _ := v["name"].(string)
|
||||
return canonical.NewToolChoiceNamed(name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeParameters 解码请求参数
|
||||
func decodeParameters(req *MessagesRequest) canonical.RequestParameters {
|
||||
params := canonical.RequestParameters{
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
TopK: req.TopK,
|
||||
}
|
||||
if req.MaxTokens > 0 {
|
||||
val := req.MaxTokens
|
||||
params.MaxTokens = &val
|
||||
}
|
||||
if len(req.StopSequences) > 0 {
|
||||
params.StopSequences = req.StopSequences
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
// decodeThinking 解码思考配置
|
||||
func decodeThinking(thinking *ThinkingConfig, outputConfig *OutputConfig) *canonical.ThinkingConfig {
|
||||
if thinking == nil {
|
||||
return nil
|
||||
}
|
||||
cfg := &canonical.ThinkingConfig{
|
||||
Type: thinking.Type,
|
||||
BudgetTokens: thinking.BudgetTokens,
|
||||
}
|
||||
if outputConfig != nil && outputConfig.Effort != "" {
|
||||
cfg.Effort = outputConfig.Effort
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// decodeOutputFormat 解码输出格式
|
||||
func decodeOutputFormat(outputConfig *OutputConfig) *canonical.OutputFormat {
|
||||
if outputConfig == nil || outputConfig.Format == nil {
|
||||
return nil
|
||||
}
|
||||
if outputConfig.Format.Type == "json_schema" && outputConfig.Format.Schema != nil {
|
||||
return &canonical.OutputFormat{
|
||||
Type: "json_schema",
|
||||
Name: "output",
|
||||
Schema: outputConfig.Format.Schema,
|
||||
Strict: boolPtr(true),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeResponse 将 Anthropic 响应解码为 Canonical 响应
|
||||
func decodeResponse(body []byte) (*canonical.CanonicalResponse, error) {
|
||||
var resp MessagesResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 Anthropic 响应失败").WithCause(err)
|
||||
}
|
||||
|
||||
var blocks []canonical.ContentBlock
|
||||
for _, block := range resp.Content {
|
||||
switch block.Type {
|
||||
case "text":
|
||||
blocks = append(blocks, canonical.NewTextBlock(block.Text))
|
||||
case "tool_use":
|
||||
blocks = append(blocks, canonical.NewToolUseBlock(block.ID, block.Name, block.Input))
|
||||
case "thinking":
|
||||
blocks = append(blocks, canonical.NewThinkingBlock(block.Thinking))
|
||||
case "redacted_thinking":
|
||||
// 丢弃
|
||||
}
|
||||
}
|
||||
if len(blocks) == 0 {
|
||||
blocks = append(blocks, canonical.NewTextBlock(""))
|
||||
}
|
||||
|
||||
sr := mapStopReason(resp.StopReason)
|
||||
usage := canonical.CanonicalUsage{
|
||||
InputTokens: resp.Usage.InputTokens,
|
||||
OutputTokens: resp.Usage.OutputTokens,
|
||||
}
|
||||
if resp.Usage.CacheReadInputTokens != nil {
|
||||
usage.CacheReadTokens = resp.Usage.CacheReadInputTokens
|
||||
}
|
||||
if resp.Usage.CacheCreationInputTokens != nil {
|
||||
usage.CacheCreationTokens = resp.Usage.CacheCreationInputTokens
|
||||
}
|
||||
|
||||
return &canonical.CanonicalResponse{
|
||||
ID: resp.ID,
|
||||
Model: resp.Model,
|
||||
Content: blocks,
|
||||
StopReason: &sr,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// mapStopReason 映射停止原因
|
||||
func mapStopReason(reason string) canonical.StopReason {
|
||||
switch reason {
|
||||
case "end_turn":
|
||||
return canonical.StopReasonEndTurn
|
||||
case "max_tokens":
|
||||
return canonical.StopReasonMaxTokens
|
||||
case "tool_use":
|
||||
return canonical.StopReasonToolUse
|
||||
case "stop_sequence":
|
||||
return canonical.StopReasonStopSequence
|
||||
case "pause_turn":
|
||||
return canonical.StopReason("pause_turn")
|
||||
case "refusal":
|
||||
return canonical.StopReasonRefusal
|
||||
default:
|
||||
return canonical.StopReasonEndTurn
|
||||
}
|
||||
}
|
||||
|
||||
// decodeModelsResponse 解码模型列表响应
|
||||
func decodeModelsResponse(body []byte) (*canonical.CanonicalModelList, error) {
|
||||
var resp ModelsResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
models := make([]canonical.CanonicalModel, len(resp.Data))
|
||||
for i, m := range resp.Data {
|
||||
name := m.DisplayName
|
||||
if name == "" {
|
||||
name = m.ID
|
||||
}
|
||||
models[i] = canonical.CanonicalModel{
|
||||
ID: m.ID,
|
||||
Name: name,
|
||||
Created: parseTimestamp(m.CreatedAt),
|
||||
OwnedBy: "anthropic",
|
||||
}
|
||||
}
|
||||
return &canonical.CanonicalModelList{Models: models}, nil
|
||||
}
|
||||
|
||||
// decodeModelInfoResponse 解码模型详情响应
|
||||
func decodeModelInfoResponse(body []byte) (*canonical.CanonicalModelInfo, error) {
|
||||
var resp ModelInfoResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name := resp.DisplayName
|
||||
if name == "" {
|
||||
name = resp.ID
|
||||
}
|
||||
return &canonical.CanonicalModelInfo{
|
||||
ID: resp.ID,
|
||||
Name: name,
|
||||
Created: parseTimestamp(resp.CreatedAt),
|
||||
OwnedBy: "anthropic",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseTimestamp 解析 RFC 3339 时间戳为 Unix
|
||||
func parseTimestamp(s string) int64 {
|
||||
if s == "" {
|
||||
return 0
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339, s)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return t.Unix()
|
||||
}
|
||||
|
||||
// formatTimestamp 将 Unix 时间戳格式化为 RFC 3339
|
||||
func formatTimestamp(unix int64) string {
|
||||
if unix == 0 {
|
||||
return time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC).Format(time.RFC3339)
|
||||
}
|
||||
return time.Unix(unix, 0).UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// boolPtr 返回 bool 指针
|
||||
func boolPtr(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
331
backend/internal/conversion/anthropic/decoder_test.go
Normal file
331
backend/internal/conversion/anthropic/decoder_test.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDecodeRequest_Basic(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "claude-3",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": "你好"}
|
||||
]
|
||||
}`)
|
||||
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "claude-3", req.Model)
|
||||
assert.Len(t, req.Messages, 1)
|
||||
assert.Equal(t, canonical.RoleUser, req.Messages[0].Role)
|
||||
assert.NotNil(t, req.Parameters.MaxTokens)
|
||||
assert.Equal(t, 1024, *req.Parameters.MaxTokens)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_System(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "claude-3",
|
||||
"max_tokens": 1024,
|
||||
"system": "你是助手",
|
||||
"messages": [
|
||||
{"role": "user", "content": "你好"}
|
||||
]
|
||||
}`)
|
||||
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "你是助手", req.System)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_SystemBlocks(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "claude-3",
|
||||
"max_tokens": 1024,
|
||||
"system": [{"text": "指令1"}, {"text": "指令2"}],
|
||||
"messages": [
|
||||
{"role": "user", "content": "你好"}
|
||||
]
|
||||
}`)
|
||||
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
blocks, ok := req.System.([]canonical.SystemBlock)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, blocks, 2)
|
||||
assert.Equal(t, "指令1", blocks[0].Text)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_ToolResultSplit(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "claude-3",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "查询天气"},
|
||||
{"type": "tool_result", "tool_use_id": "tool_1", "content": "晴天"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
// 用户消息中的 tool_result 应被拆分为独立的 tool 消息
|
||||
assert.Len(t, req.Messages, 2)
|
||||
assert.Equal(t, canonical.RoleUser, req.Messages[0].Role)
|
||||
assert.Equal(t, canonical.RoleTool, req.Messages[1].Role)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_MissingModel(t *testing.T) {
|
||||
body := []byte(`{"max_tokens": 1024, "messages": [{"role": "user", "content": "hi"}]}`)
|
||||
_, err := decodeRequest(body)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "INVALID_INPUT")
|
||||
}
|
||||
|
||||
func TestDecodeRequest_MissingMessages(t *testing.T) {
|
||||
body := []byte(`{"model": "claude-3", "max_tokens": 1024}`)
|
||||
_, err := decodeRequest(body)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "INVALID_INPUT")
|
||||
}
|
||||
|
||||
func TestDecodeResponse_Basic(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"id": "msg_123",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": "claude-3",
|
||||
"content": [{"type": "text", "text": "你好"}],
|
||||
"stop_reason": "end_turn",
|
||||
"usage": {"input_tokens": 10, "output_tokens": 5}
|
||||
}`)
|
||||
|
||||
resp, err := decodeResponse(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "msg_123", resp.ID)
|
||||
assert.Equal(t, "claude-3", resp.Model)
|
||||
assert.Len(t, resp.Content, 1)
|
||||
assert.Equal(t, "你好", resp.Content[0].Text)
|
||||
assert.Equal(t, canonical.StopReasonEndTurn, *resp.StopReason)
|
||||
assert.Equal(t, 10, resp.Usage.InputTokens)
|
||||
}
|
||||
|
||||
func TestDecodeResponse_Thinking(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"id": "msg_456",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": "claude-3",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "思考过程"},
|
||||
{"type": "text", "text": "回答"}
|
||||
],
|
||||
"stop_reason": "end_turn",
|
||||
"usage": {"input_tokens": 10, "output_tokens": 20}
|
||||
}`)
|
||||
|
||||
resp, err := decodeResponse(body)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp.Content, 2)
|
||||
assert.Equal(t, "thinking", resp.Content[0].Type)
|
||||
assert.Equal(t, "思考过程", resp.Content[0].Thinking)
|
||||
assert.Equal(t, "text", resp.Content[1].Type)
|
||||
assert.Equal(t, "回答", resp.Content[1].Text)
|
||||
}
|
||||
|
||||
func TestDecodeModelsResponse(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"data": [
|
||||
{"id": "claude-3-opus", "type": "model", "display_name": "Claude 3 Opus", "created_at": "2024-01-15T00:00:00Z"},
|
||||
{"id": "claude-3-sonnet", "type": "model", "created_at": "2024-02-01T00:00:00Z"}
|
||||
],
|
||||
"has_more": false
|
||||
}`)
|
||||
|
||||
list, err := decodeModelsResponse(body)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, list.Models, 2)
|
||||
assert.Equal(t, "claude-3-opus", list.Models[0].ID)
|
||||
assert.Equal(t, "Claude 3 Opus", list.Models[0].Name)
|
||||
// created_at RFC3339 → Unix
|
||||
assert.NotEqual(t, int64(0), list.Models[0].Created)
|
||||
// 无 display_name 时使用 ID
|
||||
assert.Equal(t, "claude-3-sonnet", list.Models[1].Name)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_InvalidJSON(t *testing.T) {
|
||||
_, err := decodeRequest([]byte(`invalid json`))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "JSON_PARSE_ERROR")
|
||||
}
|
||||
|
||||
func TestDecodeRequest_Thinking(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "claude-3",
|
||||
"max_tokens": 1024,
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"thinking": {"type": "enabled", "budget_tokens": 5000}
|
||||
}`)
|
||||
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, req.Thinking)
|
||||
assert.Equal(t, "enabled", req.Thinking.Type)
|
||||
require.NotNil(t, req.Thinking.BudgetTokens)
|
||||
assert.Equal(t, 5000, *req.Thinking.BudgetTokens)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_ThinkingAdaptive(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "claude-3",
|
||||
"max_tokens": 1024,
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"thinking": {"type": "adaptive"}
|
||||
}`)
|
||||
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, req.Thinking)
|
||||
assert.Equal(t, "adaptive", req.Thinking.Type)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_OutputConfig(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "claude-3",
|
||||
"max_tokens": 1024,
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"output_config": {
|
||||
"format": {
|
||||
"type": "json_schema",
|
||||
"schema": {"type": "object", "properties": {"name": {"type": "string"}}}
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, req.OutputFormat)
|
||||
assert.Equal(t, "json_schema", req.OutputFormat.Type)
|
||||
assert.NotNil(t, req.OutputFormat.Schema)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_DisableParallelToolUse(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "claude-3",
|
||||
"max_tokens": 1024,
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"disable_parallel_tool_use": true
|
||||
}`)
|
||||
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, req.ParallelToolUse)
|
||||
assert.False(t, *req.ParallelToolUse)
|
||||
}
|
||||
|
||||
func TestDecodeResponse_ToolUse(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"id": "msg_tool",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": "claude-3",
|
||||
"content": [
|
||||
{"type": "tool_use", "id": "tool_1", "name": "search", "input": {"q": "test"}}
|
||||
],
|
||||
"stop_reason": "tool_use",
|
||||
"usage": {"input_tokens": 10, "output_tokens": 5}
|
||||
}`)
|
||||
|
||||
resp, err := decodeResponse(body)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Content, 1)
|
||||
assert.Equal(t, "tool_use", resp.Content[0].Type)
|
||||
assert.Equal(t, "tool_1", resp.Content[0].ID)
|
||||
assert.Equal(t, "search", resp.Content[0].Name)
|
||||
assert.NotNil(t, resp.Content[0].Input)
|
||||
}
|
||||
|
||||
func TestDecodeResponse_RedactedThinking(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"id": "msg_redacted",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": "claude-3",
|
||||
"content": [
|
||||
{"type": "redacted_thinking", "data": "..."},
|
||||
{"type": "text", "text": "回答"}
|
||||
],
|
||||
"stop_reason": "end_turn",
|
||||
"usage": {"input_tokens": 10, "output_tokens": 5}
|
||||
}`)
|
||||
|
||||
resp, err := decodeResponse(body)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp.Content, 1)
|
||||
assert.Equal(t, "text", resp.Content[0].Type)
|
||||
assert.Equal(t, "回答", resp.Content[0].Text)
|
||||
}
|
||||
|
||||
func TestDecodeResponse_StopReasons(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reason string
|
||||
want canonical.StopReason
|
||||
}{
|
||||
{"end_turn→end_turn", "end_turn", canonical.StopReasonEndTurn},
|
||||
{"max_tokens→max_tokens", "max_tokens", canonical.StopReasonMaxTokens},
|
||||
{"tool_use→tool_use", "tool_use", canonical.StopReasonToolUse},
|
||||
{"stop_sequence→stop_sequence", "stop_sequence", canonical.StopReasonStopSequence},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
body := []byte(fmt.Sprintf(`{
|
||||
"id": "msg-1",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": "claude-3",
|
||||
"content": [{"type": "text", "text": "ok"}],
|
||||
"stop_reason": "%s",
|
||||
"usage": {"input_tokens": 1, "output_tokens": 1}
|
||||
}`, tt.reason))
|
||||
|
||||
resp, err := decodeResponse(body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.StopReason)
|
||||
assert.Equal(t, tt.want, *resp.StopReason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeResponse_Usage(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"id": "msg_usage",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": "claude-3",
|
||||
"content": [{"type": "text", "text": "ok"}],
|
||||
"stop_reason": "end_turn",
|
||||
"usage": {
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 50,
|
||||
"cache_read_input_tokens": 30
|
||||
}
|
||||
}`)
|
||||
|
||||
resp, err := decodeResponse(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 100, resp.Usage.InputTokens)
|
||||
assert.Equal(t, 50, resp.Usage.OutputTokens)
|
||||
require.NotNil(t, resp.Usage.CacheReadTokens)
|
||||
assert.Equal(t, 30, *resp.Usage.CacheReadTokens)
|
||||
}
|
||||
449
backend/internal/conversion/anthropic/encoder.go
Normal file
449
backend/internal/conversion/anthropic/encoder.go
Normal file
@@ -0,0 +1,449 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// encodeRequest 将 Canonical 请求编码为 Anthropic 请求
|
||||
func encodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) {
|
||||
result := map[string]any{
|
||||
"model": provider.ModelName,
|
||||
"stream": req.Stream,
|
||||
}
|
||||
|
||||
// max_tokens 必填
|
||||
if req.Parameters.MaxTokens != nil {
|
||||
result["max_tokens"] = *req.Parameters.MaxTokens
|
||||
} else {
|
||||
result["max_tokens"] = 4096
|
||||
}
|
||||
|
||||
// 系统消息
|
||||
if req.System != nil {
|
||||
result["system"] = encodeSystem(req.System)
|
||||
}
|
||||
|
||||
// 消息
|
||||
result["messages"] = encodeMessages(req.Messages)
|
||||
|
||||
// 参数
|
||||
if req.Parameters.Temperature != nil {
|
||||
result["temperature"] = *req.Parameters.Temperature
|
||||
}
|
||||
if req.Parameters.TopP != nil {
|
||||
result["top_p"] = *req.Parameters.TopP
|
||||
}
|
||||
if req.Parameters.TopK != nil {
|
||||
result["top_k"] = *req.Parameters.TopK
|
||||
}
|
||||
if len(req.Parameters.StopSequences) > 0 {
|
||||
result["stop_sequences"] = req.Parameters.StopSequences
|
||||
}
|
||||
|
||||
// 工具
|
||||
if len(req.Tools) > 0 {
|
||||
tools := make([]map[string]any, len(req.Tools))
|
||||
for i, t := range req.Tools {
|
||||
tool := map[string]any{
|
||||
"name": t.Name,
|
||||
"input_schema": t.InputSchema,
|
||||
}
|
||||
if t.Description != "" {
|
||||
tool["description"] = t.Description
|
||||
}
|
||||
tools[i] = tool
|
||||
}
|
||||
result["tools"] = tools
|
||||
}
|
||||
if req.ToolChoice != nil {
|
||||
result["tool_choice"] = encodeToolChoice(req.ToolChoice)
|
||||
}
|
||||
|
||||
// 公共字段
|
||||
if req.UserID != "" {
|
||||
result["metadata"] = map[string]any{"user_id": req.UserID}
|
||||
}
|
||||
if req.ParallelToolUse != nil && !*req.ParallelToolUse {
|
||||
result["disable_parallel_tool_use"] = true
|
||||
}
|
||||
if req.Thinking != nil {
|
||||
result["thinking"] = encodeThinkingConfig(req.Thinking)
|
||||
}
|
||||
|
||||
// output_config
|
||||
outputConfig := map[string]any{}
|
||||
hasOutputConfig := false
|
||||
if req.OutputFormat != nil {
|
||||
of := encodeOutputFormat(req.OutputFormat)
|
||||
if of != nil {
|
||||
outputConfig["format"] = of
|
||||
hasOutputConfig = true
|
||||
}
|
||||
}
|
||||
if req.Thinking != nil && req.Thinking.Effort != "" {
|
||||
outputConfig["effort"] = req.Thinking.Effort
|
||||
hasOutputConfig = true
|
||||
}
|
||||
if hasOutputConfig {
|
||||
result["output_config"] = outputConfig
|
||||
}
|
||||
|
||||
body, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeEncodingFailure, "编码 Anthropic 请求失败").WithCause(err)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// encodeSystem 编码系统消息
|
||||
func encodeSystem(system any) any {
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []canonical.SystemBlock:
|
||||
blocks := make([]map[string]any, len(v))
|
||||
for i, b := range v {
|
||||
blocks[i] = map[string]any{"text": b.Text}
|
||||
}
|
||||
return blocks
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// encodeMessages 编码消息列表(含角色约束处理)
|
||||
func encodeMessages(msgs []canonical.CanonicalMessage) []map[string]any {
|
||||
var result []map[string]any
|
||||
|
||||
for _, msg := range msgs {
|
||||
switch msg.Role {
|
||||
case canonical.RoleUser:
|
||||
result = append(result, map[string]any{
|
||||
"role": "user",
|
||||
"content": encodeContentBlocks(msg.Content),
|
||||
})
|
||||
case canonical.RoleAssistant:
|
||||
result = append(result, map[string]any{
|
||||
"role": "assistant",
|
||||
"content": encodeContentBlocks(msg.Content),
|
||||
})
|
||||
case canonical.RoleTool:
|
||||
// tool 角色合并到相邻 user 消息
|
||||
toolResults := filterToolResults(msg.Content)
|
||||
if len(result) > 0 && result[len(result)-1]["role"] == "user" {
|
||||
// 合并到最后一条 user 消息
|
||||
lastContent, ok := result[len(result)-1]["content"].([]map[string]any)
|
||||
if ok {
|
||||
result[len(result)-1]["content"] = append(lastContent, toolResults...)
|
||||
} else {
|
||||
result[len(result)-1]["content"] = toolResults
|
||||
}
|
||||
} else {
|
||||
result = append(result, map[string]any{
|
||||
"role": "user",
|
||||
"content": toolResults,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 确保首消息为 user
|
||||
if len(result) > 0 && result[0]["role"] != "user" {
|
||||
result = append([]map[string]any{{"role": "user", "content": []map[string]any{}}}, result...)
|
||||
}
|
||||
|
||||
// 合并连续同角色消息
|
||||
result = mergeConsecutiveRoles(result)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// encodeContentBlocks 编码内容块列表
|
||||
func encodeContentBlocks(blocks []canonical.ContentBlock) []map[string]any {
|
||||
result := make([]map[string]any, 0, len(blocks))
|
||||
for _, b := range blocks {
|
||||
switch b.Type {
|
||||
case "text":
|
||||
result = append(result, map[string]any{"type": "text", "text": b.Text})
|
||||
case "tool_use":
|
||||
m := map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": b.ID,
|
||||
"name": b.Name,
|
||||
"input": b.Input,
|
||||
}
|
||||
if b.Input == nil {
|
||||
m["input"] = map[string]any{}
|
||||
}
|
||||
result = append(result, m)
|
||||
case "tool_result":
|
||||
m := map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": b.ToolUseID,
|
||||
}
|
||||
if b.Content != nil {
|
||||
var contentStr string
|
||||
if json.Unmarshal(b.Content, &contentStr) == nil {
|
||||
m["content"] = contentStr
|
||||
} else {
|
||||
m["content"] = string(b.Content)
|
||||
}
|
||||
} else {
|
||||
m["content"] = ""
|
||||
}
|
||||
if b.IsError != nil {
|
||||
m["is_error"] = *b.IsError
|
||||
}
|
||||
result = append(result, m)
|
||||
case "thinking":
|
||||
result = append(result, map[string]any{"type": "thinking", "thinking": b.Thinking})
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return []map[string]any{{"type": "text", "text": ""}}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// filterToolResults 过滤工具结果
|
||||
func filterToolResults(blocks []canonical.ContentBlock) []map[string]any {
|
||||
var result []map[string]any
|
||||
for _, b := range blocks {
|
||||
if b.Type == "tool_result" {
|
||||
m := map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": b.ToolUseID,
|
||||
}
|
||||
if b.Content != nil {
|
||||
var contentStr string
|
||||
if json.Unmarshal(b.Content, &contentStr) == nil {
|
||||
m["content"] = contentStr
|
||||
} else {
|
||||
m["content"] = string(b.Content)
|
||||
}
|
||||
} else {
|
||||
m["content"] = ""
|
||||
}
|
||||
if b.IsError != nil {
|
||||
m["is_error"] = *b.IsError
|
||||
}
|
||||
result = append(result, m)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// encodeToolChoice 编码工具选择
|
||||
func encodeToolChoice(choice *canonical.ToolChoice) any {
|
||||
switch choice.Type {
|
||||
case "auto":
|
||||
return map[string]any{"type": "auto"}
|
||||
case "none":
|
||||
return map[string]any{"type": "none"}
|
||||
case "any":
|
||||
return map[string]any{"type": "any"}
|
||||
case "tool":
|
||||
return map[string]any{"type": "tool", "name": choice.Name}
|
||||
}
|
||||
return map[string]any{"type": "auto"}
|
||||
}
|
||||
|
||||
// encodeThinkingConfig 编码思考配置
|
||||
func encodeThinkingConfig(cfg *canonical.ThinkingConfig) map[string]any {
|
||||
switch cfg.Type {
|
||||
case "enabled":
|
||||
m := map[string]any{"type": "enabled"}
|
||||
if cfg.BudgetTokens != nil {
|
||||
m["budget_tokens"] = *cfg.BudgetTokens
|
||||
}
|
||||
return m
|
||||
case "disabled":
|
||||
return map[string]any{"type": "disabled"}
|
||||
case "adaptive":
|
||||
return map[string]any{"type": "adaptive"}
|
||||
}
|
||||
return map[string]any{"type": "disabled"}
|
||||
}
|
||||
|
||||
// encodeOutputFormat 编码输出格式
|
||||
func encodeOutputFormat(format *canonical.OutputFormat) map[string]any {
|
||||
if format == nil {
|
||||
return nil
|
||||
}
|
||||
switch format.Type {
|
||||
case "json_schema":
|
||||
schema := format.Schema
|
||||
if schema == nil {
|
||||
schema = json.RawMessage(`{"type":"object"}`)
|
||||
}
|
||||
return map[string]any{
|
||||
"type": "json_schema",
|
||||
"schema": schema,
|
||||
}
|
||||
case "json_object":
|
||||
return map[string]any{
|
||||
"type": "json_schema",
|
||||
"schema": map[string]any{"type": "object"},
|
||||
}
|
||||
case "text":
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeResponse 将 Canonical 响应编码为 Anthropic 响应
|
||||
func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||||
blocks := make([]map[string]any, 0, len(resp.Content))
|
||||
for _, b := range resp.Content {
|
||||
switch b.Type {
|
||||
case "text":
|
||||
blocks = append(blocks, map[string]any{"type": "text", "text": b.Text})
|
||||
case "tool_use":
|
||||
m := map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": b.ID,
|
||||
"name": b.Name,
|
||||
"input": b.Input,
|
||||
}
|
||||
if b.Input == nil {
|
||||
m["input"] = map[string]any{}
|
||||
}
|
||||
blocks = append(blocks, m)
|
||||
case "thinking":
|
||||
blocks = append(blocks, map[string]any{"type": "thinking", "thinking": b.Thinking})
|
||||
}
|
||||
}
|
||||
|
||||
sr := "end_turn"
|
||||
if resp.StopReason != nil {
|
||||
sr = mapCanonicalStopReason(*resp.StopReason)
|
||||
}
|
||||
|
||||
usage := map[string]any{
|
||||
"input_tokens": resp.Usage.InputTokens,
|
||||
"output_tokens": resp.Usage.OutputTokens,
|
||||
}
|
||||
if resp.Usage.CacheReadTokens != nil {
|
||||
usage["cache_read_input_tokens"] = *resp.Usage.CacheReadTokens
|
||||
}
|
||||
if resp.Usage.CacheCreationTokens != nil {
|
||||
usage["cache_creation_input_tokens"] = *resp.Usage.CacheCreationTokens
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"id": resp.ID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": resp.Model,
|
||||
"content": blocks,
|
||||
"stop_reason": sr,
|
||||
"stop_sequence": nil,
|
||||
"usage": usage,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeEncodingFailure, "编码 Anthropic 响应失败").WithCause(err)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// mapCanonicalStopReason 映射 Canonical 停止原因到 Anthropic
|
||||
func mapCanonicalStopReason(reason canonical.StopReason) string {
|
||||
switch reason {
|
||||
case canonical.StopReasonEndTurn, canonical.StopReasonContentFilter:
|
||||
return "end_turn"
|
||||
case canonical.StopReasonMaxTokens:
|
||||
return "max_tokens"
|
||||
case canonical.StopReasonToolUse:
|
||||
return "tool_use"
|
||||
case canonical.StopReasonStopSequence:
|
||||
return "stop_sequence"
|
||||
case canonical.StopReasonRefusal:
|
||||
return "refusal"
|
||||
default:
|
||||
return "end_turn"
|
||||
}
|
||||
}
|
||||
|
||||
// encodeModelsResponse 编码模型列表响应
|
||||
func encodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
|
||||
data := make([]map[string]any, len(list.Models))
|
||||
for i, m := range list.Models {
|
||||
name := m.Name
|
||||
if name == "" {
|
||||
name = m.ID
|
||||
}
|
||||
data[i] = map[string]any{
|
||||
"id": m.ID,
|
||||
"type": "model",
|
||||
"display_name": name,
|
||||
"created_at": formatTimestamp(m.Created),
|
||||
}
|
||||
}
|
||||
|
||||
var firstID, lastID *string
|
||||
if len(list.Models) > 0 {
|
||||
fid := list.Models[0].ID
|
||||
firstID = &fid
|
||||
lid := list.Models[len(list.Models)-1].ID
|
||||
lastID = &lid
|
||||
}
|
||||
|
||||
return json.Marshal(map[string]any{
|
||||
"data": data,
|
||||
"has_more": false,
|
||||
"first_id": firstID,
|
||||
"last_id": lastID,
|
||||
})
|
||||
}
|
||||
|
||||
// encodeModelInfoResponse 编码模型详情响应
|
||||
func encodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
|
||||
name := info.Name
|
||||
if name == "" {
|
||||
name = info.ID
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"id": info.ID,
|
||||
"type": "model",
|
||||
"display_name": name,
|
||||
"created_at": formatTimestamp(info.Created),
|
||||
})
|
||||
}
|
||||
|
||||
// mergeConsecutiveRoles 合并连续同角色消息
|
||||
func mergeConsecutiveRoles(messages []map[string]any) []map[string]any {
|
||||
if len(messages) <= 1 {
|
||||
return messages
|
||||
}
|
||||
var result []map[string]any
|
||||
for _, msg := range messages {
|
||||
if len(result) > 0 {
|
||||
lastRole := result[len(result)-1]["role"]
|
||||
currRole := msg["role"]
|
||||
if lastRole == currRole {
|
||||
// 合并 content
|
||||
lastContent := result[len(result)-1]["content"]
|
||||
currContent := msg["content"]
|
||||
switch lv := lastContent.(type) {
|
||||
case []map[string]any:
|
||||
if cv, ok := currContent.([]map[string]any); ok {
|
||||
result[len(result)-1]["content"] = append(lv, cv...)
|
||||
}
|
||||
case string:
|
||||
if cv, ok := currContent.(string); ok {
|
||||
result[len(result)-1]["content"] = lv + cv
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
result = append(result, msg)
|
||||
}
|
||||
return result
|
||||
}
|
||||
350
backend/internal/conversion/anthropic/encoder_test.go
Normal file
350
backend/internal/conversion/anthropic/encoder_test.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEncodeRequest_Basic(t *testing.T) {
|
||||
maxTokens := 1024
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "claude-3",
|
||||
Stream: true,
|
||||
Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
|
||||
Messages: []canonical.CanonicalMessage{
|
||||
{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}},
|
||||
},
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "my-model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "my-model", result["model"])
|
||||
assert.Equal(t, true, result["stream"])
|
||||
assert.Equal(t, float64(1024), result["max_tokens"])
|
||||
|
||||
msgs := result["messages"].([]any)
|
||||
assert.Len(t, msgs, 1)
|
||||
}
|
||||
|
||||
func TestEncodeRequest_ToolMergeIntoUser(t *testing.T) {
|
||||
maxTokens := 1024
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "claude-3",
|
||||
Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
|
||||
Messages: []canonical.CanonicalMessage{
|
||||
{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("查询")}},
|
||||
{Role: canonical.RoleAssistant, Content: []canonical.ContentBlock{canonical.NewToolUseBlock("tool_1", "search", json.RawMessage(`{"q":"test"}`))}},
|
||||
{Role: canonical.RoleTool, Content: []canonical.ContentBlock{canonical.NewToolResultBlock("tool_1", "结果", false)}},
|
||||
},
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
|
||||
// tool 消息应被合并到相邻 user 消息
|
||||
foundToolResult := false
|
||||
for _, m := range msgs {
|
||||
msgMap := m.(map[string]any)
|
||||
if msgMap["role"] == "user" {
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if ok {
|
||||
for _, c := range content {
|
||||
block := c.(map[string]any)
|
||||
if block["type"] == "tool_result" {
|
||||
foundToolResult = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.True(t, foundToolResult)
|
||||
}
|
||||
|
||||
func TestEncodeRequest_FirstUserGuarantee(t *testing.T) {
|
||||
maxTokens := 1024
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "claude-3",
|
||||
Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
|
||||
Messages: []canonical.CanonicalMessage{
|
||||
{Role: canonical.RoleAssistant, Content: []canonical.ContentBlock{canonical.NewTextBlock("前置")}},
|
||||
{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}},
|
||||
},
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
firstMsg := msgs[0].(map[string]any)
|
||||
assert.Equal(t, "user", firstMsg["role"])
|
||||
}
|
||||
|
||||
func TestEncodeRequest_ThinkingEnabled(t *testing.T) {
|
||||
budget := 10000
|
||||
maxTokens := 8096
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "claude-3",
|
||||
Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
|
||||
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
|
||||
Thinking: &canonical.ThinkingConfig{Type: "enabled", BudgetTokens: &budget},
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
thinking, ok := result["thinking"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "enabled", thinking["type"])
|
||||
assert.Equal(t, float64(10000), thinking["budget_tokens"])
|
||||
}
|
||||
|
||||
func TestEncodeResponse_Basic(t *testing.T) {
|
||||
sr := canonical.StopReasonEndTurn
|
||||
resp := &canonical.CanonicalResponse{
|
||||
ID: "msg_1",
|
||||
Model: "claude-3",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("你好")},
|
||||
StopReason: &sr,
|
||||
Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
|
||||
}
|
||||
|
||||
body, err := encodeResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "msg_1", result["id"])
|
||||
assert.Equal(t, "message", result["type"])
|
||||
assert.Equal(t, "assistant", result["role"])
|
||||
assert.Equal(t, "end_turn", result["stop_reason"])
|
||||
|
||||
content := result["content"].([]any)
|
||||
assert.Len(t, content, 1)
|
||||
block := content[0].(map[string]any)
|
||||
assert.Equal(t, "text", block["type"])
|
||||
assert.Equal(t, "你好", block["text"])
|
||||
}
|
||||
|
||||
func TestEncodeModelsResponse(t *testing.T) {
|
||||
ts := time.Date(2024, 3, 15, 0, 0, 0, 0, time.UTC).Unix()
|
||||
list := &canonical.CanonicalModelList{
|
||||
Models: []canonical.CanonicalModel{
|
||||
{ID: "claude-3-opus", Name: "Claude 3 Opus", Created: ts, OwnedBy: "anthropic"},
|
||||
},
|
||||
}
|
||||
|
||||
body, err := encodeModelsResponse(list)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
data := result["data"].([]any)
|
||||
assert.Len(t, data, 1)
|
||||
|
||||
model := data[0].(map[string]any)
|
||||
assert.Equal(t, "claude-3-opus", model["id"])
|
||||
// created 应为 RFC3339 格式
|
||||
createdAt, ok := model["created_at"].(string)
|
||||
assert.True(t, ok)
|
||||
assert.Contains(t, createdAt, "2024")
|
||||
}
|
||||
|
||||
func TestEncodeRequest_ThinkingDisabled(t *testing.T) {
|
||||
maxTokens := 1024
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "claude-3",
|
||||
Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
|
||||
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
_, hasThinking := result["thinking"]
|
||||
assert.False(t, hasThinking)
|
||||
}
|
||||
|
||||
func TestEncodeRequest_ThinkingAdaptive(t *testing.T) {
|
||||
maxTokens := 1024
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "claude-3",
|
||||
Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
|
||||
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
|
||||
Thinking: &canonical.ThinkingConfig{Type: "adaptive"},
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
thinking, ok := result["thinking"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "adaptive", thinking["type"])
|
||||
}
|
||||
|
||||
func TestEncodeRequest_OutputFormat_JSONSchema(t *testing.T) {
|
||||
maxTokens := 1024
|
||||
schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}}}`)
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "claude-3",
|
||||
Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
|
||||
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
|
||||
OutputFormat: &canonical.OutputFormat{
|
||||
Type: "json_schema",
|
||||
Schema: schema,
|
||||
},
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
oc, ok := result["output_config"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
format, ok := oc["format"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "json_schema", format["type"])
|
||||
assert.NotNil(t, format["schema"])
|
||||
}
|
||||
|
||||
func TestEncodeRequest_OutputFormat_JSON(t *testing.T) {
|
||||
maxTokens := 1024
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "claude-3",
|
||||
Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
|
||||
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
|
||||
OutputFormat: &canonical.OutputFormat{
|
||||
Type: "json_object",
|
||||
},
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
oc, ok := result["output_config"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
format, ok := oc["format"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "json_schema", format["type"])
|
||||
schemaMap, ok := format["schema"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "object", schemaMap["type"])
|
||||
}
|
||||
|
||||
func TestEncodeRequest_ConsecutiveRoleMerge(t *testing.T) {
|
||||
maxTokens := 1024
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "claude-3",
|
||||
Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
|
||||
Messages: []canonical.CanonicalMessage{
|
||||
{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("A")}},
|
||||
{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("B")}},
|
||||
},
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
assert.Len(t, msgs, 1)
|
||||
userMsg := msgs[0].(map[string]any)
|
||||
assert.Equal(t, "user", userMsg["role"])
|
||||
content := userMsg["content"].([]any)
|
||||
assert.Len(t, content, 2)
|
||||
}
|
||||
|
||||
func TestEncodeResponse_ContentFilter(t *testing.T) {
|
||||
sr := canonical.StopReasonContentFilter
|
||||
resp := &canonical.CanonicalResponse{
|
||||
ID: "msg-cf",
|
||||
Model: "claude-3",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")},
|
||||
StopReason: &sr,
|
||||
}
|
||||
|
||||
body, err := encodeResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "end_turn", result["stop_reason"])
|
||||
}
|
||||
|
||||
func TestEncodeResponse_ReasoningTokens(t *testing.T) {
|
||||
reasoning := 100
|
||||
sr := canonical.StopReasonEndTurn
|
||||
resp := &canonical.CanonicalResponse{
|
||||
ID: "msg-rt",
|
||||
Model: "claude-3",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")},
|
||||
StopReason: &sr,
|
||||
Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5, ReasoningTokens: &reasoning},
|
||||
}
|
||||
|
||||
body, err := encodeResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
usage := result["usage"].(map[string]any)
|
||||
_, hasReasoning := usage["reasoning_tokens"]
|
||||
assert.False(t, hasReasoning)
|
||||
}
|
||||
|
||||
func TestEncodeResponse_ToolUse(t *testing.T) {
|
||||
sr := canonical.StopReasonToolUse
|
||||
input := json.RawMessage(`{"q":"test"}`)
|
||||
resp := &canonical.CanonicalResponse{
|
||||
ID: "msg-tool",
|
||||
Model: "claude-3",
|
||||
Content: []canonical.ContentBlock{canonical.NewToolUseBlock("tool_1", "search", input)},
|
||||
StopReason: &sr,
|
||||
}
|
||||
|
||||
body, err := encodeResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
content := result["content"].([]any)
|
||||
assert.Len(t, content, 1)
|
||||
block := content[0].(map[string]any)
|
||||
assert.Equal(t, "tool_use", block["type"])
|
||||
assert.Equal(t, "tool_1", block["id"])
|
||||
assert.Equal(t, "search", block["name"])
|
||||
}
|
||||
283
backend/internal/conversion/anthropic/stream_decoder.go
Normal file
283
backend/internal/conversion/anthropic/stream_decoder.go
Normal file
@@ -0,0 +1,283 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// StreamDecoder Anthropic 流式解码器
|
||||
type StreamDecoder struct {
|
||||
messageStarted bool
|
||||
redactedBlocks map[int]bool
|
||||
utf8Remainder []byte
|
||||
accumulatedUsage *canonical.CanonicalUsage
|
||||
}
|
||||
|
||||
// NewStreamDecoder 创建 Anthropic 流式解码器
|
||||
func NewStreamDecoder() *StreamDecoder {
|
||||
return &StreamDecoder{
|
||||
redactedBlocks: make(map[int]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessChunk 处理原始 SSE chunk
|
||||
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
|
||||
data := rawChunk
|
||||
if len(d.utf8Remainder) > 0 {
|
||||
data = append(d.utf8Remainder, rawChunk...)
|
||||
d.utf8Remainder = nil
|
||||
}
|
||||
|
||||
if !utf8.Valid(data) {
|
||||
validEnd := len(data)
|
||||
for !utf8.Valid(data[:validEnd]) {
|
||||
validEnd--
|
||||
}
|
||||
d.utf8Remainder = append(d.utf8Remainder, data[validEnd:]...)
|
||||
data = data[:validEnd]
|
||||
}
|
||||
|
||||
var events []canonical.CanonicalStreamEvent
|
||||
text := string(data)
|
||||
|
||||
// 解析命名 SSE 事件
|
||||
var eventType string
|
||||
var eventData string
|
||||
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
line = strings.TrimRight(line, "\r")
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
eventType = strings.TrimPrefix(line, "event: ")
|
||||
} else if strings.HasPrefix(line, "data: ") {
|
||||
eventData = strings.TrimPrefix(line, "data: ")
|
||||
if eventType != "" && eventData != "" {
|
||||
chunkEvents := d.processEvent(eventType, []byte(eventData))
|
||||
events = append(events, chunkEvents...)
|
||||
}
|
||||
eventType = ""
|
||||
eventData = ""
|
||||
} else if line == "" {
|
||||
// SSE 事件分隔符
|
||||
}
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// Flush 刷新解码器状态
|
||||
func (d *StreamDecoder) Flush() []canonical.CanonicalStreamEvent {
|
||||
return nil
|
||||
}
|
||||
|
||||
// processEvent 处理单个命名 SSE 事件
|
||||
func (d *StreamDecoder) processEvent(eventType string, data []byte) []canonical.CanonicalStreamEvent {
|
||||
switch eventType {
|
||||
case "message_start":
|
||||
return d.processMessageStart(data)
|
||||
case "content_block_start":
|
||||
return d.processContentBlockStart(data)
|
||||
case "content_block_delta":
|
||||
return d.processContentBlockDelta(data)
|
||||
case "content_block_stop":
|
||||
return d.processContentBlockStop(data)
|
||||
case "message_delta":
|
||||
return d.processMessageDelta(data)
|
||||
case "message_stop":
|
||||
return d.processMessageStop(data)
|
||||
case "ping":
|
||||
return []canonical.CanonicalStreamEvent{canonical.NewPingEvent()}
|
||||
case "error":
|
||||
return d.processError(data)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// processMessageStart 处理消息开始事件
|
||||
func (d *StreamDecoder) processMessageStart(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Usage *struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if msgRaw, ok := raw["message"]; ok {
|
||||
if err := json.Unmarshal(msgRaw, &msg); err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
event := canonical.NewMessageStartEvent(msg.ID, msg.Model)
|
||||
if msg.Usage != nil {
|
||||
usage := &canonical.CanonicalUsage{
|
||||
InputTokens: msg.Usage.InputTokens,
|
||||
OutputTokens: msg.Usage.OutputTokens,
|
||||
}
|
||||
event = canonical.NewMessageStartEventWithUsage(msg.ID, msg.Model, usage)
|
||||
d.accumulatedUsage = usage
|
||||
}
|
||||
|
||||
d.messageStarted = true
|
||||
return []canonical.CanonicalStreamEvent{event}
|
||||
}
|
||||
|
||||
// processContentBlockStart 处理内容块开始事件
|
||||
func (d *StreamDecoder) processContentBlockStart(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw struct {
|
||||
Index int `json:"index"`
|
||||
ContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Thinking string `json:"thinking"`
|
||||
Data string `json:"data"`
|
||||
} `json:"content_block"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查需要丢弃的块类型
|
||||
switch raw.ContentBlock.Type {
|
||||
case "redacted_thinking", "server_tool_use", "web_search_tool_result",
|
||||
"code_execution_tool_result":
|
||||
d.redactedBlocks[raw.Index] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
if d.redactedBlocks[raw.Index] {
|
||||
return nil
|
||||
}
|
||||
|
||||
block := canonical.StreamContentBlock{
|
||||
Type: raw.ContentBlock.Type,
|
||||
Text: raw.ContentBlock.Text,
|
||||
ID: raw.ContentBlock.ID,
|
||||
Name: raw.ContentBlock.Name,
|
||||
Thinking: raw.ContentBlock.Thinking,
|
||||
}
|
||||
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewContentBlockStartEvent(raw.Index, block),
|
||||
}
|
||||
}
|
||||
|
||||
// processContentBlockDelta 处理内容块增量事件
|
||||
func (d *StreamDecoder) processContentBlockDelta(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw struct {
|
||||
Index int `json:"index"`
|
||||
Delta struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
PartialJSON string `json:"partial_json"`
|
||||
Thinking string `json:"thinking"`
|
||||
} `json:"delta"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查是否在丢弃的块中
|
||||
if d.redactedBlocks[raw.Index] {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 丢弃协议特有 delta 类型
|
||||
switch raw.Delta.Type {
|
||||
case "citations_delta", "signature_delta":
|
||||
return nil
|
||||
}
|
||||
|
||||
delta := canonical.StreamDelta{
|
||||
Type: raw.Delta.Type,
|
||||
Text: raw.Delta.Text,
|
||||
PartialJSON: raw.Delta.PartialJSON,
|
||||
Thinking: raw.Delta.Thinking,
|
||||
}
|
||||
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewContentBlockDeltaEvent(raw.Index, delta),
|
||||
}
|
||||
}
|
||||
|
||||
// processContentBlockStop 处理内容块结束事件
|
||||
func (d *StreamDecoder) processContentBlockStop(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw struct {
|
||||
Index int `json:"index"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, redacted := d.redactedBlocks[raw.Index]; redacted {
|
||||
delete(d.redactedBlocks, raw.Index)
|
||||
return nil
|
||||
}
|
||||
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewContentBlockStopEvent(raw.Index),
|
||||
}
|
||||
}
|
||||
|
||||
// processMessageDelta 处理消息增量事件
|
||||
func (d *StreamDecoder) processMessageDelta(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw struct {
|
||||
Delta struct {
|
||||
StopReason string `json:"stop_reason"`
|
||||
} `json:"delta"`
|
||||
Usage struct {
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sr := mapStopReason(raw.Delta.StopReason)
|
||||
usage := &canonical.CanonicalUsage{
|
||||
OutputTokens: raw.Usage.OutputTokens,
|
||||
}
|
||||
|
||||
if d.accumulatedUsage != nil {
|
||||
d.accumulatedUsage.OutputTokens += raw.Usage.OutputTokens
|
||||
}
|
||||
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewMessageDeltaEventWithUsage(sr, usage),
|
||||
}
|
||||
}
|
||||
|
||||
// processMessageStop 处理消息结束事件
|
||||
func (d *StreamDecoder) processMessageStop(data []byte) []canonical.CanonicalStreamEvent {
|
||||
return []canonical.CanonicalStreamEvent{canonical.NewMessageStopEvent()}
|
||||
}
|
||||
|
||||
// processError 处理错误事件
|
||||
func (d *StreamDecoder) processError(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw struct {
|
||||
Error struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewErrorEvent("stream_error", fmt.Sprintf("解析错误事件失败: %s", string(data))),
|
||||
}
|
||||
}
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewErrorEvent(raw.Error.Type, raw.Error.Message),
|
||||
}
|
||||
}
|
||||
274
backend/internal/conversion/anthropic/stream_decoder_test.go
Normal file
274
backend/internal/conversion/anthropic/stream_decoder_test.go
Normal file
@@ -0,0 +1,274 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func makeAnthropicEvent(eventType string, data any) []byte {
|
||||
dataBytes, _ := json.Marshal(data)
|
||||
return []byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, string(dataBytes)))
|
||||
}
|
||||
|
||||
func TestStreamDecoder_MessageStart(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
payload := map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": "msg_1",
|
||||
"model": "claude-3",
|
||||
"usage": map[string]any{"input_tokens": 10, "output_tokens": 0},
|
||||
},
|
||||
}
|
||||
raw := makeAnthropicEvent("message_start", payload)
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
require.NotEmpty(t, events)
|
||||
assert.Equal(t, canonical.EventMessageStart, events[0].Type)
|
||||
assert.Equal(t, "msg_1", events[0].Message.ID)
|
||||
assert.Equal(t, "claude-3", events[0].Message.Model)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
deltaType string
|
||||
deltaData map[string]any
|
||||
checkField string
|
||||
checkValue string
|
||||
}{
|
||||
{
|
||||
name: "text_delta",
|
||||
deltaType: "text_delta",
|
||||
deltaData: map[string]any{"type": "text_delta", "text": "你好"},
|
||||
checkField: "text",
|
||||
checkValue: "你好",
|
||||
},
|
||||
{
|
||||
name: "input_json_delta",
|
||||
deltaType: "input_json_delta",
|
||||
deltaData: map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"},
|
||||
checkField: "partial_json",
|
||||
checkValue: "{\"key\":",
|
||||
},
|
||||
{
|
||||
name: "thinking_delta",
|
||||
deltaType: "thinking_delta",
|
||||
deltaData: map[string]any{"type": "thinking_delta", "thinking": "思考中"},
|
||||
checkField: "thinking",
|
||||
checkValue: "思考中",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
payload := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": tt.deltaData,
|
||||
}
|
||||
raw := makeAnthropicEvent("content_block_delta", payload)
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
require.NotEmpty(t, events)
|
||||
assert.Equal(t, canonical.EventContentBlockDelta, events[0].Type)
|
||||
assert.NotNil(t, events[0].Delta)
|
||||
|
||||
switch tt.checkField {
|
||||
case "text":
|
||||
assert.Equal(t, tt.checkValue, events[0].Delta.Text)
|
||||
case "partial_json":
|
||||
assert.Equal(t, tt.checkValue, events[0].Delta.PartialJSON)
|
||||
case "thinking":
|
||||
assert.Equal(t, tt.checkValue, events[0].Delta.Thinking)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamDecoder_RedactedThinking(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
// redacted_thinking block start 应被抑制
|
||||
payload := map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": 1,
|
||||
"content_block": map[string]any{
|
||||
"type": "redacted_thinking",
|
||||
"data": "redacted-data",
|
||||
},
|
||||
}
|
||||
raw := makeAnthropicEvent("content_block_start", payload)
|
||||
events := d.ProcessChunk(raw)
|
||||
assert.Empty(t, events)
|
||||
assert.True(t, d.redactedBlocks[1])
|
||||
}
|
||||
|
||||
func TestStreamDecoder_RedactedBlockStopSuppressed(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
d.redactedBlocks[2] = true
|
||||
|
||||
// content_block_stop 对 redacted block 返回 nil
|
||||
payload := map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": 2,
|
||||
}
|
||||
raw := makeAnthropicEvent("content_block_stop", payload)
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
assert.Empty(t, events)
|
||||
// 应清理 redactedBlocks
|
||||
_, exists := d.redactedBlocks[2]
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_ContentBlockStart(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
payload := map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
}
|
||||
raw := makeAnthropicEvent("content_block_start", payload)
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, canonical.EventContentBlockStart, events[0].Type)
|
||||
require.NotNil(t, events[0].ContentBlock)
|
||||
assert.Equal(t, "text", events[0].ContentBlock.Type)
|
||||
require.NotNil(t, events[0].Index)
|
||||
assert.Equal(t, 0, *events[0].Index)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_ContentBlockStart_ToolUse(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
payload := map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": 1,
|
||||
"content_block": map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_1",
|
||||
"name": "search",
|
||||
},
|
||||
}
|
||||
raw := makeAnthropicEvent("content_block_start", payload)
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, canonical.EventContentBlockStart, events[0].Type)
|
||||
require.NotNil(t, events[0].ContentBlock)
|
||||
assert.Equal(t, "tool_use", events[0].ContentBlock.Type)
|
||||
assert.Equal(t, "toolu_1", events[0].ContentBlock.ID)
|
||||
assert.Equal(t, "search", events[0].ContentBlock.Name)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_ContentBlockStop(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
payload := map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": 0,
|
||||
}
|
||||
raw := makeAnthropicEvent("content_block_stop", payload)
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, canonical.EventContentBlockStop, events[0].Type)
|
||||
require.NotNil(t, events[0].Index)
|
||||
assert.Equal(t, 0, *events[0].Index)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_MessageDelta(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
payload := map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": "end_turn",
|
||||
},
|
||||
"usage": map[string]any{
|
||||
"output_tokens": 42,
|
||||
},
|
||||
}
|
||||
raw := makeAnthropicEvent("message_delta", payload)
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, canonical.EventMessageDelta, events[0].Type)
|
||||
require.NotNil(t, events[0].StopReason)
|
||||
assert.Equal(t, canonical.StopReasonEndTurn, *events[0].StopReason)
|
||||
require.NotNil(t, events[0].Usage)
|
||||
assert.Equal(t, 42, events[0].Usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_MessageStop(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
raw := makeAnthropicEvent("message_stop", map[string]any{"type": "message_stop"})
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, canonical.EventMessageStop, events[0].Type)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_Ping(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
raw := makeAnthropicEvent("ping", map[string]any{"type": "ping"})
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, canonical.EventPing, events[0].Type)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_Error(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
payload := map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"type": "overloaded_error",
|
||||
"message": "服务过载",
|
||||
},
|
||||
}
|
||||
raw := makeAnthropicEvent("error", payload)
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
require.Len(t, events, 1)
|
||||
assert.Equal(t, canonical.EventError, events[0].Type)
|
||||
require.NotNil(t, events[0].Error)
|
||||
assert.Equal(t, "overloaded_error", events[0].Error.Type)
|
||||
assert.Equal(t, "服务过载", events[0].Error.Message)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_RedactedDeltaSuppressed(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
d.redactedBlocks[1] = true
|
||||
|
||||
payload := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": 1,
|
||||
"delta": map[string]any{
|
||||
"type": "text_delta",
|
||||
"text": "被抑制的内容",
|
||||
},
|
||||
}
|
||||
raw := makeAnthropicEvent("content_block_delta", payload)
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
assert.Empty(t, events)
|
||||
}
|
||||
188
backend/internal/conversion/anthropic/stream_encoder.go
Normal file
188
backend/internal/conversion/anthropic/stream_encoder.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// StreamEncoder Anthropic 流式编码器
|
||||
type StreamEncoder struct{}
|
||||
|
||||
// NewStreamEncoder 创建 Anthropic 流式编码器
|
||||
func NewStreamEncoder() *StreamEncoder {
|
||||
return &StreamEncoder{}
|
||||
}
|
||||
|
||||
// EncodeEvent 编码 Canonical 事件为 Anthropic 命名 SSE 事件
|
||||
func (e *StreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
switch event.Type {
|
||||
case canonical.EventMessageStart:
|
||||
return e.encodeMessageStart(event)
|
||||
case canonical.EventContentBlockStart:
|
||||
return e.encodeContentBlockStart(event)
|
||||
case canonical.EventContentBlockDelta:
|
||||
return e.encodeContentBlockDelta(event)
|
||||
case canonical.EventContentBlockStop:
|
||||
return e.encodeContentBlockStop(event)
|
||||
case canonical.EventMessageDelta:
|
||||
return e.encodeMessageDelta(event)
|
||||
case canonical.EventMessageStop:
|
||||
return e.encodeMessageStop(event)
|
||||
case canonical.EventPing:
|
||||
return e.encodePing()
|
||||
case canonical.EventError:
|
||||
return e.encodeError(event)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush 刷新缓冲区(无缓冲)
|
||||
func (e *StreamEncoder) Flush() [][]byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeMessageStart 编码消息开始事件
|
||||
func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
payload := map[string]any{
|
||||
"type": "message_start",
|
||||
}
|
||||
if event.Message != nil {
|
||||
msg := map[string]any{
|
||||
"id": event.Message.ID,
|
||||
"model": event.Message.Model,
|
||||
"role": "assistant",
|
||||
}
|
||||
if event.Message.Usage != nil {
|
||||
usage := map[string]any{
|
||||
"input_tokens": event.Message.Usage.InputTokens,
|
||||
"output_tokens": event.Message.Usage.OutputTokens,
|
||||
}
|
||||
msg["usage"] = usage
|
||||
}
|
||||
payload["message"] = msg
|
||||
}
|
||||
return e.marshalEvent("message_start", payload)
|
||||
}
|
||||
|
||||
// encodeContentBlockStart 编码内容块开始事件
|
||||
func (e *StreamEncoder) encodeContentBlockStart(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.ContentBlock == nil || event.Index == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cb := map[string]any{
|
||||
"type": event.ContentBlock.Type,
|
||||
}
|
||||
switch event.ContentBlock.Type {
|
||||
case "text":
|
||||
cb["text"] = ""
|
||||
case "tool_use":
|
||||
cb["id"] = event.ContentBlock.ID
|
||||
cb["name"] = event.ContentBlock.Name
|
||||
cb["input"] = map[string]any{}
|
||||
case "thinking":
|
||||
cb["thinking"] = ""
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": *event.Index,
|
||||
"content_block": cb,
|
||||
}
|
||||
return e.marshalEvent("content_block_start", payload)
|
||||
}
|
||||
|
||||
// encodeContentBlockDelta 编码内容块增量事件
|
||||
func (e *StreamEncoder) encodeContentBlockDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.Delta == nil || event.Index == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
delta := map[string]any{
|
||||
"type": event.Delta.Type,
|
||||
}
|
||||
switch canonical.DeltaType(event.Delta.Type) {
|
||||
case canonical.DeltaTypeText:
|
||||
delta["text"] = event.Delta.Text
|
||||
case canonical.DeltaTypeInputJSON:
|
||||
delta["partial_json"] = event.Delta.PartialJSON
|
||||
case canonical.DeltaTypeThinking:
|
||||
delta["thinking"] = event.Delta.Thinking
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": *event.Index,
|
||||
"delta": delta,
|
||||
}
|
||||
return e.marshalEvent("content_block_delta", payload)
|
||||
}
|
||||
|
||||
// encodeContentBlockStop 编码内容块结束事件
|
||||
func (e *StreamEncoder) encodeContentBlockStop(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.Index == nil {
|
||||
return nil
|
||||
}
|
||||
payload := map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": *event.Index,
|
||||
}
|
||||
return e.marshalEvent("content_block_stop", payload)
|
||||
}
|
||||
|
||||
// encodeMessageDelta 编码消息增量事件
|
||||
func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
delta := map[string]any{}
|
||||
if event.StopReason != nil {
|
||||
delta["stop_reason"] = mapCanonicalStopReason(*event.StopReason)
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": delta,
|
||||
}
|
||||
if event.Usage != nil {
|
||||
payload["usage"] = map[string]any{
|
||||
"output_tokens": event.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return e.marshalEvent("message_delta", payload)
|
||||
}
|
||||
|
||||
// encodeMessageStop 编码消息结束事件
|
||||
func (e *StreamEncoder) encodeMessageStop(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
payload := map[string]any{"type": "message_stop"}
|
||||
return e.marshalEvent("message_stop", payload)
|
||||
}
|
||||
|
||||
// encodePing 编码心跳事件
|
||||
func (e *StreamEncoder) encodePing() [][]byte {
|
||||
payload := map[string]any{"type": "ping"}
|
||||
return e.marshalEvent("ping", payload)
|
||||
}
|
||||
|
||||
// encodeError 编码错误事件
|
||||
func (e *StreamEncoder) encodeError(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.Error == nil {
|
||||
return nil
|
||||
}
|
||||
payload := map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"type": event.Error.Type,
|
||||
"message": event.Error.Message,
|
||||
},
|
||||
}
|
||||
return e.marshalEvent("error", payload)
|
||||
}
|
||||
|
||||
// marshalEvent 序列化为 Anthropic 命名 SSE 事件
|
||||
func (e *StreamEncoder) marshalEvent(eventType string, payload map[string]any) [][]byte {
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return [][]byte{[]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, data))}
|
||||
}
|
||||
242
backend/internal/conversion/anthropic/stream_encoder_test.go
Normal file
242
backend/internal/conversion/anthropic/stream_encoder_test.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStreamEncoder_MessageStart(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewMessageStartEvent("msg_1", "claude-3")
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "event: message_start\n"))
|
||||
assert.Contains(t, s, "data: ")
|
||||
assert.Contains(t, s, "msg_1")
|
||||
assert.Contains(t, s, "claude-3")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ContentBlockDelta(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{Type: "text_delta", Text: "你好"})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "event: content_block_delta\n"))
|
||||
assert.Contains(t, s, "你好")
|
||||
|
||||
// 验证 JSON 格式
|
||||
lines := strings.Split(s, "\n")
|
||||
var dataLine string
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "data: ") {
|
||||
dataLine = strings.TrimPrefix(l, "data: ")
|
||||
break
|
||||
}
|
||||
}
|
||||
var payload map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(dataLine), &payload))
|
||||
assert.Equal(t, "content_block_delta", payload["type"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageStop(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewMessageStopEvent()
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "event: message_stop\n"))
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ContentBlockStart_Text(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "text", Text: ""})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "event: content_block_start\n"))
|
||||
assert.Contains(t, s, "data: ")
|
||||
|
||||
var payload map[string]any
|
||||
lines := strings.Split(s, "\n")
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "data: ") {
|
||||
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
|
||||
break
|
||||
}
|
||||
}
|
||||
cb := payload["content_block"].(map[string]any)
|
||||
assert.Equal(t, "text", cb["type"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ContentBlockStart_ToolUse(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewContentBlockStartEvent(1, canonical.StreamContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: "toolu_1",
|
||||
Name: "search",
|
||||
})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "toolu_1")
|
||||
assert.Contains(t, s, "search")
|
||||
|
||||
var payload map[string]any
|
||||
lines := strings.Split(s, "\n")
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "data: ") {
|
||||
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
|
||||
break
|
||||
}
|
||||
}
|
||||
cb := payload["content_block"].(map[string]any)
|
||||
assert.Equal(t, "tool_use", cb["type"])
|
||||
assert.Equal(t, "toolu_1", cb["id"])
|
||||
assert.Equal(t, "search", cb["name"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ContentBlockStart_Thinking(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "thinking", Thinking: ""})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "thinking")
|
||||
|
||||
var payload map[string]any
|
||||
lines := strings.Split(s, "\n")
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "data: ") {
|
||||
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
|
||||
break
|
||||
}
|
||||
}
|
||||
cb := payload["content_block"].(map[string]any)
|
||||
assert.Equal(t, "thinking", cb["type"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ContentBlockStop(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
idx := 2
|
||||
event := canonical.CanonicalStreamEvent{
|
||||
Type: canonical.EventContentBlockStop,
|
||||
Index: &idx,
|
||||
}
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "event: content_block_stop\n"))
|
||||
assert.Contains(t, s, "content_block_stop")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
sr := canonical.StopReasonEndTurn
|
||||
event := canonical.CanonicalStreamEvent{
|
||||
Type: canonical.EventMessageDelta,
|
||||
StopReason: &sr,
|
||||
}
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "stop_reason")
|
||||
|
||||
var payload map[string]any
|
||||
lines := strings.Split(s, "\n")
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "data: ") {
|
||||
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
|
||||
break
|
||||
}
|
||||
}
|
||||
delta := payload["delta"].(map[string]any)
|
||||
assert.Equal(t, "end_turn", delta["stop_reason"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
usage := canonical.CanonicalUsage{OutputTokens: 88}
|
||||
event := canonical.CanonicalStreamEvent{
|
||||
Type: canonical.EventMessageDelta,
|
||||
Usage: &usage,
|
||||
}
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "output_tokens")
|
||||
|
||||
var payload map[string]any
|
||||
lines := strings.Split(s, "\n")
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "data: ") {
|
||||
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
|
||||
break
|
||||
}
|
||||
}
|
||||
u := payload["usage"].(map[string]any)
|
||||
assert.Equal(t, float64(88), u["output_tokens"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_Ping(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewPingEvent()
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "event: ping\n"))
|
||||
assert.Contains(t, s, "ping")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_Error(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewErrorEvent("overloaded_error", "服务过载")
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "event: error\n"))
|
||||
assert.Contains(t, s, "overloaded_error")
|
||||
assert.Contains(t, s, "服务过载")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_Flush_ReturnsNil(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
chunks := e.Flush()
|
||||
assert.Nil(t, chunks)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_UnknownEvent_ReturnsNil(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.CanonicalStreamEvent{Type: "unknown_event_type"}
|
||||
chunks := e.EncodeEvent(event)
|
||||
assert.Nil(t, chunks)
|
||||
}
|
||||
183
backend/internal/conversion/anthropic/types.go
Normal file
183
backend/internal/conversion/anthropic/types.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// MessagesRequest Anthropic Messages 请求
|
||||
type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
System any `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Metadata *RequestMetadata `json:"metadata,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
OutputConfig *OutputConfig `json:"output_config,omitempty"`
|
||||
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
|
||||
Container any `json:"container,omitempty"`
|
||||
}
|
||||
|
||||
// RequestMetadata 请求元数据
|
||||
type RequestMetadata struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
// ThinkingConfig 思考配置
|
||||
type ThinkingConfig struct {
|
||||
Type string `json:"type"`
|
||||
BudgetTokens *int `json:"budget_tokens,omitempty"`
|
||||
Display string `json:"display,omitempty"`
|
||||
}
|
||||
|
||||
// OutputConfig 输出配置
|
||||
type OutputConfig struct {
|
||||
Format *OutputFormatConfig `json:"format,omitempty"`
|
||||
Effort string `json:"effort,omitempty"`
|
||||
}
|
||||
|
||||
// OutputFormatConfig 输出格式配置
|
||||
type OutputFormatConfig struct {
|
||||
Type string `json:"type"`
|
||||
Schema json.RawMessage `json:"schema,omitempty"`
|
||||
}
|
||||
|
||||
// Message Anthropic 消息
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
// TextContent 文本内容块
|
||||
type TextContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// ToolUseContent 工具调用内容块
|
||||
type ToolUseContent struct {
|
||||
Type string `json:"type"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Input json.RawMessage `json:"input"`
|
||||
}
|
||||
|
||||
// ToolResultContent 工具结果内容块
|
||||
type ToolResultContent struct {
|
||||
Type string `json:"type"`
|
||||
ToolUseID string `json:"tool_use_id"`
|
||||
Content any `json:"content"`
|
||||
IsError *bool `json:"is_error,omitempty"`
|
||||
}
|
||||
|
||||
// ThinkingContent 思考内容块
|
||||
type ThinkingContent struct {
|
||||
Type string `json:"type"`
|
||||
Thinking string `json:"thinking"`
|
||||
}
|
||||
|
||||
// RedactedThinkingContent 已编辑思考内容块
|
||||
type RedactedThinkingContent struct {
|
||||
Type string `json:"type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// Tool Anthropic 工具定义
|
||||
type Tool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema json.RawMessage `json:"input_schema"`
|
||||
}
|
||||
|
||||
// MessagesResponse Anthropic Messages 响应
|
||||
type MessagesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Model string `json:"model"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
StopSequence *string `json:"stop_sequence,omitempty"`
|
||||
StopDetails any `json:"stop_details,omitempty"`
|
||||
Container any `json:"container,omitempty"`
|
||||
Usage ResponseUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// ContentBlock Anthropic 响应内容块
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Data string `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// ResponseUsage 响应用量
|
||||
type ResponseUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheReadInputTokens *int `json:"cache_read_input_tokens,omitempty"`
|
||||
CacheCreationInputTokens *int `json:"cache_creation_input_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ModelsResponse Anthropic 模型列表响应
|
||||
type ModelsResponse struct {
|
||||
Data []ModelItem `json:"data"`
|
||||
HasMore bool `json:"has_more"`
|
||||
FirstID *string `json:"first_id,omitempty"`
|
||||
LastID *string `json:"last_id,omitempty"`
|
||||
}
|
||||
|
||||
// ModelItem Anthropic 模型项
|
||||
type ModelItem struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
CreatedAt string `json:"created_at,omitempty"`
|
||||
}
|
||||
|
||||
// ModelInfoResponse Anthropic 模型详情响应
|
||||
type ModelInfoResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
CreatedAt string `json:"created_at,omitempty"`
|
||||
}
|
||||
|
||||
// EmbeddingRequest Anthropic 不支持嵌入,但定义类型用于接口兼容
|
||||
type EmbeddingRequest struct{}
|
||||
|
||||
// EmbeddingResponse Anthropic 不支持嵌入
|
||||
type EmbeddingResponse struct{}
|
||||
|
||||
// RerankRequest Anthropic 不支持重排序
|
||||
type RerankRequest struct{}
|
||||
|
||||
// RerankResponse Anthropic 不支持重排序
|
||||
type RerankResponse struct{}
|
||||
|
||||
// ErrorResponse Anthropic 错误响应
|
||||
type ErrorResponse struct {
|
||||
Type string `json:"type"`
|
||||
Error ErrorDetail `json:"error"`
|
||||
}
|
||||
|
||||
// ErrorDetail 错误详情
|
||||
type ErrorDetail struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// SSEEvent SSE 事件
|
||||
type SSEEvent struct {
|
||||
EventType string
|
||||
Data json.RawMessage
|
||||
}
|
||||
Reference in New Issue
Block a user