package handler

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http/httptest"
	"sort"
	"strings"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"nex/backend/internal/domain"
	"nex/backend/internal/provider"
	appErrors "nex/backend/pkg/errors"
)

// init puts gin into test mode for every test in this package.
func init() {
	gin.SetMode(gin.TestMode)
}

// ============ Mock implementations ============

// mockRoutingService is a stub routing service whose only method returns the
// canned result and err.
type mockRoutingService struct {
	result *domain.RouteResult
	err    error
}

// RouteByModelName ignores its arguments and returns m.result, m.err.
func (m *mockRoutingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
	return m.result, m.err
}

// mockStatsService is a stub stats service.
//   - err is returned by Record and Get.
//   - stats is the canned payload for Get.
//   - aggrResult is the canned payload for Aggregate.
type mockStatsService struct {
	err        error
	stats      []domain.UsageStats
	aggrResult []map[string]interface{}
}

// Record ignores its arguments and returns m.err.
func (m *mockStatsService) Record(providerID, modelName string) error {
	return m.err
}

// Get ignores its filters and returns the canned stats together with m.err.
// Propagating m.err lets stats-handler error paths be tested the same way as
// with the other mocks.
func (m *mockStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
	return m.stats, m.err
}

// Aggregate ignores its inputs and returns the canned aggrResult.
func (m *mockStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} {
	return m.aggrResult
}

// mockProviderService is a stub provider service.
//   - provider is returned by Get.
//   - providers is returned by List.
//   - err is returned by every method except the two model lookups, which
//     always return nil, nil.
type mockProviderService struct {
	provider  *domain.Provider
	providers []domain.Provider
	err       error
}

func (m *mockProviderService) ListEnabledModels() ([]domain.Model, error) {
	return nil, nil
}

func (m *mockProviderService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
	return nil, nil
}

func (m *mockProviderService) Create(provider *domain.Provider) error {
	return m.err
}

func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
	return m.provider, m.err
}

func (m *mockProviderService) List() ([]domain.Provider, error) {
	return m.providers, m.err
}

func (m *mockProviderService) Update(id string, updates map[string]interface{}) error {
	return m.err
}

func (m *mockProviderService) Delete(id string) error {
	return m.err
}

// mockModelService is a stub model service.
//   - model is returned by Get.
//   - models is returned by List.
//   - err is returned by every method except ListEnabled.
type mockModelService struct {
	model  *domain.Model
	models []domain.Model
	err    error
}

// Create assigns the fixed ID "mock-uuid-1234" when m.err is nil, so tests
// can assert that the handler echoes back the service-assigned ID. It returns
// m.err.
func (m *mockModelService) Create(model *domain.Model) error {
	if m.err == nil {
		model.ID = "mock-uuid-1234"
	}
	return m.err
}

func (m *mockModelService) Get(id string) (*domain.Model, error) {
	return m.model, m.err
}

func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
	return m.models, m.err
}

// ListEnabled always returns an empty (non-nil) slice and no error.
func (m *mockModelService) ListEnabled() ([]domain.Model, error) {
	return []domain.Model{}, nil
}

func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
	return m.err
}

func (m *mockModelService) Delete(id string) error {
	return m.err
}

// mockProviderClient is a stub provider client whose Send and SendStream
// always return a nil result together with m.err.
type mockProviderClient struct {
	err error
}

func (m *mockProviderClient) Send(ctx context.Context, spec interface{}) (interface{}, error) {
	return nil, m.err
}

func (m *mockProviderClient) SendStream(ctx context.Context, spec interface{}) (<-chan provider.StreamEvent, error) {
	return nil, m.err
}

// ============ Provider handler tests ============

func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
	h := NewProviderHandler(&mockProviderService{})

	body, err := json.Marshal(map[string]string{"id": "p1"})
	require.NoError(t, err)

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")

	h.CreateProvider(c)

	assert.Equal(t, 400, w.Code)
}

func TestProviderHandler_ListProviders(t *testing.T) {
	h := NewProviderHandler(&mockProviderService{
		providers: []domain.Provider{
			{ID: "p1", Name: "P1"},
			{ID: "p2", Name: "P2"},
		},
	})

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest("GET", "/api/providers", nil)

	h.ListProviders(c)

	assert.Equal(t, 200, w.Code)
	var result []domain.Provider
	// Fail loudly on a malformed body instead of reporting a misleading length.
	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
	assert.Len(t, result, 2)
}

func TestProviderHandler_GetProvider(t *testing.T) {
	h := NewProviderHandler(&mockProviderService{
		provider: &domain.Provider{ID: "p1", Name: "P1", APIKey: "***"},
	})

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: "p1"}}
	c.Request = httptest.NewRequest("GET", "/api/providers/p1", nil)

	h.GetProvider(c)

	assert.Equal(t, 200, w.Code)
}

// ============ Model handler tests ============

func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
	h := NewModelHandler(&mockModelService{})

	body, err := json.Marshal(map[string]string{"id": "m1"})
	require.NoError(t, err)

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")

	h.CreateModel(c)

	assert.Equal(t, 400, w.Code)
}

