fix: 修正 conversion 代理路径和错误边界
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -43,6 +44,14 @@ type StreamEvent struct {
|
||||
Done bool
|
||||
}
|
||||
|
||||
// StreamResponse 表示上游流式 HTTP 响应。
|
||||
type StreamResponse struct {
|
||||
StatusCode int
|
||||
Headers map[string]string
|
||||
Body []byte
|
||||
Events <-chan StreamEvent
|
||||
}
|
||||
|
||||
// Client 协议无关的供应商客户端
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
@@ -55,7 +64,7 @@ type Client struct {
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
|
||||
type ProviderClient interface {
|
||||
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
|
||||
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error)
|
||||
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*StreamResponse, error)
|
||||
}
|
||||
|
||||
// NewClient 创建供应商客户端
|
||||
@@ -116,7 +125,7 @@ func (c *Client) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*co
|
||||
}
|
||||
|
||||
// SendStream 发送流式请求
|
||||
func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error) {
|
||||
func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*StreamResponse, error) {
|
||||
var bodyReader io.Reader
|
||||
if len(spec.Body) > 0 {
|
||||
bodyReader = bytes.NewReader(spec.Body)
|
||||
@@ -139,23 +148,29 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec
|
||||
return nil, pkgErrors.ErrRequestSend.WithCause(err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respHeaders := extractResponseHeaders(resp.Header)
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
defer resp.Body.Close()
|
||||
cancel()
|
||||
errBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d,读取错误响应失败: %w", resp.StatusCode, readErr)
|
||||
return nil, pkgErrors.ErrResponseRead.WithCause(readErr)
|
||||
}
|
||||
if len(errBody) > 0 {
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
|
||||
return &StreamResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Headers: respHeaders,
|
||||
Body: errBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize)
|
||||
go c.readStream(streamCtx, cancel, resp.Body, eventChan)
|
||||
|
||||
return eventChan, nil
|
||||
return &StreamResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Headers: respHeaders,
|
||||
Events: eventChan,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// readStream 读取 SSE 流
|
||||
@@ -208,15 +223,17 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
|
||||
}
|
||||
|
||||
for {
|
||||
idx := bytes.Index(dataBuf, []byte("\n\n"))
|
||||
idx, sepLen := findSSEFrameSeparator(dataBuf)
|
||||
if idx == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
rawEvent := dataBuf[:idx]
|
||||
dataBuf = dataBuf[idx+2:]
|
||||
frameEnd := idx + sepLen
|
||||
rawEvent := append([]byte(nil), dataBuf[:frameEnd]...)
|
||||
dataBuf = dataBuf[frameEnd:]
|
||||
|
||||
if bytes.Contains(rawEvent, []byte("data: [DONE]")) {
|
||||
if isSSEDoneFrame(rawEvent) {
|
||||
eventChan <- StreamEvent{Data: rawEvent}
|
||||
eventChan <- StreamEvent{Done: true}
|
||||
return
|
||||
}
|
||||
@@ -225,11 +242,66 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
if len(dataBuf) > 0 {
|
||||
eventChan <- StreamEvent{Data: dataBuf}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isSSEDoneFrame(frame []byte) bool {
|
||||
payload, ok := sseFrameDataPayload(frame)
|
||||
return ok && strings.TrimSpace(payload) == "[DONE]"
|
||||
}
|
||||
|
||||
func sseFrameDataPayload(frame []byte) (string, bool) {
|
||||
text := strings.TrimRight(string(frame), "\r\n")
|
||||
lines := strings.Split(text, "\n")
|
||||
var dataLines []string
|
||||
for _, line := range lines {
|
||||
line = strings.TrimRight(line, "\r")
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
value := strings.TrimPrefix(line, "data:")
|
||||
if strings.HasPrefix(value, " ") {
|
||||
value = value[1:]
|
||||
}
|
||||
dataLines = append(dataLines, value)
|
||||
}
|
||||
}
|
||||
if len(dataLines) == 0 {
|
||||
return "", false
|
||||
}
|
||||
return strings.Join(dataLines, "\n"), true
|
||||
}
|
||||
|
||||
func extractResponseHeaders(header http.Header) map[string]string {
|
||||
respHeaders := make(map[string]string)
|
||||
for k, vs := range header {
|
||||
if len(vs) > 0 {
|
||||
respHeaders[k] = vs[0]
|
||||
}
|
||||
}
|
||||
return respHeaders
|
||||
}
|
||||
|
||||
func findSSEFrameSeparator(data []byte) (int, int) {
|
||||
lf := bytes.Index(data, []byte("\n\n"))
|
||||
crlf := bytes.Index(data, []byte("\r\n\r\n"))
|
||||
switch {
|
||||
case lf < 0 && crlf < 0:
|
||||
return -1, 0
|
||||
case lf < 0:
|
||||
return crlf, 4
|
||||
case crlf < 0:
|
||||
return lf, 2
|
||||
case crlf <= lf:
|
||||
return crlf, 4
|
||||
default:
|
||||
return lf, 2
|
||||
}
|
||||
}
|
||||
|
||||
// isNetworkError 判断是否为网络相关错误
|
||||
func isNetworkError(err error) bool {
|
||||
if err == nil {
|
||||
|
||||
@@ -110,11 +110,13 @@ func TestClient_SendStream_CreatesChannel(t *testing.T) {
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStream(context.Background(), spec)
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, eventChan)
|
||||
require.NotNil(t, streamResp)
|
||||
require.Equal(t, http.StatusOK, streamResp.StatusCode)
|
||||
require.NotNil(t, streamResp.Events)
|
||||
|
||||
for range eventChan {
|
||||
for range streamResp.Events {
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,8 +134,10 @@ func TestClient_SendStream_ErrorResponse(t *testing.T) {
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
_, err := client.SendStream(context.Background(), spec)
|
||||
assert.Error(t, err)
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, streamResp)
|
||||
assert.Equal(t, http.StatusInternalServerError, streamResp.StatusCode)
|
||||
}
|
||||
|
||||
func TestClient_SendStream_SSEEvents(t *testing.T) {
|
||||
@@ -164,12 +168,13 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
|
||||
Body: []byte(`{"model":"gpt-4","messages":[],"stream":true}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStream(context.Background(), spec)
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, streamResp)
|
||||
|
||||
var dataEvents [][]byte
|
||||
var doneEvents int
|
||||
for event := range eventChan {
|
||||
for event := range streamResp.Events {
|
||||
switch {
|
||||
case event.Done:
|
||||
doneEvents++
|
||||
@@ -180,9 +185,56 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 2, len(dataEvents), "expected exactly 2 data events from SSE stream")
|
||||
assert.Equal(t, 3, len(dataEvents), "expected 2 data frames plus DONE frame from SSE stream")
|
||||
assert.Contains(t, string(dataEvents[0]), "Hello")
|
||||
assert.Contains(t, string(dataEvents[1]), "World")
|
||||
assert.Contains(t, string(dataEvents[2]), "[DONE]")
|
||||
assert.Equal(t, 1, doneEvents)
|
||||
assert.Contains(t, string(dataEvents[0]), "\n\n")
|
||||
}
|
||||
|
||||
func TestClient_SendStream_DONEOnlyWhenDataPayloadEqualsDone(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, ok := w.(http.Flusher)
|
||||
require.True(t, ok)
|
||||
_, err := w.Write([]byte("data: {\"text\":\"data: [DONE] is plain text\"}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
_, err = w.Write([]byte("data: [DONE]\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, streamResp)
|
||||
|
||||
var dataEvents [][]byte
|
||||
var doneEvents int
|
||||
for event := range streamResp.Events {
|
||||
switch {
|
||||
case event.Done:
|
||||
doneEvents++
|
||||
case event.Error != nil:
|
||||
t.Fatalf("unexpected error: %v", event.Error)
|
||||
default:
|
||||
dataEvents = append(dataEvents, event.Data)
|
||||
}
|
||||
}
|
||||
|
||||
require.Len(t, dataEvents, 2)
|
||||
assert.Contains(t, string(dataEvents[0]), "plain text")
|
||||
assert.Contains(t, string(dataEvents[1]), "[DONE]")
|
||||
assert.Equal(t, 1, doneEvents)
|
||||
}
|
||||
|
||||
@@ -203,13 +255,13 @@ func TestClient_SendStream_ContextCancellation(t *testing.T) {
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStream(ctx, spec)
|
||||
streamResp, err := client.SendStream(ctx, spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
cancel()
|
||||
|
||||
var gotError bool
|
||||
for event := range eventChan {
|
||||
for event := range streamResp.Events {
|
||||
if event.Error != nil {
|
||||
gotError = true
|
||||
}
|
||||
@@ -264,12 +316,12 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStream(context.Background(), spec)
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
var dataCount int
|
||||
var doneCount int
|
||||
for event := range eventChan {
|
||||
for event := range streamResp.Events {
|
||||
switch {
|
||||
case event.Done:
|
||||
doneCount++
|
||||
@@ -279,7 +331,7 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
|
||||
dataCount++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, dataCount, "expected exactly 1 data event from slow SSE")
|
||||
assert.Equal(t, 2, dataCount, "expected 1 data frame plus DONE frame from slow SSE")
|
||||
assert.Equal(t, 1, doneCount, "expected exactly 1 done event from slow SSE")
|
||||
}
|
||||
|
||||
@@ -308,19 +360,19 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStream(context.Background(), spec)
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
var dataEvents int
|
||||
var doneEvents int
|
||||
for event := range eventChan {
|
||||
for event := range streamResp.Events {
|
||||
if event.Done {
|
||||
doneEvents++
|
||||
} else {
|
||||
dataEvents++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 2, dataEvents, "expected exactly 2 data events from split SSE")
|
||||
assert.Equal(t, 3, dataEvents, "expected 2 data frames plus DONE frame from split SSE")
|
||||
assert.Equal(t, 1, doneEvents)
|
||||
}
|
||||
|
||||
@@ -397,11 +449,11 @@ func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStream(context.Background(), spec)
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
var gotData bool
|
||||
for event := range eventChan {
|
||||
for event := range streamResp.Events {
|
||||
if event.Error != nil {
|
||||
} else if !event.Done {
|
||||
gotData = true
|
||||
|
||||
Reference in New Issue
Block a user