package logger

import (
	"context"
	"errors"
	"fmt"
	"regexp"
	"strings"
	"time"

	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
	"gorm.io/gorm"
	gormlogger "gorm.io/gorm/logger"
)

// sqlWhitespaceRE matches any run of whitespace in a SQL statement. It is
// compiled once here instead of on every formatSQL call, which runs per query.
var sqlWhitespaceRE = regexp.MustCompile(`\s+`)

// Compile-time check that GormLogger satisfies GORM's logger interface.
var _ gormlogger.Interface = (*GormLogger)(nil)

// GormLogger adapts a *zap.Logger to GORM's logger.Interface.
//
// level is the minimum zap level this adapter emits; GORM log levels are
// translated into it by gormLevelToZap.
type GormLogger struct {
	logger *zap.Logger
	level  zapcore.Level
}

// NewGormLogger returns a GormLogger that writes through a child of logger
// named "database". It starts at zapcore.DebugLevel, so the adapter itself
// filters nothing until LogMode is called (the underlying zap core may still
// apply its own level). A nil logger is replaced with a no-op logger rather
// than panicking inside Named.
func NewGormLogger(logger *zap.Logger) *GormLogger {
	if logger == nil {
		logger = zap.NewNop()
	}
	return &GormLogger{
		logger: logger.Named("database"),
		level:  zapcore.DebugLevel,
	}
}

// LogMode returns a new GormLogger that shares the receiver's zap logger but
// uses the zap threshold corresponding to the given GORM level. The receiver
// itself is not modified.
func (l *GormLogger) LogMode(level gormlogger.LogLevel) gormlogger.Interface {
	return &GormLogger{
		logger: l.logger,
		level:  l.gormLevelToZap(level),
	}
}

// Info logs a GORM informational message at zap debug level. It is emitted
// only when the threshold is debug (GORM Info).
func (l *GormLogger) Info(ctx context.Context, msg string, data ...interface{}) {
	if !l.enabled(zapcore.DebugLevel) {
		return
	}
	l.log(ctx, zapcore.DebugLevel, msg, data...)
}

// Warn logs a GORM warning at zap warn level when the threshold allows it.
func (l *GormLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
	if !l.enabled(zapcore.WarnLevel) {
		return
	}
	l.log(ctx, zapcore.WarnLevel, msg, data...)
}

// Error logs a GORM error message at zap error level when the threshold
// allows it.
func (l *GormLogger) Error(ctx context.Context, msg string, data ...interface{}) {
	if !l.enabled(zapcore.ErrorLevel) {
		return
	}
	l.log(ctx, zapcore.ErrorLevel, msg, data...)
}

// Trace logs one executed SQL statement. begin is the statement's start time,
// fc returns the SQL text and affected-row count, and err is the execution
// error, if any.
//
// Errors other than gorm.ErrRecordNotFound are logged at error level whenever
// the threshold admits errors (GORM Error, Warn or Info), so failures remain
// visible in quieter modes. Every other statement, including record-not-found
// lookups, is logged at debug level only when the threshold is debug.
// fc is invoked only if an entry is actually written.
func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
	failed := err != nil && !errors.Is(err, gorm.ErrRecordNotFound)

	switch {
	case failed && l.enabled(zapcore.ErrorLevel):
		fields := l.traceFields(ctx, begin, fc)
		fields = append(fields, zap.Error(err))
		l.logger.Error("SQL执行错误", fields...)
	case l.enabled(zapcore.DebugLevel):
		l.logger.Debug("SQL查询", l.traceFields(ctx, begin, fc)...)
	}
}

// enabled reports whether an entry at level passes this logger's threshold.
func (l *GormLogger) enabled(level zapcore.Level) bool {
	return l.level <= level
}

// traceFields builds the structured fields for a Trace entry: the request-ID
// field (unless contextFields omits it), the whitespace-normalized SQL, the
// affected-row count and the elapsed time since begin. Capacity is reserved
// for one extra field (the error) so callers can append without regrowing.
func (l *GormLogger) traceFields(ctx context.Context, begin time.Time, fc func() (string, int64)) []zap.Field {
	elapsed := time.Since(begin)
	sql, rows := fc()
	fields := l.contextFields(ctx, 4)
	return append(fields,
		zap.String("sql", l.formatSQL(sql)),
		zap.Int64("rows", rows),
		zap.Duration("latency", elapsed),
	)
}

// contextFields returns a slice seeded with the request-ID field obtained from
// ctx via RequestIDFromContext, with spare capacity for extra more fields. The
// field is omitted when RequestIDFromContext returns a skip field.
func (l *GormLogger) contextFields(ctx context.Context, extra int) []zap.Field {
	fields := make([]zap.Field, 0, extra+1)
	// Compare the field type instead of the whole struct: zap.Field carries an
	// interface value, and == on it panics if that value's dynamic type is not
	// comparable.
	if f := RequestIDFromContext(ctx); f.Type != zapcore.SkipType {
		fields = append(fields, f)
	}
	return fields
}

// log writes a GORM message at the given level (debug, warn or error; other
// levels are ignored), tagged with the request ID from ctx when present.
//
// GORM passes data as printf arguments for msg, so they are rendered with
// fmt.Sprintf and not reinterpreted as key/value fields. With no arguments msg
// is logged verbatim so a literal '%' is not mangled, mirroring zap's
// SugaredLogger behavior.
func (l *GormLogger) log(ctx context.Context, level zapcore.Level, msg string, data ...interface{}) {
	if len(data) > 0 {
		msg = fmt.Sprintf(msg, data...)
	}
	fields := l.contextFields(ctx, 0)

	switch level {
	case zapcore.DebugLevel:
		l.logger.Debug(msg, fields...)
	case zapcore.WarnLevel:
		l.logger.Warn(msg, fields...)
	case zapcore.ErrorLevel:
		l.logger.Error(msg, fields...)
	}
}

// gormLevelToZap maps a GORM log level to the zap threshold used by this
// adapter. Silent maps to PanicLevel, which is above every level the adapter
// emits (at most ErrorLevel), so all output is suppressed. Info and any
// unrecognized level map to DebugLevel.
func (l *GormLogger) gormLevelToZap(level gormlogger.LogLevel) zapcore.Level {
	switch level {
	case gormlogger.Silent:
		return zapcore.PanicLevel
	case gormlogger.Error:
		return zapcore.ErrorLevel
	case gormlogger.Warn:
		return zapcore.WarnLevel
	case gormlogger.Info:
		return zapcore.DebugLevel
	default:
		return zapcore.DebugLevel
	}
}

// formatSQL collapses every run of whitespace in sql to a single space and
// trims leading/trailing whitespace, yielding a one-line statement for logs.
func (l *GormLogger) formatSQL(sql string) string {
	return strings.TrimSpace(sqlWhitespaceRE.ReplaceAllString(sql, " "))
}