package openai

import (
	"encoding/json"
	"fmt"
	"sort"
	"strings"
	"unicode/utf8"

	"nex/backend/internal/conversion/canonical"
)

// StreamDecoder converts an OpenAI chat-completions SSE stream into canonical
// stream events. It is stateful: one decoder instance handles one stream.
type StreamDecoder struct {
	// messageStarted records whether the MessageStart event was already emitted.
	messageStarted bool
	// openBlocks maps a canonical block index to the kind of block opened at
	// that index ("text", "thinking" or "tool_use_<toolCallIndex>").
	openBlocks map[int]string
	// Block indexes of the currently open text / thinking / refusal blocks,
	// or -1 when no such block is open.
	textBlockIndex     int
	thinkingBlockIndex int
	refusalBlockIndex  int
	// toolCallIDMap and toolCallNameMap remember the ID and function name last
	// seen for each OpenAI tool-call index.
	toolCallIDMap   map[int]string
	toolCallNameMap map[int]string
	// nextToolCallIdx is not used by the code in this file.
	nextToolCallIdx int
	// utf8Remainder holds the unterminated trailing line of the previous raw
	// chunk (which may end in the middle of a multi-byte character or of a
	// JSON document) until the rest of that line arrives.
	utf8Remainder []byte
	// accumulatedUsage is the usage decoded from the most recent usage chunk.
	accumulatedUsage *canonical.CanonicalUsage
}

// NewStreamDecoder returns a decoder with no open blocks.
func NewStreamDecoder() *StreamDecoder {
	return &StreamDecoder{
		openBlocks:         make(map[int]string),
		toolCallIDMap:      make(map[int]string),
		toolCallNameMap:    make(map[int]string),
		textBlockIndex:     -1,
		thinkingBlockIndex: -1,
		refusalBlockIndex:  -1,
	}
}

// ProcessChunk decodes one raw SSE chunk and returns the canonical events it
// produces.
//
// A raw chunk may contain several lines and may end in the middle of a line.
// Every newline-terminated line is processed immediately. The unterminated
// tail is processed immediately only when it already is a complete "data:"
// line (so callers that pass single lines without a trailing newline keep
// working); otherwise it is buffered and prepended to the next chunk.
//
// When the "[DONE]" terminator is seen, all open blocks are closed and the
// rest of the chunk is ignored.
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
	data := rawChunk
	if len(d.utf8Remainder) > 0 {
		data = append(d.utf8Remainder, rawChunk...)
		d.utf8Remainder = nil
	}

	// strings.Split always returns at least one element; the last one is the
	// text after the final newline, i.e. the possibly incomplete tail.
	lines := strings.Split(string(data), "\n")
	tail := lines[len(lines)-1]
	if isCompleteDataLine(tail) {
		tail = ""
	} else {
		lines = lines[:len(lines)-1]
	}

	var events []canonical.CanonicalStreamEvent
	for _, line := range lines {
		payload, ok := dataPayload(line)
		if !ok {
			continue
		}
		if payload == "[DONE]" {
			return append(events, d.flushOpenBlocks()...)
		}
		events = append(events, d.processDataChunk([]byte(payload))...)
	}

	// Keep the tail only if it can still grow into a data line; other SSE
	// fields and comments are ignored by this decoder anyway.
	if mayBecomeDataLine(tail) {
		d.utf8Remainder = []byte(tail)
	}
	return events
}

// Flush is called at the end of the stream. It currently emits nothing.
//
// NOTE(review): blocks left open by a stream that ends without a
// finish_reason or "[DONE]" are not closed here, and a buffered partial line
// is not discarded — confirm whether callers expect Flush to do either.
func (d *StreamDecoder) Flush() []canonical.CanonicalStreamEvent {
	return nil
}

// dataPayload extracts the payload of an SSE "data:" line. The space after
// the colon is optional. ok is false for any other kind of line.
func dataPayload(line string) (payload string, ok bool) {
	line = strings.TrimSpace(line)
	if !strings.HasPrefix(line, "data:") {
		return "", false
	}
	return strings.TrimSpace(line[len("data:"):]), true
}

