1
0

fix: 修正 conversion 代理路径和错误边界

This commit is contained in:
2026-04-25 23:12:54 +08:00
parent f5c82b6980
commit 2c043c6cf7
25 changed files with 2020 additions and 214 deletions

View File

@@ -8,6 +8,7 @@ import (
"io"
"net"
"net/http"
"strings"
"syscall"
"time"
@@ -43,6 +44,14 @@ type StreamEvent struct {
Done bool
}
// StreamResponse 表示上游流式 HTTP 响应。
type StreamResponse struct {
StatusCode int
Headers map[string]string
Body []byte
Events <-chan StreamEvent
}
// Client 协议无关的供应商客户端
type Client struct {
httpClient *http.Client
@@ -55,7 +64,7 @@ type Client struct {
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
type ProviderClient interface {
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error)
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*StreamResponse, error)
}
// NewClient 创建供应商客户端
@@ -116,7 +125,7 @@ func (c *Client) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*co
}
// SendStream 发送流式请求
func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error) {
func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*StreamResponse, error) {
var bodyReader io.Reader
if len(spec.Body) > 0 {
bodyReader = bytes.NewReader(spec.Body)
@@ -139,23 +148,29 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec
return nil, pkgErrors.ErrRequestSend.WithCause(err)
}
if resp.StatusCode != http.StatusOK {
respHeaders := extractResponseHeaders(resp.Header)
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
defer resp.Body.Close()
cancel()
errBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, fmt.Errorf("供应商返回错误: HTTP %d读取错误响应失败: %w", resp.StatusCode, readErr)
return nil, pkgErrors.ErrResponseRead.WithCause(readErr)
}
if len(errBody) > 0 {
return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody))
}
return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
return &StreamResponse{
StatusCode: resp.StatusCode,
Headers: respHeaders,
Body: errBody,
}, nil
}
eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize)
go c.readStream(streamCtx, cancel, resp.Body, eventChan)
return eventChan, nil
return &StreamResponse{
StatusCode: resp.StatusCode,
Headers: respHeaders,
Events: eventChan,
}, nil
}
// readStream 读取 SSE 流
@@ -208,15 +223,17 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
}
for {
idx := bytes.Index(dataBuf, []byte("\n\n"))
idx, sepLen := findSSEFrameSeparator(dataBuf)
if idx == -1 {
break
}
rawEvent := dataBuf[:idx]
dataBuf = dataBuf[idx+2:]
frameEnd := idx + sepLen
rawEvent := append([]byte(nil), dataBuf[:frameEnd]...)
dataBuf = dataBuf[frameEnd:]
if bytes.Contains(rawEvent, []byte("data: [DONE]")) {
if isSSEDoneFrame(rawEvent) {
eventChan <- StreamEvent{Data: rawEvent}
eventChan <- StreamEvent{Done: true}
return
}
@@ -225,11 +242,66 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
}
if err == io.EOF {
if len(dataBuf) > 0 {
eventChan <- StreamEvent{Data: dataBuf}
}
return
}
}
}
func isSSEDoneFrame(frame []byte) bool {
payload, ok := sseFrameDataPayload(frame)
return ok && strings.TrimSpace(payload) == "[DONE]"
}
func sseFrameDataPayload(frame []byte) (string, bool) {
text := strings.TrimRight(string(frame), "\r\n")
lines := strings.Split(text, "\n")
var dataLines []string
for _, line := range lines {
line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "data:") {
value := strings.TrimPrefix(line, "data:")
if strings.HasPrefix(value, " ") {
value = value[1:]
}
dataLines = append(dataLines, value)
}
}
if len(dataLines) == 0 {
return "", false
}
return strings.Join(dataLines, "\n"), true
}
func extractResponseHeaders(header http.Header) map[string]string {
respHeaders := make(map[string]string)
for k, vs := range header {
if len(vs) > 0 {
respHeaders[k] = vs[0]
}
}
return respHeaders
}
func findSSEFrameSeparator(data []byte) (int, int) {
lf := bytes.Index(data, []byte("\n\n"))
crlf := bytes.Index(data, []byte("\r\n\r\n"))
switch {
case lf < 0 && crlf < 0:
return -1, 0
case lf < 0:
return crlf, 4
case crlf < 0:
return lf, 2
case crlf <= lf:
return crlf, 4
default:
return lf, 2
}
}
// isNetworkError 判断是否为网络相关错误
func isNetworkError(err error) bool {
if err == nil {