package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
	"go.uber.org/zap/zaptest/observer"
)

func init() {
	// Run gin in TestMode (instead of its default DebugMode) for every test
	// in this package.
	gin.SetMode(gin.TestMode)
}

// TestRequestID_GeneratesUUID checks that, when the client sends no
// X-Request-ID header, RequestID stores a non-empty ID in the context under
// RequestIDKey and echoes that same ID in the X-Request-ID response header.
func TestRequestID_GeneratesUUID(t *testing.T) {
	var (
		handlerCalled bool
		ctxID         interface{}
		ctxIDExists   bool
	)

	r := gin.New()
	r.Use(RequestID())
	r.GET("/test", func(c *gin.Context) {
		handlerCalled = true
		ctxID, ctxIDExists = c.Get(RequestIDKey)
		c.Status(http.StatusOK)
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	r.ServeHTTP(w, req)

	// Assert after ServeHTTP rather than inside the handler: if the middleware
	// aborted, in-handler assertions would never run, and w.Code alone cannot
	// tell (httptest.ResponseRecorder defaults Code to 200).
	assert.True(t, handlerCalled, "handler was not invoked")
	assert.Equal(t, http.StatusOK, w.Code)
	assert.True(t, ctxIDExists, "request ID not stored in context")
	assert.NotEmpty(t, ctxID)

	headerID := w.Header().Get("X-Request-ID")
	assert.NotEmpty(t, headerID)
	assert.Equal(t, headerID, ctxID, "context request ID and response header should match")
}

// TestRequestID_UsesExistingHeader checks that an incoming X-Request-ID is
// reused as-is: stored in the context and echoed back in the response header.
func TestRequestID_UsesExistingHeader(t *testing.T) {
	const incomingID = "existing-id-123"

	var (
		handlerCalled bool
		ctxID         interface{}
	)

	r := gin.New()
	r.Use(RequestID())
	r.GET("/test", func(c *gin.Context) {
		handlerCalled = true
		ctxID, _ = c.Get(RequestIDKey)
		c.Status(http.StatusOK)
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	req.Header.Set("X-Request-ID", incomingID)
	r.ServeHTTP(w, req)

	assert.True(t, handlerCalled, "handler was not invoked")
	assert.Equal(t, http.StatusOK, w.Code)
	assert.Equal(t, incomingID, ctxID)
	assert.Equal(t, incomingID, w.Header().Get("X-Request-ID"))
}

// TestLogging checks that the Logging middleware lets a request with a query
// string through to the handler unchanged.
func TestLogging(t *testing.T) {
	logger := zap.NewNop()

	r := gin.New()
	r.Use(Logging(logger))
	r.GET("/test", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/test?key=value", nil)
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)
}

// TestLogging_DoesNotLogLifecycleAtInfoLevel checks that the per-request
// start/end entries are not emitted when the logger is at Info level, i.e.
// they are logged below Info.
func TestLogging_DoesNotLogLifecycleAtInfoLevel(t *testing.T) {
	core, logs := observer.New(zapcore.InfoLevel)
	logger := zap.New(core)

	w := serveLoggingRequest(logger)

	assert.Equal(t, http.StatusOK, w.Code)
	// "请求开始" = "request started", "请求结束" = "request finished"; these
	// must match the middleware's messages byte-for-byte.
	assert.Empty(t, logs.FilterMessage("请求开始").All())
	assert.Empty(t, logs.FilterMessage("请求结束").All())
}

// TestLogging_LogsLifecycleAtDebugLevel checks that, at Debug level, exactly
// one start entry and one end entry are logged, with the expected fields.
func TestLogging_LogsLifecycleAtDebugLevel(t *testing.T) {
	core, logs := observer.New(zapcore.DebugLevel)
	logger := zap.New(core)

	w := serveLoggingRequest(logger)

	assert.Equal(t, http.StatusOK, w.Code)

	// "请求开始" = "request started", "请求结束" = "request finished".
	startLogs := logs.FilterMessage("请求开始").All()
	endLogs := logs.FilterMessage("请求结束").All()

	if assert.Len(t, startLogs, 1) {
		fields := startLogs[0].ContextMap()
		assert.Equal(t, "GET", fields["method"])
		assert.Equal(t, "/test", fields["path"])
		assert.Equal(t, "key=value", fields["query"])
		assert.Equal(t, "existing-id-123", fields["request_id"])
		assert.NotEmpty(t, fields["client_ip"])
	}

	if assert.Len(t, endLogs, 1) {
		fields := endLogs[0].ContextMap()
		// zap stores integer fields as int64 in ContextMap.
		assert.Equal(t, int64(200), fields["status"])
		assert.Equal(t, "GET", fields["method"])
		assert.Equal(t, "/test", fields["path"])
		// The handler writes "ok", a 2-byte body.
		assert.Equal(t, int64(2), fields["body_size"])
		assert.Equal(t, "existing-id-123", fields["request_id"])
		assert.Contains(t, fields, "latency")
	}
}

// serveLoggingRequest sends GET /test?key=value with X-Request-ID
// "existing-id-123" through a router using RequestID followed by
// Logging(logger). The handler responds 200 with body "ok". It returns the
// recorded response.
func serveLoggingRequest(logger *zap.Logger) *httptest.ResponseRecorder {
	r := gin.New()
	r.Use(RequestID())
	r.Use(Logging(logger))
	r.GET("/test", func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/test?key=value", nil)
	req.Header.Set("X-Request-ID", "existing-id-123")
	r.ServeHTTP(w, req)

	return w
}

// TestRecovery_NoPanic checks that Recovery does not alter a normal response.
func TestRecovery_NoPanic(t *testing.T) {
	logger := zap.NewNop()

	r := gin.New()
	r.Use(Recovery(logger))
	r.GET("/test", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)
}

// TestRecovery_WithPanic checks that a panicking handler is recovered and
// turned into a 500 response instead of crashing the test.
func TestRecovery_WithPanic(t *testing.T) {
	logger := zap.NewNop()

	r := gin.New()
	r.Use(Recovery(logger))
	r.GET("/test", func(c *gin.Context) {
		panic("test panic")
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusInternalServerError, w.Code)
}

// TestCORS_NormalRequest checks that CORS headers are added to a regular
// (non-preflight) request and that the request still reaches the handler.
func TestCORS_NormalRequest(t *testing.T) {
	r := gin.New()
	r.Use(CORS())
	r.GET("/test", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)
	assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))

	allowMethods := w.Header().Get("Access-Control-Allow-Methods")
	assert.Contains(t, allowMethods, "GET")
	assert.Contains(t, allowMethods, "POST")
}

// TestCORS_PreflightRequest checks that an OPTIONS request gets a 204 with
// the allow-origin header, even though the route's own handler would
// respond 200.
func TestCORS_PreflightRequest(t *testing.T) {
	r := gin.New()
	r.Use(CORS())
	r.OPTIONS("/test", func(c *gin.Context) {
		c.Status(http.StatusOK)
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodOptions, "/test", nil)
	r.ServeHTTP(w, req)

	assert.Equal(t, http.StatusNoContent, w.Code)
	assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
}