package openai

import (
	"encoding/json"
	"fmt"
	"time"

	"nex/backend/internal/conversion/canonical"
)

// StreamEncoder turns canonical stream events into OpenAI-style
// "chat.completion.chunk" SSE frames.
//
// The encoder is stateful: a content-block start is held back until the first
// delta of that block arrives, because the opening tool_calls chunk needs the
// id and name from the start event together with the first arguments fragment.
type StreamEncoder struct {
	// bufferedStart is the most recent content-block start that has not been
	// consumed by a delta (or by the block's stop event) yet.
	bufferedStart *canonical.CanonicalStreamEvent
	// toolCallIndexMap maps a tool call ID to its allocated tool_calls index.
	toolCallIndexMap map[string]int
	// blockToolCallIndex maps a canonical content-block index to the
	// tool_calls index allocated for that block, so every delta of a block is
	// emitted under the same index as its opening chunk.
	blockToolCallIndex map[int]int
	// nextToolCallIndex is the index handed to the next tool_use block.
	nextToolCallIndex int
}

// NewStreamEncoder returns a StreamEncoder with empty tool-call bookkeeping.
func NewStreamEncoder() *StreamEncoder {
	return &StreamEncoder{
		toolCallIndexMap:   make(map[string]int),
		blockToolCallIndex: make(map[int]int),
	}
}

// EncodeEvent encodes one canonical event into zero or more SSE frames.
//
// A nil result means the event produces no output: block starts are only
// buffered, and ping, error and unknown event types are ignored.
func (e *StreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
	switch event.Type {
	case canonical.EventMessageStart:
		return e.encodeMessageStart(event)
	case canonical.EventContentBlockStart:
		return e.bufferBlockStart(event)
	case canonical.EventContentBlockDelta:
		return e.encodeContentBlockDelta(event)
	case canonical.EventContentBlockStop:
		return e.encodeContentBlockStop(event)
	case canonical.EventMessageDelta:
		return e.encodeMessageDelta(event)
	case canonical.EventMessageStop:
		return [][]byte{[]byte("data: [DONE]\n\n")}
	case canonical.EventPing, canonical.EventError:
		return nil
	}
	return nil
}

// Flush returns any frames still pending in the encoder. It currently always
// returns nil.
func (e *StreamEncoder) Flush() [][]byte {
	return nil
}

// encodeMessageStart emits the opening chunk carrying the assistant role.
// The id and model are taken from event.Message when it is present and are
// empty strings otherwise; "created" is the current Unix time.
func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent) [][]byte {
	id := ""
	model := ""
	if event.Message != nil {
		id = event.Message.ID
		model = event.Message.Model
	}
	chunk := map[string]any{
		"id":      id,
		"object":  "chat.completion.chunk",
		"created": time.Now().Unix(),
		"model":   model,
		"choices": []map[string]any{{
			"index": 0,
			"delta": map[string]any{"role": "assistant"},
		}},
	}
	return e.marshalChunk(chunk)
}

// bufferBlockStart remembers a content-block start without emitting anything.
// For tool_use blocks it also allocates the next sequential tool_calls index
// and records it under the block's ID and canonical block index (when set).
func (e *StreamEncoder) bufferBlockStart(event canonical.CanonicalStreamEvent) [][]byte {
	e.bufferedStart = &event
	if event.ContentBlock == nil || event.ContentBlock.Type != "tool_use" {
		return nil
	}

	idx := e.nextToolCallIndex
	e.nextToolCallIndex++

	if event.ContentBlock.ID != "" {
		// Lazily create the map so an encoder that was not built through
		// NewStreamEncoder does not panic on a nil-map write.
		if e.toolCallIndexMap == nil {
			e.toolCallIndexMap = make(map[string]int)
		}
		e.toolCallIndexMap[event.ContentBlock.ID] = idx
	}
	if event.Index != nil {
		if e.blockToolCallIndex == nil {
			e.blockToolCallIndex = make(map[int]int)
		}
		e.blockToolCallIndex[*event.Index] = idx
	}
	return nil
}

// encodeContentBlockDelta dispatches a delta event on its delta type.
// Events without a delta, or with an unhandled delta type, produce no output.
func (e *StreamEncoder) encodeContentBlockDelta(event canonical.CanonicalStreamEvent) [][]byte {
	if event.Delta == nil {
		return nil
	}
	switch canonical.DeltaType(event.Delta.Type) {
	case canonical.DeltaTypeText:
		return e.encodeTextDelta(event)
	case canonical.DeltaTypeInputJSON:
		return e.encodeInputJSONDelta(event)
	case canonical.DeltaTypeThinking:
		return e.encodeThinkingDelta(event)
	}
	return nil
}

