1
0
Files
nex/backend/internal/provider/client.go
lanyuanxiaoyao 4c6b49099d feat: 配置 golangci-lint 静态分析并修复存量违规
- 新增 backend/.golangci.yml 配置 12 个 linter(forbidigo、errorlint、errcheck、staticcheck、revive、gocritic、gosec、bodyclose、noctx、nilerr、goimports、gocyclo)
- 新增 lefthook.yml 配置 pre-commit hook 自动运行 lint
- 修复存量代码违规:errors.Is/As 替换、zap.Error 替换、import 排序、errcheck 修复
- 更新 README 补充编码规范说明
- 归档 backend-code-lint 变更
2026-04-24 13:01:48 +08:00

281 lines
6.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package provider
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"syscall"
"time"
"go.uber.org/zap"
"nex/backend/internal/conversion"
pkgErrors "nex/backend/pkg/errors"
pkglogger "nex/backend/pkg/logger"
)
// StreamConfig holds the tuning knobs used when reading streaming (SSE)
// responses.
type StreamConfig struct {
	// InitialBufferSize is the size, in bytes, of the first read buffer.
	InitialBufferSize int
	// MaxBufferSize is the upper bound, in bytes, the read buffer may grow to.
	MaxBufferSize int
	// Timeout bounds the lifetime of a single stream (applied via context).
	Timeout time.Duration
	// ChannelBufferSize is the capacity of the event channel handed to callers.
	ChannelBufferSize int
}

