//go:build mysql package mysql import ( "database/sql" "fmt" "os" "path/filepath" "runtime" "testing" "time" "github.com/pressly/goose/v3" "github.com/stretchr/testify/require" "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/logger" ) type MySQLTestConfig struct { Host string Port int User string Password string Database string } func getMySQLTestConfig() *MySQLTestConfig { return &MySQLTestConfig{ Host: getEnvOrDefault("NEX_TEST_MYSQL_HOST", "localhost"), Port: getEnvOrDefaultInt("NEX_TEST_MYSQL_PORT", 13306), User: getEnvOrDefault("NEX_TEST_MYSQL_USER", "nex_test"), Password: getEnvOrDefault("NEX_TEST_MYSQL_PASSWORD", "testpass"), Database: getEnvOrDefault("NEX_TEST_MYSQL_DATABASE", "nex_test"), } } func getEnvOrDefault(key, defaultValue string) string { if value := os.Getenv(key); value != "" { return value } return defaultValue } func getEnvOrDefaultInt(key string, defaultValue int) int { if value := os.Getenv(key); value != "" { var intValue int if _, err := fmt.Sscanf(value, "%d", &intValue); err == nil { return intValue } } return defaultValue } func SkipIfMySQLUnavailable(t *testing.T) { t.Helper() cfg := getMySQLTestConfig() dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database) db, err := sql.Open("mysql", dsn) if err != nil { t.Skipf("MySQL 不可用: %v", err) } defer db.Close() if err := db.Ping(); err != nil { t.Skipf("MySQL 不可用: %v", err) } } func SetupMySQLTestDB(t *testing.T) *gorm.DB { t.Helper() SkipIfMySQLUnavailable(t) cfg := getMySQLTestConfig() dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database) db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), }) require.NoError(t, err, "连接 MySQL 失败") if err := runMigrations(db); err != nil { require.NoError(t, err, "运行迁移失败") } if err := cleanupTables(db); err != nil { require.NoError(t, err, "清理表数据失败") } sqlDB, err := db.DB() require.NoError(t, err) sqlDB.SetMaxIdleConns(10) sqlDB.SetMaxOpenConns(100) sqlDB.SetConnMaxLifetime(time.Hour) t.Cleanup(func() { time.Sleep(50 * time.Millisecond) sqlDB, err := db.DB() if err == nil { sqlDB.Close() } }) return db } func cleanupTables(db *gorm.DB) error { if err := db.Exec("SET FOREIGN_KEY_CHECKS = 0").Error; err != nil { return err } if err := db.Exec("TRUNCATE TABLE usage_stats").Error; err != nil { return err } if err := db.Exec("TRUNCATE TABLE models").Error; err != nil { return err } if err := db.Exec("TRUNCATE TABLE providers").Error; err != nil { return err } if err := db.Exec("SET FOREIGN_KEY_CHECKS = 1").Error; err != nil { return err } return nil } func runMigrations(db *gorm.DB) error { sqlDB, err := db.DB() if err != nil { return err } migrationsDir := getMigrationsDir() if _, err := os.Stat(migrationsDir); os.IsNotExist(err) { return fmt.Errorf("迁移目录不存在: %s", migrationsDir) } goose.SetDialect("mysql") if err := goose.Up(sqlDB, migrationsDir); err != nil { return err } return nil } func getMigrationsDir() string { _, filename, _, ok := runtime.Caller(0) if ok { dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations", "mysql") if abs, err := filepath.Abs(dir); err == nil { return abs } } return "./migrations/mysql" }