package conversion import ( "bytes" "strings" "nex/backend/internal/conversion/canonical" ) // StreamDecoder 流式解码器接口 type StreamDecoder interface { ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent Flush() []canonical.CanonicalStreamEvent } // StreamEncoder 流式编码器接口 type StreamEncoder interface { EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte Flush() [][]byte } // StreamConverter 流式转换器接口 type StreamConverter interface { ProcessChunk(rawChunk []byte) [][]byte Flush() [][]byte } // PassthroughStreamConverter 同协议透传流式转换器 type PassthroughStreamConverter struct{} // NewPassthroughStreamConverter 创建透传流式转换器 func NewPassthroughStreamConverter() *PassthroughStreamConverter { return &PassthroughStreamConverter{} } // ProcessChunk 直接传递原始字节 func (c *PassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte { return [][]byte{rawChunk} } // Flush 无缓冲数据 func (c *PassthroughStreamConverter) Flush() [][]byte { return nil } // SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器 // 按 SSE frame 改写 data JSON 中的 model 字段 type SmartPassthroughStreamConverter struct { adapter ProtocolAdapter modelOverride string interfaceType InterfaceType buffer []byte } // NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器 func NewSmartPassthroughStreamConverter(adapter ProtocolAdapter, modelOverride string, interfaceType InterfaceType) *SmartPassthroughStreamConverter { return &SmartPassthroughStreamConverter{ adapter: adapter, modelOverride: modelOverride, interfaceType: interfaceType, } } // ProcessChunk 按 SSE frame 改写 data JSON 中的 model 字段 func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte { if len(rawChunk) == 0 { return nil } c.buffer = append(c.buffer, rawChunk...) frames, rest := splitSSEFrames(c.buffer) c.buffer = rest result := make([][]byte, 0, len(frames)) for _, frame := range frames { result = append(result, c.rewriteFrame(frame)) } return result } func (c *SmartPassthroughStreamConverter) rewriteFrame(frame []byte) []byte { payload, ok := sseFrameDataPayload(frame) if !ok || strings.TrimSpace(payload) == "[DONE]" { return frame } rewrittenPayload, err := c.adapter.RewriteResponseModelName([]byte(payload), c.modelOverride, c.interfaceType) if err != nil { return frame } return rebuildSSEFrameWithData(frame, string(rewrittenPayload)) } // Flush 输出未形成完整 frame 的剩余数据 func (c *SmartPassthroughStreamConverter) Flush() [][]byte { if len(c.buffer) == 0 { return nil } frame := append([]byte(nil), c.buffer...) c.buffer = nil return [][]byte{c.rewriteFrame(frame)} } // CanonicalStreamConverter 跨协议规范流式转换器 type CanonicalStreamConverter struct { decoder StreamDecoder encoder StreamEncoder chain *MiddlewareChain ctx ConversionContext clientProtocol string providerProtocol string modelOverride string } // NewCanonicalStreamConverter 创建规范流式转换器 func NewCanonicalStreamConverter(decoder StreamDecoder, encoder StreamEncoder) *CanonicalStreamConverter { return &CanonicalStreamConverter{ decoder: decoder, encoder: encoder, } } // NewCanonicalStreamConverterWithMiddleware 创建带中间件的规范流式转换器 func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol, modelOverride string) *CanonicalStreamConverter { return &CanonicalStreamConverter{ decoder: decoder, encoder: encoder, chain: chain, ctx: ctx, clientProtocol: clientProtocol, providerProtocol: providerProtocol, modelOverride: modelOverride, } } // ProcessChunk 解码 → 中间件 → modelOverride → 编码管道 func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte { events := c.decoder.ProcessChunk(rawChunk) var result [][]byte for i := range events { if c.chain != nil { processed, err := c.chain.ApplyStreamEvent(&events[i], c.clientProtocol, c.providerProtocol, &c.ctx) if err != nil { continue } events[i] = *processed } c.applyModelOverride(&events[i]) chunks := c.encoder.EncodeEvent(events[i]) result = append(result, chunks...) } return result } // Flush 刷新解码器和编码器缓冲区 func (c *CanonicalStreamConverter) Flush() [][]byte { events := c.decoder.Flush() var result [][]byte for i := range events { if c.chain != nil { processed, err := c.chain.ApplyStreamEvent(&events[i], c.clientProtocol, c.providerProtocol, &c.ctx) if err != nil { continue } events[i] = *processed } c.applyModelOverride(&events[i]) chunks := c.encoder.EncodeEvent(events[i]) result = append(result, chunks...) } encoderChunks := c.encoder.Flush() result = append(result, encoderChunks...) return result } // applyModelOverride 在跨协议场景下覆写流式事件中的 Model 字段 func (c *CanonicalStreamConverter) applyModelOverride(event *canonical.CanonicalStreamEvent) { if c.modelOverride != "" && event.Message != nil { event.Message.Model = c.modelOverride } } func splitSSEFrames(data []byte) ([][]byte, []byte) { var frames [][]byte for len(data) > 0 { idx, sepLen := findSSEFrameSeparator(data) if idx < 0 { break } end := idx + sepLen frames = append(frames, append([]byte(nil), data[:end]...)) data = data[end:] } return frames, data } func findSSEFrameSeparator(data []byte) (int, int) { lf := bytes.Index(data, []byte("\n\n")) crlf := bytes.Index(data, []byte("\r\n\r\n")) switch { case lf < 0 && crlf < 0: return -1, 0 case lf < 0: return crlf, 4 case crlf < 0: return lf, 2 case crlf <= lf: return crlf, 4 default: return lf, 2 } } func sseFrameDataPayload(frame []byte) (string, bool) { text := strings.TrimRight(string(frame), "\r\n") lines := strings.Split(text, "\n") var dataLines []string for _, line := range lines { line = strings.TrimRight(line, "\r") if strings.HasPrefix(line, "data:") { value := strings.TrimPrefix(line, "data:") if strings.HasPrefix(value, " ") { value = value[1:] } dataLines = append(dataLines, value) } } if len(dataLines) == 0 { return "", false } return strings.Join(dataLines, "\n"), true } func rebuildSSEFrameWithData(frame []byte, data string) []byte { lineEnding, separator := sseLineEnding(frame) text := strings.TrimRight(string(frame), "\r\n") lines := strings.Split(text, "\n") out := make([]string, 0, len(lines)+1) dataWritten := false for _, line := range lines { line = strings.TrimRight(line, "\r") if strings.HasPrefix(line, "data:") { if !dataWritten { for _, dataLine := range strings.Split(data, "\n") { out = append(out, "data: "+dataLine) } dataWritten = true } continue } out = append(out, line) } if !dataWritten { out = append(out, "data: "+data) } return []byte(strings.Join(out, lineEnding) + separator) } func sseLineEnding(frame []byte) (string, string) { if bytes.Contains(frame, []byte("\r\n")) { return "\r\n", "\r\n\r\n" } return "\n", "\n\n" }