package provider

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"time"

	"go.uber.org/zap"

	"nex/backend/internal/protocol/openai"
)

// finalEventGrace bounds how long readStream waits to hand over a terminal
// context error once the stream context has already ended, so an abandoned
// channel cannot pin the reader goroutine forever.
const finalEventGrace = time.Second

// StreamConfig holds the tunables for streaming (SSE) responses.
type StreamConfig struct {
	InitialBufferSize int           // initial read buffer size in bytes (default 4096)
	MaxBufferSize     int           // upper bound for the read buffer in bytes (default 65536)
	Timeout           time.Duration // overall stream timeout (default 5 minutes)
	ChannelBufferSize int           // capacity of the event channel (default 100)
}

// DefaultStreamConfig returns the default streaming configuration.
func DefaultStreamConfig() StreamConfig {
	return StreamConfig{
		InitialBufferSize: 4096,
		MaxBufferSize:     65536,
		Timeout:           5 * time.Minute,
		ChannelBufferSize: 100,
	}
}

// withDefaults returns a copy of cfg in which every non-positive field is
// replaced by its DefaultStreamConfig value. This protects Clients built
// without NewClient: a zero-length read buffer would make Read return
// (0, nil) forever, and a zero timeout would expire the stream immediately.
func (cfg StreamConfig) withDefaults() StreamConfig {
	def := DefaultStreamConfig()
	if cfg.InitialBufferSize <= 0 {
		cfg.InitialBufferSize = def.InitialBufferSize
	}
	if cfg.MaxBufferSize <= 0 {
		cfg.MaxBufferSize = def.MaxBufferSize
	}
	if cfg.Timeout <= 0 {
		cfg.Timeout = def.Timeout
	}
	if cfg.ChannelBufferSize <= 0 {
		cfg.ChannelBufferSize = def.ChannelBufferSize
	}
	return cfg
}

// Client is a client for OpenAI-compatible providers.
type Client struct {
	httpClient *http.Client
	adapter    *openai.Adapter
	logger     *zap.Logger
	streamCfg  StreamConfig
}

// StreamEvent is a single item delivered on a streaming response channel.
// Exactly one of Data, Error or Done is meaningful per event.
type StreamEvent struct {
	Data  []byte
	Error error
	Done  bool
}

// ProviderClient is the interface implemented by provider clients.
type ProviderClient interface {
	SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error)
	SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan StreamEvent, error)
}

// Compile-time check that *Client satisfies ProviderClient.
var _ ProviderClient = (*Client)(nil)

// NewClient creates a provider client with a 30 second HTTP timeout, the
// global zap logger and the default stream configuration.
func NewClient() *Client {
	return &Client{
		httpClient: &http.Client{
			Timeout: 30 * time.Second,
		},
		adapter:   openai.NewAdapter(),
		logger:    zap.L(),
		streamCfg: DefaultStreamConfig(),
	}
}

// SendRequest sends a non-streaming chat completion request and returns the
// parsed response. A non-200 status is turned into an error that carries the
// provider's message when the error body can be parsed, or the bare HTTP
// status otherwise.
func (c *Client) SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error) {
	httpReq, err := c.adapter.PrepareRequest(req, apiKey, baseURL)
	if err != nil {
		return nil, fmt.Errorf("准备请求失败: %w", err)
	}

	c.logger.Debug("发送请求",
		zap.String("url", httpReq.URL.String()),
		zap.String("method", httpReq.Method),
	)

	httpReq = httpReq.WithContext(ctx)

	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("发送请求失败: %w", err)
	}
	// Always release the body so the transport can reuse the connection.
	// Closing twice is harmless should the adapter close it as well.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		errorResp, parseErr := c.adapter.ParseErrorResponse(resp)
		if parseErr != nil {
			return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
		}
		return nil, fmt.Errorf("供应商错误: %s", errorResp.Error.Message)
	}

	result, err := c.adapter.ParseResponse(resp)
	if err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}
	return result, nil
}