// encodeContentBlockStop handles the end of a content block.
//
// If the block's start is still buffered, no delta consumed it. For a tool_use
// block that means the call would never be announced, so the opening chunk is
// emitted here with empty arguments. Any other pending start is discarded.
func (e *StreamEncoder) encodeContentBlockStop(event canonical.CanonicalStreamEvent) [][]byte {
	start := e.bufferedStart
	if start == nil {
		return nil
	}
	// When both events carry a block index and they differ, the buffered start
	// belongs to another block that is still open; leave it untouched.
	if event.Index != nil && start.Index != nil && *event.Index != *start.Index {
		return nil
	}
	e.bufferedStart = nil
	if start.ContentBlock == nil || start.ContentBlock.Type != "tool_use" {
		return nil
	}
	return e.encodeDelta(e.toolCallStartDelta(start, ""))
}

// encodeTextDelta emits a text fragment as delta.content and drops any
// buffered block start.
func (e *StreamEncoder) encodeTextDelta(event canonical.CanonicalStreamEvent) [][]byte {
	e.bufferedStart = nil
	return e.encodeDelta(map[string]any{
		"content": event.Delta.Text,
	})
}

// encodeInputJSONDelta emits a fragment of tool-call arguments.
//
// The first fragment after a block start carries the call's id, type and name
// taken from the buffered start; later fragments carry arguments only. Both
// kinds resolve their tool_calls index through toolCallIndex, so every chunk
// of one call uses the same index.
func (e *StreamEncoder) encodeInputJSONDelta(event canonical.CanonicalStreamEvent) [][]byte {
	if start := e.bufferedStart; start != nil && start.ContentBlock != nil {
		e.bufferedStart = nil
		return e.encodeDelta(e.toolCallStartDelta(start, event.Delta.PartialJSON))
	}

	delta := map[string]any{
		"tool_calls": []map[string]any{{
			"index": e.toolCallIndex(event.Index),
			"function": map[string]any{
				"arguments": event.Delta.PartialJSON,
			},
		}},
	}
	return e.encodeDelta(delta)
}

// toolCallStartDelta builds the opening tool_calls delta for a buffered block
// start. arguments is stored in the payload as given. The caller must ensure
// start.ContentBlock is non-nil.
func (e *StreamEncoder) toolCallStartDelta(start *canonical.CanonicalStreamEvent, arguments any) map[string]any {
	block := start.ContentBlock
	return map[string]any{
		"tool_calls": []map[string]any{{
			"index": e.toolCallIndex(start.Index),
			"id":    block.ID,
			"type":  "function",
			"function": map[string]any{
				"name":      block.Name,
				"arguments": arguments,
			},
		}},
	}
}

// toolCallIndex resolves the tool_calls index for a canonical block index.
//
// Resolution order:
//  1. the index recorded for that block when its start was buffered;
//  2. the most recently allocated index, if any tool block has started;
//  3. the canonical block index itself, when no tool block start was seen;
//  4. zero.
func (e *StreamEncoder) toolCallIndex(blockIndex *int) int {
	if blockIndex != nil {
		if idx, ok := e.blockToolCallIndex[*blockIndex]; ok {
			return idx
		}
	}
	if e.nextToolCallIndex > 0 {
		return e.nextToolCallIndex - 1
	}
	if blockIndex != nil {
		return *blockIndex
	}
	return 0
}

// encodeThinkingDelta emits a thinking fragment as delta.reasoning_content
// and drops any buffered block start.
func (e *StreamEncoder) encodeThinkingDelta(event canonical.CanonicalStreamEvent) [][]byte {
	e.bufferedStart = nil
	return e.encodeDelta(map[string]any{
		"reasoning_content": event.Delta.Thinking,
	})
}

// encodeMessageDelta emits up to two chunks: one with finish_reason when the
// event has a stop reason, then one with usage (and an empty choices list)
// when the event has usage data. It returns nil when neither is present.
func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent) [][]byte {
	var chunks [][]byte
	if event.StopReason != nil {
		fr := mapCanonicalToFinishReason(*event.StopReason)
		chunk := map[string]any{
			"choices": []map[string]any{{
				"index":         0,
				"delta":         map[string]any{},
				"finish_reason": fr,
			}},
		}
		chunks = append(chunks, e.marshalChunk(chunk)...)
	}
	if event.Usage != nil {
		chunk := map[string]any{
			"choices": []map[string]any{},
			"usage":   encodeUsage(*event.Usage),
		}
		chunks = append(chunks, e.marshalChunk(chunk)...)
	}
	return chunks
}

// encodeDelta wraps a delta payload in a single-choice chunk and serializes it.
func (e *StreamEncoder) encodeDelta(delta map[string]any) [][]byte {
	chunk := map[string]any{
		"choices": []map[string]any{{
			"index": 0,
			"delta": delta,
		}},
	}
	return e.marshalChunk(chunk)
}

// marshalChunk serializes a chunk as one SSE "data:" frame. A chunk that
// cannot be marshaled is dropped and nil is returned.
func (e *StreamEncoder) marshalChunk(chunk map[string]any) [][]byte {
	data, err := json.Marshal(chunk)
	if err != nil {
		return nil
	}
	return [][]byte{[]byte(fmt.Sprintf("data: %s\n\n", data))}
}