package handler

import (
	"bytes"
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"nex/backend/internal/domain"
	"nex/backend/internal/protocol/openai"
	"nex/backend/internal/provider"
	appErrors "nex/backend/pkg/errors"
)

// init switches gin into test mode for every test in this package.
func init() {
	gin.SetMode(gin.TestMode)
}

// ============ Mock implementations ============

// mockRoutingService returns the configured result and err from Route,
// regardless of the requested model name.
type mockRoutingService struct {
	result *domain.RouteResult
	err    error
}

func (m *mockRoutingService) Route(modelName string) (*domain.RouteResult, error) {
	return m.result, m.err
}

// mockStatsService returns canned usage stats and aggregation results.
// err is returned from both Record and Get so handler error paths can be
// exercised.
type mockStatsService struct {
	err        error
	stats      []domain.UsageStats
	aggrResult []map[string]interface{}
}

func (m *mockStatsService) Record(providerID, modelName string) error {
	return m.err
}

func (m *mockStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
	// Propagate err like every other mock in this file; previously Get always
	// returned nil, which made the GetStats failure path untestable.
	return m.stats, m.err
}

func (m *mockStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} {
	return m.aggrResult
}

// mockProviderService returns the configured provider(s) and err from every
// method; mutating calls (Create/Update/Delete) only report err.
type mockProviderService struct {
	provider  *domain.Provider
	providers []domain.Provider
	err       error
}

func (m *mockProviderService) Create(provider *domain.Provider) error {
	return m.err
}

func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
	return m.provider, m.err
}

func (m *mockProviderService) List() ([]domain.Provider, error) {
	return m.providers, m.err
}

func (m *mockProviderService) Update(id string, updates map[string]interface{}) error {
	return m.err
}

func (m *mockProviderService) Delete(id string) error {
	return m.err
}

// mockModelService returns the configured model(s) and err from every
// method; mutating calls (Create/Update/Delete) only report err.
type mockModelService struct {
	model  *domain.Model
	models []domain.Model
	err    error
}

func (m *mockModelService) Create(model *domain.Model) error {
	return m.err
}

func (m *mockModelService) Get(id string) (*domain.Model, error) {
	return m.model, m.err
}

func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
	return m.models, m.err
}

func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
	return m.err
}

func (m *mockModelService) Delete(id string) error {
	return m.err
}

// mockProviderClient returns the configured response / event channel and err
// from the upstream send methods without performing any network I/O.
type mockProviderClient struct {
	resp      *openai.ChatCompletionResponse
	eventChan chan provider.StreamEvent
	err       error
}

func (m *mockProviderClient) SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error) {
	return m.resp, m.err
}

func (m *mockProviderClient) SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan provider.StreamEvent, error) {
	return m.eventChan, m.err
}

// ============ OpenAI handler tests ============

func TestOpenAIHandler_HandleChatCompletions_InvalidJSON(t *testing.T) {
	h := NewOpenAIHandler(nil, nil, nil)

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader([]byte("invalid")))

	h.HandleChatCompletions(c)

	assert.Equal(t, http.StatusBadRequest, w.Code)
}

func TestOpenAIHandler_HandleChatCompletions_ValidationError(t *testing.T) {
	h := NewOpenAIHandler(nil, nil, nil)

	// The required "model" field is deliberately omitted.
	body, err := json.Marshal(map[string]interface{}{
		"messages": []map[string]string{{"role": "user", "content": "hi"}},
	})
	require.NoError(t, err)

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")

	h.HandleChatCompletions(c)

	assert.Equal(t, http.StatusBadRequest, w.Code)
}

func TestOpenAIHandler_HandleChatCompletions_RouteError(t *testing.T) {
	routingSvc := &mockRoutingService{err: appErrors.ErrModelNotFound}
	h := NewOpenAIHandler(nil, routingSvc, nil)

	body, err := json.Marshal(map[string]interface{}{
		"model":    "nonexistent",
		"messages": []map[string]string{{"role": "user", "content": "hi"}},
	})
	require.NoError(t, err)

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")

	h.HandleChatCompletions(c)

	assert.Equal(t, http.StatusNotFound, w.Code)
}

