1
0

feat: 新增 MySQL 专项测试能力

- 新增 backend/tests/mysql/ 目录,包含 Docker Compose 配置和测试文件
- 新增 Makefile 命令: test-mysql, test-mysql-up, test-mysql-down, test-mysql-quick
- 使用 build tag 控制测试启用,默认不运行
- 测试覆盖: 迁移正确性、外键约束、UNIQUE 约束、并发写入
- 发现 statsRepo.Record 存在并发 bug(检查-然后-操作竞态条件)
This commit is contained in:
2026-04-23 12:25:55 +08:00
parent 5b765c8b5e
commit 5b401e29cb
7 changed files with 732 additions and 0 deletions

View File

@@ -0,0 +1,158 @@
//go:build mysql
package mysql
import (
"sync"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/config"
"nex/backend/internal/repository"
)
func TestConcurrent_UsageStatsRecord(t *testing.T) {
db := SetupMySQLTestDB(t)
statsRepo := repository.NewStatsRepository(db)
providerID := "concurrent-test-provider"
modelName := "gpt-4"
concurrency := 10
var wg sync.WaitGroup
wg.Add(concurrency)
errChan := make(chan error, concurrency)
for i := 0; i < concurrency; i++ {
go func() {
defer wg.Done()
err := statsRepo.Record(providerID, modelName)
if err != nil {
errChan <- err
}
}()
}
wg.Wait()
close(errChan)
var errorCount int
uniqueErrors := make(map[string]int)
for err := range errChan {
errorCount++
uniqueErrors[err.Error()]++
}
t.Logf("并发 %d 次,错误 %d 次", concurrency, errorCount)
for errMsg, count := range uniqueErrors {
t.Logf(" 错误: %s (出现 %d 次)", errMsg, count)
}
var stats config.UsageStats
err := db.Where("provider_id = ? AND model_name = ?", providerID, modelName).
First(&stats).Error
require.NoError(t, err, "应能查到 usage_stats 记录")
successCount := concurrency - errorCount
t.Logf("成功次数: %d, 最终 request_count: %d", successCount, stats.RequestCount)
assert.Equal(t, concurrency, stats.RequestCount, "request_count 应等于并发数,无数据丢失或重复")
}
func TestConcurrent_ProviderCreate(t *testing.T) {
db := SetupMySQLTestDB(t)
providerID := "concurrent-provider-id"
concurrency := 10
var wg sync.WaitGroup
wg.Add(concurrency)
successCount := 0
var mu sync.Mutex
for i := 0; i < concurrency; i++ {
go func() {
defer wg.Done()
provider := config.Provider{
ID: providerID,
Name: "Concurrent Provider",
APIKey: "test-key",
BaseURL: "https://test.com",
Enabled: true,
}
err := db.Create(&provider).Error
if err == nil {
mu.Lock()
successCount++
mu.Unlock()
}
}()
}
wg.Wait()
assert.Equal(t, 1, successCount, "仅 1 个创建应成功")
var count int64
db.Model(&config.Provider{}).Where("id = ?", providerID).Count(&count)
assert.Equal(t, int64(1), count, "最终应有 1 条记录")
}
func TestConcurrent_ModelCreate(t *testing.T) {
db := SetupMySQLTestDB(t)
provider := config.Provider{
ID: "concurrent-model-provider",
Name: "Test Provider",
APIKey: "test-key",
BaseURL: "https://test.com",
Enabled: true,
}
err := db.Create(&provider).Error
require.NoError(t, err, "创建 provider 应成功")
modelName := "gpt-4-concurrent"
concurrency := 10
var wg sync.WaitGroup
wg.Add(concurrency)
successCount := 0
var mu sync.Mutex
for i := 0; i < concurrency; i++ {
go func(idx int) {
defer wg.Done()
model := config.Model{
ID: uuid.New().String(),
ProviderID: provider.ID,
ModelName: modelName,
Enabled: true,
}
err := db.Create(&model).Error
if err == nil {
mu.Lock()
successCount++
mu.Unlock()
}
}(i)
}
wg.Wait()
assert.Equal(t, 1, successCount, "仅 1 个创建应成功")
var count int64
db.Model(&config.Model{}).Where("provider_id = ? AND model_name = ?", provider.ID, modelName).Count(&count)
assert.Equal(t, int64(1), count, "最终应有 1 条记录")
}

