package provider import ( "bytes" "context" "fmt" "io" "net/http" "strings" "time" "nex/backend/internal/protocol/openai" ) // Client OpenAI 兼容供应商客户端 type Client struct { httpClient *http.Client adapter *openai.Adapter } // NewClient 创建供应商客户端 func NewClient() *Client { return &Client{ httpClient: &http.Client{ Timeout: 30 * time.Second, // 非流式请求超时 }, adapter: openai.NewAdapter(), } } // SendRequest 发送非流式请求 func (c *Client) SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error) { // 准备请求 httpReq, err := c.adapter.PrepareRequest(req, apiKey, baseURL) if err != nil { return nil, fmt.Errorf("准备请求失败: %w", err) } // 调试日志:打印完整请求信息 fmt.Printf("[DEBUG] 请求URL: %s\n", httpReq.URL.String()) fmt.Printf("[DEBUG] 请求Method: %s\n", httpReq.Method) fmt.Printf("[DEBUG] 请求Headers: %v\n", httpReq.Header) // 设置上下文 httpReq = httpReq.WithContext(ctx) // 发送请求 resp, err := c.httpClient.Do(httpReq) if err != nil { return nil, fmt.Errorf("发送请求失败: %w", err) } // 检查状态码 if resp.StatusCode != http.StatusOK { // 解析错误响应 errorResp, parseErr := c.adapter.ParseErrorResponse(resp) if parseErr != nil { return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode) } return nil, fmt.Errorf("供应商错误: %s", errorResp.Error.Message) } // 解析响应 result, err := c.adapter.ParseResponse(resp) if err != nil { return nil, fmt.Errorf("解析响应失败: %w", err) } return result, nil } // SendStreamRequest 发送流式请求 func (c *Client) SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan StreamEvent, error) { // 确保请求设置为流式 req.Stream = true // 准备请求 httpReq, err := c.adapter.PrepareRequest(req, apiKey, baseURL) if err != nil { return nil, fmt.Errorf("准备请求失败: %w", err) } // 设置上下文 httpReq = httpReq.WithContext(ctx) // 发送请求 resp, err := c.httpClient.Do(httpReq) if err != nil { return nil, fmt.Errorf("发送请求失败: %w", err) } // 检查状态码 if resp.StatusCode != http.StatusOK { defer resp.Body.Close() errorResp, parseErr := c.adapter.ParseErrorResponse(resp) if parseErr != nil { return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode) } return nil, fmt.Errorf("供应商错误: %s", errorResp.Error.Message) } // 创建事件通道 eventChan := make(chan StreamEvent, 100) // 启动 goroutine 读取流 go c.readStream(ctx, resp.Body, eventChan) return eventChan, nil } // StreamEvent 流事件 type StreamEvent struct { Data []byte Error error Done bool } // readStream 读取 SSE 流 func (c *Client) readStream(ctx context.Context, body io.ReadCloser, eventChan chan<- StreamEvent) { defer close(eventChan) defer body.Close() buf := make([]byte, 4096) var dataBuf []byte for { select { case <-ctx.Done(): eventChan <- StreamEvent{Error: ctx.Err()} return default: } n, err := body.Read(buf) if err != nil { if err == io.EOF { // 流结束 return } eventChan <- StreamEvent{Error: err} return } dataBuf = append(dataBuf, buf[:n]...) // 处理完整的 SSE 事件 for { // 查找事件边界(双换行) idx := bytes.Index(dataBuf, []byte("\n\n")) if idx == -1 { break } // 提取事件 event := dataBuf[:idx] dataBuf = dataBuf[idx+2:] // 解析 data 行 lines := strings.Split(string(event), "\n") for _, line := range lines { if strings.HasPrefix(line, "data: ") { data := strings.TrimPrefix(line, "data: ") // 检查是否是结束标记 if data == "[DONE]" { eventChan <- StreamEvent{Done: true} return } // 发送数据 eventChan <- StreamEvent{Data: []byte(data)} } } } } }