// SendStreamRequest sends a streaming chat completion request and returns a
// channel of events. The channel is closed when the stream ends, fails, times
// out (StreamConfig.Timeout) or ctx is cancelled.
//
// Note: req.Stream is set to true on the caller's request value.
func (c *Client) SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan StreamEvent, error) {
	req.Stream = true

	httpReq, err := c.adapter.PrepareRequest(req, apiKey, baseURL)
	if err != nil {
		return nil, fmt.Errorf("准备请求失败: %w", err)
	}

	cfg := c.streamCfg.withDefaults()

	streamCtx, cancel := context.WithTimeout(ctx, cfg.Timeout)
	httpReq = httpReq.WithContext(streamCtx)

	// http.Client.Timeout also covers reading the response body, which would
	// cut every stream after the client's timeout. Use a shallow copy without
	// it (sharing transport, jar and redirect policy); streamCtx bounds the
	// whole exchange instead.
	streamClient := *c.httpClient
	streamClient.Timeout = 0

	resp, err := streamClient.Do(httpReq)
	if err != nil {
		cancel()
		return nil, fmt.Errorf("发送请求失败: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		// Deferred calls run last-in-first-out: the body is closed and only
		// then is the context cancelled, both after the error body has been
		// parsed. Cancelling earlier would abort the body read.
		defer cancel()
		defer resp.Body.Close()

		errorResp, parseErr := c.adapter.ParseErrorResponse(resp)
		if parseErr != nil {
			return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
		}
		return nil, fmt.Errorf("供应商错误: %s", errorResp.Error.Message)
	}

	eventChan := make(chan StreamEvent, cfg.ChannelBufferSize)

	// readStream owns resp.Body, cancel and eventChan from here on.
	go c.readStream(streamCtx, cancel, resp.Body, eventChan)

	return eventChan, nil
}

// readStream reads an SSE stream from body and forwards each "data" line as a
// StreamEvent. It stops on the "[DONE]" marker, end of stream, a read error or
// the end of ctx. On return it cancels the stream context, closes body and
// closes eventChan.
func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body io.ReadCloser, eventChan chan<- StreamEvent) {
	defer close(eventChan)
	defer body.Close()
	defer cancel()

	cfg := c.streamCfg.withDefaults()
	bufSize := cfg.InitialBufferSize
	buf := make([]byte, bufSize)
	var pending []byte // bytes received but not yet terminated by an event boundary

	for {
		if err := ctx.Err(); err != nil {
			c.reportContextEnd(eventChan, err)
			return
		}

		n, readErr := body.Read(buf)

		// Per the io.Reader contract, consume the n bytes before looking at
		// the error: the final chunk may arrive together with io.EOF.
		if n > 0 {
			pending = append(pending, buf[:n]...)

			// Grow the read buffer (doubling, capped at MaxBufferSize) while
			// the backlog exceeds half of its current size.
			if len(pending) > bufSize/2 && bufSize < cfg.MaxBufferSize {
				newSize := bufSize * 2
				if newSize > cfg.MaxBufferSize {
					newSize = cfg.MaxBufferSize
				}
				buf = make([]byte, newSize)
				bufSize = newSize
			}

			var done bool
			pending, done = c.dispatchEvents(ctx, pending, eventChan)
			if done {
				return
			}
		}

		if readErr != nil {
			c.reportReadError(ctx, eventChan, readErr)
			return
		}
	}
}

// dispatchEvents emits every complete SSE event contained in pending and
// returns the unconsumed remainder. done is true once the "[DONE]" marker has
// been delivered. It returns early with done == false when ctx ends while an
// event is being handed over; the caller then observes ctx.Err().
func (c *Client) dispatchEvents(ctx context.Context, pending []byte, eventChan chan<- StreamEvent) (rest []byte, done bool) {
	for {
		event, tail, ok := nextSSEEvent(pending)
		if !ok {
			return pending, false
		}
		pending = tail

		for _, line := range strings.Split(string(event), "\n") {
			line = strings.TrimSuffix(line, "\r") // tolerate CRLF line endings
			if !strings.HasPrefix(line, "data:") {
				continue
			}
			// The field value may be preceded by one optional space.
			data := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")

			if data == "[DONE]" {
				return pending, sendStreamEvent(ctx, eventChan, StreamEvent{Done: true})
			}
			if !sendStreamEvent(ctx, eventChan, StreamEvent{Data: []byte(data)}) {
				return pending, false
			}
		}
	}
}