// DefaultStreamConfig returns the stream settings NewClient installs: a
// 4 KiB initial read buffer that may grow to 64 KiB, a 5-minute stream
// timeout, and room for 100 pending events on the channel.
func DefaultStreamConfig() StreamConfig {
	const (
		initialReadBuf = 4 << 10  // 4096 bytes
		maxReadBuf     = 64 << 10 // 65536 bytes
		pendingEvents  = 100
	)

	var cfg StreamConfig
	cfg.InitialBufferSize = initialReadBuf
	cfg.MaxBufferSize = maxReadBuf
	cfg.Timeout = 5 * time.Minute
	cfg.ChannelBufferSize = pendingEvents
	return cfg
}
// StreamEvent is one item delivered on the channel returned by SendStream.
//
// readStream fills exactly one field per event:
//   - Data:  the raw bytes of one SSE event, without its blank-line delimiter;
//   - Error: a terminal error, after which the channel is closed;
//   - Done:  true when the upstream sent "data: [DONE]"; the channel is then closed.
type StreamEvent struct {
	Data  []byte
	Error error
	Done  bool
}
// Client is a protocol-agnostic upstream provider client. It sends pre-built
// HTTP requests (conversion.HTTPRequestSpec) and returns either the raw
// response or a channel of raw SSE events. See NewClient for construction.
type Client struct {
	httpClient *http.Client // shared HTTP client; NewClient sets a 30s Timeout
	logger     *zap.Logger  // logger tagged with the "provider.client" module by NewClient
	streamCfg  StreamConfig // buffer sizes and timeout used by SendStream/readStream
}
// ProviderClient is the provider client interface; *Client implements it.
// A mock is generated from it by the go:generate directive below.
//
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
type ProviderClient interface {
	// Send performs a non-streaming request and returns the complete response.
	Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
	// SendStream performs a streaming request and returns a channel of stream events.
	SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error)
}
// NewClient builds a Client with:
//   - an HTTP client with a 30-second overall request timeout;
//   - the given logger, scoped to the "provider.client" module;
//   - the streaming settings from DefaultStreamConfig.
func NewClient(logger *zap.Logger) *Client {
	hc := &http.Client{Timeout: 30 * time.Second}
	scoped := pkglogger.WithModule(logger, "provider.client")

	c := new(Client)
	c.httpClient = hc
	c.logger = scoped
	c.streamCfg = DefaultStreamConfig()
	return c
}
// Send performs one non-streaming request described by spec and returns the
// upstream status code, headers and fully-read body.
//
// The body is sent only when spec.Body is non-empty. For response headers,
// only the first value of each header is kept.
//
// Any HTTP status is returned as a normal result. An error is returned only
// when:
//   - the request cannot be built (ErrRequestCreate);
//   - the request cannot be sent (ErrRequestSend);
//   - the response body cannot be read (ErrResponseRead).
func (c *Client) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
	var payload io.Reader // stays a nil interface when there is no body
	if len(spec.Body) != 0 {
		payload = bytes.NewReader(spec.Body)
	}

	req, err := http.NewRequestWithContext(ctx, spec.Method, spec.URL, payload)
	if err != nil {
		return nil, pkgErrors.ErrRequestCreate.WithCause(err)
	}
	for name, value := range spec.Headers {
		req.Header.Set(name, value)
	}

	c.logger.Debug("发送请求",
		zap.String("url", spec.URL),
		zap.String("method", spec.Method),
	)

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, pkgErrors.ErrRequestSend.WithCause(err)
	}
	defer resp.Body.Close()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, pkgErrors.ErrResponseRead.WithCause(err)
	}

	out := &conversion.HTTPResponseSpec{
		StatusCode: resp.StatusCode,
		Headers:    make(map[string]string, len(resp.Header)),
		Body:       raw,
	}
	for name, values := range resp.Header {
		if len(values) == 0 {
			continue
		}
		out.Headers[name] = values[0]
	}
	return out, nil
}
// SendStream sends a streaming request described by spec. When the upstream
// answers 200 OK, it returns a channel of raw SSE events. A background
// goroutine (readStream) fills the channel and closes it when the stream ends.
//
// The stream's lifetime is bounded by ctx and by c.streamCfg.Timeout. It is
// deliberately NOT bounded by c.httpClient.Timeout. net/http's Client.Timeout
// also covers reading the response body, so reusing it would cut off any
// stream that outlives the non-streaming request timeout.
//
// Errors:
//   - ErrRequestCreate / ErrRequestSend (pkgErrors) when the request cannot be
//     built or sent;
//   - a plain error carrying the HTTP status (and the response body, if any)
//     for any non-200 response.
func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error) {
	var bodyReader io.Reader
	if len(spec.Body) > 0 {
		bodyReader = bytes.NewReader(spec.Body)
	}

	streamCtx, cancel := context.WithTimeout(ctx, c.streamCfg.Timeout)
	httpReq, err := http.NewRequestWithContext(streamCtx, spec.Method, spec.URL, bodyReader)
	if err != nil {
		cancel()
		return nil, pkgErrors.ErrRequestCreate.WithCause(err)
	}
	for k, v := range spec.Headers {
		httpReq.Header.Set(k, v)
	}

	// The shallow copy shares the Transport (and its connection pool), Jar and
	// CheckRedirect. Only the overall Timeout is dropped, so streamCtx alone
	// governs how long the stream may run.
	streamClient := *c.httpClient
	streamClient.Timeout = 0

	resp, err := streamClient.Do(httpReq)
	if err != nil {
		cancel()
		return nil, pkgErrors.ErrRequestSend.WithCause(err)
	}

	if resp.StatusCode != http.StatusOK {
		// Defers run LIFO: the body is closed first, then streamCtx is cancelled.
		// Cancelling the request context before reading would abort the body read
		// and lose the upstream's error message.
		defer cancel()
		defer resp.Body.Close()

		errBody, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			return nil, fmt.Errorf("供应商返回错误: HTTP %d: 读取错误响应失败: %w", resp.StatusCode, readErr)
		}
		if len(errBody) > 0 {
			return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody))
		}
		return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
	}

	// Ownership of resp.Body and cancel passes to readStream, which releases both.
	eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize)
	go c.readStream(streamCtx, cancel, resp.Body, eventChan)
	return eventChan, nil
}
// readStream consumes an SSE response body and forwards each event on
// eventChan. It is run as a goroutine by SendStream. On return it always calls
// cancel, closes body and closes eventChan.
//
// Events are delimited by a blank line ("\n\n"). Each event's raw bytes,
// without the delimiter, are sent as StreamEvent{Data: ...}. The stream ends
// when one of the following happens:
//   - An event contains "data: [DONE]": StreamEvent{Done: true} is sent.
//   - The body reaches io.EOF: the channel is closed. A trailing event without
//     its terminating blank line is discarded.
//   - ctx is cancelled or its deadline passes: one event carrying the context
//     error is sent. Deadline expiry is logged and wrapped as a stream timeout.
//   - Any other read error occurs: one event carrying the wrapped error is
//     sent, labelled as a network or read error.
//
// The read buffer starts at streamCfg.InitialBufferSize. It doubles (capped at
// streamCfg.MaxBufferSize) whenever the accumulated, not-yet-split data exceeds
// half of its current size.
//
// NOTE(review): sends on eventChan block. The consumer must keep draining the
// channel until it is closed; a consumer that stops reading while the channel
// buffer is full leaves this goroutine blocked.
func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body io.ReadCloser, eventChan chan<- StreamEvent) {
	defer close(eventChan)
	defer body.Close()
	defer cancel()

	eventDelim := []byte("\n\n")
	doneMarker := []byte("data: [DONE]")

	// contextErrorEvent turns ctx's error into the terminal event for the
	// consumer. It logs and labels deadline expiry as a stream timeout.
	contextErrorEvent := func() StreamEvent {
		ctxErr := ctx.Err()
		if errors.Is(ctxErr, context.DeadlineExceeded) {
			c.logger.Warn("流读取超时")
			return StreamEvent{Error: fmt.Errorf("流读取超时: %w", ctxErr)}
		}
		return StreamEvent{Error: ctxErr}
	}

	bufSize := c.streamCfg.InitialBufferSize
	buf := make([]byte, bufSize)
	var dataBuf []byte

	for {
		if ctx.Err() != nil {
			eventChan <- contextErrorEvent()
			return
		}

		n, err := body.Read(buf)
		if n > 0 {
			dataBuf = append(dataBuf, buf[:n]...)
		}
		if err != nil && !errors.Is(err, io.EOF) {
			// The request context controls reading the response body. When
			// ctx is cancelled or expires, the blocked Read fails as a
			// consequence. Report the context error so timeouts surface as
			// timeouts rather than as generic network errors.
			if ctx.Err() != nil {
				eventChan <- contextErrorEvent()
				return
			}
			if isNetworkError(err) {
				c.logger.Error("流网络错误", zap.Error(err))
				eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
			} else {
				c.logger.Error("流读取错误", zap.Error(err))
				eventChan <- StreamEvent{Error: fmt.Errorf("读取错误: %w", err)}
			}
			return
		}

		// Grow the read buffer while a lot of data is pending, up to the cap.
		if len(dataBuf) > bufSize/2 && bufSize < c.streamCfg.MaxBufferSize {
			newSize := bufSize * 2
			if newSize > c.streamCfg.MaxBufferSize {
				newSize = c.streamCfg.MaxBufferSize
			}
			buf = make([]byte, newSize)
			bufSize = newSize
		}

		for {
			idx := bytes.Index(dataBuf, eventDelim)
			if idx == -1 {
				break
			}
			// Cap the capacity at idx. Otherwise a consumer appending to Data
			// would overwrite the following bytes of dataBuf's backing array,
			// which this goroutine is still parsing.
			rawEvent := dataBuf[:idx:idx]
			dataBuf = dataBuf[idx+len(eventDelim):]

			if bytes.Contains(rawEvent, doneMarker) {
				eventChan <- StreamEvent{Done: true}
				return
			}
			eventChan <- StreamEvent{Data: rawEvent}
		}

		if errors.Is(err, io.EOF) {
			return
		}
	}
}
// isNetworkError reports whether err, or any error it wraps, indicates a
// transport-level failure rather than a data or protocol problem.
//
// It returns true for:
//   - connection reset, broken pipe and connection timed out errnos;
//   - context.DeadlineExceeded and context.Canceled;
//   - io.EOF;
//   - any error implementing net.Error (e.g. *net.OpError, *net.DNSError,
//     syscall.Errno).
//
// A nil error is never a network error.
func isNetworkError(err error) bool {
	if err == nil {
		return false
	}

	switch {
	// Common socket failures, listed explicitly for readability. syscall.Errno
	// also implements net.Error, so these are matched by the final check too.
	// The previous *net.OpError branch was unreachable for the same reason and
	// has been removed.
	case errors.Is(err, syscall.ECONNRESET),
		errors.Is(err, syscall.EPIPE),
		errors.Is(err, syscall.ETIMEDOUT):
		return true
	// Context errors: cancellation or deadline expiry of the request.
	case errors.Is(err, context.DeadlineExceeded),
		errors.Is(err, context.Canceled):
		return true
	// Treated as a network condition here; callers that expect a normal end
	// of stream must test for io.EOF before calling this.
	case errors.Is(err, io.EOF):
		return true
	}

	// Any error from the net package (dial/read/write failures, DNS errors,
	// timeouts) satisfies net.Error.
	var netErr net.Error
	return errors.As(err, &netErr)
}