实现支持 OpenAI 和 Anthropic 双协议的统一大模型 API 网关 MVP 版本,包含: - OpenAI 和 Anthropic 协议代理 - 供应商和模型管理 - 用量统计 - 前端配置界面
165 lines
3.8 KiB
Go
165 lines
3.8 KiB
Go
package anthropic
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
|
||
"nex/backend/internal/protocol/openai"
|
||
)
|
||
|
||
// StreamConverter 流式转换器
|
||
type StreamConverter struct {
|
||
messageID string
|
||
model string
|
||
index int // 当前 content block index
|
||
toolCallArgs map[int]string // 缓存每个 tool_call 的 arguments
|
||
sentStart bool // 是否已发送 message_start
|
||
sentBlockStart map[int]bool // 每个 index 是否已发送 content_block_start
|
||
}
|
||
|
||
// NewStreamConverter 创建流式转换器
|
||
func NewStreamConverter(messageID, model string) *StreamConverter {
|
||
return &StreamConverter{
|
||
messageID: messageID,
|
||
model: model,
|
||
index: 0,
|
||
toolCallArgs: make(map[int]string),
|
||
sentStart: false,
|
||
sentBlockStart: make(map[int]bool),
|
||
}
|
||
}
|
||
|
||
// ConvertChunk 转换 OpenAI 流块为 Anthropic 事件
|
||
func (c *StreamConverter) ConvertChunk(chunk *openai.StreamChunk) ([]StreamEvent, error) {
|
||
var events []StreamEvent
|
||
|
||
// 发送 message_start(仅一次)
|
||
if !c.sentStart {
|
||
events = append(events, StreamEvent{
|
||
Type: "message_start",
|
||
Message: &MessagesResponse{
|
||
ID: c.messageID,
|
||
Type: "message",
|
||
Role: "assistant",
|
||
Model: c.model,
|
||
Content: []ContentBlock{},
|
||
Usage: Usage{
|
||
InputTokens: 0,
|
||
OutputTokens: 0,
|
||
},
|
||
},
|
||
})
|
||
c.sentStart = true
|
||
}
|
||
|
||
// 处理每个 choice
|
||
for _, choice := range chunk.Choices {
|
||
// 处理 content delta
|
||
if choice.Delta.Content != "" {
|
||
// 发送 content_block_start(如果还没发送)
|
||
if !c.sentBlockStart[c.index] {
|
||
events = append(events, StreamEvent{
|
||
Type: "content_block_start",
|
||
Index: c.index,
|
||
ContentBlock: &ContentBlock{
|
||
Type: "text",
|
||
},
|
||
})
|
||
c.sentBlockStart[c.index] = true
|
||
}
|
||
|
||
// 发送 text delta
|
||
events = append(events, StreamEvent{
|
||
Type: "content_block_delta",
|
||
Index: c.index,
|
||
Delta: &Delta{
|
||
Type: "text_delta",
|
||
Text: choice.Delta.Content,
|
||
},
|
||
})
|
||
}
|
||
|
||
// 处理 tool_calls delta
|
||
if len(choice.Delta.ToolCalls) > 0 {
|
||
for _, tc := range choice.Delta.ToolCalls {
|
||
// 确定 tool_call index
|
||
toolIndex := c.index + len(c.toolCallArgs)
|
||
|
||
// 发送 content_block_start(如果还没发送)
|
||
if !c.sentBlockStart[toolIndex] {
|
||
events = append(events, StreamEvent{
|
||
Type: "content_block_start",
|
||
Index: toolIndex,
|
||
ContentBlock: &ContentBlock{
|
||
Type: "tool_use",
|
||
ID: tc.ID,
|
||
Name: tc.Function.Name,
|
||
},
|
||
})
|
||
c.sentBlockStart[toolIndex] = true
|
||
c.toolCallArgs[toolIndex] = ""
|
||
}
|
||
|
||
// 缓存 arguments
|
||
c.toolCallArgs[toolIndex] += tc.Function.Arguments
|
||
|
||
// 发送 input delta
|
||
events = append(events, StreamEvent{
|
||
Type: "content_block_delta",
|
||
Index: toolIndex,
|
||
Delta: &Delta{
|
||
Type: "input_json_delta",
|
||
Input: tc.Function.Arguments,
|
||
},
|
||
})
|
||
}
|
||
}
|
||
|
||
// 处理 finish_reason
|
||
if choice.FinishReason != "" {
|
||
// 发送 content_block_stop
|
||
for idx := range c.sentBlockStart {
|
||
events = append(events, StreamEvent{
|
||
Type: "content_block_stop",
|
||
Index: idx,
|
||
})
|
||
}
|
||
|
||
// 转换 stop_reason
|
||
stopReason := ""
|
||
switch choice.FinishReason {
|
||
case "stop":
|
||
stopReason = "end_turn"
|
||
case "tool_calls":
|
||
stopReason = "tool_use"
|
||
case "length":
|
||
stopReason = "max_tokens"
|
||
}
|
||
|
||
// 发送 message_delta
|
||
events = append(events, StreamEvent{
|
||
Type: "message_delta",
|
||
Delta: &Delta{
|
||
StopReason: stopReason,
|
||
},
|
||
})
|
||
|
||
// 发送 message_stop
|
||
events = append(events, StreamEvent{
|
||
Type: "message_stop",
|
||
})
|
||
}
|
||
}
|
||
|
||
return events, nil
|
||
}
|
||
|
||
// SerializeEvent 序列化事件为 SSE 格式
|
||
func SerializeEvent(event StreamEvent) (string, error) {
|
||
bytes, err := json.Marshal(event)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return fmt.Sprintf("event: %s\ndata: %s\n\n", event.Type, string(bytes)), nil
|
||
}
|