// reportReadError translates a failed body read into a terminal event.
// A read that failed because ctx ended is reported as the context error,
// io.EOF is a normal end of stream and produces no event, and everything else
// is logged and forwarded as a network or generic read error.
func (c *Client) reportReadError(ctx context.Context, eventChan chan<- StreamEvent, err error) {
	if ctxErr := ctx.Err(); ctxErr != nil {
		c.reportContextEnd(eventChan, ctxErr)
		return
	}
	if err == io.EOF {
		return
	}
	if isNetworkError(err) {
		c.logger.Error("流网络错误", zap.String("error", err.Error()))
		sendStreamEvent(ctx, eventChan, StreamEvent{Error: fmt.Errorf("网络错误: %w", err)})
		return
	}
	c.logger.Error("流读取错误", zap.String("error", err.Error()))
	sendStreamEvent(ctx, eventChan, StreamEvent{Error: fmt.Errorf("读取错误: %w", err)})
}

// reportContextEnd delivers the terminal event for a finished context: a
// wrapped timeout error when the deadline was exceeded, the raw context error
// otherwise. Because ctx is already done it cannot guard the send, so the
// wait is bounded by finalEventGrace instead.
func (c *Client) reportContextEnd(eventChan chan<- StreamEvent, ctxErr error) {
	event := StreamEvent{Error: ctxErr}
	if errors.Is(ctxErr, context.DeadlineExceeded) {
		c.logger.Warn("流读取超时")
		event.Error = fmt.Errorf("流读取超时: %w", ctxErr)
	}

	timer := time.NewTimer(finalEventGrace)
	defer timer.Stop()
	select {
	case eventChan <- event:
	case <-timer.C:
		// Nobody is receiving; closing the channel is the only signal left.
	}
}

// sendStreamEvent delivers event on eventChan unless ctx ends first and
// reports whether the event was delivered. Selecting on ctx keeps the reader
// goroutine from blocking forever on a consumer that stopped receiving.
func sendStreamEvent(ctx context.Context, eventChan chan<- StreamEvent, event StreamEvent) bool {
	select {
	case eventChan <- event:
		return true
	case <-ctx.Done():
		return false
	}
}

// nextSSEEvent splits the first complete event off pending. Events are
// terminated by a blank line, written as either "\n\n" or "\r\n\r\n";
// whichever terminator occurs first wins. ok is false when pending holds no
// complete event yet.
func nextSSEEvent(pending []byte) (event, rest []byte, ok bool) {
	lf := bytes.Index(pending, []byte("\n\n"))
	crlf := bytes.Index(pending, []byte("\r\n\r\n"))

	switch {
	case lf == -1 && crlf == -1:
		return nil, pending, false
	case crlf != -1 && (lf == -1 || crlf < lf):
		return pending[:crlf], pending[crlf+4:], true
	default:
		return pending[:lf], pending[lf+2:], true
	}
}

// isNetworkError reports whether err looks like a network-level failure.
// Typed checks come first (any net.Error in the chain, or an unexpected EOF);
// a substring match on the message remains as a fallback for errors that
// carry no usable type information.
func isNetworkError(err error) bool {
	if err == nil {
		return false
	}

	var netErr net.Error
	if errors.As(err, &netErr) || errors.Is(err, io.ErrUnexpectedEOF) {
		return true
	}

	msg := err.Error()
	for _, marker := range [...]string{"connection reset", "broken pipe", "network", "timeout", "EOF"} {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}