// ============ Provider handler tests ============

func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
	h := NewProviderHandler(&mockProviderService{})

	body, err := json.Marshal(map[string]string{"id": "p1"})
	require.NoError(t, err)

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest(http.MethodPost, "/api/providers", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")

	h.CreateProvider(c)

	assert.Equal(t, http.StatusBadRequest, w.Code)
}

func TestProviderHandler_ListProviders(t *testing.T) {
	h := NewProviderHandler(&mockProviderService{
		providers: []domain.Provider{
			{ID: "p1", Name: "P1"},
			{ID: "p2", Name: "P2"},
		},
	})

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/providers", nil)

	h.ListProviders(c)

	assert.Equal(t, http.StatusOK, w.Code)

	// Fail loudly on a malformed body instead of reporting a misleading
	// length mismatch.
	var result []domain.Provider
	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
	assert.Len(t, result, 2)
}

func TestProviderHandler_GetProvider(t *testing.T) {
	h := NewProviderHandler(&mockProviderService{
		provider: &domain.Provider{ID: "p1", Name: "P1", APIKey: "***"},
	})

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: "p1"}}
	c.Request = httptest.NewRequest(http.MethodGet, "/api/providers/p1", nil)

	h.GetProvider(c)

	assert.Equal(t, http.StatusOK, w.Code)
}

// ============ Model handler tests ============

func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
	h := NewModelHandler(&mockModelService{})

	body, err := json.Marshal(map[string]string{"id": "m1"})
	require.NoError(t, err)

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest(http.MethodPost, "/api/models", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")

	h.CreateModel(c)

	assert.Equal(t, http.StatusBadRequest, w.Code)
}

func TestModelHandler_ListModels(t *testing.T) {
	h := NewModelHandler(&mockModelService{
		models: []domain.Model{
			{ID: "m1", ModelName: "gpt-4"},
			{ID: "m2", ModelName: "gpt-3.5"},
		},
	})

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/models", nil)

	h.ListModels(c)

	assert.Equal(t, http.StatusOK, w.Code)
}

// ============ Stats handler tests ============

func TestStatsHandler_GetStats(t *testing.T) {
	h := NewStatsHandler(&mockStatsService{
		stats: []domain.UsageStats{
			{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
		},
	})

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/stats", nil)

	h.GetStats(c)

	assert.Equal(t, http.StatusOK, w.Code)
}

func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
	h := NewStatsHandler(&mockStatsService{})

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/stats?start_date=invalid", nil)

	h.GetStats(c)

	assert.Equal(t, http.StatusBadRequest, w.Code)
}

func TestStatsHandler_AggregateStats(t *testing.T) {
	h := NewStatsHandler(&mockStatsService{
		stats: []domain.UsageStats{
			{ProviderID: "p1", RequestCount: 10},
		},
		aggrResult: []map[string]interface{}{
			{"provider_id": "p1", "request_count": 10},
		},
	})

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/stats/aggregate?group_by=provider", nil)

	h.AggregateStats(c)

	assert.Equal(t, http.StatusOK, w.Code)
}

// ============ writeError tests ============

func TestWriteError(t *testing.T) {
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest(http.MethodGet, "/", nil)

	writeError(c, appErrors.ErrModelNotFound)

	assert.Equal(t, http.StatusNotFound, w.Code)
}

func TestFormatValidationErrors(t *testing.T) {
	errs := map[string]string{
		"model":    "模型名称不能为空",
		"messages": "消息列表不能为空",
	}

	result := formatValidationErrors(errs)

	require.Contains(t, result, "请求验证失败")
	require.Contains(t, result, "model")
	require.Contains(t, result, "messages")
}