266 lines
7.2 KiB
Go
266 lines
7.2 KiB
Go
package conversion
|
|
|
|
import (
|
|
"bytes"
|
|
"strings"
|
|
|
|
"nex/backend/internal/conversion/canonical"
|
|
)
|
|
|
|
// StreamDecoder 流式解码器接口
|
|
type StreamDecoder interface {
|
|
ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent
|
|
Flush() []canonical.CanonicalStreamEvent
|
|
}
|
|
|
|
// StreamEncoder 流式编码器接口
|
|
type StreamEncoder interface {
|
|
EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte
|
|
Flush() [][]byte
|
|
}
|
|
|
|
// StreamConverter is the streaming converter interface: raw chunks go in,
// converted output chunks come out.
type StreamConverter interface {
	// ProcessChunk converts one raw chunk and returns the resulting output
	// chunks.
	ProcessChunk(rawChunk []byte) [][]byte
	// Flush returns the output that is still pending inside the converter.
	Flush() [][]byte
}
|
|
|
|
// PassthroughStreamConverter is the same-protocol passthrough stream
// converter. It holds no state and forwards every chunk untouched.
type PassthroughStreamConverter struct{}

// NewPassthroughStreamConverter creates a passthrough stream converter.
func NewPassthroughStreamConverter() *PassthroughStreamConverter {
	return new(PassthroughStreamConverter)
}

// ProcessChunk hands the raw bytes straight through: the result always holds
// exactly one element, which is rawChunk itself (not a copy).
func (c *PassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
	out := make([][]byte, 1)
	out[0] = rawChunk
	return out
}

