引入 Canonical Model 和 ProtocolAdapter 架构，支持 OpenAI/Anthropic 协议间无缝转换，统一 ProxyHandler 替代分散的 OpenAI/Anthropic Handler，简化 ProviderClient 为协议无关的 HTTP 发送器，Provider 新增 protocol 字段。
218 lines
5.3 KiB
Go
218 lines
5.3 KiB
Go
package openai
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"time"
|
||
|
||
"nex/backend/internal/conversion/canonical"
|
||
)
|
||
|
||
// StreamEncoder OpenAI 流式编码器
|
||
type StreamEncoder struct {
|
||
bufferedStart *canonical.CanonicalStreamEvent
|
||
toolCallIndexMap map[string]int
|
||
nextToolCallIndex int
|
||
}
|
||
|
||
// NewStreamEncoder 创建 OpenAI 流式编码器
|
||
func NewStreamEncoder() *StreamEncoder {
|
||
return &StreamEncoder{
|
||
toolCallIndexMap: make(map[string]int),
|
||
}
|
||
}
|
||
|
||
// EncodeEvent 编码 Canonical 事件为 SSE chunk
|
||
func (e *StreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
|
||
switch event.Type {
|
||
case canonical.EventMessageStart:
|
||
return e.encodeMessageStart(event)
|
||
case canonical.EventContentBlockStart:
|
||
return e.bufferBlockStart(event)
|
||
case canonical.EventContentBlockDelta:
|
||
return e.encodeContentBlockDelta(event)
|
||
case canonical.EventContentBlockStop:
|
||
return nil
|
||
case canonical.EventMessageDelta:
|
||
return e.encodeMessageDelta(event)
|
||
case canonical.EventMessageStop:
|
||
return [][]byte{[]byte("data: [DONE]\n\n")}
|
||
case canonical.EventPing, canonical.EventError:
|
||
return nil
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// Flush 刷新缓冲区
|
||
func (e *StreamEncoder) Flush() [][]byte {
|
||
return nil
|
||
}
|
||
|
||
// encodeMessageStart 编码消息开始事件
|
||
func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent) [][]byte {
|
||
id := ""
|
||
model := ""
|
||
if event.Message != nil {
|
||
id = event.Message.ID
|
||
model = event.Message.Model
|
||
}
|
||
|
||
chunk := map[string]any{
|
||
"id": id,
|
||
"object": "chat.completion.chunk",
|
||
"created": time.Now().Unix(),
|
||
"model": model,
|
||
"choices": []map[string]any{{
|
||
"index": 0,
|
||
"delta": map[string]any{"role": "assistant"},
|
||
}},
|
||
}
|
||
|
||
return e.marshalChunk(chunk)
|
||
}
|
||
|
||
// bufferBlockStart 缓冲 block start 事件
|
||
func (e *StreamEncoder) bufferBlockStart(event canonical.CanonicalStreamEvent) [][]byte {
|
||
e.bufferedStart = &event
|
||
if event.ContentBlock != nil && event.ContentBlock.Type == "tool_use" {
|
||
idx := e.nextToolCallIndex
|
||
e.nextToolCallIndex++
|
||
if event.ContentBlock.ID != "" {
|
||
e.toolCallIndexMap[event.ContentBlock.ID] = idx
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// encodeContentBlockDelta 编码内容块增量事件
|
||
func (e *StreamEncoder) encodeContentBlockDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||
if event.Delta == nil {
|
||
return nil
|
||
}
|
||
|
||
switch canonical.DeltaType(event.Delta.Type) {
|
||
case canonical.DeltaTypeText:
|
||
return e.encodeTextDelta(event)
|
||
case canonical.DeltaTypeInputJSON:
|
||
return e.encodeInputJSONDelta(event)
|
||
case canonical.DeltaTypeThinking:
|
||
return e.encodeThinkingDelta(event)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// encodeTextDelta 编码文本增量
|
||
func (e *StreamEncoder) encodeTextDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||
delta := map[string]any{
|
||
"content": event.Delta.Text,
|
||
}
|
||
if e.bufferedStart != nil {
|
||
e.bufferedStart = nil
|
||
}
|
||
return e.encodeDelta(delta)
|
||
}
|
||
|
||
// encodeInputJSONDelta 编码 JSON 输入增量
|
||
func (e *StreamEncoder) encodeInputJSONDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||
if e.bufferedStart != nil && e.bufferedStart.ContentBlock != nil {
|
||
// 首次 delta,含 id 和 name
|
||
start := e.bufferedStart.ContentBlock
|
||
tcIdx := 0
|
||
if start.ID != "" {
|
||
tcIdx = e.toolCallIndexMap[start.ID]
|
||
}
|
||
delta := map[string]any{
|
||
"tool_calls": []map[string]any{{
|
||
"index": tcIdx,
|
||
"id": start.ID,
|
||
"type": "function",
|
||
"function": map[string]any{
|
||
"name": start.Name,
|
||
"arguments": event.Delta.PartialJSON,
|
||
},
|
||
}},
|
||
}
|
||
e.bufferedStart = nil
|
||
return e.encodeDelta(delta)
|
||
}
|
||
|
||
// 后续 delta,仅含 arguments
|
||
// 通过 index 查找 tool call
|
||
tcIdx := 0
|
||
if event.Index != nil {
|
||
for id, idx := range e.toolCallIndexMap {
|
||
if idx == tcIdx {
|
||
_ = id
|
||
break
|
||
}
|
||
}
|
||
}
|
||
delta := map[string]any{
|
||
"tool_calls": []map[string]any{{
|
||
"index": tcIdx,
|
||
"function": map[string]any{
|
||
"arguments": event.Delta.PartialJSON,
|
||
},
|
||
}},
|
||
}
|
||
return e.encodeDelta(delta)
|
||
}
|
||
|
||
// encodeThinkingDelta 编码思考增量
|
||
func (e *StreamEncoder) encodeThinkingDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||
delta := map[string]any{
|
||
"reasoning_content": event.Delta.Thinking,
|
||
}
|
||
if e.bufferedStart != nil {
|
||
e.bufferedStart = nil
|
||
}
|
||
return e.encodeDelta(delta)
|
||
}
|
||
|
||
// encodeMessageDelta 编码消息增量事件
|
||
func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||
var chunks [][]byte
|
||
|
||
if event.StopReason != nil {
|
||
fr := mapCanonicalToFinishReason(*event.StopReason)
|
||
chunk := map[string]any{
|
||
"choices": []map[string]any{{
|
||
"index": 0,
|
||
"delta": map[string]any{},
|
||
"finish_reason": fr,
|
||
}},
|
||
}
|
||
chunks = append(chunks, e.marshalChunk(chunk)...)
|
||
}
|
||
|
||
if event.Usage != nil {
|
||
chunk := map[string]any{
|
||
"choices": []map[string]any{},
|
||
"usage": encodeUsage(*event.Usage),
|
||
}
|
||
chunks = append(chunks, e.marshalChunk(chunk)...)
|
||
}
|
||
|
||
return chunks
|
||
}
|
||
|
||
// encodeDelta 编码 delta 到 SSE chunk
|
||
func (e *StreamEncoder) encodeDelta(delta map[string]any) [][]byte {
|
||
chunk := map[string]any{
|
||
"choices": []map[string]any{{
|
||
"index": 0,
|
||
"delta": delta,
|
||
}},
|
||
}
|
||
return e.marshalChunk(chunk)
|
||
}
|
||
|
||
// marshalChunk 序列化 chunk 为 SSE data
|
||
func (e *StreamEncoder) marshalChunk(chunk map[string]any) [][]byte {
|
||
data, err := json.Marshal(chunk)
|
||
if err != nil {
|
||
return nil
|
||
}
|
||
return [][]byte{[]byte(fmt.Sprintf("data: %s\n\n", data))}
|
||
}
|