1
0
Files
nex/backend/internal/conversion/openai/stream_encoder.go
lanyuanxiaoyao bc1ee612d9 refactor: 实现 ConversionEngine 协议转换引擎,替代旧 protocol 包
- 新增 ConversionEngine 核心引擎,支持 OpenAI 和 Anthropic 协议转换
- 添加 stream decoder/encoder 实现
- 更新 provider client 支持新引擎
- 补充单元测试和集成测试
- 更新 specs 文档
2026-04-20 13:02:28 +08:00

213 lines
5.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package openai
import (
"encoding/json"
"fmt"
"time"
"nex/backend/internal/conversion/canonical"
)
// StreamEncoder encodes canonical stream events as OpenAI chat-completion
// SSE chunks. It keeps a small amount of state between events so that tool
// call chunks can carry the id, name and tool_calls index of their block.
type StreamEncoder struct {
// bufferedStart holds the most recent content-block start event; it is
// cleared once a delta has been encoded after it.
bufferedStart *canonical.CanonicalStreamEvent
// toolCallIndexMap maps a tool_use block ID to the tool_calls index that
// was assigned to it when its start event was buffered.
toolCallIndexMap map[string]int
// nextToolCallIndex is the index that the next tool_use block will receive.
nextToolCallIndex int
}
// NewStreamEncoder returns an OpenAI stream encoder with an empty tool-call
// index table, ready to encode the events of one response stream.
func NewStreamEncoder() *StreamEncoder {
	enc := new(StreamEncoder)
	enc.toolCallIndexMap = map[string]int{}
	return enc
}
// EncodeEvent converts one canonical stream event into zero or more SSE
// chunks in OpenAI chat-completion format.
//
// Content-block start events are only buffered and produce no output.
// Content-block stop, ping, error and unrecognized events also produce no
// output. A message stop event produces the terminating "[DONE]" frame.
func (e *StreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
	switch event.Type {
	case canonical.EventMessageStart:
		return e.encodeMessageStart(event)
	case canonical.EventContentBlockStart:
		return e.bufferBlockStart(event)
	case canonical.EventContentBlockDelta:
		return e.encodeContentBlockDelta(event)
	case canonical.EventMessageDelta:
		return e.encodeMessageDelta(event)
	case canonical.EventMessageStop:
		done := []byte("data: [DONE]\n\n")
		return [][]byte{done}
	default:
		// EventContentBlockStop, EventPing, EventError and anything unknown
		// have no OpenAI chunk equivalent here.
		return nil
	}
}
// Flush returns any output still held by the encoder. This encoder never
// holds encoded output, so the result is always nil.
func (e *StreamEncoder) Flush() [][]byte {
	var pending [][]byte
	return pending
}
// encodeMessageStart encodes the message start event as the opening chunk of
// the stream: it carries the message id and model (empty strings when the
// event has no message), the current Unix time as "created", and a single
// choice whose delta announces the assistant role.
func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent) [][]byte {
	var msgID, msgModel string
	if msg := event.Message; msg != nil {
		msgID, msgModel = msg.ID, msg.Model
	}

	roleChoice := map[string]any{
		"index": 0,
		"delta": map[string]any{"role": "assistant"},
	}
	return e.marshalChunk(map[string]any{
		"id":      msgID,
		"object":  "chat.completion.chunk",
		"created": time.Now().Unix(),
		"model":   msgModel,
		"choices": []map[string]any{roleChoice},
	})
}
// bufferBlockStart remembers a content-block start event so that the first
// delta of the block can include data from it, and emits no output itself.
//
// For a "tool_use" block it also reserves the next tool_calls index and, when
// the block has an ID, records the ID-to-index mapping. The mapping table is
// allocated on demand so that an encoder created as a zero value (rather than
// through NewStreamEncoder) does not panic on a nil-map write.
func (e *StreamEncoder) bufferBlockStart(event canonical.CanonicalStreamEvent) [][]byte {
	// event is this function's own copy, so keeping its address is safe.
	e.bufferedStart = &event

	block := event.ContentBlock
	if block == nil || block.Type != "tool_use" {
		return nil
	}

	// The index is consumed even when the block has no ID, so later tool
	// calls still receive distinct indices.
	idx := e.nextToolCallIndex
	e.nextToolCallIndex++
	if block.ID == "" {
		return nil
	}

	if e.toolCallIndexMap == nil {
		e.toolCallIndexMap = make(map[string]int)
	}
	e.toolCallIndexMap[block.ID] = idx
	return nil
}
// encodeContentBlockDelta dispatches a content-block delta event to the
// encoder for its delta type. Events without a delta, or with a delta type
// other than text, input JSON or thinking, produce no output.
func (e *StreamEncoder) encodeContentBlockDelta(event canonical.CanonicalStreamEvent) [][]byte {
	if event.Delta == nil {
		return nil
	}

	kind := canonical.DeltaType(event.Delta.Type)
	switch kind {
	case canonical.DeltaTypeThinking:
		return e.encodeThinkingDelta(event)
	case canonical.DeltaTypeInputJSON:
		return e.encodeInputJSONDelta(event)
	case canonical.DeltaTypeText:
		return e.encodeTextDelta(event)
	default:
		return nil
	}
}
// encodeTextDelta encodes a text delta as a chunk whose delta carries the
// text under "content". Any buffered block start is discarded, since text
// chunks need nothing from it.
func (e *StreamEncoder) encodeTextDelta(event canonical.CanonicalStreamEvent) [][]byte {
	payload := map[string]any{"content": event.Delta.Text}
	e.bufferedStart = nil
	return e.encodeDelta(payload)
}
// encodeInputJSONDelta encodes a partial-JSON tool-argument delta as an
// OpenAI "tool_calls" delta chunk.
//
// The first delta after a buffered block start also carries the tool call's
// id, type and function name taken from that start event; later deltas carry
// only the argument fragment. All chunks of one tool call use the same
// tool_calls index, namely the ordinal reserved for the call by
// bufferBlockStart. The canonical content-block index is deliberately not
// used for this: a text or thinking block ahead of the tool call would make
// it differ from the ordinal, and the argument fragments would then be
// attributed to a tool call that never received an id or name.
//
// NOTE(review): this assumes content blocks are delivered one at a time (no
// interleaved deltas of different tool calls), which the single bufferedStart
// slot already relies on — confirm against the stream decoders.
func (e *StreamEncoder) encodeInputJSONDelta(event canonical.CanonicalStreamEvent) [][]byte {
	// Ordinal of the most recently started tool call; -1 when none was seen.
	latestIdx := e.nextToolCallIndex - 1

	if e.bufferedStart != nil && e.bufferedStart.ContentBlock != nil {
		// First delta of the block: include id and name.
		start := e.bufferedStart.ContentBlock
		tcIdx := 0
		if idx, ok := e.toolCallIndexMap[start.ID]; ok && start.ID != "" {
			tcIdx = idx
		} else if start.Type == "tool_use" && latestIdx >= 0 {
			// A tool_use block without an ID is not in the table, but its
			// index was still reserved when the start was buffered.
			tcIdx = latestIdx
		}
		delta := map[string]any{
			"tool_calls": []map[string]any{{
				"index": tcIdx,
				"id":    start.ID,
				"type":  "function",
				"function": map[string]any{
					"name":      start.Name,
					"arguments": event.Delta.PartialJSON,
				},
			}},
		}
		e.bufferedStart = nil
		return e.encodeDelta(delta)
	}

	// Follow-up delta: only the argument fragment, on the same index as the
	// first chunk of this tool call.
	tcIdx := latestIdx
	if tcIdx < 0 {
		// No tool_use start was recorded; fall back to the canonical index
		// (or 0 when the event has none), as before.
		tcIdx = 0
		if event.Index != nil {
			tcIdx = *event.Index
		}
	}
	delta := map[string]any{
		"tool_calls": []map[string]any{{
			"index": tcIdx,
			"function": map[string]any{
				"arguments": event.Delta.PartialJSON,
			},
		}},
	}
	return e.encodeDelta(delta)
}
// encodeThinkingDelta encodes a thinking delta as a chunk whose delta carries
// the text under "reasoning_content". Any buffered block start is discarded,
// since thinking chunks need nothing from it.
func (e *StreamEncoder) encodeThinkingDelta(event canonical.CanonicalStreamEvent) [][]byte {
	payload := map[string]any{"reasoning_content": event.Delta.Thinking}
	e.bufferedStart = nil
	return e.encodeDelta(payload)
}
// encodeMessageDelta encodes a message delta event as up to two chunks, in
// this order: a chunk with an empty delta and the mapped "finish_reason" when
// the event has a stop reason, then a chunk with an empty "choices" list and
// the encoded "usage" when the event has usage data. It returns nil when the
// event has neither.
func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent) [][]byte {
	var out [][]byte

	if reason := event.StopReason; reason != nil {
		finishChoice := map[string]any{
			"index":         0,
			"delta":         map[string]any{},
			"finish_reason": mapCanonicalToFinishReason(*reason),
		}
		finishChunk := map[string]any{
			"choices": []map[string]any{finishChoice},
		}
		out = append(out, e.marshalChunk(finishChunk)...)
	}

	if usage := event.Usage; usage != nil {
		usageChunk := map[string]any{
			"choices": []map[string]any{},
			"usage":   encodeUsage(*usage),
		}
		out = append(out, e.marshalChunk(usageChunk)...)
	}

	return out
}
// encodeDelta wraps delta in a chunk with a single choice at index 0 and
// serializes it as an SSE frame.
func (e *StreamEncoder) encodeDelta(delta map[string]any) [][]byte {
	choice := map[string]any{
		"index": 0,
		"delta": delta,
	}
	return e.marshalChunk(map[string]any{
		"choices": []map[string]any{choice},
	})
}
// marshalChunk serializes chunk to JSON and wraps it in one SSE frame of the
// form "data: <json>\n\n". A chunk that cannot be marshaled is dropped: the
// error is not reported and the result is nil.
func (e *StreamEncoder) marshalChunk(chunk map[string]any) [][]byte {
	payload, marshalErr := json.Marshal(chunk)
	if marshalErr != nil {
		return nil
	}
	frame := fmt.Sprintf("data: %s\n\n", payload)
	return [][]byte{[]byte(frame)}
}