// isCompleteDataLine reports whether an unterminated line already carries a
// whole payload: the "[DONE]" terminator or a syntactically valid JSON value.
func isCompleteDataLine(line string) bool {
	payload, ok := dataPayload(line)
	if !ok {
		return false
	}
	// A line cut inside a multi-byte character is certainly incomplete, so
	// the JSON scan can be skipped.
	if endsWithPartialRune(line) {
		return false
	}
	return payload == "[DONE]" || json.Valid([]byte(payload))
}

// mayBecomeDataLine reports whether a partial line is either the beginning of
// the "data:" field name or an unfinished "data:" line.
func mayBecomeDataLine(partial string) bool {
	trimmed := strings.TrimSpace(partial)
	if trimmed == "" {
		return false
	}
	return strings.HasPrefix(trimmed, "data:") || strings.HasPrefix("data:", trimmed)
}

// endsWithPartialRune reports whether s ends with the leading bytes of a
// multi-byte UTF-8 sequence whose remaining bytes are missing.
func endsWithPartialRune(s string) bool {
	for i := len(s) - 1; i >= 0 && i > len(s)-1-utf8.UTFMax; i-- {
		if utf8.RuneStart(s[i]) {
			return !utf8.FullRuneInString(s[i:])
		}
	}
	return false
}

// processDataChunk decodes the JSON payload of a single "data:" line and
// returns the canonical events for it. A payload that is not valid JSON for
// StreamChunk is skipped and yields no events.
func (d *StreamDecoder) processDataChunk(data []byte) []canonical.CanonicalStreamEvent {
	var chunk StreamChunk
	if err := json.Unmarshal(data, &chunk); err != nil {
		return nil
	}

	var events []canonical.CanonicalStreamEvent

	// The first decodable chunk opens the message.
	if !d.messageStarted {
		events = append(events, canonical.NewMessageStartEvent(chunk.ID, chunk.Model))
		d.messageStarted = true
	}

	for _, choice := range chunk.Choices {
		if delta := choice.Delta; delta != nil {
			// Text content. Non-string content is rendered with %v.
			if delta.Content != nil {
				var text string
				switch v := delta.Content.(type) {
				case string:
					text = v
				default:
					text = fmt.Sprintf("%v", v)
				}
				if text != "" {
					events = d.ensureBlock(events, &d.textBlockIndex, "text",
						canonical.StreamContentBlock{Type: "text", Text: ""})
					events = append(events, canonical.NewContentBlockDeltaEvent(d.textBlockIndex,
						canonical.StreamDelta{Type: string(canonical.DeltaTypeText), Text: text}))
				}
			}

			// reasoning_content (non-standard extension) becomes a thinking block.
			if delta.ReasoningContent != "" {
				events = d.ensureBlock(events, &d.thinkingBlockIndex, "thinking",
					canonical.StreamContentBlock{Type: "thinking", Thinking: ""})
				events = append(events, canonical.NewContentBlockDeltaEvent(d.thinkingBlockIndex,
					canonical.StreamDelta{Type: string(canonical.DeltaTypeThinking), Thinking: delta.ReasoningContent}))
			}

			// A refusal is surfaced as its own text block.
			if delta.Refusal != "" {
				events = d.ensureBlock(events, &d.refusalBlockIndex, "text",
					canonical.StreamContentBlock{Type: "text", Text: ""})
				events = append(events, canonical.NewContentBlockDeltaEvent(d.refusalBlockIndex,
					canonical.StreamDelta{Type: string(canonical.DeltaTypeText), Text: delta.Refusal}))
			}

			// Tool calls: a delta carrying an ID starts a block, later deltas
			// for the same tool-call index append argument JSON to it.
			for _, tc := range delta.ToolCalls {
				tcIdx := 0
				if tc.Index != nil {
					tcIdx = *tc.Index
				}

				// Some providers repeat the ID on every delta; only open a
				// new block when none is open for this index or the ID is new.
				_, open := d.lookupToolUseBlock(tcIdx)
				if tc.ID != "" && (!open || d.toolCallIDMap[tcIdx] != tc.ID) {
					d.toolCallIDMap[tcIdx] = tc.ID
					if tc.Function != nil {
						d.toolCallNameMap[tcIdx] = tc.Function.Name
					}
					newIdx := d.allocateBlockIndex()
					d.openBlocks[newIdx] = fmt.Sprintf("tool_use_%d", tcIdx)
					events = append(events, canonical.NewContentBlockStartEvent(newIdx,
						canonical.StreamContentBlock{Type: "tool_use", ID: tc.ID, Name: d.toolCallNameMap[tcIdx]}))
				}

				// NOTE(review): if no block was ever started for tcIdx this
				// targets an index that has no ContentBlockStart event.
				blockIdx := d.findToolUseBlockIndex(tcIdx)
				if tc.Function != nil && tc.Function.Arguments != "" {
					events = append(events, canonical.NewContentBlockDeltaEvent(blockIdx,
						canonical.StreamDelta{Type: string(canonical.DeltaTypeInputJSON), PartialJSON: tc.Function.Arguments}))
				}
			}
		}

		// finish_reason closes all blocks and ends the message. It is handled
		// even when the choice carries no delta.
		if choice.FinishReason != nil && *choice.FinishReason != "" {
			events = append(events, d.flushOpenBlocks()...)
			sr := mapFinishReason(*choice.FinishReason)
			events = append(events, canonical.NewMessageDeltaEventWithUsage(sr, nil))
			events = append(events, canonical.NewMessageStopEvent())
		}
	}

	// Usage-only chunk (empty choices).
	// NOTE(review): this chunk normally follows the finish_reason chunk, so
	// the delta below is emitted after MessageStop and always reports
	// StopReasonEndTurn — confirm downstream consumers tolerate that.
	if len(chunk.Choices) == 0 && chunk.Usage != nil {
		usage := decodeUsage(chunk.Usage)
		d.accumulatedUsage = &usage
		events = append(events, canonical.NewMessageDeltaEventWithUsage(canonical.StopReasonEndTurn, &usage))
	}

	return events
}

