package provider

import (
	"context"
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"nex/backend/internal/protocol/openai"
)

// TestNewClient checks that NewClient wires up its HTTP client and adapter and
// applies the default stream buffer sizes.
func TestNewClient(t *testing.T) {
	client := NewClient()
	require.NotNil(t, client)
	assert.NotNil(t, client.httpClient)
	assert.NotNil(t, client.adapter)
	assert.Equal(t, 4096, client.streamCfg.InitialBufferSize)
	assert.Equal(t, 65536, client.streamCfg.MaxBufferSize)
	assert.Equal(t, 100, client.streamCfg.ChannelBufferSize)
}

// TestDefaultStreamConfig pins the default stream buffer sizes.
func TestDefaultStreamConfig(t *testing.T) {
	cfg := DefaultStreamConfig()
	assert.Equal(t, 4096, cfg.InitialBufferSize)
	assert.Equal(t, 65536, cfg.MaxBufferSize)
	assert.Equal(t, 100, cfg.ChannelBufferSize)
}

// TestClient_SendRequest_Success checks the outgoing request's method and
// headers, and that a 200 JSON response is decoded into the result.
func TestClient_SendRequest_Success(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// assert (not require) is used here: the handler runs on a server
		// goroutine, where FailNow must not be called.
		assert.Equal(t, "POST", r.Method)
		assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
		assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))

		resp := openai.ChatCompletionResponse{
			ID: "chatcmpl-123",
			Choices: []openai.Choice{
				{Index: 0, Message: &openai.Message{Role: "assistant", Content: "Hello!"}},
			},
		}
		w.Header().Set("Content-Type", "application/json")
		assert.NoError(t, json.NewEncoder(w).Encode(resp))
	}))
	defer server.Close()

	client := NewClient()
	req := &openai.ChatCompletionRequest{
		Model:    "gpt-4",
		Messages: []openai.Message{{Role: "user", Content: "Hi"}},
	}

	result, err := client.SendRequest(context.Background(), req, "test-key", server.URL)
	require.NoError(t, err)
	assert.Equal(t, "chatcmpl-123", result.ID)
}

// TestClient_SendRequest_ErrorResponse checks that a non-2xx response yields an
// error whose message includes the upstream error detail.
func TestClient_SendRequest_ErrorResponse(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusUnauthorized)
		assert.NoError(t, json.NewEncoder(w).Encode(openai.ErrorResponse{
			Error: openai.ErrorDetail{Message: "Invalid API key"},
		}))
	}))
	defer server.Close()

	client := NewClient()
	req := &openai.ChatCompletionRequest{
		Model:    "gpt-4",
		Messages: []openai.Message{{Role: "user", Content: "Hi"}},
	}

	_, err := client.SendRequest(context.Background(), req, "bad-key", server.URL)
	// require, not assert: err.Error() below would panic on a nil error.
	require.Error(t, err)
	assert.Contains(t, err.Error(), "Invalid API key")
}

// TestClient_SendRequest_ConnectionError checks that a transport-level failure
// (connection refused) is reported as an error.
func TestClient_SendRequest_ConnectionError(t *testing.T) {
	// Start and immediately close a server so its address refuses connections,
	// rather than assuming nothing is listening on a hard-coded port.
	server := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
	baseURL := server.URL
	server.Close()

	client := NewClient()
	req := &openai.ChatCompletionRequest{
		Model:    "gpt-4",
		Messages: []openai.Message{{Role: "user", Content: "Hi"}},
	}

	_, err := client.SendRequest(context.Background(), req, "key", baseURL)
	assert.Error(t, err)
}

// TestClient_SendStreamRequest_CreatesChannel checks that a successful stream
// request returns a channel, and that the channel is closed once the server
// ends the response.
func TestClient_SendStreamRequest_CreatesChannel(t *testing.T) {
	// The handler sends SSE headers and returns without writing any events, so
	// the response body ends right away and the client's reader should see EOF.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	client := NewClient()
	req := &openai.ChatCompletionRequest{
		Model:    "gpt-4",
		Messages: []openai.Message{{Role: "user", Content: "Hi"}},
	}

	// Cancelled on return so a failing run does not leave the request open.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	eventChan, err := client.SendStreamRequest(ctx, req, "test-key", server.URL)
	require.NoError(t, err)
	require.NotNil(t, eventChan)

	// Drain every event until the channel closes. Fail instead of blocking until
	// the global test timeout if it never closes.
	deadline := time.After(5 * time.Second)
	for {
		select {
		case _, ok := <-eventChan:
			if !ok {
				return
			}
		case <-deadline:
			t.Fatal("event channel was not closed after the stream ended")
		}
	}
}

// TestClient_SendStreamRequest_ErrorResponse checks that a non-2xx status on a
// stream request is reported as an error.
func TestClient_SendStreamRequest_ErrorResponse(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer server.Close()

	client := NewClient()
	req := &openai.ChatCompletionRequest{
		Model:    "gpt-4",
		Messages: []openai.Message{{Role: "user", Content: "Hi"}},
	}

	_, err := client.SendStreamRequest(context.Background(), req, "key", server.URL)
	assert.Error(t, err)
}

// TestIsNetworkError pins which error messages isNetworkError classifies as
// network errors.
func TestIsNetworkError(t *testing.T) {
	tests := []struct {
		input string
		want  bool
	}{
		{"connection reset by peer", true},
		{"broken pipe", true},
		{"network is unreachable", true},
		{"timeout waiting for response", true},
		{"unexpected EOF", true},
		{"normal error", false},
		{"", false},
	}

	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			got := isNetworkError(errors.New(tt.input))
			assert.Equal(t, tt.want, got, "isNetworkError(%q)", tt.input)
		})
	}
}