View File

@@ -0,0 +1,130 @@
//go:build mysql
package mysql
import (
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"nex/backend/internal/config"
)
func TestConstraint_ForeignKeyEnforced(t *testing.T) {
db := SetupMySQLTestDB(t)
model := config.Model{
ID: "test-model-id",
ProviderID: "non-existent-provider",
ModelName: "gpt-4",
Enabled: true,
}
err := db.Create(&model).Error
assert.Error(t, err, "创建 model 时 provider_id 不存在应失败")
assert.Contains(t, err.Error(), "foreign key constraint", "错误应为外键约束错误")
}
func TestConstraint_CascadeDelete(t *testing.T) {
db := SetupMySQLTestDB(t)
provider := config.Provider{
ID: "test-provider-cascade",
Name: "Test Provider",
APIKey: "test-key",
BaseURL: "https://test.com",
Enabled: true,
}
err := db.Create(&provider).Error
require.NoError(t, err, "创建 provider 应成功")
model := config.Model{
ID: "test-model-cascade",
ProviderID: provider.ID,
ModelName: "gpt-4",
Enabled: true,
}
err = db.Create(&model).Error
require.NoError(t, err, "创建 model 应成功")
err = db.Delete(&provider).Error
require.NoError(t, err, "删除 provider 应成功")
var count int64
err = db.Model(&config.Model{}).Where("provider_id = ?", provider.ID).Count(&count).Error
require.NoError(t, err)
assert.Equal(t, int64(0), count, "删除 provider 后其 models 应被级联删除")
}
func TestConstraint_UniqueProviderModel(t *testing.T) {
db := SetupMySQLTestDB(t)
provider := config.Provider{
ID: "test-provider-unique",
Name: "Test Provider",
APIKey: "test-key",
BaseURL: "https://test.com",
Enabled: true,
}
err := db.Create(&provider).Error
require.NoError(t, err, "创建 provider 应成功")
model1 := config.Model{
ID: "test-model-unique-1",
ProviderID: provider.ID,
ModelName: "gpt-4",
Enabled: true,
}
err = db.Create(&model1).Error
require.NoError(t, err, "创建第一个 model 应成功")
model2 := config.Model{
ID: "test-model-unique-2",
ProviderID: provider.ID,
ModelName: "gpt-4",
Enabled: true,
}
err = db.Create(&model2).Error
assert.Error(t, err, "创建相同 (provider_id, model_name) 的 model 应失败")
assert.True(t, errors.Is(err, gorm.ErrDuplicatedKey) ||
(err != nil && (err.Error() == "Error 1062" || containsDuplicateError(err.Error()))),
"错误应为唯一约束错误")
}
func TestConstraint_UniqueUsageStats(t *testing.T) {
db := SetupMySQLTestDB(t)
today := time.Now().Format("2006-01-02")
todayTime, _ := time.Parse("2006-01-02", today)
providerID := "test-provider-unique-stats"
stats1 := config.UsageStats{
ProviderID: providerID,
ModelName: "gpt-4",
RequestCount: 10,
Date: todayTime,
}
err := db.Create(&stats1).Error
require.NoError(t, err, "创建第一个 usage_stats 应成功")
stats2 := config.UsageStats{
ProviderID: providerID,
ModelName: "gpt-4",
RequestCount: 20,
Date: todayTime,
}
err = db.Create(&stats2).Error
assert.Error(t, err, "创建相同 (provider_id, model_name, date) 的 usage_stats 应失败")
assert.True(t, errors.Is(err, gorm.ErrDuplicatedKey) ||
(err != nil && (err.Error() == "Error 1062" || containsDuplicateError(err.Error()))),
"错误应为唯一约束错误")
}
func containsDuplicateError(errStr string) bool {
return len(errStr) > 0 && (errStr[0:8] == "Error 10" || errStr[0:5] == "Dupli")
}

View File

