188 lines
4.5 KiB
Go
188 lines
4.5 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/assert"
|
|
"go.uber.org/zap"
|
|
"go.uber.org/zap/zapcore"
|
|
"go.uber.org/zap/zaptest/observer"
|
|
)
|
|
|
|
func init() {
|
|
gin.SetMode(gin.TestMode)
|
|
}
|
|
|
|
func TestRequestID_GeneratesUUID(t *testing.T) {
|
|
r := gin.New()
|
|
r.Use(RequestID())
|
|
r.GET("/test", func(c *gin.Context) {
|
|
id, exists := c.Get(RequestIDKey)
|
|
assert.True(t, exists)
|
|
assert.NotEmpty(t, id)
|
|
c.Status(200)
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
r.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, 200, w.Code)
|
|
assert.NotEmpty(t, w.Header().Get("X-Request-ID"))
|
|
}
|
|
|
|
func TestRequestID_UsesExistingHeader(t *testing.T) {
|
|
r := gin.New()
|
|
r.Use(RequestID())
|
|
r.GET("/test", func(c *gin.Context) {
|
|
id, _ := c.Get(RequestIDKey)
|
|
assert.Equal(t, "existing-id-123", id)
|
|
c.Status(200)
|
|
})
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.Header.Set("X-Request-ID", "existing-id-123")
|
|
r.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, 200, w.Code)
|
|
assert.Equal(t, "existing-id-123", w.Header().Get("X-Request-ID"))
|
|
}
|
|
|
|
// TestLogging is a smoke test. It checks that a GET carrying a query string
// passes through Logging, backed by a no-op zap logger, and completes with
// status 200.
func TestLogging(t *testing.T) {
	router := gin.New()
	router.Use(Logging(zap.NewNop()))
	router.GET("/test", func(c *gin.Context) { c.Status(200) })

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest("GET", "/test?key=value", nil))

	assert.Equal(t, 200, rec.Code)
}
|
|
|
|
// TestLogging_DoesNotLogLifecycleAtInfoLevel checks the case where the logger
// only records Info and above. In that case, no request-start or request-end
// entries from Logging reach the observer.
func TestLogging_DoesNotLogLifecycleAtInfoLevel(t *testing.T) {
	core, recorded := observer.New(zapcore.InfoLevel)
	rec := serveLoggingRequest(zap.New(core))

	assert.Equal(t, 200, rec.Code)

	// "请求开始" = "request started", "请求结束" = "request finished".
	for _, msg := range []string{"请求开始", "请求结束"} {
		assert.Empty(t, recorded.FilterMessage(msg).All())
	}
}
|
|
|
|
// TestLogging_LogsLifecycleAtDebugLevel checks Logging with a Debug-level
// logger. It expects exactly one request-start entry and one request-end
// entry, and checks the structured fields attached to each.
func TestLogging_LogsLifecycleAtDebugLevel(t *testing.T) {
	core, recorded := observer.New(zapcore.DebugLevel)
	rec := serveLoggingRequest(zap.New(core))

	assert.Equal(t, 200, rec.Code)

	// "请求开始" = "request started", "请求结束" = "request finished".
	started := recorded.FilterMessage("请求开始").All()
	finished := recorded.FilterMessage("请求结束").All()

	// The start entry describes the incoming request.
	if assert.Len(t, started, 1) {
		f := started[0].ContextMap()
		assert.Equal(t, "GET", f["method"])
		assert.Equal(t, "/test", f["path"])
		assert.Equal(t, "key=value", f["query"])
		assert.Equal(t, "existing-id-123", f["request_id"])
		assert.NotEmpty(t, f["client_ip"])
	}

	// The end entry describes the outcome. Integer fields are compared as
	// int64, as read back through ContextMap. The body_size of 2 corresponds
	// to the "ok" body written by serveLoggingRequest's handler.
	if assert.Len(t, finished, 1) {
		f := finished[0].ContextMap()
		assert.Equal(t, int64(200), f["status"])
		assert.Equal(t, "GET", f["method"])
		assert.Equal(t, "/test", f["path"])
		assert.Equal(t, int64(2), f["body_size"])
		assert.Equal(t, "existing-id-123", f["request_id"])
		assert.Contains(t, f, "latency")
	}
}
|
|
|
|
// serveLoggingRequest sends one request through a router that applies
// RequestID and then Logging(logger), and returns the recorded response.
//
// The request is GET /test?key=value with the header
// X-Request-ID: existing-id-123. The handler replies 200 with body "ok".
func serveLoggingRequest(logger *zap.Logger) *httptest.ResponseRecorder {
	req := httptest.NewRequest("GET", "/test?key=value", nil)
	req.Header.Set("X-Request-ID", "existing-id-123")

	router := gin.New()
	// Order matters: RequestID runs first so Logging can see the ID.
	router.Use(RequestID(), Logging(logger))
	router.GET("/test", func(c *gin.Context) { c.String(200, "ok") })

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	return rec
}
|
|
|
|
// TestRecovery_NoPanic checks that Recovery does not alter a request whose
// handler returns normally: the handler's 200 status reaches the client.
func TestRecovery_NoPanic(t *testing.T) {
	router := gin.New()
	router.Use(Recovery(zap.NewNop()))
	router.GET("/test", func(c *gin.Context) { c.Status(200) })

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest("GET", "/test", nil))

	assert.Equal(t, 200, rec.Code)
}
|
|
|
|
// TestRecovery_WithPanic checks that a panicking handler is turned into a 500
// response. If the panic escaped ServeHTTP, this test would crash instead of
// reaching the assertion.
func TestRecovery_WithPanic(t *testing.T) {
	router := gin.New()
	router.Use(Recovery(zap.NewNop()))
	router.GET("/test", func(c *gin.Context) { panic("test panic") })

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest("GET", "/test", nil))

	assert.Equal(t, 500, rec.Code)
}
|
|
|
|
// TestCORS_NormalRequest checks the headers CORS adds to a plain GET. It
// expects a wildcard Access-Control-Allow-Origin and an
// Access-Control-Allow-Methods value that mentions both GET and POST.
func TestCORS_NormalRequest(t *testing.T) {
	router := gin.New()
	router.Use(CORS())
	router.GET("/test", func(c *gin.Context) { c.Status(200) })

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest("GET", "/test", nil))

	assert.Equal(t, 200, rec.Code)

	hdr := rec.Header()
	assert.Equal(t, "*", hdr.Get("Access-Control-Allow-Origin"))

	allowedMethods := hdr.Get("Access-Control-Allow-Methods")
	for _, method := range []string{"GET", "POST"} {
		assert.Contains(t, allowedMethods, method)
	}
}
|
|
|
|
// TestCORS_PreflightRequest checks that an OPTIONS preflight gets status 204
// and a wildcard Access-Control-Allow-Origin header. The route's own handler
// would reply 200. Expecting 204 therefore implies the CORS middleware
// answers the preflight itself.
func TestCORS_PreflightRequest(t *testing.T) {
	router := gin.New()
	router.Use(CORS())
	router.OPTIONS("/test", func(c *gin.Context) { c.Status(200) })

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest("OPTIONS", "/test", nil))

	assert.Equal(t, 204, rec.Code)
	assert.Equal(t, "*", rec.Header().Get("Access-Control-Allow-Origin"))
}
|