1
0
Files
nex/backend/internal/conversion/openai/stream_decoder.go
lanyuanxiaoyao 1dac347d3b refactor: 实现 ConversionEngine 协议转换引擎,替代旧 protocol 包
引入 Canonical Model 和 ProtocolAdapter 架构,支持 OpenAI/Anthropic 协议间
无缝转换,统一 ProxyHandler 替代分散的 OpenAI/Anthropic Handler,简化
ProviderClient 为协议无关的 HTTP 发送器,Provider 新增 protocol 字段。
2026-04-20 00:36:27 +08:00

231 lines
6.6 KiB
Go

package openai
import (
"encoding/json"
"fmt"
"strings"
"unicode/utf8"
"nex/backend/internal/conversion/canonical"
)
// StreamDecoder incrementally turns an OpenAI chat-completions SSE stream
// into canonical stream events. It is stateful: one decoder instance is
// expected to follow a single response stream from start to end.
type StreamDecoder struct {
	// messageStarted is set once the MessageStart event has been emitted.
	messageStarted bool

	// openBlocks maps the index of every content block that has been started
	// but not yet stopped to a tag describing it ("text", "thinking" or
	// "tool_use_<tool call index>").
	openBlocks map[int]string

	// Index of the current text / thinking / refusal block; -1 until the
	// first block of that kind has been opened.
	textBlockIndex, thinkingBlockIndex, refusalBlockIndex int

	// Tool call ID and function name, keyed by the tool call index carried
	// in the stream.
	toolCallIDMap, toolCallNameMap map[int]string

	// NOTE(review): not referenced by any method visible in this file.
	nextToolCallIdx int

	// utf8Remainder holds bytes kept back by one ProcessChunk call so they
	// can be prepended to the next raw chunk.
	utf8Remainder []byte

	// accumulatedUsage is the usage decoded from the most recent usage-only
	// chunk. NOTE(review): written but never read in this file.
	accumulatedUsage *canonical.CanonicalUsage
}
// NewStreamDecoder returns a decoder ready for the first chunk of a stream:
// all bookkeeping maps are allocated and empty, and no text, thinking or
// refusal block is considered open yet.
func NewStreamDecoder() *StreamDecoder {
	d := new(StreamDecoder)

	// A negative index means "no block of this kind has been opened".
	d.textBlockIndex, d.thinkingBlockIndex, d.refusalBlockIndex = -1, -1, -1

	d.openBlocks = map[int]string{}
	d.toolCallIDMap = map[int]string{}
	d.toolCallNameMap = map[int]string{}
	return d
}
// ProcessChunk decodes one raw SSE chunk and returns the canonical events it
// produces (nil when it produces none).
//
// Every line of the form "data: <payload>" is handed to processDataChunk;
// all other lines are ignored. The payload "[DONE]" closes every open content
// block and ends processing of the chunk.
//
// A raw chunk may end in the middle of a line. Such an unterminated "data:"
// fragment is held back in utf8Remainder and prepended to the next chunk
// instead of being parsed (and dropped) as broken JSON. A trailing line that
// is already complete — "[DONE]" or a valid JSON document — is processed
// immediately even without a newline, so callers that pass one
// newline-stripped line per call behave as before.
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
	text := string(rawChunk)
	if len(d.utf8Remainder) > 0 {
		text = string(d.utf8Remainder) + text
		d.utf8Remainder = nil
	}

	// Everything up to the last newline consists of complete lines; whatever
	// follows it may still be continued by the next chunk.
	terminated, tail := "", text
	if idx := strings.LastIndexByte(text, '\n'); idx >= 0 {
		terminated, tail = text[:idx], text[idx+1:]
	}
	lines := strings.Split(terminated, "\n")

	// Decide whether the tail is processed now, carried over, or ignored.
	// The raw tail (not the trimmed form) is carried so that whitespace
	// inside a split JSON string survives.
	carry := ""
	if trimmed := strings.TrimSpace(tail); trimmed != "" {
		switch {
		case strings.HasPrefix(trimmed, "data: "):
			payload := strings.TrimPrefix(trimmed, "data: ")
			if payload == "[DONE]" || json.Valid([]byte(payload)) {
				lines = append(lines, tail)
			} else {
				carry = tail
			}
		case strings.HasPrefix("data: ", trimmed):
			// The line was cut inside the "data: " prefix itself.
			carry = tail
		}
		// Any other tail is the start of a non-data line and is ignored,
		// exactly as a complete non-data line would be.
	}

	var events []canonical.CanonicalStreamEvent
	for _, line := range lines {
		line = strings.TrimSpace(line)
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		payload := strings.TrimPrefix(line, "data: ")
		if payload == "[DONE]" {
			// End of stream: nothing that was buffered can still complete.
			d.utf8Remainder = nil
			return append(events, d.flushOpenBlocks()...)
		}
		events = append(events, d.processDataChunk([]byte(payload))...)
	}

	// Assigned after the loop so the buffer holds exactly the unterminated
	// tail and nothing written to it while the lines were processed.
	d.utf8Remainder = nil
	if carry != "" {
		d.utf8Remainder = []byte(carry)
	}
	return events
}
// Flush finalizes the decoder when the upstream stream has ended.
//
// It discards any bytes still held back for a following chunk (none will
// arrive) and returns a stop event for every content block that is still
// open, so a stream that ended without "[DONE]" or a finish_reason does not
// leave blocks unterminated. It returns nil when no block is open.
func (d *StreamDecoder) Flush() []canonical.CanonicalStreamEvent {
	d.utf8Remainder = nil
	return d.flushOpenBlocks()
}
// processDataChunk decodes the JSON payload of one SSE "data:" line and
// returns the canonical events it produces.
//
// A payload that does not parse as a StreamChunk yields no events. For the
// first chunk that does parse, a MessageStart event is emitted. Each choice
// delta may then contribute text, thinking (reasoning_content), refusal and
// tool-call events; a content block is started the first time content of its
// kind is seen. A non-empty finish_reason closes all open blocks and emits
// MessageDelta followed by MessageStop. A chunk without choices that carries
// usage emits a MessageDelta with that usage.
func (d *StreamDecoder) processDataChunk(data []byte) []canonical.CanonicalStreamEvent {
	// Replace invalid UTF-8 instead of cutting the payload at the first bad
	// byte: a truncated payload can never parse, and stashing the cut-off
	// bytes for the next raw chunk would glue them in front of that chunk's
	// "data: " prefix and hide its first line.
	if !utf8.Valid(data) {
		data = []byte(strings.ToValidUTF8(string(data), "\uFFFD"))
	}

	var chunk StreamChunk
	if err := json.Unmarshal(data, &chunk); err != nil {
		// Best effort by design: undecodable payloads are skipped.
		return nil
	}

	var events []canonical.CanonicalStreamEvent

	// The first decodable chunk opens the canonical message.
	if !d.messageStarted {
		events = append(events, canonical.NewMessageStartEvent(chunk.ID, chunk.Model))
		d.messageStarted = true
	}

	for _, choice := range chunk.Choices {
		if delta := choice.Delta; delta != nil {
			// Text content. Non-string content is rendered with %v.
			if delta.Content != nil {
				var text string
				switch v := delta.Content.(type) {
				case string:
					text = v
				default:
					text = fmt.Sprintf("%v", v)
				}
				if text != "" {
					// The initial index -1 is never a key, so the lookup also
					// covers "no text block opened yet".
					if _, open := d.openBlocks[d.textBlockIndex]; !open {
						d.textBlockIndex = d.allocateBlockIndex()
						d.openBlocks[d.textBlockIndex] = "text"
						events = append(events, canonical.NewContentBlockStartEvent(d.textBlockIndex,
							canonical.StreamContentBlock{Type: "text", Text: ""}))
					}
					events = append(events, canonical.NewContentBlockDeltaEvent(d.textBlockIndex,
						canonical.StreamDelta{Type: string(canonical.DeltaTypeText), Text: text}))
				}
			}

			// reasoning_content (non-standard field) becomes a thinking block.
			if delta.ReasoningContent != "" {
				if _, open := d.openBlocks[d.thinkingBlockIndex]; !open {
					d.thinkingBlockIndex = d.allocateBlockIndex()
					d.openBlocks[d.thinkingBlockIndex] = "thinking"
					events = append(events, canonical.NewContentBlockStartEvent(d.thinkingBlockIndex,
						canonical.StreamContentBlock{Type: "thinking", Thinking: ""}))
				}
				events = append(events, canonical.NewContentBlockDeltaEvent(d.thinkingBlockIndex,
					canonical.StreamDelta{Type: string(canonical.DeltaTypeThinking), Thinking: delta.ReasoningContent}))
			}

			// A refusal is surfaced as its own text block.
			if delta.Refusal != "" {
				if _, open := d.openBlocks[d.refusalBlockIndex]; !open {
					d.refusalBlockIndex = d.allocateBlockIndex()
					d.openBlocks[d.refusalBlockIndex] = "text"
					events = append(events, canonical.NewContentBlockStartEvent(d.refusalBlockIndex,
						canonical.StreamContentBlock{Type: "text", Text: ""}))
				}
				events = append(events, canonical.NewContentBlockDeltaEvent(d.refusalBlockIndex,
					canonical.StreamDelta{Type: string(canonical.DeltaTypeText), Text: delta.Refusal}))
			}

			// Tool calls. A missing index is treated as 0.
			for _, tc := range delta.ToolCalls {
				tcIdx := 0
				if tc.Index != nil {
					tcIdx = *tc.Index
				}
				tag := fmt.Sprintf("tool_use_%d", tcIdx)
				hasArgs := tc.Function != nil && tc.Function.Arguments != ""

				// findToolUseBlockIndex returns a fresh, unopened index when
				// no block exists for tcIdx, so the tag tells the two apart.
				blockIdx := d.findToolUseBlockIndex(tcIdx)
				open := d.openBlocks[blockIdx] == tag

				switch {
				case tc.ID != "" && !(open && d.toolCallIDMap[tcIdx] == tc.ID):
					// A tool call not seen before. (An ID merely repeated on
					// a later chunk of the same call must not start a second
					// block, hence the comparison above.)
					if open {
						// Another call already used this index: retire its
						// tag so later lookups resolve to the new block only.
						// The old block stays open and is closed on flush.
						d.openBlocks[blockIdx] = "tool_use_superseded"
					}
					d.toolCallIDMap[tcIdx] = tc.ID
					if tc.Function != nil {
						d.toolCallNameMap[tcIdx] = tc.Function.Name
					}
					blockIdx = d.allocateBlockIndex()
					d.openBlocks[blockIdx] = tag
					events = append(events, canonical.NewContentBlockStartEvent(blockIdx,
						canonical.StreamContentBlock{Type: "tool_use", ID: tc.ID, Name: d.toolCallNameMap[tcIdx]}))

				case !open && hasArgs:
					// Arguments for a call with no open block: start one with
					// whatever ID/name is on record so the delta below never
					// refers to a block that was not started.
					d.openBlocks[blockIdx] = tag
					events = append(events, canonical.NewContentBlockStartEvent(blockIdx,
						canonical.StreamContentBlock{Type: "tool_use", ID: d.toolCallIDMap[tcIdx], Name: d.toolCallNameMap[tcIdx]}))
				}

				if hasArgs {
					events = append(events, canonical.NewContentBlockDeltaEvent(blockIdx,
						canonical.StreamDelta{Type: string(canonical.DeltaTypeInputJSON), PartialJSON: tc.Function.Arguments}))
				}
			}
		}

		// finish_reason is honoured even when the choice carries no delta.
		if choice.FinishReason != nil && *choice.FinishReason != "" {
			events = append(events, d.flushOpenBlocks()...)
			sr := mapFinishReason(*choice.FinishReason)
			events = append(events, canonical.NewMessageDeltaEventWithUsage(sr, nil))
			events = append(events, canonical.NewMessageStopEvent())
		}
	}

	// Usage-only chunk (no choices).
	// NOTE(review): if the upstream sends this chunk after the finish_reason
	// chunk, this MessageDelta follows MessageStop and always reports
	// end_turn, whatever the real stop reason was — confirm consumers
	// tolerate that, or defer MessageStop until the usage has been seen.
	if len(chunk.Choices) == 0 && chunk.Usage != nil {
		usage := decodeUsage(chunk.Usage)
		d.accumulatedUsage = &usage
		events = append(events, canonical.NewMessageDeltaEventWithUsage(canonical.StopReasonEndTurn, &usage))
	}

	return events
}
// allocateBlockIndex returns the index for the next content block: one past
// the highest index currently present in openBlocks, or 0 when no block is
// open. It does not register the block; the caller does that.
func (d *StreamDecoder) allocateBlockIndex() int {
	next := 0
	for idx := range d.openBlocks {
		if idx >= next {
			next = idx + 1
		}
	}
	return next
}
// findToolUseBlockIndex returns the index of the open content block tagged
// for tool call tcIdx ("tool_use_<tcIdx>").
//
// If several open blocks carry that tag, the highest index is returned.
// Indexes are handed out in increasing order, so that is the block opened
// most recently, and the result no longer depends on map iteration order.
// If no block carries the tag, a fresh index from allocateBlockIndex is
// returned; that index is NOT registered in openBlocks.
func (d *StreamDecoder) findToolUseBlockIndex(tcIdx int) int {
	tag := fmt.Sprintf("tool_use_%d", tcIdx)

	found := -1
	for blockIdx, kind := range d.openBlocks {
		if kind == tag && blockIdx > found {
			found = blockIdx
		}
	}
	if found >= 0 {
		return found
	}
	return d.allocateBlockIndex()
}
// flushOpenBlocks returns a ContentBlockStop event for every open content
// block, in ascending block-index order, and then resets openBlocks to an
// empty map. It returns nil when no block is open.
//
// Ordering the indexes makes the emitted sequence deterministic; ranging
// over the map directly would produce the stop events in an arbitrary order
// that changes from run to run.
func (d *StreamDecoder) flushOpenBlocks() []canonical.CanonicalStreamEvent {
	indexes := make([]int, 0, len(d.openBlocks))
	for idx := range d.openBlocks {
		indexes = append(indexes, idx)
	}

	// Insertion sort: only a handful of blocks are ever open at once, and it
	// keeps this file's import list unchanged.
	for i := 1; i < len(indexes); i++ {
		for j := i; j > 0 && indexes[j] < indexes[j-1]; j-- {
			indexes[j], indexes[j-1] = indexes[j-1], indexes[j]
		}
	}

	var events []canonical.CanonicalStreamEvent
	for _, idx := range indexes {
		events = append(events, canonical.NewContentBlockStopEvent(idx))
	}

	d.openBlocks = make(map[int]string)
	return events
}