package provider import ( "bytes" "context" "errors" "fmt" "io" "net" "net/http" "syscall" "time" "go.uber.org/zap" "nex/backend/internal/conversion" pkgErrors "nex/backend/pkg/errors" ) // StreamConfig 流式处理配置 type StreamConfig struct { InitialBufferSize int MaxBufferSize int Timeout time.Duration ChannelBufferSize int } // DefaultStreamConfig 返回默认流式处理配置 func DefaultStreamConfig() StreamConfig { return StreamConfig{ InitialBufferSize: 4096, MaxBufferSize: 65536, Timeout: 5 * time.Minute, ChannelBufferSize: 100, } } // StreamEvent 流事件 type StreamEvent struct { Data []byte Error error Done bool } // Client 协议无关的供应商客户端 type Client struct { httpClient *http.Client logger *zap.Logger streamCfg StreamConfig } // ProviderClient 供应商客户端接口 //go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks type ProviderClient interface { Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error) } // NewClient 创建供应商客户端 func NewClient() *Client { return &Client{ httpClient: &http.Client{ Timeout: 30 * time.Second, }, logger: zap.L(), streamCfg: DefaultStreamConfig(), } } // Send 发送非流式请求 func (c *Client) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) { var bodyReader io.Reader if len(spec.Body) > 0 { bodyReader = bytes.NewReader(spec.Body) } httpReq, err := http.NewRequestWithContext(ctx, spec.Method, spec.URL, bodyReader) if err != nil { return nil, pkgErrors.ErrRequestCreate.WithCause(err) } for k, v := range spec.Headers { httpReq.Header.Set(k, v) } c.logger.Debug("发送请求", zap.String("url", spec.URL), zap.String("method", spec.Method), ) resp, err := c.httpClient.Do(httpReq) if err != nil { return nil, pkgErrors.ErrRequestSend.WithCause(err) } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { return nil, pkgErrors.ErrResponseRead.WithCause(err) } respHeaders := make(map[string]string) for k, vs := range resp.Header { if len(vs) > 0 { respHeaders[k] = vs[0] } } return &conversion.HTTPResponseSpec{ StatusCode: resp.StatusCode, Headers: respHeaders, Body: respBody, }, nil } // SendStream 发送流式请求 func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error) { var bodyReader io.Reader if len(spec.Body) > 0 { bodyReader = bytes.NewReader(spec.Body) } streamCtx, cancel := context.WithTimeout(ctx, c.streamCfg.Timeout) httpReq, err := http.NewRequestWithContext(streamCtx, spec.Method, spec.URL, bodyReader) if err != nil { cancel() return nil, pkgErrors.ErrRequestCreate.WithCause(err) } for k, v := range spec.Headers { httpReq.Header.Set(k, v) } resp, err := c.httpClient.Do(httpReq) if err != nil { cancel() return nil, pkgErrors.ErrRequestSend.WithCause(err) } if resp.StatusCode != http.StatusOK { defer resp.Body.Close() cancel() errBody, _ := io.ReadAll(resp.Body) if len(errBody) > 0 { return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody)) } return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode) } eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize) go c.readStream(streamCtx, cancel, resp.Body, eventChan) return eventChan, nil } // readStream 读取 SSE 流 func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body io.ReadCloser, eventChan chan<- StreamEvent) { defer close(eventChan) defer body.Close() defer cancel() bufSize := c.streamCfg.InitialBufferSize buf := make([]byte, bufSize) var dataBuf []byte for { select { case <-ctx.Done(): if ctx.Err() == context.DeadlineExceeded { c.logger.Warn("流读取超时") eventChan <- StreamEvent{Error: fmt.Errorf("流读取超时: %w", ctx.Err())} } else { eventChan <- StreamEvent{Error: ctx.Err()} } return default: } n, err := body.Read(buf) if n > 0 { dataBuf = append(dataBuf, buf[:n]...) } if err != nil { if err != io.EOF { if isNetworkError(err) { c.logger.Error("流网络错误", zap.String("error", err.Error())) eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)} } else { c.logger.Error("流读取错误", zap.String("error", err.Error())) eventChan <- StreamEvent{Error: fmt.Errorf("读取错误: %w", err)} } return } } if len(dataBuf) > bufSize/2 && bufSize < c.streamCfg.MaxBufferSize { newSize := bufSize * 2 if newSize > c.streamCfg.MaxBufferSize { newSize = c.streamCfg.MaxBufferSize } buf = make([]byte, newSize) bufSize = newSize } for { idx := bytes.Index(dataBuf, []byte("\n\n")) if idx == -1 { break } rawEvent := dataBuf[:idx] dataBuf = dataBuf[idx+2:] if bytes.Contains(rawEvent, []byte("data: [DONE]")) { eventChan <- StreamEvent{Done: true} return } eventChan <- StreamEvent{Data: rawEvent} } if err == io.EOF { return } } } // isNetworkError 判断是否为网络相关错误 func isNetworkError(err error) bool { if err == nil { return false } // 检查标准库定义的网络错误类型 var netErr net.Error if errors.As(err, &netErr) { return true } // 检查操作错误 var opErr *net.OpError if errors.As(err, &opErr) { // 检查具体的系统错误 if opErr.Err != nil { // 连接重置 if errors.Is(opErr.Err, syscall.ECONNRESET) { return true } // 断管 if errors.Is(opErr.Err, syscall.EPIPE) { return true } // 超时 if errors.Is(opErr.Err, syscall.ETIMEDOUT) { return true } } return true } // 检查上下文错误 if errors.Is(err, context.DeadlineExceeded) { return true } if errors.Is(err, context.Canceled) { return true } // 检查 EOF if errors.Is(err, io.EOF) { return true } return false }