package provider

import (
	"context"
	"errors"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"syscall"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"nex/backend/internal/conversion"
)

// TestNewClient checks that NewClient wires an HTTP client and the default
// stream buffer settings.
func TestNewClient(t *testing.T) {
	client := NewClient()
	require.NotNil(t, client)
	assert.NotNil(t, client.httpClient)
	assert.Equal(t, 4096, client.streamCfg.InitialBufferSize)
	assert.Equal(t, 65536, client.streamCfg.MaxBufferSize)
	assert.Equal(t, 100, client.streamCfg.ChannelBufferSize)
}

// TestDefaultStreamConfig pins the values returned by DefaultStreamConfig.
func TestDefaultStreamConfig(t *testing.T) {
	cfg := DefaultStreamConfig()
	assert.Equal(t, 4096, cfg.InitialBufferSize)
	assert.Equal(t, 65536, cfg.MaxBufferSize)
	assert.Equal(t, 100, cfg.ChannelBufferSize)
}

// TestClient_Send_Success verifies that Send forwards method, headers and body,
// and returns the upstream status code and body.
func TestClient_Send_Success(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// assert (not require) is safe here: handlers run off the test goroutine.
		assert.Equal(t, "POST", r.Method)
		assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
		assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
	}))
	defer server.Close()

	client := NewClient()
	spec := conversion.HTTPRequestSpec{
		URL:    server.URL + "/v1/chat/completions",
		Method: "POST",
		Headers: map[string]string{
			"Authorization": "Bearer test-key",
			"Content-Type":  "application/json",
		},
		Body: []byte(`{"model":"gpt-4","messages":[]}`),
	}

	result, err := client.Send(context.Background(), spec)
	require.NoError(t, err)
	assert.Equal(t, 200, result.StatusCode)
	assert.Contains(t, string(result.Body), "test")
}

// TestClient_Send_ErrorResponse verifies that a non-2xx upstream status is
// reported through the result, not as a Send error.
func TestClient_Send_ErrorResponse(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
		w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
	}))
	defer server.Close()

	client := NewClient()
	spec := conversion.HTTPRequestSpec{
		URL:     server.URL + "/v1/chat/completions",
		Method:  "POST",
		Headers: map[string]string{"Authorization": "Bearer bad-key"},
		Body:    []byte(`{}`),
	}

	result, err := client.Send(context.Background(), spec)
	require.NoError(t, err)
	assert.Equal(t, 401, result.StatusCode)
}

// TestClient_Send_ConnectionError verifies that Send returns an error when the
// upstream cannot be reached.
func TestClient_Send_ConnectionError(t *testing.T) {
	client := NewClient()
	spec := conversion.HTTPRequestSpec{
		// Port 1 is assumed to have no listener on the test host.
		URL:    "http://localhost:1/v1/chat/completions",
		Method: "POST",
	}

	_, err := client.Send(context.Background(), spec)
	assert.Error(t, err)
}

// TestClient_SendStream_CreatesChannel verifies that SendStream returns a
// channel for a 200 event-stream response, and that the channel is closed once
// the (empty) stream ends.
func TestClient_SendStream_CreatesChannel(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	client := NewClient()
	spec := conversion.HTTPRequestSpec{
		URL:     server.URL + "/v1/chat/completions",
		Method:  "POST",
		Headers: map[string]string{"Authorization": "Bearer test-key"},
		Body:    []byte(`{}`),
	}

	eventChan, err := client.SendStream(context.Background(), spec)
	require.NoError(t, err)
	require.NotNil(t, eventChan)
	// Drain until the client closes the channel.
	for range eventChan {
	}
}

// TestClient_SendStream_ErrorResponse verifies that SendStream itself returns
// an error for a non-2xx upstream status.
func TestClient_SendStream_ErrorResponse(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer server.Close()

	client := NewClient()
	spec := conversion.HTTPRequestSpec{
		URL:     server.URL + "/v1/chat/completions",
		Method:  "POST",
		Headers: map[string]string{"Authorization": "Bearer key"},
		Body:    []byte(`{}`),
	}

	_, err := client.SendStream(context.Background(), spec)
	assert.Error(t, err)
}

