1
0
Files
nex/backend/internal/conversion/openai/stream_encoder.go
lanyuanxiaoyao 1dac347d3b refactor: 实现 ConversionEngine 协议转换引擎,替代旧 protocol 包
引入 Canonical Model 和 ProtocolAdapter 架构,支持 OpenAI/Anthropic 协议间
无缝转换,统一 ProxyHandler 替代分散的 OpenAI/Anthropic Handler,简化
ProviderClient 为协议无关的 HTTP 发送器,Provider 新增 protocol 字段。
2026-04-20 00:36:27 +08:00

218 lines
5.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package openai
import (
"encoding/json"
"fmt"
"time"
"nex/backend/internal/conversion/canonical"
)
// StreamEncoder encodes canonical stream events into OpenAI chat-completion
// SSE chunks ("data: {...}\n\n" frames, terminated by "data: [DONE]").
//
// It is stateful: it remembers the most recent content block start and the
// tool_calls indexes it has handed out, so one encoder should be used per
// response stream.
type StreamEncoder struct {
	// bufferedStart holds the latest content block start event until a
	// subsequent delta consumes (and clears) it.
	bufferedStart *canonical.CanonicalStreamEvent

	// toolCallIndexMap maps a tool_use block ID to the OpenAI tool_calls
	// index allocated for that block.
	toolCallIndexMap map[string]int

	// nextToolCallIndex is the tool_calls index the next tool_use block
	// start will be assigned.
	nextToolCallIndex int
}
// NewStreamEncoder returns an OpenAI stream encoder ready for a new response
// stream: no buffered block start, an empty tool-call index table, and the
// next tool_calls index set to 0.
func NewStreamEncoder() *StreamEncoder {
	enc := new(StreamEncoder)
	enc.toolCallIndexMap = map[string]int{}
	return enc
}
// EncodeEvent converts a single canonical stream event into zero or more SSE
// frames in OpenAI chat-completion chunk format.
//
// Message start, content deltas and message deltas produce chunks; a content
// block start is only buffered (its data is emitted with the first delta);
// message stop yields the "data: [DONE]" terminator. Every other event type
// produces no output.
func (e *StreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
	switch event.Type {
	case canonical.EventMessageStart:
		return e.encodeMessageStart(event)
	case canonical.EventContentBlockStart:
		return e.bufferBlockStart(event)
	case canonical.EventContentBlockDelta:
		return e.encodeContentBlockDelta(event)
	case canonical.EventMessageDelta:
		return e.encodeMessageDelta(event)
	case canonical.EventMessageStop:
		return [][]byte{[]byte("data: [DONE]\n\n")}
	default:
		// Content block stop, ping, error and unrecognized events are not
		// forwarded to the OpenAI client.
		return nil
	}
}
// Flush returns any frames still pending at the end of the stream.
// This encoder never emits anything on flush: it always returns nil, and a
// block start that is still buffered at this point is not written out.
func (e *StreamEncoder) Flush() [][]byte {
	var pending [][]byte
	return pending
}
// encodeMessageStart emits the opening chat.completion.chunk, which carries
// the message id, model and creation time, plus a delta announcing the
// assistant role. If the event has no Message, id and model are empty.
func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent) [][]byte {
	var id, model string
	if msg := event.Message; msg != nil {
		id, model = msg.ID, msg.Model
	}

	roleChoice := map[string]any{
		"index": 0,
		"delta": map[string]any{"role": "assistant"},
	}
	return e.marshalChunk(map[string]any{
		"id":      id,
		"object":  "chat.completion.chunk",
		"created": time.Now().Unix(),
		"model":   model,
		"choices": []map[string]any{roleChoice},
	})
}
// bufferBlockStart remembers a content block start instead of emitting it;
// its details are folded into the first delta of the block.
//
// For tool_use blocks it also allocates the next OpenAI tool_calls index and,
// when the block has a non-empty ID, records that ID -> index mapping.
// It always returns nil.
func (e *StreamEncoder) bufferBlockStart(event canonical.CanonicalStreamEvent) [][]byte {
	e.bufferedStart = &event

	block := event.ContentBlock
	if block == nil || block.Type != "tool_use" {
		return nil
	}

	assigned := e.nextToolCallIndex
	e.nextToolCallIndex = assigned + 1
	if block.ID != "" {
		e.toolCallIndexMap[block.ID] = assigned
	}
	return nil
}
// encodeContentBlockDelta routes a content delta to the encoder for its kind
// (text, thinking, or tool-input JSON). Events without a Delta, or with any
// other delta type, produce no output.
func (e *StreamEncoder) encodeContentBlockDelta(event canonical.CanonicalStreamEvent) [][]byte {
	d := event.Delta
	if d == nil {
		return nil
	}

	switch canonical.DeltaType(d.Type) {
	case canonical.DeltaTypeText:
		return e.encodeTextDelta(event)
	case canonical.DeltaTypeThinking:
		return e.encodeThinkingDelta(event)
	case canonical.DeltaTypeInputJSON:
		return e.encodeInputJSONDelta(event)
	default:
		return nil
	}
}
// encodeTextDelta emits a text fragment as a "content" delta. The buffered
// block start carries nothing OpenAI needs for text, so it is discarded.
func (e *StreamEncoder) encodeTextDelta(event canonical.CanonicalStreamEvent) [][]byte {
	e.bufferedStart = nil
	return e.encodeDelta(map[string]any{"content": event.Delta.Text})
}
// encodeInputJSONDelta emits a tool-call arguments fragment (input_json delta)
// as an OpenAI "tool_calls" delta.
//
// The first fragment after a buffered block start also carries the tool
// call's id, type ("function") and function name; later fragments carry only
// the arguments text. Every fragment is tagged with the tool_calls index that
// bufferBlockStart allocated for its block, so clients can reassemble several
// tool calls within one response.
func (e *StreamEncoder) encodeInputJSONDelta(event canonical.CanonicalStreamEvent) [][]byte {
	// Content blocks are assumed to arrive one at a time (start, deltas,
	// stop), so a fragment belongs to the most recently started tool_use
	// block, whose index is nextToolCallIndex-1.
	// NOTE(review): if a decoder ever interleaves blocks, this must instead
	// map event.Index to the tool_calls index recorded at block start.
	tcIdx := 0
	if e.nextToolCallIndex > 0 {
		tcIdx = e.nextToolCallIndex - 1
	}

	if e.bufferedStart == nil || e.bufferedStart.ContentBlock == nil {
		// Continuation fragment: arguments only.
		return e.encodeDelta(map[string]any{
			"tool_calls": []map[string]any{{
				"index": tcIdx,
				"function": map[string]any{
					"arguments": event.Delta.PartialJSON,
				},
			}},
		})
	}

	// First fragment of the block: include id, type and name.
	start := e.bufferedStart.ContentBlock
	if idx, ok := e.toolCallIndexMap[start.ID]; ok && start.ID != "" {
		tcIdx = idx
	}
	e.bufferedStart = nil
	return e.encodeDelta(map[string]any{
		"tool_calls": []map[string]any{{
			"index": tcIdx,
			"id":    start.ID,
			"type":  "function",
			"function": map[string]any{
				"name":      start.Name,
				"arguments": event.Delta.PartialJSON,
			},
		}},
	})
}
// encodeThinkingDelta emits a reasoning fragment as a "reasoning_content"
// delta and drops any buffered block start, which is not needed for it.
func (e *StreamEncoder) encodeThinkingDelta(event canonical.CanonicalStreamEvent) [][]byte {
	e.bufferedStart = nil
	return e.encodeDelta(map[string]any{"reasoning_content": event.Delta.Thinking})
}
// encodeMessageDelta emits up to two chunks for a message-level delta:
// first a chunk with an empty delta and the mapped finish_reason (when a stop
// reason is present), then a usage-only chunk with an empty choices array
// (when usage is present). It returns nil if neither is set.
func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent) [][]byte {
	var out [][]byte

	if sr := event.StopReason; sr != nil {
		finished := map[string]any{
			"index":         0,
			"delta":         map[string]any{},
			"finish_reason": mapCanonicalToFinishReason(*sr),
		}
		out = append(out, e.marshalChunk(map[string]any{
			"choices": []map[string]any{finished},
		})...)
	}

	if u := event.Usage; u != nil {
		out = append(out, e.marshalChunk(map[string]any{
			"choices": []map[string]any{},
			"usage":   encodeUsage(*u),
		})...)
	}

	return out
}
// encodeDelta wraps a delta object in a single-choice chunk (choice index 0)
// and serializes it as an SSE frame.
func (e *StreamEncoder) encodeDelta(delta map[string]any) [][]byte {
	only := map[string]any{"index": 0, "delta": delta}
	return e.marshalChunk(map[string]any{"choices": []map[string]any{only}})
}
// marshalChunk serializes chunk to JSON and frames it as one SSE event
// ("data: <json>\n\n"). If marshaling fails, the chunk is dropped: nil is
// returned and the error is not reported.
func (e *StreamEncoder) marshalChunk(chunk map[string]any) [][]byte {
	payload, err := json.Marshal(chunk)
	if err != nil {
		return nil
	}
	frame := fmt.Sprintf("data: %s\n\n", payload)
	return [][]byte{[]byte(frame)}
}