@@ -0,0 +1,21 @@
version: '3.8'
services:
mysql:
image: mysql:8.0
container_name: nex-mysql-test
environment:
MYSQL_ROOT_PASSWORD: testpass
MYSQL_DATABASE: nex_test
MYSQL_USER: nex_test
MYSQL_PASSWORD: testpass
ports:
- "13306:3306"
tmpfs:
- /var/lib/mysql
healthcheck:
test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-u", "root", "-p$$MYSQL_ROOT_PASSWORD"]
interval: 1s
timeout: 5s
retries: 10
start_period: 10s

View File

@@ -0,0 +1,126 @@
//go:build mysql
package mysql
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMigration_TablesExist(t *testing.T) {
db := SetupMySQLTestDB(t)
var tables []string
err := db.Raw("SHOW TABLES").Scan(&tables).Error
require.NoError(t, err)
expectedTables := []string{"providers", "models", "usage_stats"}
for _, expected := range expectedTables {
assert.Contains(t, tables, expected, "表 %s 应存在", expected)
}
}
func TestMigration_TableColumns(t *testing.T) {
db := SetupMySQLTestDB(t)
t.Run("providers 表字段", func(t *testing.T) {
var columns []struct {
Field string
Type string
Null string
}
err := db.Raw("SHOW COLUMNS FROM providers").Scan(&columns).Error
require.NoError(t, err)
columnMap := make(map[string]string)
for _, col := range columns {
columnMap[col.Field] = col.Type
}
assert.Contains(t, columnMap["id"], "varchar", "id 应为 VARCHAR 类型")
assert.Contains(t, columnMap["name"], "varchar", "name 应为 VARCHAR 类型")
assert.Contains(t, columnMap["api_key"], "varchar", "api_key 应为 VARCHAR 类型")
assert.Contains(t, columnMap["base_url"], "varchar", "base_url 应为 VARCHAR 类型")
assert.Contains(t, columnMap["protocol"], "varchar", "protocol 应为 VARCHAR 类型")
assert.Contains(t, columnMap["enabled"], "tinyint", "enabled 应为 TINYINT (BOOLEAN) 类型")
assert.Contains(t, columnMap["created_at"], "datetime", "created_at 应为 DATETIME 类型")
assert.Contains(t, columnMap["updated_at"], "datetime", "updated_at 应为 DATETIME 类型")
})
t.Run("models 表字段", func(t *testing.T) {
var columns []struct {
Field string
Type string
}
err := db.Raw("SHOW COLUMNS FROM models").Scan(&columns).Error
require.NoError(t, err)
columnMap := make(map[string]string)
for _, col := range columns {
columnMap[col.Field] = col.Type
}
assert.Contains(t, columnMap["id"], "varchar", "id 应为 VARCHAR 类型")
assert.Contains(t, columnMap["provider_id"], "varchar", "provider_id 应为 VARCHAR 类型")
assert.Contains(t, columnMap["model_name"], "varchar", "model_name 应为 VARCHAR 类型")
assert.Contains(t, columnMap["enabled"], "tinyint", "enabled 应为 TINYINT (BOOLEAN) 类型")
assert.Contains(t, columnMap["created_at"], "datetime", "created_at 应为 DATETIME 类型")
})
t.Run("usage_stats 表字段", func(t *testing.T) {
var columns []struct {
Field string
Type string
}
err := db.Raw("SHOW COLUMNS FROM usage_stats").Scan(&columns).Error
require.NoError(t, err)
columnMap := make(map[string]string)
for _, col := range columns {
columnMap[col.Field] = col.Type
}
assert.Contains(t, columnMap["id"], "int", "id 应为 INT 类型")
assert.Contains(t, columnMap["provider_id"], "varchar", "provider_id 应为 VARCHAR 类型")
assert.Contains(t, columnMap["model_name"], "varchar", "model_name 应为 VARCHAR 类型")
assert.Contains(t, columnMap["request_count"], "int", "request_count 应为 INT 类型")
assert.Contains(t, columnMap["date"], "date", "date 应为 DATE 类型")
})
}
func TestMigration_IndexesExist(t *testing.T) {
db := SetupMySQLTestDB(t)
t.Run("models 表索引", func(t *testing.T) {
var indexes []struct {
KeyName string
}
err := db.Raw("SHOW INDEX FROM models").Scan(&indexes).Error
require.NoError(t, err)
indexMap := make(map[string]bool)
for _, idx := range indexes {
indexMap[idx.KeyName] = true
}
assert.True(t, indexMap["idx_models_provider_id"], "idx_models_provider_id 索引应存在")
assert.True(t, indexMap["idx_models_model_name"], "idx_models_model_name 索引应存在")
})
t.Run("usage_stats 表索引", func(t *testing.T) {
var indexes []struct {
KeyName string
}
err := db.Raw("SHOW INDEX FROM usage_stats").Scan(&indexes).Error
require.NoError(t, err)
indexMap := make(map[string]bool)
for _, idx := range indexes {
indexMap[idx.KeyName] = true
}
assert.True(t, indexMap["idx_usage_stats_provider_model_date"], "idx_usage_stats_provider_model_date 索引应存在")
})
}