// ensureBlock makes sure a block of the given kind is open at *idx. If it is
// not, a new index is allocated, stored in *idx, and a ContentBlockStart
// event carrying start is appended to events. The (possibly extended) event
// slice is returned.
func (d *StreamDecoder) ensureBlock(events []canonical.CanonicalStreamEvent, idx *int, kind string, start canonical.StreamContentBlock) []canonical.CanonicalStreamEvent {
	if *idx >= 0 && d.openBlocks[*idx] == kind {
		return events
	}
	*idx = d.allocateBlockIndex()
	d.openBlocks[*idx] = kind
	return append(events, canonical.NewContentBlockStartEvent(*idx, start))
}

// allocateBlockIndex returns the next free block index: one more than the
// highest index currently open, or 0 when no block is open.
func (d *StreamDecoder) allocateBlockIndex() int {
	maxIdx := -1
	for k := range d.openBlocks {
		if k > maxIdx {
			maxIdx = k
		}
	}
	return maxIdx + 1
}

// lookupToolUseBlock returns the index of the most recently opened block for
// the given OpenAI tool-call index. ok is false when no such block is open.
func (d *StreamDecoder) lookupToolUseBlock(tcIdx int) (blockIdx int, ok bool) {
	key := fmt.Sprintf("tool_use_%d", tcIdx)
	blockIdx = -1
	// Take the highest match so the result does not depend on map order when
	// a tool-call index was reused for a second call.
	for idx, typ := range d.openBlocks {
		if typ == key && idx > blockIdx {
			blockIdx = idx
		}
	}
	return blockIdx, blockIdx >= 0
}

// findToolUseBlockIndex returns the block index for the given OpenAI
// tool-call index. When no block is open for it, the next free index is
// returned without being registered.
func (d *StreamDecoder) findToolUseBlockIndex(tcIdx int) int {
	if blockIdx, ok := d.lookupToolUseBlock(tcIdx); ok {
		return blockIdx
	}
	return d.allocateBlockIndex()
}

// flushOpenBlocks emits a ContentBlockStop event for every open block, in
// ascending index order, and resets the block bookkeeping.
func (d *StreamDecoder) flushOpenBlocks() []canonical.CanonicalStreamEvent {
	indexes := make([]int, 0, len(d.openBlocks))
	for idx := range d.openBlocks {
		indexes = append(indexes, idx)
	}
	sort.Ints(indexes)

	var events []canonical.CanonicalStreamEvent
	for _, idx := range indexes {
		events = append(events, canonical.NewContentBlockStopEvent(idx))
	}

	d.openBlocks = make(map[int]string)
	// Indexes restart from 0 after a flush, so the remembered ones must not
	// be matched against blocks opened later.
	d.textBlockIndex = -1
	d.thinkingBlockIndex = -1
	d.refusalBlockIndex = -1
	return events
}