引入 Canonical Model 和 ProtocolAdapter 架构，支持 OpenAI/Anthropic 协议间的无缝转换；以统一的 ProxyHandler 替代分散的 OpenAI/Anthropic Handler；将 ProviderClient 简化为协议无关的 HTTP 发送器；Provider 新增 protocol 字段。
231 lines
6.6 KiB
Go
231 lines
6.6 KiB
Go
package openai
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"unicode/utf8"
|
|
|
|
"nex/backend/internal/conversion/canonical"
|
|
)
|
|
|
|
// StreamDecoder OpenAI 流式解码器
|
|
type StreamDecoder struct {
|
|
messageStarted bool
|
|
openBlocks map[int]string
|
|
textBlockIndex int
|
|
thinkingBlockIndex int
|
|
refusalBlockIndex int
|
|
toolCallIDMap map[int]string
|
|
toolCallNameMap map[int]string
|
|
nextToolCallIdx int
|
|
utf8Remainder []byte
|
|
accumulatedUsage *canonical.CanonicalUsage
|
|
}
|
|
|
|
// NewStreamDecoder 创建 OpenAI 流式解码器
|
|
func NewStreamDecoder() *StreamDecoder {
|
|
return &StreamDecoder{
|
|
openBlocks: make(map[int]string),
|
|
toolCallIDMap: make(map[int]string),
|
|
toolCallNameMap: make(map[int]string),
|
|
textBlockIndex: -1,
|
|
thinkingBlockIndex: -1,
|
|
refusalBlockIndex: -1,
|
|
}
|
|
}
|
|
|
|
// ProcessChunk 处理原始 SSE chunk
|
|
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
|
|
// 处理 UTF-8 残余
|
|
data := rawChunk
|
|
if len(d.utf8Remainder) > 0 {
|
|
data = append(d.utf8Remainder, rawChunk...)
|
|
d.utf8Remainder = nil
|
|
}
|
|
|
|
var events []canonical.CanonicalStreamEvent
|
|
|
|
// 解析 SSE data 行
|
|
lines := strings.Split(string(data), "\n")
|
|
for _, line := range lines {
|
|
line = strings.TrimSpace(line)
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
continue
|
|
}
|
|
payload := strings.TrimPrefix(line, "data: ")
|
|
|
|
if payload == "[DONE]" {
|
|
events = append(events, d.flushOpenBlocks()...)
|
|
return events
|
|
}
|
|
|
|
chunkEvents := d.processDataChunk([]byte(payload))
|
|
events = append(events, chunkEvents...)
|
|
}
|
|
|
|
return events
|
|
}
|
|
|
|
// Flush 刷新解码器状态
|
|
func (d *StreamDecoder) Flush() []canonical.CanonicalStreamEvent {
|
|
return nil
|
|
}
|
|
|
|
// processDataChunk 处理单个 data chunk
|
|
func (d *StreamDecoder) processDataChunk(data []byte) []canonical.CanonicalStreamEvent {
|
|
// 检查 UTF-8 完整性
|
|
if !utf8.Valid(data) {
|
|
validEnd := len(data)
|
|
for !utf8.Valid(data[:validEnd]) {
|
|
validEnd--
|
|
}
|
|
d.utf8Remainder = append(d.utf8Remainder, data[validEnd:]...)
|
|
data = data[:validEnd]
|
|
}
|
|
|
|
var chunk StreamChunk
|
|
if err := json.Unmarshal(data, &chunk); err != nil {
|
|
return nil
|
|
}
|
|
|
|
var events []canonical.CanonicalStreamEvent
|
|
|
|
// 首个 chunk: MessageStart
|
|
if !d.messageStarted {
|
|
events = append(events, canonical.NewMessageStartEvent(chunk.ID, chunk.Model))
|
|
d.messageStarted = true
|
|
}
|
|
|
|
for _, choice := range chunk.Choices {
|
|
if choice.Delta == nil {
|
|
continue
|
|
}
|
|
delta := choice.Delta
|
|
|
|
// text content
|
|
if delta.Content != nil {
|
|
text := ""
|
|
switch v := delta.Content.(type) {
|
|
case string:
|
|
text = v
|
|
default:
|
|
text = fmt.Sprintf("%v", v)
|
|
}
|
|
if text != "" {
|
|
if _, ok := d.openBlocks[d.textBlockIndex]; !ok || d.textBlockIndex < 0 {
|
|
d.textBlockIndex = d.allocateBlockIndex()
|
|
d.openBlocks[d.textBlockIndex] = "text"
|
|
events = append(events, canonical.NewContentBlockStartEvent(d.textBlockIndex,
|
|
canonical.StreamContentBlock{Type: "text", Text: ""}))
|
|
}
|
|
events = append(events, canonical.NewContentBlockDeltaEvent(d.textBlockIndex,
|
|
canonical.StreamDelta{Type: string(canonical.DeltaTypeText), Text: text}))
|
|
}
|
|
}
|
|
|
|
// reasoning_content (非标准)
|
|
if delta.ReasoningContent != "" {
|
|
if _, ok := d.openBlocks[d.thinkingBlockIndex]; !ok || d.thinkingBlockIndex < 0 {
|
|
d.thinkingBlockIndex = d.allocateBlockIndex()
|
|
d.openBlocks[d.thinkingBlockIndex] = "thinking"
|
|
events = append(events, canonical.NewContentBlockStartEvent(d.thinkingBlockIndex,
|
|
canonical.StreamContentBlock{Type: "thinking", Thinking: ""}))
|
|
}
|
|
events = append(events, canonical.NewContentBlockDeltaEvent(d.thinkingBlockIndex,
|
|
canonical.StreamDelta{Type: string(canonical.DeltaTypeThinking), Thinking: delta.ReasoningContent}))
|
|
}
|
|
|
|
// refusal
|
|
if delta.Refusal != "" {
|
|
if _, ok := d.openBlocks[d.refusalBlockIndex]; !ok || d.refusalBlockIndex < 0 {
|
|
d.refusalBlockIndex = d.allocateBlockIndex()
|
|
d.openBlocks[d.refusalBlockIndex] = "text"
|
|
events = append(events, canonical.NewContentBlockStartEvent(d.refusalBlockIndex,
|
|
canonical.StreamContentBlock{Type: "text", Text: ""}))
|
|
}
|
|
events = append(events, canonical.NewContentBlockDeltaEvent(d.refusalBlockIndex,
|
|
canonical.StreamDelta{Type: string(canonical.DeltaTypeText), Text: delta.Refusal}))
|
|
}
|
|
|
|
// tool_calls
|
|
if len(delta.ToolCalls) > 0 {
|
|
for _, tc := range delta.ToolCalls {
|
|
tcIdx := 0
|
|
if tc.Index != nil {
|
|
tcIdx = *tc.Index
|
|
}
|
|
|
|
if tc.ID != "" {
|
|
// 新 tool call block
|
|
d.toolCallIDMap[tcIdx] = tc.ID
|
|
if tc.Function != nil {
|
|
d.toolCallNameMap[tcIdx] = tc.Function.Name
|
|
}
|
|
blockIdx := d.allocateBlockIndex()
|
|
d.openBlocks[blockIdx] = fmt.Sprintf("tool_use_%d", tcIdx)
|
|
name := d.toolCallNameMap[tcIdx]
|
|
events = append(events, canonical.NewContentBlockStartEvent(blockIdx,
|
|
canonical.StreamContentBlock{Type: "tool_use", ID: tc.ID, Name: name}))
|
|
}
|
|
|
|
// 查找该 tool call 的 block index
|
|
blockIdx := d.findToolUseBlockIndex(tcIdx)
|
|
if tc.Function != nil && tc.Function.Arguments != "" {
|
|
events = append(events, canonical.NewContentBlockDeltaEvent(blockIdx,
|
|
canonical.StreamDelta{Type: string(canonical.DeltaTypeInputJSON), PartialJSON: tc.Function.Arguments}))
|
|
}
|
|
}
|
|
}
|
|
|
|
// finish_reason
|
|
if choice.FinishReason != nil && *choice.FinishReason != "" {
|
|
events = append(events, d.flushOpenBlocks()...)
|
|
sr := mapFinishReason(*choice.FinishReason)
|
|
events = append(events, canonical.NewMessageDeltaEventWithUsage(sr, nil))
|
|
events = append(events, canonical.NewMessageStopEvent())
|
|
}
|
|
}
|
|
|
|
// usage chunk (choices 为空)
|
|
if len(chunk.Choices) == 0 && chunk.Usage != nil {
|
|
usage := decodeUsage(chunk.Usage)
|
|
d.accumulatedUsage = &usage
|
|
events = append(events, canonical.NewMessageDeltaEventWithUsage(canonical.StopReasonEndTurn, &usage))
|
|
}
|
|
|
|
return events
|
|
}
|
|
|
|
// allocateBlockIndex 分配 block 索引
|
|
func (d *StreamDecoder) allocateBlockIndex() int {
|
|
maxIdx := -1
|
|
for k := range d.openBlocks {
|
|
if k > maxIdx {
|
|
maxIdx = k
|
|
}
|
|
}
|
|
return maxIdx + 1
|
|
}
|
|
|
|
// findToolUseBlockIndex 查找 tool use block 索引
|
|
func (d *StreamDecoder) findToolUseBlockIndex(tcIdx int) int {
|
|
key := fmt.Sprintf("tool_use_%d", tcIdx)
|
|
for blockIdx, typ := range d.openBlocks {
|
|
if typ == key {
|
|
return blockIdx
|
|
}
|
|
}
|
|
return d.allocateBlockIndex()
|
|
}
|
|
|
|
// flushOpenBlocks 关闭所有 open blocks
|
|
func (d *StreamDecoder) flushOpenBlocks() []canonical.CanonicalStreamEvent {
|
|
var events []canonical.CanonicalStreamEvent
|
|
for idx := range d.openBlocks {
|
|
events = append(events, canonical.NewContentBlockStopEvent(idx))
|
|
}
|
|
d.openBlocks = make(map[int]string)
|
|
return events
|
|
}
|