package database

import (
	"io/fs"
	"os"
	"path/filepath"
	"testing"

	"nex/backend/internal/config"
	"nex/backend/migrations"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/zap"
)

// testDatabaseConfig builds the DatabaseConfig shared by the tests in this
// file. Driver is set to driver. Path points at "test.db" inside dir, with a
// fixed small connection-pool configuration.
func testDatabaseConfig(driver, dir string) *config.DatabaseConfig {
	return &config.DatabaseConfig{
		Driver:          driver,
		Path:            filepath.Join(dir, "test.db"),
		MaxIdleConns:    5,
		MaxOpenConns:    10,
		ConnMaxLifetime: 0,
	}
}

// TestInit_SQLite verifies that Init succeeds for the "sqlite" driver with a
// database file in a temporary directory. It also checks that db.DB() then
// returns a non-nil handle without error.
func TestInit_SQLite(t *testing.T) {
	cfg := testDatabaseConfig("sqlite", t.TempDir())

	db, err := Init(cfg, zap.NewNop())
	require.NoError(t, err)
	require.NotNil(t, db)
	defer Close(db)

	sqlDB, err := db.DB()
	require.NoError(t, err)
	require.NotNil(t, sqlDB)
}

// TestClose verifies that Close can be called on a database freshly returned
// by Init. It only checks that the call completes without panicking.
func TestClose(t *testing.T) {
	cfg := testDatabaseConfig("sqlite", t.TempDir())

	db, err := Init(cfg, zap.NewNop())
	require.NoError(t, err)
	require.NotNil(t, db)

	Close(db)
}

// TestBuildDSN checks the exact MySQL DSN produced for a fully populated
// configuration, including the fixed charset/parseTime/loc query parameters.
func TestBuildDSN(t *testing.T) {
	cfg := &config.DatabaseConfig{
		Driver:   "mysql",
		Host:     "db.example.com",
		Port:     3306,
		User:     "nexuser",
		Password: "secretpass",
		DBName:   "nexdb",
	}

	dsn := BuildDSN(cfg)
	assert.Equal(t, "nexuser:secretpass@tcp(db.example.com:3306)/nexdb?charset=utf8mb4&parseTime=true&loc=Local", dsn)
}

// TestBuildDSN_EmptyPassword checks that an unset password still yields the
// "user:@" form rather than dropping the colon.
func TestBuildDSN_EmptyPassword(t *testing.T) {
	cfg := &config.DatabaseConfig{
		Driver: "mysql",
		Host:   "localhost",
		Port:   3306,
		User:   "root",
		DBName: "nex",
	}

	dsn := BuildDSN(cfg)
	assert.Equal(t, "root:@tcp(localhost:3306)/nex?charset=utf8mb4&parseTime=true&loc=Local", dsn)
}

// TestInit_SQLite_AnyCWD verifies that Init still succeeds after the process
// working directory has been changed to an unrelated temporary directory.
func TestInit_SQLite_AnyCWD(t *testing.T) {
	dir := t.TempDir()

	// The working directory is process-wide. Without the original path the
	// Chdir below could not be undone, and every later test in this package
	// would run from the wrong directory. So skip rather than leak it.
	origDir, err := os.Getwd()
	if err != nil {
		t.Skipf("无法获取当前工作目录: %v", err)
	}
	if chdirErr := os.Chdir(dir); chdirErr != nil {
		t.Skipf("无法切换到临时目录: %v", chdirErr)
	}
	// Cleanups run last-in first-out. This one is registered after t.TempDir,
	// so the CWD is restored before the temporary directory is removed.
	t.Cleanup(func() {
		if chdirErr := os.Chdir(origDir); chdirErr != nil {
			t.Logf("无法恢复工作目录: %v", chdirErr)
		}
	})

	cfg := testDatabaseConfig("sqlite", dir)

	db, err := Init(cfg, zap.NewNop())
	require.NoError(t, err)
	require.NotNil(t, db)
	defer Close(db)

	sqlDB, err := db.DB()
	require.NoError(t, err)
	require.NotNil(t, sqlDB)
}

// TestForDriverDialect_SQLite verifies that a full Init/Close cycle succeeds
// for the "sqlite" driver.
func TestForDriverDialect_SQLite(t *testing.T) {
	require.NoError(t, testMigrateWithDriver(t, "sqlite"))
}

// TestForDriverDialect_MySQL verifies that migrations.ForDriver("mysql")
// reports the "mysql" dialect and returns a non-empty migration filesystem.
func TestForDriverDialect_MySQL(t *testing.T) {
	dialect, fsys, err := migrations.ForDriver("mysql")
	require.NoError(t, err)
	assert.Equal(t, "mysql", string(dialect))

	entries, fsErr := fs.ReadDir(fsys, ".")
	require.NoError(t, fsErr)
	assert.NotEmpty(t, entries, "MySQL 迁移资源应至少包含一个文件")
}

// TestForDriverDialect_Invalid verifies that Init rejects an unsupported
// driver with an error whose message mentions the unsupported driver.
func TestForDriverDialect_Invalid(t *testing.T) {
	cfg := testDatabaseConfig("postgres", t.TempDir())

	_, err := Init(cfg, zap.NewNop())
	// require, not assert: err.Error() below would panic on a nil error.
	require.Error(t, err, "非法 driver 应返回错误")
	assert.Contains(t, err.Error(), "不支持的数据库驱动")
}

// testMigrateWithDriver initialises a database for driver in a fresh temporary
// directory, then closes it again. It returns the error from Init, if any.
func testMigrateWithDriver(t *testing.T, driver string) error {
	t.Helper()

	db, err := Init(testDatabaseConfig(driver, t.TempDir()), zap.NewNop())
	if err != nil {
		return err
	}
	Close(db)
	return nil
}