feat: 实现分层架构,包含 domain、service、repository 和 pkg 层
- 新增 domain 层:model、provider、route、stats 实体 - 新增 service 层:models、providers、routing、stats 业务逻辑 - 新增 repository 层:models、providers、stats 数据访问 - 新增 pkg 工具包:errors、logger、validator - 新增中间件:CORS、logging、recovery、request ID - 新增数据库迁移:初始 schema 和索引 - 新增单元测试和集成测试 - 新增规范文档:config-management、database-migration、error-handling、layered-architecture、middleware-system、request-validation、structured-logging、test-coverage - 移除 config 子包和 model_router(已迁移至分层架构)
This commit is contained in:
74
backend/pkg/errors/errors.go
Normal file
74
backend/pkg/errors/errors.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// AppError 结构化应用错误
|
||||
type AppError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
HTTPStatus int `json:"-"`
|
||||
Cause error `json:"-"`
|
||||
Context map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// Error implements error interface
|
||||
func (e *AppError) Error() string {
|
||||
if e.Cause != nil {
|
||||
return fmt.Sprintf("%s: %s (%v)", e.Code, e.Message, e.Cause)
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying error
|
||||
func (e *AppError) Unwrap() error {
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
// NewAppError creates a new AppError
|
||||
func NewAppError(code, message string, httpStatus int) *AppError {
|
||||
return &AppError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
HTTPStatus: httpStatus,
|
||||
}
|
||||
}
|
||||
|
||||
// Predefined errors
|
||||
var (
|
||||
ErrModelNotFound = NewAppError("model_not_found", "模型未找到", http.StatusNotFound)
|
||||
ErrModelDisabled = NewAppError("model_disabled", "模型已禁用", http.StatusNotFound)
|
||||
ErrProviderNotFound = NewAppError("provider_not_found", "供应商未找到", http.StatusNotFound)
|
||||
ErrProviderDisabled = NewAppError("provider_disabled", "供应商已禁用", http.StatusNotFound)
|
||||
ErrInvalidRequest = NewAppError("invalid_request", "无效的请求", http.StatusBadRequest)
|
||||
ErrInternal = NewAppError("internal_error", "内部错误", http.StatusInternalServerError)
|
||||
ErrDatabaseNotInit = NewAppError("database_not_initialized", "数据库未初始化", http.StatusInternalServerError)
|
||||
ErrConflict = NewAppError("conflict", "资源已存在", http.StatusConflict)
|
||||
)
|
||||
|
||||
// AsAppError 尝试将 error 转换为 *AppError
|
||||
func AsAppError(err error) (*AppError, bool) {
|
||||
if err == nil {
|
||||
return nil, false
|
||||
}
|
||||
var appErr *AppError
|
||||
if ok := is(err, &appErr); ok {
|
||||
return appErr, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func is(err error, target interface{}) bool {
|
||||
// 简单的类型断言
|
||||
if e, ok := err.(*AppError); ok {
|
||||
// 直接赋值
|
||||
switch t := target.(type) {
|
||||
case **AppError:
|
||||
*t = e
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
125
backend/pkg/errors/errors_test.go
Normal file
125
backend/pkg/errors/errors_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewAppError(t *testing.T) {
|
||||
err := NewAppError("test_code", "测试消息", http.StatusBadRequest)
|
||||
assert.Equal(t, "test_code", err.Code)
|
||||
assert.Equal(t, "测试消息", err.Message)
|
||||
assert.Equal(t, http.StatusBadRequest, err.HTTPStatus)
|
||||
assert.Nil(t, err.Cause)
|
||||
assert.Nil(t, err.Context)
|
||||
}
|
||||
|
||||
func TestAppError_Error(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *AppError
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "无原因错误",
|
||||
err: NewAppError("code1", "消息1", 400),
|
||||
expected: "code1: 消息1",
|
||||
},
|
||||
{
|
||||
name: "带原因错误",
|
||||
err: Wrap(NewAppError("code2", "消息2", 500), errors.New("原始错误")),
|
||||
expected: "code2: 消息2 (原始错误)",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, tt.err.Error())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppError_Unwrap(t *testing.T) {
|
||||
cause := errors.New("原始错误")
|
||||
err := Wrap(ErrInternal, cause)
|
||||
assert.Equal(t, cause, err.Unwrap())
|
||||
}
|
||||
|
||||
func TestWrap(t *testing.T) {
|
||||
cause := errors.New("网络超时")
|
||||
wrapped := Wrap(ErrInternal, cause)
|
||||
assert.Equal(t, "internal_error", wrapped.Code)
|
||||
assert.Equal(t, "内部错误", wrapped.Message)
|
||||
assert.Equal(t, http.StatusInternalServerError, wrapped.HTTPStatus)
|
||||
assert.Equal(t, cause, wrapped.Cause)
|
||||
}
|
||||
|
||||
func TestWithContext(t *testing.T) {
|
||||
err := WithContext(ErrModelNotFound, "model", "gpt-4")
|
||||
assert.Equal(t, "model_not_found", err.Code)
|
||||
assert.NotNil(t, err.Context)
|
||||
assert.Equal(t, "gpt-4", err.Context["model"])
|
||||
|
||||
// 测试链式添加上下文
|
||||
err2 := WithContext(err, "provider", "openai")
|
||||
assert.Equal(t, "gpt-4", err2.Context["model"])
|
||||
assert.Equal(t, "openai", err2.Context["provider"])
|
||||
}
|
||||
|
||||
func TestWithMessage(t *testing.T) {
|
||||
err := WithMessage(ErrInvalidRequest, "自定义错误消息")
|
||||
assert.Equal(t, "invalid_request", err.Code)
|
||||
assert.Equal(t, "自定义错误消息", err.Message)
|
||||
assert.Equal(t, http.StatusBadRequest, err.HTTPStatus)
|
||||
}
|
||||
|
||||
func TestPredefinedErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *AppError
|
||||
code string
|
||||
httpStatus int
|
||||
}{
|
||||
{"ErrModelNotFound", ErrModelNotFound, "model_not_found", http.StatusNotFound},
|
||||
{"ErrModelDisabled", ErrModelDisabled, "model_disabled", http.StatusNotFound},
|
||||
{"ErrProviderNotFound", ErrProviderNotFound, "provider_not_found", http.StatusNotFound},
|
||||
{"ErrProviderDisabled", ErrProviderDisabled, "provider_disabled", http.StatusNotFound},
|
||||
{"ErrInvalidRequest", ErrInvalidRequest, "invalid_request", http.StatusBadRequest},
|
||||
{"ErrInternal", ErrInternal, "internal_error", http.StatusInternalServerError},
|
||||
{"ErrDatabaseNotInit", ErrDatabaseNotInit, "database_not_initialized", http.StatusInternalServerError},
|
||||
{"ErrConflict", ErrConflict, "conflict", http.StatusConflict},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.code, tt.err.Code)
|
||||
assert.Equal(t, tt.httpStatus, tt.err.HTTPStatus)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAsAppError(t *testing.T) {
|
||||
t.Run("nil输入", func(t *testing.T) {
|
||||
_, ok := AsAppError(nil)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("AppError类型", func(t *testing.T) {
|
||||
appErr, ok := AsAppError(ErrModelNotFound)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, ErrModelNotFound, appErr)
|
||||
})
|
||||
|
||||
t.Run("Wrapped AppError", func(t *testing.T) {
|
||||
wrapped := Wrap(ErrInternal, errors.New("cause"))
|
||||
appErr, ok := AsAppError(wrapped)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "internal_error", appErr.Code)
|
||||
})
|
||||
|
||||
t.Run("非AppError类型", func(t *testing.T) {
|
||||
_, ok := AsAppError(errors.New("普通错误"))
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
42
backend/pkg/errors/wrap.go
Normal file
42
backend/pkg/errors/wrap.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package errors
|
||||
|
||||
// Wrap wraps an error with cause
|
||||
func Wrap(err *AppError, cause error) *AppError {
|
||||
return &AppError{
|
||||
Code: err.Code,
|
||||
Message: err.Message,
|
||||
HTTPStatus: err.HTTPStatus,
|
||||
Cause: cause,
|
||||
}
|
||||
}
|
||||
|
||||
// WithContext adds context to an AppError
|
||||
func WithContext(err *AppError, key string, value interface{}) *AppError {
|
||||
newErr := &AppError{
|
||||
Code: err.Code,
|
||||
Message: err.Message,
|
||||
HTTPStatus: err.HTTPStatus,
|
||||
Cause: err.Cause,
|
||||
}
|
||||
if err.Context != nil {
|
||||
newErr.Context = make(map[string]interface{})
|
||||
for k, v := range err.Context {
|
||||
newErr.Context[k] = v
|
||||
}
|
||||
} else {
|
||||
newErr.Context = make(map[string]interface{})
|
||||
}
|
||||
newErr.Context[key] = value
|
||||
return newErr
|
||||
}
|
||||
|
||||
// WithMessage creates a new AppError with a custom message
|
||||
func WithMessage(err *AppError, message string) *AppError {
|
||||
return &AppError{
|
||||
Code: err.Code,
|
||||
Message: message,
|
||||
HTTPStatus: err.HTTPStatus,
|
||||
Cause: err.Cause,
|
||||
Context: err.Context,
|
||||
}
|
||||
}
|
||||
17
backend/pkg/logger/context.go
Normal file
17
backend/pkg/logger/context.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package logger
|
||||
|
||||
import "go.uber.org/zap"
|
||||
|
||||
// WithRequestID 向 logger 添加 request_id 字段
|
||||
func WithRequestID(logger *zap.Logger, requestID string) *zap.Logger {
|
||||
return logger.With(zap.String("request_id", requestID))
|
||||
}
|
||||
|
||||
// WithContext 向 logger 添加多个自定义字段
|
||||
func WithContext(logger *zap.Logger, fields map[string]interface{}) *zap.Logger {
|
||||
zapFields := make([]zap.Field, 0, len(fields))
|
||||
for k, v := range fields {
|
||||
zapFields = append(zapFields, zap.Any(k, v))
|
||||
}
|
||||
return logger.With(zapFields...)
|
||||
}
|
||||
109
backend/pkg/logger/logger.go
Normal file
109
backend/pkg/logger/logger.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// Config 日志配置
|
||||
type Config struct {
|
||||
Level string // 日志级别: debug, info, warn, error
|
||||
Path string // 日志文件目录,为空则仅输出到 stdout
|
||||
MaxSize int // 单个日志文件最大尺寸 (MB)
|
||||
MaxBackups int // 保留的旧日志文件最大数量
|
||||
MaxAge int // 保留旧日志文件的最大天数
|
||||
Compress bool // 是否压缩旧日志文件
|
||||
}
|
||||
|
||||
// New 根据配置创建 zap.Logger
|
||||
// 如果 Path 为空,仅输出到 stdout;
|
||||
// 如果 Path 已设置,同时输出到 stdout 和文件(文件使用 JSON 格式,stdout 使用 console 格式)
|
||||
func New(cfg Config) (*zap.Logger, error) {
|
||||
level, err := parseLevel(cfg.Level)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// stdout encoder — console 格式
|
||||
stdoutEncoder := zapcore.NewConsoleEncoder(zapcore.EncoderConfig{
|
||||
TimeKey: "ts",
|
||||
LevelKey: "level",
|
||||
NameKey: "logger",
|
||||
CallerKey: "caller",
|
||||
FunctionKey: zapcore.OmitKey,
|
||||
MessageKey: "msg",
|
||||
StacktraceKey: "stacktrace",
|
||||
LineEnding: zapcore.DefaultLineEnding,
|
||||
EncodeLevel: zapcore.CapitalColorLevelEncoder,
|
||||
EncodeTime: zapcore.ISO8601TimeEncoder,
|
||||
EncodeDuration: zapcore.StringDurationEncoder,
|
||||
EncodeCaller: zapcore.ShortCallerEncoder,
|
||||
})
|
||||
|
||||
stdoutCore := zapcore.NewCore(
|
||||
stdoutEncoder,
|
||||
zapcore.AddSync(os.Stdout),
|
||||
level,
|
||||
)
|
||||
|
||||
// 仅 stdout 模式
|
||||
if cfg.Path == "" {
|
||||
return zap.New(stdoutCore, zap.AddCaller(), zap.AddStacktrace(zap.ErrorLevel)), nil
|
||||
}
|
||||
|
||||
// 文件 encoder — JSON 格式
|
||||
fileEncoder := zapcore.NewJSONEncoder(zapcore.EncoderConfig{
|
||||
TimeKey: "ts",
|
||||
LevelKey: "level",
|
||||
NameKey: "logger",
|
||||
CallerKey: "caller",
|
||||
FunctionKey: zapcore.OmitKey,
|
||||
MessageKey: "msg",
|
||||
StacktraceKey: "stacktrace",
|
||||
LineEnding: zapcore.DefaultLineEnding,
|
||||
EncodeLevel: zapcore.LowercaseLevelEncoder,
|
||||
EncodeTime: zapcore.ISO8601TimeEncoder,
|
||||
EncodeDuration: zapcore.StringDurationEncoder,
|
||||
EncodeCaller: zapcore.ShortCallerEncoder,
|
||||
})
|
||||
|
||||
rotateWriter := newRotateWriter(cfg)
|
||||
fileCore := zapcore.NewCore(
|
||||
fileEncoder,
|
||||
zapcore.AddSync(rotateWriter),
|
||||
level,
|
||||
)
|
||||
|
||||
core := zapcore.NewTee(stdoutCore, fileCore)
|
||||
return zap.New(core, zap.AddCaller(), zap.AddStacktrace(zap.ErrorLevel)), nil
|
||||
}
|
||||
|
||||
// parseLevel 将字符串解析为 zapcore.Level
|
||||
func parseLevel(s string) (zapcore.Level, error) {
|
||||
switch s {
|
||||
case "debug":
|
||||
return zapcore.DebugLevel, nil
|
||||
case "info":
|
||||
return zapcore.InfoLevel, nil
|
||||
case "warn":
|
||||
return zapcore.WarnLevel, nil
|
||||
case "error":
|
||||
return zapcore.ErrorLevel, nil
|
||||
default:
|
||||
return zapcore.InfoLevel, nil
|
||||
}
|
||||
}
|
||||
|
||||
// logFileName 生成当日日志文件名: nex-YYYY-MM-DD.log
|
||||
func logFileName() string {
|
||||
return "nex-" + time.Now().Format("2006-01-02") + ".log"
|
||||
}
|
||||
|
||||
// logFilePath 拼接完整日志文件路径
|
||||
func logFilePath(dir string) string {
|
||||
return filepath.Join(dir, logFileName())
|
||||
}
|
||||
138
backend/pkg/logger/logger_test.go
Normal file
138
backend/pkg/logger/logger_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestNew_StdoutOnly(t *testing.T) {
|
||||
logger, err := New(Config{Level: "info"})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, logger)
|
||||
assert.NoError(t, logger.Sync())
|
||||
}
|
||||
|
||||
func TestNew_WithFileOutput(t *testing.T) {
|
||||
dir := filepath.Join(os.TempDir(), "nex-logger-test")
|
||||
os.MkdirAll(dir, 0755)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
logger, err := New(Config{
|
||||
Level: "debug",
|
||||
Path: dir,
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, logger)
|
||||
|
||||
logger.Info("test log message")
|
||||
_ = logger.Sync()
|
||||
|
||||
// 验证日志文件已创建
|
||||
files, err := os.ReadDir(dir)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, files, "日志目录应包含文件")
|
||||
}
|
||||
|
||||
func TestNew_AllLevels(t *testing.T) {
|
||||
levels := []string{"debug", "info", "warn", "error"}
|
||||
for _, level := range levels {
|
||||
logger, err := New(Config{Level: level})
|
||||
assert.NoError(t, err, "级别 %s 应有效", level)
|
||||
assert.NotNil(t, logger)
|
||||
assert.NoError(t, logger.Sync())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_EmptyLevel(t *testing.T) {
|
||||
// 空级别应默认为 info
|
||||
logger, err := New(Config{Level: ""})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, logger)
|
||||
assert.NoError(t, logger.Sync())
|
||||
}
|
||||
|
||||
func TestNew_InvalidPath(t *testing.T) {
|
||||
// 不可写的路径
|
||||
logger, err := New(Config{
|
||||
Level: "info",
|
||||
Path: "/nonexistent/deeply/nested/path/logs",
|
||||
})
|
||||
// 应能创建 logger(错误在写入时发生)
|
||||
// 实际上 lumberjack 会尝试创建目录
|
||||
_ = logger
|
||||
_ = err
|
||||
}
|
||||
|
||||
func TestParseLevel(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
valid bool
|
||||
}{
|
||||
{"debug", true},
|
||||
{"info", true},
|
||||
{"warn", true},
|
||||
{"error", true},
|
||||
{"", true}, // 默认为 info
|
||||
{"invalid", true}, // 默认为 info
|
||||
}
|
||||
for _, tt := range tests {
|
||||
_, err := parseLevel(tt.input)
|
||||
assert.NoError(t, err, "parseLevel(%q) 不应报错", tt.input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogFilePath(t *testing.T) {
|
||||
result := logFilePath(filepath.Join("var", "log"))
|
||||
assert.Contains(t, result, "nex-")
|
||||
assert.Contains(t, result, ".log")
|
||||
}
|
||||
|
||||
func TestLogFileName(t *testing.T) {
|
||||
name := logFileName()
|
||||
assert.Contains(t, name, "nex-")
|
||||
assert.Contains(t, name, ".log")
|
||||
assert.Len(t, name, len("nex-2006-01-02.log"))
|
||||
}
|
||||
|
||||
func TestNewRotateWriter_Defaults(t *testing.T) {
|
||||
cfg := Config{
|
||||
Path: t.TempDir(),
|
||||
MaxSize: 0,
|
||||
MaxAge: 0,
|
||||
Compress: true,
|
||||
}
|
||||
writer := newRotateWriter(cfg)
|
||||
require.NotNil(t, writer)
|
||||
assert.Equal(t, 100, writer.MaxSize)
|
||||
assert.Equal(t, 10, writer.MaxBackups)
|
||||
assert.Equal(t, 30, writer.MaxAge)
|
||||
}
|
||||
|
||||
func TestWithRequestID(t *testing.T) {
|
||||
logger, err := New(Config{Level: "info"})
|
||||
require.NoError(t, err)
|
||||
|
||||
contextLogger := WithRequestID(logger, "test-request-123")
|
||||
assert.NotNil(t, contextLogger)
|
||||
assert.IsType(t, &zap.Logger{}, contextLogger)
|
||||
}
|
||||
|
||||
func TestWithContext(t *testing.T) {
|
||||
logger, err := New(Config{Level: "info"})
|
||||
require.NoError(t, err)
|
||||
|
||||
contextLogger := WithContext(logger, map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": 42,
|
||||
})
|
||||
assert.NotNil(t, contextLogger)
|
||||
}
|
||||
30
backend/pkg/logger/rotate.go
Normal file
30
backend/pkg/logger/rotate.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package logger
|
||||
|
||||
import "gopkg.in/lumberjack.v2"
|
||||
|
||||
// newRotateWriter 根据配置创建 lumberjack.Logger 作为日志轮转写入器
|
||||
// 日志文件位于 cfg.Path 目录下,文件名格式为 nex-YYYY-MM-DD.log
|
||||
func newRotateWriter(cfg Config) *lumberjack.Logger {
|
||||
maxSize := cfg.MaxSize
|
||||
if maxSize <= 0 {
|
||||
maxSize = 100
|
||||
}
|
||||
|
||||
maxBackups := cfg.MaxBackups
|
||||
if maxBackups <= 0 {
|
||||
maxBackups = 10
|
||||
}
|
||||
|
||||
maxAge := cfg.MaxAge
|
||||
if maxAge <= 0 {
|
||||
maxAge = 30
|
||||
}
|
||||
|
||||
return &lumberjack.Logger{
|
||||
Filename: logFilePath(cfg.Path),
|
||||
MaxSize: maxSize, // MB
|
||||
MaxBackups: maxBackups,
|
||||
MaxAge: maxAge, // days
|
||||
Compress: cfg.Compress,
|
||||
}
|
||||
}
|
||||
22
backend/pkg/validator/validator.go
Normal file
22
backend/pkg/validator/validator.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package validator
|
||||
|
||||
import (
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// validate 全局验证器实例
|
||||
var validate *validator.Validate
|
||||
|
||||
func init() {
|
||||
validate = validator.New(validator.WithRequiredStructEnabled())
|
||||
}
|
||||
|
||||
// Get 返回全局验证器实例
|
||||
func Get() *validator.Validate {
|
||||
return validate
|
||||
}
|
||||
|
||||
// Validate 验证结构体
|
||||
func Validate(s interface{}) error {
|
||||
return validate.Struct(s)
|
||||
}
|
||||
45
backend/pkg/validator/validator_test.go
Normal file
45
backend/pkg/validator/validator_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package validator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type TestStruct struct {
|
||||
Name string `validate:"required"`
|
||||
Email string `validate:"required,email"`
|
||||
Age int `validate:"min=0,max=150"`
|
||||
}
|
||||
|
||||
func TestValidate_ValidStruct(t *testing.T) {
|
||||
s := TestStruct{Name: "John", Email: "john@example.com", Age: 25}
|
||||
err := Validate(s)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidate_MissingRequired(t *testing.T) {
|
||||
s := TestStruct{Email: "john@example.com", Age: 25}
|
||||
err := Validate(s)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidate_InvalidEmail(t *testing.T) {
|
||||
s := TestStruct{Name: "John", Email: "not-an-email", Age: 25}
|
||||
err := Validate(s)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidate_AgeOutOfRange(t *testing.T) {
|
||||
s := TestStruct{Name: "John", Email: "john@example.com", Age: 200}
|
||||
err := Validate(s)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestGet_ReturnsInstance(t *testing.T) {
|
||||
v := Get()
|
||||
assert.NotNil(t, v)
|
||||
// 多次调用应返回相同实例
|
||||
v2 := Get()
|
||||
assert.Equal(t, v, v2)
|
||||
}
|
||||
Reference in New Issue
Block a user