View File

@@ -0,0 +1,160 @@
//go:build mysql
package mysql
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/pressly/goose/v3"
"github.com/stretchr/testify/require"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
type MySQLTestConfig struct {
Host string
Port int
User string
Password string
Database string
}
func getMySQLTestConfig() *MySQLTestConfig {
return &MySQLTestConfig{
Host: getEnvOrDefault("NEX_TEST_MYSQL_HOST", "localhost"),
Port: getEnvOrDefaultInt("NEX_TEST_MYSQL_PORT", 13306),
User: getEnvOrDefault("NEX_TEST_MYSQL_USER", "nex_test"),
Password: getEnvOrDefault("NEX_TEST_MYSQL_PASSWORD", "testpass"),
Database: getEnvOrDefault("NEX_TEST_MYSQL_DATABASE", "nex_test"),
}
}
func getEnvOrDefault(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
func getEnvOrDefaultInt(key string, defaultValue int) int {
if value := os.Getenv(key); value != "" {
var intValue int
if _, err := fmt.Sscanf(value, "%d", &intValue); err == nil {
return intValue
}
}
return defaultValue
}
func SkipIfMySQLUnavailable(t *testing.T) {
t.Helper()
cfg := getMySQLTestConfig()
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database)
db, err := sql.Open("mysql", dsn)
if err != nil {
t.Skipf("MySQL 不可用: %v", err)
}
defer db.Close()
if err := db.Ping(); err != nil {
t.Skipf("MySQL 不可用: %v", err)
}
}
func SetupMySQLTestDB(t *testing.T) *gorm.DB {
t.Helper()
SkipIfMySQLUnavailable(t)
cfg := getMySQLTestConfig()
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database)
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
require.NoError(t, err, "连接 MySQL 失败")
if err := runMigrations(db); err != nil {
require.NoError(t, err, "运行迁移失败")
}
if err := cleanupTables(db); err != nil {
require.NoError(t, err, "清理表数据失败")
}
sqlDB, err := db.DB()
require.NoError(t, err)
sqlDB.SetMaxIdleConns(10)
sqlDB.SetMaxOpenConns(100)
sqlDB.SetConnMaxLifetime(time.Hour)
t.Cleanup(func() {
time.Sleep(50 * time.Millisecond)
sqlDB, err := db.DB()
if err == nil {
sqlDB.Close()
}
})
return db
}
func cleanupTables(db *gorm.DB) error {
if err := db.Exec("SET FOREIGN_KEY_CHECKS = 0").Error; err != nil {
return err
}
if err := db.Exec("TRUNCATE TABLE usage_stats").Error; err != nil {
return err
}
if err := db.Exec("TRUNCATE TABLE models").Error; err != nil {
return err
}
if err := db.Exec("TRUNCATE TABLE providers").Error; err != nil {
return err
}
if err := db.Exec("SET FOREIGN_KEY_CHECKS = 1").Error; err != nil {
return err
}
return nil
}
func runMigrations(db *gorm.DB) error {
sqlDB, err := db.DB()
if err != nil {
return err
}
migrationsDir := getMigrationsDir()
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
}
goose.SetDialect("mysql")
if err := goose.Up(sqlDB, migrationsDir); err != nil {
return err
}
return nil
}
func getMigrationsDir() string {
_, filename, _, ok := runtime.Caller(0)
if ok {
dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations", "mysql")
if abs, err := filepath.Abs(dir); err == nil {
return abs
}
}
return "./migrations/mysql"
}