- 新增 domain 层:model、provider、route、stats 实体
- 新增 service 层:models、providers、routing、stats 业务逻辑
- 新增 repository 层:models、providers、stats 数据访问
- 新增 pkg 工具包:errors、logger、validator
- 新增中间件:CORS、logging、recovery、request ID
- 新增数据库迁移:初始 schema 和索引
- 新增单元测试和集成测试
- 新增规范文档:config-management、database-migration、error-handling、layered-architecture、middleware-system、request-validation、structured-logging、test-coverage
- 移除 config 子包和 model_router(已迁移至分层架构)
152 lines
4.3 KiB
Go
152 lines
4.3 KiB
Go
package provider
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"testing"
|
||
|
||
"github.com/stretchr/testify/assert"
|
||
"github.com/stretchr/testify/require"
|
||
|
||
"nex/backend/internal/protocol/openai"
|
||
)
|
||
|
||
func TestNewClient(t *testing.T) {
|
||
client := NewClient()
|
||
require.NotNil(t, client)
|
||
assert.NotNil(t, client.httpClient)
|
||
assert.NotNil(t, client.adapter)
|
||
assert.Equal(t, 4096, client.streamCfg.InitialBufferSize)
|
||
assert.Equal(t, 65536, client.streamCfg.MaxBufferSize)
|
||
assert.Equal(t, 100, client.streamCfg.ChannelBufferSize)
|
||
}
|
||
|
||
func TestDefaultStreamConfig(t *testing.T) {
|
||
cfg := DefaultStreamConfig()
|
||
assert.Equal(t, 4096, cfg.InitialBufferSize)
|
||
assert.Equal(t, 65536, cfg.MaxBufferSize)
|
||
assert.Equal(t, 100, cfg.ChannelBufferSize)
|
||
}
|
||
|
||
func TestClient_SendRequest_Success(t *testing.T) {
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
assert.Equal(t, "POST", r.Method)
|
||
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
|
||
assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
|
||
|
||
resp := openai.ChatCompletionResponse{
|
||
ID: "chatcmpl-123",
|
||
Choices: []openai.Choice{
|
||
{Index: 0, Message: &openai.Message{Role: "assistant", Content: "Hello!"}},
|
||
},
|
||
}
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(resp)
|
||
}))
|
||
defer server.Close()
|
||
|
||
client := NewClient()
|
||
req := &openai.ChatCompletionRequest{
|
||
Model: "gpt-4",
|
||
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
|
||
}
|
||
|
||
result, err := client.SendRequest(context.Background(), req, "test-key", server.URL)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, "chatcmpl-123", result.ID)
|
||
}
|
||
|
||
func TestClient_SendRequest_ErrorResponse(t *testing.T) {
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusUnauthorized)
|
||
json.NewEncoder(w).Encode(openai.ErrorResponse{
|
||
Error: openai.ErrorDetail{Message: "Invalid API key"},
|
||
})
|
||
}))
|
||
defer server.Close()
|
||
|
||
client := NewClient()
|
||
req := &openai.ChatCompletionRequest{
|
||
Model: "gpt-4",
|
||
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
|
||
}
|
||
|
||
_, err := client.SendRequest(context.Background(), req, "bad-key", server.URL)
|
||
assert.Error(t, err)
|
||
assert.Contains(t, err.Error(), "Invalid API key")
|
||
}
|
||
|
||
func TestClient_SendRequest_ConnectionError(t *testing.T) {
|
||
client := NewClient()
|
||
req := &openai.ChatCompletionRequest{
|
||
Model: "gpt-4",
|
||
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
|
||
}
|
||
|
||
_, err := client.SendRequest(context.Background(), req, "key", "http://localhost:1")
|
||
assert.Error(t, err)
|
||
}
|
||
|
||
func TestClient_SendStreamRequest_CreatesChannel(t *testing.T) {
|
||
// 使用一个慢服务器确保客户端有时间读取
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.Header().Set("Content-Type", "text/event-stream")
|
||
w.WriteHeader(http.StatusOK)
|
||
}))
|
||
defer server.Close()
|
||
|
||
client := NewClient()
|
||
req := &openai.ChatCompletionRequest{
|
||
Model: "gpt-4",
|
||
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
|
||
}
|
||
|
||
eventChan, err := client.SendStreamRequest(context.Background(), req, "test-key", server.URL)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, eventChan)
|
||
|
||
// 读取直到 channel 关闭(服务器关闭后应产生 EOF)
|
||
for range eventChan {
|
||
// 消费所有事件
|
||
}
|
||
// channel 应已关闭(不阻塞即通过)
|
||
}
|
||
|
||
func TestClient_SendStreamRequest_ErrorResponse(t *testing.T) {
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusInternalServerError)
|
||
}))
|
||
defer server.Close()
|
||
|
||
client := NewClient()
|
||
req := &openai.ChatCompletionRequest{
|
||
Model: "gpt-4",
|
||
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
|
||
}
|
||
|
||
_, err := client.SendStreamRequest(context.Background(), req, "key", server.URL)
|
||
assert.Error(t, err)
|
||
}
|
||
|
||
func TestIsNetworkError(t *testing.T) {
|
||
tests := []struct {
|
||
input string
|
||
want bool
|
||
}{
|
||
{"connection reset by peer", true},
|
||
{"broken pipe", true},
|
||
{"network is unreachable", true},
|
||
{"timeout waiting for response", true},
|
||
{"unexpected EOF", true},
|
||
{"normal error", false},
|
||
{"", false},
|
||
}
|
||
for _, tt := range tests {
|
||
err := fmt.Errorf("%s", tt.input) //nolint:govet
|
||
assert.Equal(t, tt.want, isNetworkError(err), "isNetworkError(%q)", tt.input)
|
||
}
|
||
}
|