// TestClient_SendStream_SSEEvents verifies that two SSE data events followed by
// "data: [DONE]" yield exactly two data events (in order) and one done event.
func TestClient_SendStream_SSEEvents(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		flusher, ok := w.(http.Flusher)
		// require would call t.FailNow off the test goroutine; use assert + return.
		if !assert.True(t, ok, "ResponseWriter must implement http.Flusher") {
			return
		}

		w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
		flusher.Flush()
		w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
		flusher.Flush()
		time.Sleep(50 * time.Millisecond)
		w.Write([]byte("data: [DONE]\n\n"))
		flusher.Flush()
		time.Sleep(50 * time.Millisecond)
	}))
	defer server.Close()

	client := NewClient()
	spec := conversion.HTTPRequestSpec{
		URL:     server.URL + "/v1/chat/completions",
		Method:  "POST",
		Headers: map[string]string{"Authorization": "Bearer test-key"},
		Body:    []byte(`{"model":"gpt-4","messages":[],"stream":true}`),
	}

	eventChan, err := client.SendStream(context.Background(), spec)
	require.NoError(t, err)

	var dataEvents [][]byte
	var doneEvents int
	for event := range eventChan {
		switch {
		case event.Done:
			doneEvents++
		case event.Error != nil:
			t.Fatalf("unexpected error: %v", event.Error)
		default:
			dataEvents = append(dataEvents, event.Data)
		}
	}

	// require, not assert: the indexing below would panic on a short slice.
	require.Len(t, dataEvents, 2, "expected exactly 2 data events from SSE stream")
	assert.Contains(t, string(dataEvents[0]), "Hello")
	assert.Contains(t, string(dataEvents[1]), "World")
	assert.Equal(t, 1, doneEvents)
}

// TestClient_SendStream_ContextCancellation verifies that cancelling the
// request context while the upstream is stalled mid-stream surfaces an error
// event on the channel.
func TestClient_SendStream_ContextCancellation(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Consume the request body so net/http begins watching the connection
		// and cancels r.Context() when the client disconnects.
		_, _ = io.Copy(io.Discard, r.Body)

		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		// Send the headers now. Without a flush they stay buffered until the
		// handler returns, so SendStream would block through the whole stall
		// and the cancel below would not happen mid-stream.
		if flusher, ok := w.(http.Flusher); ok {
			flusher.Flush()
		}

		// Stall without emitting events, but return as soon as the client goes
		// away so server.Close does not wait out the full timeout.
		select {
		case <-r.Context().Done():
		case <-time.After(10 * time.Second):
		}
	}))
	defer server.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	client := NewClient()
	spec := conversion.HTTPRequestSpec{
		URL:     server.URL + "/v1/chat/completions",
		Method:  "POST",
		Headers: map[string]string{"Authorization": "Bearer test-key"},
		Body:    []byte(`{}`),
	}

	eventChan, err := client.SendStream(ctx, spec)
	require.NoError(t, err)
	cancel()

	var gotError bool
	for event := range eventChan {
		if event.Error != nil {
			gotError = true
		}
	}
	assert.True(t, gotError)
}