// Flush has nothing to emit because this converter never buffers data.
func (c *PassthroughStreamConverter) Flush() [][]byte {
	var pending [][]byte
	return pending
}
|
|
|
|
// SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器
|
|
// 按 SSE frame 改写 data JSON 中的 model 字段
|
|
type SmartPassthroughStreamConverter struct {
|
|
adapter ProtocolAdapter
|
|
modelOverride string
|
|
interfaceType InterfaceType
|
|
buffer []byte
|
|
}
|
|
|
|
// NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器
|
|
func NewSmartPassthroughStreamConverter(adapter ProtocolAdapter, modelOverride string, interfaceType InterfaceType) *SmartPassthroughStreamConverter {
|
|
return &SmartPassthroughStreamConverter{
|
|
adapter: adapter,
|
|
modelOverride: modelOverride,
|
|
interfaceType: interfaceType,
|
|
}
|
|
}
|
|
|
|
// ProcessChunk 按 SSE frame 改写 data JSON 中的 model 字段
|
|
func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
|
if len(rawChunk) == 0 {
|
|
return nil
|
|
}
|
|
|
|
c.buffer = append(c.buffer, rawChunk...)
|
|
frames, rest := splitSSEFrames(c.buffer)
|
|
c.buffer = rest
|
|
|
|
result := make([][]byte, 0, len(frames))
|
|
for _, frame := range frames {
|
|
result = append(result, c.rewriteFrame(frame))
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (c *SmartPassthroughStreamConverter) rewriteFrame(frame []byte) []byte {
|
|
payload, ok := sseFrameDataPayload(frame)
|
|
if !ok || strings.TrimSpace(payload) == "[DONE]" {
|
|
return frame
|
|
}
|
|
|
|
rewrittenPayload, err := c.adapter.RewriteResponseModelName([]byte(payload), c.modelOverride, c.interfaceType)
|
|
if err != nil {
|
|
return frame
|
|
}
|
|
|
|
return rebuildSSEFrameWithData(frame, string(rewrittenPayload))
|
|
}
|
|
|
|
// Flush 输出未形成完整 frame 的剩余数据
|
|
func (c *SmartPassthroughStreamConverter) Flush() [][]byte {
|
|
if len(c.buffer) == 0 {
|
|
return nil
|
|
}
|
|
frame := append([]byte(nil), c.buffer...)
|
|
c.buffer = nil
|
|
return [][]byte{c.rewriteFrame(frame)}
|
|
}
|
|
|
|
// CanonicalStreamConverter 跨协议规范流式转换器
|
|
type CanonicalStreamConverter struct {
|
|
decoder StreamDecoder
|
|
encoder StreamEncoder
|
|
chain *MiddlewareChain
|
|
ctx ConversionContext
|
|
clientProtocol string
|
|
providerProtocol string
|
|
modelOverride string
|
|
}
|
|
|
|
// NewCanonicalStreamConverter 创建规范流式转换器
|
|
func NewCanonicalStreamConverter(decoder StreamDecoder, encoder StreamEncoder) *CanonicalStreamConverter {
|
|
return &CanonicalStreamConverter{
|
|
decoder: decoder,
|
|
encoder: encoder,
|
|
}
|
|
}
|
|
|
|
// NewCanonicalStreamConverterWithMiddleware 创建带中间件的规范流式转换器
|
|
func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol, modelOverride string) *CanonicalStreamConverter {
|
|
return &CanonicalStreamConverter{
|
|
decoder: decoder,
|
|
encoder: encoder,
|
|
chain: chain,
|
|
ctx: ctx,
|
|
clientProtocol: clientProtocol,
|
|
providerProtocol: providerProtocol,
|
|
modelOverride: modelOverride,
|
|
}
|
|
}
|
|
|
|
// ProcessChunk 解码 → 中间件 → modelOverride → 编码管道
|
|
func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
|
events := c.decoder.ProcessChunk(rawChunk)
|
|
var result [][]byte
|
|
for i := range events {
|
|
if c.chain != nil {
|
|
processed, err := c.chain.ApplyStreamEvent(&events[i], c.clientProtocol, c.providerProtocol, &c.ctx)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
events[i] = *processed
|
|
}
|
|
c.applyModelOverride(&events[i])
|
|
chunks := c.encoder.EncodeEvent(events[i])
|
|
result = append(result, chunks...)
|
|
}
|
|
return result
|
|
}
|
|
|
|
// Flush 刷新解码器和编码器缓冲区
|
|
func (c *CanonicalStreamConverter) Flush() [][]byte {
|
|
events := c.decoder.Flush()
|
|
var result [][]byte
|
|
for i := range events {
|
|
if c.chain != nil {
|
|
processed, err := c.chain.ApplyStreamEvent(&events[i], c.clientProtocol, c.providerProtocol, &c.ctx)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
events[i] = *processed
|
|
}
|
|
c.applyModelOverride(&events[i])
|
|
chunks := c.encoder.EncodeEvent(events[i])
|
|
result = append(result, chunks...)
|
|
}
|
|
encoderChunks := c.encoder.Flush()
|
|
result = append(result, encoderChunks...)
|
|
return result
|
|
}
|
|
|
|
// applyModelOverride 在跨协议场景下覆写流式事件中的 Model 字段
|
|
func (c *CanonicalStreamConverter) applyModelOverride(event *canonical.CanonicalStreamEvent) {
|
|
if c.modelOverride != "" && event.Message != nil {
|
|
event.Message.Model = c.modelOverride
|
|
}
|
|
}
|
|
|
|
func splitSSEFrames(data []byte) ([][]byte, []byte) {
|
|
var frames [][]byte
|
|
for len(data) > 0 {
|
|
idx, sepLen := findSSEFrameSeparator(data)
|
|
if idx < 0 {
|
|
break
|
|
}
|
|
end := idx + sepLen
|
|
frames = append(frames, append([]byte(nil), data[:end]...))
|
|
data = data[end:]
|
|
}
|
|
return frames, data
|
|
}
|
|
|
|
// findSSEFrameSeparator locates the first SSE frame separator in data.
//
// It returns the index at which the separator starts and the separator's
// length: 2 for "\n\n", 4 for "\r\n\r\n". When both kinds occur, the one that
// starts first wins. It returns (-1, 0) when data holds no separator.
func findSSEFrameSeparator(data []byte) (int, int) {
	lf := bytes.Index(data, []byte("\n\n"))

	// A CRLF separator only matters when it starts before the first LF
	// separator, and in that case it ends no later than the LF separator's
	// first byte. Bounding the search to that prefix avoids rescanning the
	// whole buffer for every LF-delimited frame.
	window := data
	if lf >= 0 {
		window = data[:lf+1]
	}
	if crlf := bytes.Index(window, []byte("\r\n\r\n")); crlf >= 0 {
		return crlf, 4
	}
	if lf >= 0 {
		return lf, 2
	}
	return -1, 0
}
|
|
|
|
// sseFrameDataPayload extracts the data payload of one SSE frame.
//
// Every line that starts with "data:" contributes its value, with a single
// leading space removed if present; the values are joined with "\n". The
// second result is false when the frame contains no data line at all.
func sseFrameDataPayload(frame []byte) (string, bool) {
	const field = "data:"

	trimmed := strings.TrimRight(string(frame), "\r\n")
	var values []string
	for _, raw := range strings.Split(trimmed, "\n") {
		entry := strings.TrimRight(raw, "\r")
		if !strings.HasPrefix(entry, field) {
			continue
		}
		rest := entry[len(field):]
		if len(rest) > 0 && rest[0] == ' ' {
			rest = rest[1:]
		}
		values = append(values, rest)
	}

	if len(values) == 0 {
		return "", false
	}
	return strings.Join(values, "\n"), true
}
|
|
|
|
func rebuildSSEFrameWithData(frame []byte, data string) []byte {
|
|
lineEnding, separator := sseLineEnding(frame)
|
|
text := strings.TrimRight(string(frame), "\r\n")
|
|
lines := strings.Split(text, "\n")
|
|
out := make([]string, 0, len(lines)+1)
|
|
dataWritten := false
|
|
for _, line := range lines {
|
|
line = strings.TrimRight(line, "\r")
|
|
if strings.HasPrefix(line, "data:") {
|
|
if !dataWritten {
|
|
for _, dataLine := range strings.Split(data, "\n") {
|
|
out = append(out, "data: "+dataLine)
|
|
}
|
|
dataWritten = true
|
|
}
|
|
continue
|
|
}
|
|
out = append(out, line)
|
|
}
|
|
if !dataWritten {
|
|
out = append(out, "data: "+data)
|
|
}
|
|
return []byte(strings.Join(out, lineEnding) + separator)
|
|
}
|
|
|
|
// sseLineEnding reports the line ending used by frame and the matching
// blank-line frame separator: ("\r\n", "\r\n\r\n") when frame contains a CRLF
// anywhere, ("\n", "\n\n") otherwise.
func sseLineEnding(frame []byte) (string, string) {
	eol := "\n"
	if bytes.Index(frame, []byte("\r\n")) >= 0 {
		eol = "\r\n"
	}
	return eol, eol + eol
}
|