1
0
Files
nex/backend/internal/provider/client_test.go
lanyuanxiaoyao 280099b89c refactor: 后端日志系统重构
- 新增模块化日志器(pkg/logger/module.go)
- 新增 GORM 日志适配器
- 统一日志入口,移除所有 zap.L() 全局 logger 调用
- 字段标准化
- 启动阶段使用结构化日志
- 更新所有相关测试
2026-04-23 18:37:51 +08:00

399 lines
11 KiB
Go

package provider
import (
"context"
"errors"
"io"
"net"
"net/http"
"net/http/httptest"
"syscall"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"nex/backend/internal/conversion"
)
// TestNewClient verifies that NewClient builds an underlying HTTP client and
// starts out with the default streaming buffer configuration.
func TestNewClient(t *testing.T) {
	c := NewClient(zap.NewNop())
	require.NotNil(t, c)
	assert.NotNil(t, c.httpClient)

	// These values must match DefaultStreamConfig.
	cfg := c.streamCfg
	assert.Equal(t, 4096, cfg.InitialBufferSize)
	assert.Equal(t, 65536, cfg.MaxBufferSize)
	assert.Equal(t, 100, cfg.ChannelBufferSize)
}
// TestDefaultStreamConfig pins the values returned by DefaultStreamConfig.
func TestDefaultStreamConfig(t *testing.T) {
	cfg := DefaultStreamConfig()

	cases := []struct {
		field     string
		want, got int
	}{
		{"InitialBufferSize", 4096, cfg.InitialBufferSize},
		{"MaxBufferSize", 65536, cfg.MaxBufferSize},
		{"ChannelBufferSize", 100, cfg.ChannelBufferSize},
	}
	for _, c := range cases {
		assert.Equal(t, c.want, c.got, c.field)
	}
}
// TestClient_Send_Success checks that Send forwards the spec's method and
// headers to the upstream and returns the upstream's 200 response body.
func TestClient_Send_Success(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		// The upstream must observe exactly what the spec requested.
		assert.Equal(t, "POST", r.Method)
		assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
		assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
	}
	upstream := httptest.NewServer(http.HandlerFunc(handler))
	defer upstream.Close()

	headers := map[string]string{
		"Authorization": "Bearer test-key",
		"Content-Type":  "application/json",
	}
	resp, err := NewClient(zap.NewNop()).Send(context.Background(), conversion.HTTPRequestSpec{
		URL:     upstream.URL + "/v1/chat/completions",
		Method:  "POST",
		Headers: headers,
		Body:    []byte(`{"model":"gpt-4","messages":[]}`),
	})
	require.NoError(t, err)
	assert.Equal(t, 200, resp.StatusCode)
	assert.Contains(t, string(resp.Body), "test")
}
// TestClient_Send_ErrorResponse checks that a non-2xx upstream status is not a
// transport error: Send succeeds and reports the upstream status code.
func TestClient_Send_ErrorResponse(t *testing.T) {
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
		_, _ = w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
	}))
	defer upstream.Close()

	req := conversion.HTTPRequestSpec{
		URL:     upstream.URL + "/v1/chat/completions",
		Method:  "POST",
		Headers: map[string]string{"Authorization": "Bearer bad-key"},
		Body:    []byte(`{}`),
	}
	resp, err := NewClient(zap.NewNop()).Send(context.Background(), req)
	require.NoError(t, err)
	assert.Equal(t, 401, resp.StatusCode)
}
// TestClient_Send_ConnectionError checks that Send returns an error when
// nothing is listening at the target address.
func TestClient_Send_ConnectionError(t *testing.T) {
	// Start and immediately stop a server to obtain a loopback address that
	// should refuse connections. A hard-coded port such as localhost:1 may be
	// in use on some hosts, or silently dropped by a firewall, which would make
	// the request hang instead of failing fast.
	dead := httptest.NewServer(http.NotFoundHandler())
	deadURL := dead.URL
	dead.Close()

	client := NewClient(zap.NewNop())
	spec := conversion.HTTPRequestSpec{
		URL:    deadURL + "/v1/chat/completions",
		Method: "POST",
	}
	_, err := client.Send(context.Background(), spec)
	assert.Error(t, err)
}
// TestClient_SendStream_CreatesChannel checks that SendStream returns a non-nil
// event channel for a 200 event-stream response with no events.
func TestClient_SendStream_CreatesChannel(t *testing.T) {
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
	}))
	defer upstream.Close()

	events, err := NewClient(zap.NewNop()).SendStream(context.Background(), conversion.HTTPRequestSpec{
		URL:     upstream.URL + "/v1/chat/completions",
		Method:  "POST",
		Headers: map[string]string{"Authorization": "Bearer test-key"},
		Body:    []byte(`{}`),
	})
	require.NoError(t, err)
	require.NotNil(t, events)

	// Drain the channel; this loop only finishes once the client closes it.
	for ev := range events {
		_ = ev
	}
}
// TestClient_SendStream_ErrorResponse checks that SendStream returns an error
// when the upstream answers with a 500 status.
func TestClient_SendStream_ErrorResponse(t *testing.T) {
	failing := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer failing.Close()

	req := conversion.HTTPRequestSpec{
		URL:     failing.URL + "/v1/chat/completions",
		Method:  "POST",
		Headers: map[string]string{"Authorization": "Bearer key"},
		Body:    []byte(`{}`),
	}
	_, err := NewClient(zap.NewNop()).SendStream(context.Background(), req)
	assert.Error(t, err)
}
// TestClient_SendStream_SSEEvents checks that each `data:` line of an SSE
// stream becomes one data event, in order, and that the `[DONE]` sentinel is
// reported as exactly one Done event.
func TestClient_SendStream_SSEEvents(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		// This handler runs on a server goroutine, where require (t.FailNow)
		// must not be called; report the failure and return instead.
		flusher, ok := w.(http.Flusher)
		if !assert.True(t, ok, "response writer must implement http.Flusher") {
			return
		}
		w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
		flusher.Flush()
		w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
		flusher.Flush()
		time.Sleep(50 * time.Millisecond)
		w.Write([]byte("data: [DONE]\n\n"))
		flusher.Flush()
		time.Sleep(50 * time.Millisecond)
	}))
	defer server.Close()

	client := NewClient(zap.NewNop())
	spec := conversion.HTTPRequestSpec{
		URL:     server.URL + "/v1/chat/completions",
		Method:  "POST",
		Headers: map[string]string{"Authorization": "Bearer test-key"},
		Body:    []byte(`{"model":"gpt-4","messages":[],"stream":true}`),
	}
	eventChan, err := client.SendStream(context.Background(), spec)
	require.NoError(t, err)

	var dataEvents [][]byte
	var doneEvents int
	for event := range eventChan {
		switch {
		case event.Done:
			doneEvents++
		case event.Error != nil:
			t.Fatalf("unexpected error: %v", event.Error)
		default:
			dataEvents = append(dataEvents, event.Data)
		}
	}

	// Fatal on a length mismatch: the indexing below would otherwise panic.
	require.Len(t, dataEvents, 2, "expected exactly 2 data events from SSE stream")
	assert.Contains(t, string(dataEvents[0]), "Hello")
	assert.Contains(t, string(dataEvents[1]), "World")
	assert.Equal(t, 1, doneEvents)
}
// TestClient_SendStream_ContextCancellation checks that cancelling the caller's
// context while the upstream stream is still open surfaces an error event.
func TestClient_SendStream_ContextCancellation(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Consume the request body so the server can notice the client
		// disconnecting and cancel r.Context().
		_, _ = io.Copy(io.Discard, r.Body)

		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		// Push the headers out now. Otherwise net/http buffers them until the
		// handler returns, and SendStream would block for the whole wait below.
		if flusher, ok := w.(http.Flusher); ok {
			flusher.Flush()
		}

		// Keep the stream open without sending events until the client goes
		// away. A bare Sleep would keep server.Close() blocked for the full
		// duration, because Close waits for in-flight handlers to return.
		select {
		case <-r.Context().Done():
		case <-time.After(10 * time.Second):
		}
	}))
	defer server.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	client := NewClient(zap.NewNop())
	spec := conversion.HTTPRequestSpec{
		URL:     server.URL + "/v1/chat/completions",
		Method:  "POST",
		Headers: map[string]string{"Authorization": "Bearer test-key"},
		Body:    []byte(`{}`),
	}
	eventChan, err := client.SendStream(ctx, spec)
	require.NoError(t, err)

	cancel()

	var gotError bool
	for event := range eventChan {
		if event.Error != nil {
			gotError = true
		}
	}
	assert.True(t, gotError)
}
// TestClient_Send_EmptyBody checks that a GET spec with no Body is sent with
// the requested method, and that the response is returned intact.
func TestClient_Send_EmptyBody(t *testing.T) {
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "GET", r.Method)
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"result":"ok"}`))
	}))
	defer upstream.Close()

	req := conversion.HTTPRequestSpec{
		URL:     upstream.URL + "/v1/models",
		Method:  "GET",
		Headers: map[string]string{"Authorization": "Bearer test-key"},
	}
	resp, err := NewClient(zap.NewNop()).Send(context.Background(), req)
	require.NoError(t, err)
	assert.Equal(t, 200, resp.StatusCode)
	assert.Contains(t, string(resp.Body), "ok")
}
// TestClient_SendStream_SlowSSE checks that events separated by a pause are
// still delivered: one data event followed by one Done event, with no errors.
func TestClient_SendStream_SlowSSE(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		// Server goroutine: use assert + return, never require (t.FailNow).
		flusher, ok := w.(http.Flusher)
		if !assert.True(t, ok, "response writer must implement http.Flusher") {
			return
		}
		w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
		flusher.Flush()
		time.Sleep(100 * time.Millisecond)
		w.Write([]byte("data: [DONE]\n\n"))
		flusher.Flush()
		time.Sleep(100 * time.Millisecond)
	}))
	defer server.Close()

	client := NewClient(zap.NewNop())
	spec := conversion.HTTPRequestSpec{
		URL:     server.URL + "/v1/chat/completions",
		Method:  "POST",
		Headers: map[string]string{"Authorization": "Bearer test-key"},
		Body:    []byte(`{}`),
	}
	eventChan, err := client.SendStream(context.Background(), spec)
	require.NoError(t, err)

	var dataCount, doneCount int
	for event := range eventChan {
		switch {
		case event.Done:
			doneCount++
		case event.Error != nil:
			t.Fatalf("unexpected error: %v", event.Error)
		default:
			dataCount++
		}
	}
	assert.Equal(t, 1, dataCount, "expected exactly 1 data event from slow SSE")
	assert.Equal(t, 1, doneCount, "expected exactly 1 done event from slow SSE")
}
// TestClient_SendStream_SplitSSEEvents checks that two SSE events arriving in a
// single write are split into two data events, followed by one Done event.
func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		// Server goroutine: use assert + return, never require (t.FailNow).
		flusher, ok := w.(http.Flusher)
		if !assert.True(t, ok, "response writer must implement http.Flusher") {
			return
		}
		w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
		flusher.Flush()
		time.Sleep(50 * time.Millisecond)
		w.Write([]byte("data: [DONE]\n\n"))
		flusher.Flush()
		time.Sleep(50 * time.Millisecond)
	}))
	defer server.Close()

	client := NewClient(zap.NewNop())
	spec := conversion.HTTPRequestSpec{
		URL:     server.URL + "/v1/chat/completions",
		Method:  "POST",
		Headers: map[string]string{"Authorization": "Bearer test-key"},
		Body:    []byte(`{}`),
	}
	eventChan, err := client.SendStream(context.Background(), spec)
	require.NoError(t, err)

	var dataEvents, doneEvents int
	for event := range eventChan {
		switch {
		case event.Done:
			doneEvents++
		case event.Error != nil:
			// Previously error events were silently counted as data.
			t.Fatalf("unexpected error: %v", event.Error)
		default:
			dataEvents++
		}
	}
	assert.Equal(t, 2, dataEvents, "expected exactly 2 data events from split SSE")
	assert.Equal(t, 1, doneEvents)
}
// TestIsNetworkError covers the error kinds isNetworkError classifies as
// network failures, and confirms that ordinary errors and nil are not.
func TestIsNetworkError(t *testing.T) {
	// net.Error: context.DeadlineExceeded satisfies the net.Error interface,
	// so it is used to exercise that path.
	t.Run("net_error", func(t *testing.T) {
		err := context.DeadlineExceeded
		var netErr net.Error
		assert.True(t, errors.As(err, &netErr))
		assert.True(t, isNetworkError(err))
	})

	// io.EOF.
	t.Run("io_eof", func(t *testing.T) {
		assert.True(t, isNetworkError(io.EOF))
	})

	// Context errors.
	t.Run("context_errors", func(t *testing.T) {
		for _, err := range []error{context.DeadlineExceeded, context.Canceled} {
			assert.True(t, isNetworkError(err))
		}
	})

	// Syscall errors wrapped in *net.OpError.
	t.Run("syscall_errors", func(t *testing.T) {
		wrapped := []struct {
			op    string
			errno error
		}{
			{op: "read", errno: syscall.ECONNRESET},
			{op: "write", errno: syscall.EPIPE},
		}
		for _, w := range wrapped {
			assert.True(t, isNetworkError(&net.OpError{Op: w.op, Net: "tcp", Err: w.errno}))
		}
	})

	// Plain errors and nil.
	t.Run("normal_error", func(t *testing.T) {
		assert.False(t, isNetworkError(errors.New("normal error")))
		assert.False(t, isNetworkError(nil))
	})
}
// TestClient_SendStream_MidStreamNetworkError checks that data sent before the
// upstream connection is dropped mid-stream still reaches the caller.
func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		// Server goroutine: use assert + return, never require (t.FailNow).
		flusher, ok := w.(http.Flusher)
		if !assert.True(t, ok, "response writer must implement http.Flusher") {
			return
		}
		w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
		flusher.Flush()
		time.Sleep(50 * time.Millisecond)

		// Take over the raw connection and close it, so the stream ends
		// without a terminating chunk or [DONE] (a simulated network failure).
		hijacker, ok := w.(http.Hijacker)
		if !assert.True(t, ok, "response writer must implement http.Hijacker") {
			return
		}
		conn, _, err := hijacker.Hijack()
		if !assert.NoError(t, err, "hijacking the connection") {
			return
		}
		conn.Close()
	}))
	defer server.Close()

	client := NewClient(zap.NewNop())
	spec := conversion.HTTPRequestSpec{
		URL:     server.URL + "/v1/chat/completions",
		Method:  "POST",
		Headers: map[string]string{"Authorization": "Bearer test-key"},
		Body:    []byte(`{}`),
	}
	eventChan, err := client.SendStream(context.Background(), spec)
	require.NoError(t, err)

	var gotData bool
	for event := range eventChan {
		// Error events may follow the dropped connection; only data events
		// matter for this assertion.
		if event.Error == nil && !event.Done {
			gotData = true
		}
	}
	assert.True(t, gotData, "should have received at least one data event before error")
}