// TestClient_Send_EmptyBody verifies that a GET request with no body works.
func TestClient_Send_EmptyBody(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "GET", r.Method)
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"result":"ok"}`))
	}))
	defer server.Close()

	client := NewClient()
	spec := conversion.HTTPRequestSpec{
		URL:     server.URL + "/v1/models",
		Method:  "GET",
		Headers: map[string]string{"Authorization": "Bearer test-key"},
	}

	result, err := client.Send(context.Background(), spec)
	require.NoError(t, err)
	assert.Equal(t, 200, result.StatusCode)
	assert.Contains(t, string(result.Body), "ok")
}

// TestClient_SendStream_SlowSSE verifies that events separated by pauses are
// still delivered exactly once each.
func TestClient_SendStream_SlowSSE(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		flusher, ok := w.(http.Flusher)
		if !assert.True(t, ok, "ResponseWriter must implement http.Flusher") {
			return
		}

		w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
		flusher.Flush()
		time.Sleep(100 * time.Millisecond)
		w.Write([]byte("data: [DONE]\n\n"))
		flusher.Flush()
		time.Sleep(100 * time.Millisecond)
	}))
	defer server.Close()

	client := NewClient()
	spec := conversion.HTTPRequestSpec{
		URL:     server.URL + "/v1/chat/completions",
		Method:  "POST",
		Headers: map[string]string{"Authorization": "Bearer test-key"},
		Body:    []byte(`{}`),
	}

	eventChan, err := client.SendStream(context.Background(), spec)
	require.NoError(t, err)

	var dataCount, doneCount int
	for event := range eventChan {
		switch {
		case event.Done:
			doneCount++
		case event.Error != nil:
			t.Fatalf("unexpected error: %v", event.Error)
		default:
			dataCount++
		}
	}

	assert.Equal(t, 1, dataCount, "expected exactly 1 data event from slow SSE")
	assert.Equal(t, 1, doneCount, "expected exactly 1 done event from slow SSE")
}

// TestClient_SendStream_SplitSSEEvents verifies that two SSE events delivered
// in a single write are split into two separate data events.
func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		flusher, ok := w.(http.Flusher)
		if !assert.True(t, ok, "ResponseWriter must implement http.Flusher") {
			return
		}

		// Both events arrive in one chunk.
		w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
		flusher.Flush()
		time.Sleep(50 * time.Millisecond)
		w.Write([]byte("data: [DONE]\n\n"))
		flusher.Flush()
		time.Sleep(50 * time.Millisecond)
	}))
	defer server.Close()

	client := NewClient()
	spec := conversion.HTTPRequestSpec{
		URL:     server.URL + "/v1/chat/completions",
		Method:  "POST",
		Headers: map[string]string{"Authorization": "Bearer test-key"},
		Body:    []byte(`{}`),
	}

	eventChan, err := client.SendStream(context.Background(), spec)
	require.NoError(t, err)

	var dataEvents, doneEvents int
	for event := range eventChan {
		switch {
		case event.Done:
			doneEvents++
		case event.Error != nil:
			// Previously error events were silently counted as data.
			t.Fatalf("unexpected error: %v", event.Error)
		default:
			dataEvents++
		}
	}

	assert.Equal(t, 2, dataEvents, "expected exactly 2 data events from split SSE")
	assert.Equal(t, 1, doneEvents)
}

// TestIsNetworkError checks which error values isNetworkError classifies as
// network errors.
func TestIsNetworkError(t *testing.T) {
	// A value implementing net.Error (context.DeadlineExceeded does).
	t.Run("net_error", func(t *testing.T) {
		var netErr net.Error
		err := context.DeadlineExceeded
		assert.True(t, errors.As(err, &netErr))
		assert.True(t, isNetworkError(err))
	})

	// io.EOF.
	t.Run("io_eof", func(t *testing.T) {
		assert.True(t, isNetworkError(io.EOF))
	})

	// Context errors.
	t.Run("context_errors", func(t *testing.T) {
		assert.True(t, isNetworkError(context.DeadlineExceeded))
		assert.True(t, isNetworkError(context.Canceled))
	})

	// syscall errors wrapped in *net.OpError.
	t.Run("syscall_errors", func(t *testing.T) {
		// ECONNRESET
		opErr := &net.OpError{
			Op:  "read",
			Net: "tcp",
			Err: syscall.ECONNRESET,
		}
		assert.True(t, isNetworkError(opErr))

		// EPIPE
		opErr = &net.OpError{
			Op:  "write",
			Net: "tcp",
			Err: syscall.EPIPE,
		}
		assert.True(t, isNetworkError(opErr))
	})

	// Ordinary errors and nil are not network errors.
	t.Run("normal_error", func(t *testing.T) {
		assert.False(t, isNetworkError(errors.New("normal error")))
		assert.False(t, isNetworkError(nil))
	})
}

// TestClient_SendStream_MidStreamNetworkError verifies that data delivered
// before the upstream abruptly drops the connection still reaches the consumer.
func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		flusher, ok := w.(http.Flusher)
		if !assert.True(t, ok, "ResponseWriter must implement http.Flusher") {
			return
		}

		w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
		flusher.Flush()
		time.Sleep(50 * time.Millisecond)

		// Drop the raw connection with no terminating chunk and no [DONE].
		// Fail loudly rather than silently skipping the scenario under test.
		hijacker, ok := w.(http.Hijacker)
		if !assert.True(t, ok, "ResponseWriter must implement http.Hijacker") {
			return
		}
		conn, _, err := hijacker.Hijack()
		if !assert.NoError(t, err) {
			return
		}
		conn.Close()
	}))
	defer server.Close()

	client := NewClient()
	spec := conversion.HTTPRequestSpec{
		URL:     server.URL + "/v1/chat/completions",
		Method:  "POST",
		Headers: map[string]string{"Authorization": "Bearer test-key"},
		Body:    []byte(`{}`),
	}

	eventChan, err := client.SendStream(context.Background(), spec)
	require.NoError(t, err)

	var gotData bool
	for event := range eventChan {
		// Error events from the dropped connection are tolerated here; only
		// the presence of a preceding data event is asserted.
		if event.Error == nil && !event.Done {
			gotData = true
		}
	}
	assert.True(t, gotData, "should have received at least one data event before error")
}