// TestModelHandler_ListModels checks that each listed model carries a unified
// ID of the form "<provider_id>/<model_name>".
func TestModelHandler_ListModels(t *testing.T) {
	h := NewModelHandler(&mockModelService{
		models: []domain.Model{
			{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
			{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3"},
		},
	})

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest("GET", "/api/models", nil)

	h.ListModels(c)

	assert.Equal(t, 200, w.Code)
	var result []modelResponse
	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
	require.Len(t, result, 2)
	assert.Equal(t, "openai/gpt-4", result[0].UnifiedModelID)
	assert.Equal(t, "anthropic/claude-3", result[1].UnifiedModelID)
}

func TestModelHandler_GetModel_UnifiedID(t *testing.T) {
	h := NewModelHandler(&mockModelService{
		model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
	})

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: "m1"}}
	c.Request = httptest.NewRequest("GET", "/api/models/m1", nil)

	h.GetModel(c)

	assert.Equal(t, 200, w.Code)
	var result modelResponse
	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
	assert.Equal(t, "m1", result.ID)
	assert.Equal(t, "openai/gpt-4", result.UnifiedModelID)
}

// TestModelHandler_CreateModel_UnifiedID relies on mockModelService.Create
// assigning "mock-uuid-1234" to the new model.
func TestModelHandler_CreateModel_UnifiedID(t *testing.T) {
	h := NewModelHandler(&mockModelService{})

	body, err := json.Marshal(map[string]string{
		"provider_id": "openai",
		"model_name":  "gpt-4",
	})
	require.NoError(t, err)

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")

	h.CreateModel(c)

	assert.Equal(t, 201, w.Code)
	var result modelResponse
	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
	assert.Equal(t, "mock-uuid-1234", result.ID)
	assert.Equal(t, "openai/gpt-4", result.UnifiedModelID)
}

// TestModelHandler_UpdateModel_UnifiedID expects the response to reflect the
// model returned by the service's Get ("gpt-4-turbo"), not the request body.
func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) {
	h := NewModelHandler(&mockModelService{
		model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4-turbo"},
	})

	body, err := json.Marshal(map[string]interface{}{"enabled": false})
	require.NoError(t, err)

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: "m1"}}
	c.Request = httptest.NewRequest("PUT", "/api/models/m1", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")

	h.UpdateModel(c)

	assert.Equal(t, 200, w.Code)
	var result modelResponse
	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
	assert.Equal(t, "openai/gpt-4-turbo", result.UnifiedModelID)
}

// ============ Stats handler tests ============

func TestStatsHandler_GetStats(t *testing.T) {
	h := NewStatsHandler(&mockStatsService{
		stats: []domain.UsageStats{
			{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
		},
	})

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest("GET", "/api/stats", nil)

	h.GetStats(c)

	assert.Equal(t, 200, w.Code)
}

func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
	h := NewStatsHandler(&mockStatsService{})

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest("GET", "/api/stats?start_date=invalid", nil)

	h.GetStats(c)

	assert.Equal(t, 400, w.Code)
}

func TestStatsHandler_AggregateStats(t *testing.T) {
	h := NewStatsHandler(&mockStatsService{
		stats: []domain.UsageStats{
			{ProviderID: "p1", RequestCount: 10},
		},
		aggrResult: []map[string]interface{}{
			{"provider_id": "p1", "request_count": 10},
		},
	})

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest("GET", "/api/stats/aggregate?group_by=provider", nil)

	h.AggregateStats(c)

	assert.Equal(t, 200, w.Code)
}

// ============ writeError tests ============

func TestWriteError(t *testing.T) {
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest("GET", "/", nil)

	writeError(c, appErrors.ErrModelNotFound)

	assert.Equal(t, 404, w.Code)
}

func TestFormatValidationErrors(t *testing.T) {
	errs := map[string]string{
		"model":    "模型名称不能为空",
		"messages": "消息列表不能为空",
	}

	result := formatMapErrors(errs)

	require.Contains(t, result, "请求验证失败")
	require.Contains(t, result, "model")
	require.Contains(t, result, "messages")
	// Field order is deterministic (sorted), so the full message can be pinned.
	assert.Equal(t, "请求验证失败: messages: 消息列表不能为空; model: 模型名称不能为空", result)
}

// formatMapErrors joins field/message pairs into a single validation message.
// The output is the literal prefix "请求验证失败: " ("request validation
// failed") followed by "field: msg" pairs joined with "; ". Fields are sorted,
// so the output is deterministic even though Go map iteration order is random.
//
// NOTE(review): this helper is defined in the test file, so
// TestFormatValidationErrors only exercises test code. It presumably mirrors
// a production formatter; confirm this, and test that formatter instead.
func formatMapErrors(errs map[string]string) string {
	fields := make([]string, 0, len(errs))
	for field := range errs {
		fields = append(fields, field)
	}
	sort.Strings(fields)

	parts := make([]string, 0, len(fields))
	for _, field := range fields {
		parts = append(parts, fmt.Sprintf("%s: %s", field, errs[field]))
	}
	return "请求验证失败: " + strings.Join(parts, "; ")
}

// ============ Error classification tests ============

// TestProviderHandler_CreateProvider_DuplicatedKey checks that a conflict error
// from the service is mapped to HTTP 409.
func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
	h := NewProviderHandler(&mockProviderService{
		err: appErrors.ErrConflict,
	})

	body, err := json.Marshal(map[string]string{
		"id":       "p1",
		"name":     "Test",
		"api_key":  "sk-test",
		"base_url": "https://test.com",
	})
	require.NoError(t, err)

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader(body))
	c.Request.Header.Set("Content-Type", "application/json")

	h.CreateProvider(c)

	assert.Equal(t, 409, w.Code)
}