1
0

Merge branch 'dev-mysql-support' into master

- 新增 MySQL 数据库驱动支持,支持跨设备数据同步
- 新增 MySQL 专项测试能力(并发、约束、迁移)
- 重构迁移目录结构:migrations/sqlite 和 migrations/mysql
- 修复 statsRepo 并发竞态条件,使用 upsert 保证原子性
- Makefile 合并:保留完整命令体系 + 新增 MySQL 测试命令
This commit is contained in:
2026-04-23 16:31:29 +08:00
26 changed files with 1421 additions and 251 deletions

View File

@@ -2,6 +2,7 @@
backend-build backend-run backend-dev backend-test backend-test-unit backend-test-integration backend-test-coverage \
backend-lint backend-clean backend-deps backend-generate \
backend-db-up backend-db-down backend-db-status backend-db-create \
test-mysql-up test-mysql-down test-mysql test-mysql-quick \
frontend-build frontend-dev frontend-test frontend-test-watch frontend-test-coverage frontend-test-e2e frontend-lint frontend-clean \
desktop-build desktop-build-mac desktop-build-win desktop-build-linux \
desktop-dev desktop-package-mac desktop-package-win desktop-package-linux desktop-clean \
@@ -66,17 +67,53 @@ backend-generate:
cd backend && go generate ./...
backend-db-up:
cd backend && goose -dir migrations sqlite3 $(DB_PATH) up
@echo "Running database migration up..."
cd backend && goose -dir migrations/sqlite3 sqlite3 "$(DB_PATH)" up
backend-db-down:
cd backend && goose -dir migrations sqlite3 $(DB_PATH) down
@echo "Running database migration down..."
cd backend && goose -dir migrations/sqlite3 sqlite3 "$(DB_PATH)" down
backend-db-status:
cd backend && goose -dir migrations sqlite3 $(DB_PATH) status
@echo "Checking database migration status..."
cd backend && goose -dir migrations/sqlite3 sqlite3 "$(DB_PATH)" status
backend-db-create:
@read -p "Migration name: " name; \
cd backend && goose -dir migrations create $$name sql
cd backend && goose -dir migrations/sqlite create $$name sql; \
cd backend && goose -dir migrations/mysql create $$name sql
# ============================================
# MySQL 专项测试
# ============================================
test-mysql-up:
@echo "Starting MySQL test container..."
cd backend/tests/mysql && docker-compose up -d
@echo "Waiting for MySQL to be ready..."
@for i in $$(seq 1 30); do \
if docker exec nex-mysql-test mysqladmin ping -h localhost -u root -ptestpass --silent 2>/dev/null; then \
echo "MySQL is ready!"; \
exit 0; \
fi; \
echo "Waiting... ($$i/30)"; \
sleep 1; \
done; \
echo "MySQL failed to start"; \
exit 1
test-mysql-down:
@echo "Stopping MySQL test container..."
cd backend/tests/mysql && docker-compose down -v
test-mysql: test-mysql-up
@echo "Running MySQL tests..."
cd backend && go test -tags=mysql ./tests/mysql/... -v -count=1
$(MAKE) test-mysql-down
test-mysql-quick:
@echo "Running MySQL tests (without container management)..."
cd backend && go test -tags=mysql ./tests/mysql/... -v -count=1
# ============================================
# 前端

View File

@@ -66,7 +66,7 @@ nex/
- **语言**: Go 1.26+
- **HTTP 框架**: Gin
- **ORM**: GORM
- **数据库**: SQLite
- **数据库**: SQLite / MySQL
- **日志**: zap + lumberjack结构化日志 + 日志轮转)
- **配置**: Viper + pflag多层配置CLI > 环境变量 > 配置文件 > 默认值)
- **验证**: go-playground/validator/v10
@@ -204,7 +204,14 @@ server:
write_timeout: 30s
database:
path: ~/.nex/config.db
driver: sqlite # sqlite 或 mysql
path: ~/.nex/config.db # SQLite 数据库文件路径
# --- MySQL 配置driver=mysql 时生效)---
# host: localhost
# port: 3306
# user: nex
# password: ""
# dbname: nex
max_idle_conns: 10
max_open_conns: 100
conn_max_lifetime: 1h
@@ -226,6 +233,14 @@ log:
export NEX_SERVER_PORT=9000
export NEX_DATABASE_PATH=/data/nex.db
export NEX_LOG_LEVEL=debug
# MySQL 模式
export NEX_DATABASE_DRIVER=mysql
export NEX_DATABASE_HOST=db.example.com
export NEX_DATABASE_PORT=3306
export NEX_DATABASE_USER=nex
export NEX_DATABASE_PASSWORD=secret
export NEX_DATABASE_DBNAME=nex
```
命名规则:配置路径转大写 + 下划线(如 `server.port``NEX_SERVER_PORT`)。
@@ -241,7 +256,7 @@ export NEX_LOG_LEVEL=debug
### 数据文件
- `~/.nex/config.yaml` - 配置文件
- `~/.nex/config.db` - SQLite 数据库
- `~/.nex/config.db` - SQLite 数据库MySQL 模式下不使用本地数据库文件)
- `~/.nex/log/` - 日志目录
## 测试

View File

@@ -24,7 +24,7 @@ AI 网关后端服务,提供统一的大模型 API 代理接口。
- **语言**: Go 1.26+
- **HTTP 框架**: Gin
- **ORM**: GORM
- **数据库**: SQLite
- **数据库**: SQLite / MySQL
- **日志**: zap + lumberjack
- **配置**: Viper + pflag多层配置CLI > 环境变量 > 配置文件 > 默认值)
- **验证**: go-playground/validator/v10
@@ -294,7 +294,14 @@ server:
write_timeout: 30s
database:
path: ~/.nex/config.db
driver: sqlite # sqlite 或 mysql
path: ~/.nex/config.db # SQLite 数据库文件路径
# --- MySQL 配置driver=mysql 时生效)---
# host: localhost
# port: 3306
# user: nex
# password: ""
# dbname: nex
max_idle_conns: 10
max_open_conns: 100
conn_max_lifetime: 1h
@@ -316,6 +323,14 @@ log:
export NEX_SERVER_PORT=9000
export NEX_DATABASE_PATH=/data/nex.db
export NEX_LOG_LEVEL=debug
# MySQL 模式
export NEX_DATABASE_DRIVER=mysql
export NEX_DATABASE_HOST=db.example.com
export NEX_DATABASE_PORT=3306
export NEX_DATABASE_USER=nex
export NEX_DATABASE_PASSWORD=secret
export NEX_DATABASE_DBNAME=nex
```
命名规则:配置路径转大写 + 下划线 + `NEX_` 前缀(如 `server.port``NEX_SERVER_PORT`)。
@@ -332,7 +347,7 @@ export NEX_LOG_LEVEL=debug
```
服务器: --server-port, --server-read-timeout, --server-write-timeout
数据库: --database-path, --database-max-idle-conns, --database-max-open-conns, --database-conn-max-lifetime
数据库: --database-driver, --database-path, --database-host, --database-port, --database-user, --database-password, --database-dbname, --database-max-idle-conns, --database-max-open-conns, --database-conn-max-lifetime
日志: --log-level, --log-path, --log-max-size, --log-max-backups, --log-max-age, --log-compress
通用: --config (指定配置文件路径)
```
@@ -352,15 +367,20 @@ export NEX_LOG_LEVEL=debug
# Docker 部署
docker run -d -e NEX_SERVER_PORT=9000 -e NEX_LOG_LEVEL=info nex-server
# MySQL 模式
./server --database-driver mysql --database-host db.example.com --database-user nex --database-password secret --database-dbname nex
# 自定义配置文件
./server --config /path/to/custom.yaml
```
数据文件:
- `~/.nex/config.yaml` - 配置文件
- `~/.nex/config.db` - SQLite 数据库
- `~/.nex/config.db` - SQLite 数据库MySQL 模式下不使用本地数据库文件)
- `~/.nex/log/` - 日志目录
**MySQL 连接说明**MySQL 连接使用 DSN 格式: `user:password@tcp(host:port)/dbname?charset=utf8mb4&parseTime=true&loc=Local`,最低支持 MySQL 8.0+。
## 测试
```bash

View File

@@ -17,16 +17,13 @@ import (
"github.com/getlantern/systray"
"github.com/gin-gonic/gin"
"github.com/gofrs/flock"
"github.com/pressly/goose/v3"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"nex/backend/internal/config"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"nex/backend/internal/database"
"nex/backend/internal/handler"
"nex/backend/internal/handler/middleware"
"nex/backend/internal/provider"
@@ -79,12 +76,12 @@ func main() {
}
defer zapLogger.Sync()
db, err := initDatabase(cfg)
db, err := database.Init(&cfg.Database, zapLogger)
if err != nil {
showError("Nex Gateway", fmt.Sprintf("初始化数据库失败: %v", err))
os.Exit(1)
}
defer closeDB(db)
defer database.Close(db)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
@@ -159,76 +156,6 @@ func main() {
setupSystray(port)
}
func initDatabase(cfg *config.Config) (*gorm.DB, error) {
dbDir := filepath.Dir(cfg.Database.Path)
if err := os.MkdirAll(dbDir, 0755); err != nil {
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
}
db, err := gorm.Open(sqlite.Open(cfg.Database.Path), &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
})
if err != nil {
return nil, err
}
if err := runMigrations(db); err != nil {
return nil, fmt.Errorf("数据库迁移失败: %w", err)
}
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
log.Printf("警告: 启用 WAL 模式失败: %v", err)
}
sqlDB, err := db.DB()
if err != nil {
return nil, err
}
sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns)
sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns)
sqlDB.SetConnMaxLifetime(cfg.Database.ConnMaxLifetime)
return db, nil
}
func runMigrations(db *gorm.DB) error {
sqlDB, err := db.DB()
if err != nil {
return err
}
migrationsDir := getMigrationsDir()
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
}
goose.SetDialect("sqlite3")
if err := goose.Up(sqlDB, migrationsDir); err != nil {
return err
}
return nil
}
func getMigrationsDir() string {
_, filename, _, ok := runtime.Caller(0)
if ok {
dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations")
if abs, err := filepath.Abs(dir); err == nil {
return abs
}
}
return "./migrations"
}
func closeDB(db *gorm.DB) {
sqlDB, err := db.DB()
if err != nil {
return
}
sqlDB.Close()
}
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
r.Any("/v1/*path", proxyHandler.HandleProxy)

View File

@@ -7,22 +7,17 @@ import (
"net/http"
"os"
"os/signal"
"path/filepath"
"runtime"
"syscall"
"time"
"github.com/gin-gonic/gin"
"github.com/pressly/goose/v3"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"nex/backend/internal/config"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"nex/backend/internal/database"
"nex/backend/internal/handler"
"nex/backend/internal/handler/middleware"
"nex/backend/internal/provider"
@@ -32,16 +27,13 @@ import (
)
func main() {
// 1. 加载配置(已包含 CLI 参数解析、环境变量绑定、配置文件读取和验证)
cfg, err := config.LoadConfig()
if err != nil {
log.Fatalf("加载配置失败: %v", err)
}
// 2. 打印配置摘要
cfg.PrintSummary()
// 3. 初始化日志
zapLogger, err := pkgLogger.New(pkgLogger.Config{
Level: cfg.Log.Level,
Path: cfg.Log.Path,
@@ -55,37 +47,31 @@ func main() {
}
defer zapLogger.Sync()
// 3. 初始化数据库
db, err := initDatabase(cfg)
db, err := database.Init(&cfg.Database, zapLogger)
if err != nil {
zapLogger.Fatal("初始化数据库失败", zap.String("error", err.Error()))
}
defer closeDB(db)
defer database.Close(db)
// 4. 初始化 repository 层
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
statsRepo := repository.NewStatsRepository(db)
// 5. 初始化缓存
routingCache := service.NewRoutingCache(modelRepo, providerRepo, zapLogger)
if err := routingCache.Preload(); err != nil {
zapLogger.Warn("缓存预热失败,将使用懒加载", zap.Error(err))
}
// 6. 初始化统计缓冲
statsBuffer := service.NewStatsBuffer(statsRepo, zapLogger,
service.WithFlushInterval(5*time.Second),
service.WithFlushThreshold(100))
statsBuffer.Start()
// 7. 初始化 service 层
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
routingService := service.NewRoutingService(routingCache)
statsService := service.NewStatsService(statsRepo, statsBuffer)
// 8. 创建 ConversionEngine
registry := conversion.NewMemoryRegistry()
if err := registry.Register(openai.NewAdapter()); err != nil {
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error()))
@@ -95,16 +81,13 @@ func main() {
}
engine := conversion.NewConversionEngine(registry, zapLogger)
// 9. 初始化 provider client
providerClient := provider.NewClient()
// 10. 初始化 handler 层
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
providerHandler := handler.NewProviderHandler(providerService)
modelHandler := handler.NewModelHandler(modelService)
statsHandler := handler.NewStatsHandler(statsService)
// 11. 创建 Gin 引擎
gin.SetMode(gin.ReleaseMode)
r := gin.New()
@@ -115,9 +98,8 @@ func main() {
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
// 12. 启动服务器
srv := &http.Server{
Addr: formatAddr(cfg.Server.Port),
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
Handler: r,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
@@ -148,83 +130,9 @@ func main() {
zapLogger.Info("服务器已关闭")
}
func initDatabase(cfg *config.Config) (*gorm.DB, error) {
db, err := gorm.Open(sqlite.Open(cfg.Database.Path), &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
})
if err != nil {
return nil, err
}
if err := runMigrations(db); err != nil {
return nil, fmt.Errorf("数据库迁移失败: %w", err)
}
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
log.Printf("警告: 启用 WAL 模式失败: %v", err)
}
sqlDB, err := db.DB()
if err != nil {
return nil, err
}
sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns)
sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns)
sqlDB.SetConnMaxLifetime(cfg.Database.ConnMaxLifetime)
log.Printf("数据库连接池配置: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v",
cfg.Database.MaxIdleConns, cfg.Database.MaxOpenConns, cfg.Database.ConnMaxLifetime)
return db, nil
}
func runMigrations(db *gorm.DB) error {
sqlDB, err := db.DB()
if err != nil {
return err
}
migrationsDir := getMigrationsDir()
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
}
goose.SetDialect("sqlite3")
if err := goose.Up(sqlDB, migrationsDir); err != nil {
return err
}
return nil
}
func getMigrationsDir() string {
_, filename, _, ok := runtime.Caller(0)
if ok {
dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations")
if abs, err := filepath.Abs(dir); err == nil {
return abs
}
}
return "./migrations"
}
func closeDB(db *gorm.DB) {
sqlDB, err := db.DB()
if err != nil {
return
}
sqlDB.Close()
}
func formatAddr(port int) string {
return fmt.Sprintf(":%d", port)
}
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
// 统一代理入口: /{protocol}/{path}
r.Any("/:protocol/*path", proxyHandler.HandleProxy)
// 供应商管理 API
providers := r.Group("/api/providers")
{
providers.GET("", providerHandler.ListProviders)
@@ -234,7 +142,6 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
providers.DELETE("/:id", providerHandler.DeleteProvider)
}
// 模型管理 API
models := r.Group("/api/models")
{
models.GET("", modelHandler.ListModels)
@@ -244,14 +151,12 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
models.DELETE("/:id", modelHandler.DeleteModel)
}
// 统计查询 API
stats := r.Group("/api/stats")
{
stats.GET("", statsHandler.GetStats)
stats.GET("/aggregate", statsHandler.AggregateStats)
}
// 健康检查
r.GET("/health", func(c *gin.Context) {
c.JSON(200, gin.H{"status": "ok"})
})

View File

@@ -24,6 +24,7 @@ require (
go.uber.org/zap v1.27.1
gopkg.in/lumberjack.v2 v2.0.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.6.0
gorm.io/driver/sqlite v1.6.0
gorm.io/gorm v1.31.1
nex/embedfs v0.0.0-00010101000000-000000000000
@@ -32,6 +33,7 @@ require (
require (
4d63.com/gocheckcompilerdirectives v1.3.0 // indirect
4d63.com/gochecknoglobals v0.2.2 // indirect
filippo.io/edwards25519 v1.2.0 // indirect
github.com/4meepo/tagalign v1.4.2 // indirect
github.com/Abirdcfly/dupword v0.1.3 // indirect
github.com/Antonboom/errname v1.0.0 // indirect
@@ -90,6 +92,7 @@ require (
github.com/go-critic/go-critic v0.12.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.9.3 // indirect
github.com/go-stack/stack v1.8.0 // indirect
github.com/go-toolsmith/astcast v1.1.0 // indirect
github.com/go-toolsmith/astcopy v1.1.0 // indirect

View File

@@ -35,6 +35,8 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl
cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs=
cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0=
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo=
filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc=
github.com/4meepo/tagalign v1.4.2 h1:0hcLHPGMjDyM1gHG58cS73aQF8J4TdVR96TZViorO9E=
github.com/4meepo/tagalign v1.4.2/go.mod h1:+p4aMyFM+ra7nb41CnFG6aSDXqRxU/w1VQqScKqDARI=
github.com/Abirdcfly/dupword v0.1.3 h1:9Pa1NuAsZvpFPi9Pqkd93I7LIYRURj+A//dFd5tgBeE=
@@ -206,6 +208,8 @@ github.com/go-playground/validator/v10 v10.30.2 h1:JiFIMtSSHb2/XBUbWM4i/MpeQm9ZK
github.com/go-playground/validator/v10 v10.30.2/go.mod h1:mAf2pIOVXjTEBrwUMGKkCWKKPs9NheYGabeB04txQSc=
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
@@ -1052,6 +1056,8 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg=
gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo=
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=

View File

@@ -32,7 +32,13 @@ type ServerConfig struct {
// DatabaseConfig 数据库配置
type DatabaseConfig struct {
Path string `yaml:"path" mapstructure:"path" validate:"required"`
Driver string `yaml:"driver" mapstructure:"driver" validate:"required,oneof=sqlite mysql"`
Path string `yaml:"path" mapstructure:"path" validate:"required_if=Driver sqlite"`
Host string `yaml:"host" mapstructure:"host" validate:"required_if=Driver mysql"`
Port int `yaml:"port" mapstructure:"port" validate:"required_if=Driver mysql,min=1,max=65535"`
User string `yaml:"user" mapstructure:"user" validate:"required_if=Driver mysql"`
Password string `yaml:"password" mapstructure:"password"`
DBName string `yaml:"dbname" mapstructure:"dbname" validate:"required_if=Driver mysql"`
MaxIdleConns int `yaml:"max_idle_conns" mapstructure:"max_idle_conns" validate:"required,min=1"`
MaxOpenConns int `yaml:"max_open_conns" mapstructure:"max_open_conns" validate:"required,min=1"`
ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" mapstructure:"conn_max_lifetime" validate:"required"`
@@ -61,7 +67,13 @@ func DefaultConfig() *Config {
WriteTimeout: 30 * time.Second,
},
Database: DatabaseConfig{
Driver: "sqlite",
Path: filepath.Join(nexDir, "config.db"),
Host: "",
Port: 3306,
User: "",
Password: "",
DBName: "nex",
MaxIdleConns: 10,
MaxOpenConns: 100,
ConnMaxLifetime: 1 * time.Hour,
@@ -117,7 +129,13 @@ func setupDefaults(v *viper.Viper) {
v.SetDefault("server.read_timeout", "30s")
v.SetDefault("server.write_timeout", "30s")
v.SetDefault("database.driver", "sqlite")
v.SetDefault("database.path", filepath.Join(nexDir, "config.db"))
v.SetDefault("database.host", "")
v.SetDefault("database.port", 3306)
v.SetDefault("database.user", "")
v.SetDefault("database.password", "")
v.SetDefault("database.dbname", "nex")
v.SetDefault("database.max_idle_conns", 10)
v.SetDefault("database.max_open_conns", 100)
v.SetDefault("database.conn_max_lifetime", "1h")
@@ -138,7 +156,13 @@ func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
flagSet.Duration("server-read-timeout", 0, "读超时")
flagSet.Duration("server-write-timeout", 0, "写超时")
flagSet.String("database-driver", "", "数据库驱动sqlite/mysql")
flagSet.String("database-path", "", "数据库文件路径")
flagSet.String("database-host", "", "MySQL 主机地址")
flagSet.Int("database-port", 0, "MySQL 端口")
flagSet.String("database-user", "", "MySQL 用户名")
flagSet.String("database-password", "", "MySQL 密码")
flagSet.String("database-dbname", "", "MySQL 数据库名")
flagSet.Int("database-max-idle-conns", 0, "最大空闲连接数")
flagSet.Int("database-max-open-conns", 0, "最大打开连接数")
flagSet.Duration("database-conn-max-lifetime", 0, "连接最大生命周期")
@@ -156,7 +180,13 @@ func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
v.BindPFlag("server.read_timeout", flagSet.Lookup("server-read-timeout"))
v.BindPFlag("server.write_timeout", flagSet.Lookup("server-write-timeout"))
v.BindPFlag("database.driver", flagSet.Lookup("database-driver"))
v.BindPFlag("database.path", flagSet.Lookup("database-path"))
v.BindPFlag("database.host", flagSet.Lookup("database-host"))
v.BindPFlag("database.port", flagSet.Lookup("database-port"))
v.BindPFlag("database.user", flagSet.Lookup("database-user"))
v.BindPFlag("database.password", flagSet.Lookup("database-password"))
v.BindPFlag("database.dbname", flagSet.Lookup("database-dbname"))
v.BindPFlag("database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
v.BindPFlag("database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
v.BindPFlag("database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
@@ -268,7 +298,7 @@ func SaveConfig(cfg *Config) error {
return appErrors.Wrap(appErrors.ErrInternal, err)
}
return os.WriteFile(configPath, data, 0644)
return os.WriteFile(configPath, data, 0600)
}
// Validate validates the config
@@ -285,7 +315,13 @@ func (c *Config) PrintSummary() {
fmt.Println("\nAI Gateway 启动配置")
fmt.Println("==================")
fmt.Printf("服务器端口: %d\n", c.Server.Port)
fmt.Printf("数据库路径: %s\n", c.Database.Path)
if c.Database.Driver == "mysql" {
fmt.Printf("数据库类型: mysql\n")
fmt.Printf("数据库地址: %s:%d/%s\n", c.Database.Host, c.Database.Port, c.Database.DBName)
} else {
fmt.Printf("数据库类型: sqlite\n")
fmt.Printf("数据库路径: %s\n", c.Database.Path)
}
fmt.Printf("日志级别: %s\n", c.Log.Level)
fmt.Println("\n配置来源:")
configPath, _ := GetConfigPath()

View File

@@ -19,6 +19,12 @@ func TestDefaultConfig(t *testing.T) {
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
assert.Equal(t, "sqlite", cfg.Database.Driver)
assert.Equal(t, "", cfg.Database.Host)
assert.Equal(t, 3306, cfg.Database.Port)
assert.Equal(t, "", cfg.Database.User)
assert.Equal(t, "", cfg.Database.Password)
assert.Equal(t, "nex", cfg.Database.DBName)
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
@@ -86,11 +92,76 @@ func TestConfig_Validate(t *testing.T) {
wantErr: false,
},
{
name: "数据库路径为空无效",
name: "SQLite模式路径为空无效",
modify: func(c *Config) { c.Database.Path = "" },
wantErr: true,
errMsg: "配置验证失败",
},
{
name: "driver值不合法",
modify: func(c *Config) { c.Database.Driver = "postgres" },
wantErr: true,
errMsg: "配置验证失败",
},
{
name: "MySQL配置有效",
modify: func(c *Config) {
c.Database.Driver = "mysql"
c.Database.Host = "localhost"
c.Database.Port = 3306
c.Database.User = "root"
c.Database.DBName = "nex"
c.Database.Path = ""
},
wantErr: false,
},
{
name: "MySQL模式host为空无效",
modify: func(c *Config) {
c.Database.Driver = "mysql"
c.Database.Host = ""
c.Database.User = "root"
c.Database.DBName = "nex"
c.Database.Path = ""
},
wantErr: true,
errMsg: "配置验证失败",
},
{
name: "MySQL模式user为空无效",
modify: func(c *Config) {
c.Database.Driver = "mysql"
c.Database.Host = "localhost"
c.Database.User = ""
c.Database.DBName = "nex"
c.Database.Path = ""
},
wantErr: true,
errMsg: "配置验证失败",
},
{
name: "MySQL模式dbname为空无效",
modify: func(c *Config) {
c.Database.Driver = "mysql"
c.Database.Host = "localhost"
c.Database.User = "root"
c.Database.DBName = ""
c.Database.Path = ""
},
wantErr: true,
errMsg: "配置验证失败",
},
{
name: "MySQL模式忽略path字段",
modify: func(c *Config) {
c.Database.Driver = "mysql"
c.Database.Host = "localhost"
c.Database.User = "root"
c.Database.DBName = "nex"
c.Database.Path = ""
},
wantErr: false,
},
}
for _, tt := range tests {
@@ -140,7 +211,10 @@ func TestSaveAndLoadConfig(t *testing.T) {
WriteTimeout: 20 * time.Second,
},
Database: DatabaseConfig{
Driver: "sqlite",
Path: filepath.Join(dir, "test.db"),
Port: 3306,
DBName: "nex",
MaxIdleConns: 5,
MaxOpenConns: 50,
ConnMaxLifetime: 30 * time.Minute,
@@ -210,6 +284,9 @@ func TestConfigPriority(t *testing.T) {
assert.Equal(t, 9826, cfg.Server.Port)
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
assert.Equal(t, "sqlite", cfg.Database.Driver)
assert.Equal(t, 3306, cfg.Database.Port)
assert.Equal(t, "nex", cfg.Database.DBName)
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
@@ -222,11 +299,19 @@ func TestConfigPriority(t *testing.T) {
}
func TestPrintSummary(t *testing.T) {
// 测试配置摘要输出
t.Run("打印配置摘要", func(t *testing.T) {
t.Run("SQLite模式摘要", func(t *testing.T) {
cfg := DefaultConfig()
// PrintSummary 只是打印,不会返回错误
// 这里主要验证不会 panic
assert.NotPanics(t, func() {
cfg.PrintSummary()
})
})
t.Run("MySQL模式摘要", func(t *testing.T) {
cfg := DefaultConfig()
cfg.Database.Driver = "mysql"
cfg.Database.Host = "db.example.com"
cfg.Database.Port = 3306
cfg.Database.User = "nex"
cfg.Database.DBName = "nex"
assert.NotPanics(t, func() {
cfg.PrintSummary()
})

View File

@@ -29,8 +29,8 @@ type Model struct {
// UsageStats 用量统计
type UsageStats struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
ProviderID string `gorm:"not null;index" json:"provider_id"`
ModelName string `gorm:"not null;index" json:"model_name"`
ProviderID string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"provider_id"`
ModelName string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"model_name"`
RequestCount int `gorm:"default:0" json:"request_count"`
Date time.Time `gorm:"type:date;not null;uniqueIndex:idx_provider_model_date" json:"date"`
}

View File

@@ -0,0 +1,126 @@
package database
import (
"fmt"
"log"
"os"
"path/filepath"
"runtime"
"github.com/pressly/goose/v3"
"go.uber.org/zap"
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"nex/backend/internal/config"
)
func Init(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
db, err := initDB(cfg)
if err != nil {
return nil, fmt.Errorf("初始化数据库失败: %w", err)
}
if err := runMigrations(db, cfg.Driver); err != nil {
return nil, fmt.Errorf("数据库迁移失败: %w", err)
}
configurePool(db, cfg)
return db, nil
}
func Close(db *gorm.DB) {
sqlDB, err := db.DB()
if err != nil {
return
}
sqlDB.Close()
}
func initDB(cfg *config.DatabaseConfig) (*gorm.DB, error) {
gormConfig := &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
}
switch cfg.Driver {
case "mysql":
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName)
return gorm.Open(mysql.Open(dsn), gormConfig)
default:
dbDir := filepath.Dir(cfg.Path)
if err := os.MkdirAll(dbDir, 0755); err != nil {
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
}
return gorm.Open(sqlite.Open(cfg.Path), gormConfig)
}
}
func runMigrations(db *gorm.DB, driver string) error {
sqlDB, err := db.DB()
if err != nil {
return err
}
migrationsDir := getMigrationsDir(driver)
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
}
gooseDialect := "sqlite3"
migrationsSubDir := "sqlite"
if driver == "mysql" {
gooseDialect = "mysql"
migrationsSubDir = "mysql"
}
goose.SetDialect(gooseDialect)
if err := goose.Up(sqlDB, migrationsDir); err != nil {
return err
}
log.Printf("使用 %s 方言执行迁移,目录: %s", gooseDialect, migrationsSubDir)
return nil
}
func configurePool(db *gorm.DB, cfg *config.DatabaseConfig) {
if cfg.Driver == "sqlite" {
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
log.Printf("警告: 启用 WAL 模式失败: %v", err)
}
}
sqlDB, err := db.DB()
if err != nil {
return
}
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime)
log.Printf("数据库连接池配置: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v",
cfg.MaxIdleConns, cfg.MaxOpenConns, cfg.ConnMaxLifetime)
}
func getMigrationsDir(driver string) string {
_, filename, _, ok := runtime.Caller(0)
if ok {
subDir := "sqlite"
if driver == "mysql" {
subDir = "mysql"
}
dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations", subDir)
if abs, err := filepath.Abs(dir); err == nil {
return abs
}
}
return "./migrations"
}
func BuildDSN(cfg *config.DatabaseConfig) string {
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName)
}

View File

@@ -0,0 +1,75 @@
package database
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/config"
)
func TestInit_SQLite(t *testing.T) {
dir := t.TempDir()
cfg := &config.DatabaseConfig{
Driver: "sqlite",
Path: filepath.Join(dir, "test.db"),
MaxIdleConns: 5,
MaxOpenConns: 10,
ConnMaxLifetime: 0,
}
db, err := Init(cfg, nil)
require.NoError(t, err)
require.NotNil(t, db)
defer Close(db)
sqlDB, err := db.DB()
require.NoError(t, err)
require.NotNil(t, sqlDB)
}
func TestClose(t *testing.T) {
dir := t.TempDir()
cfg := &config.DatabaseConfig{
Driver: "sqlite",
Path: filepath.Join(dir, "test.db"),
MaxIdleConns: 5,
MaxOpenConns: 10,
ConnMaxLifetime: 0,
}
db, err := Init(cfg, nil)
require.NoError(t, err)
require.NotNil(t, db)
Close(db)
}
func TestBuildDSN(t *testing.T) {
cfg := &config.DatabaseConfig{
Driver: "mysql",
Host: "db.example.com",
Port: 3306,
User: "nexuser",
Password: "secretpass",
DBName: "nexdb",
}
dsn := BuildDSN(cfg)
assert.Equal(t, "nexuser:secretpass@tcp(db.example.com:3306)/nexdb?charset=utf8mb4&parseTime=true&loc=Local", dsn)
}
func TestBuildDSN_EmptyPassword(t *testing.T) {
cfg := &config.DatabaseConfig{
Driver: "mysql",
Host: "localhost",
Port: 3306,
User: "root",
DBName: "nex",
}
dsn := BuildDSN(cfg)
assert.Equal(t, "root:@tcp(localhost:3306)/nex?charset=utf8mb4&parseTime=true&loc=Local", dsn)
}

View File

@@ -1,10 +1,10 @@
package repository
import (
"errors"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"nex/backend/internal/config"
"nex/backend/internal/domain"
@@ -22,47 +22,43 @@ func (r *statsRepository) Record(providerID, modelName string) error {
today := time.Now().Format("2006-01-02")
todayTime, _ := time.Parse("2006-01-02", today)
return r.db.Transaction(func(tx *gorm.DB) error {
var stats config.UsageStats
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
providerID, modelName, todayTime).First(&stats).Error
stats := config.UsageStats{
ProviderID: providerID,
ModelName: modelName,
RequestCount: 1,
Date: todayTime,
}
if errors.Is(err, gorm.ErrRecordNotFound) {
stats = config.UsageStats{
ProviderID: providerID,
ModelName: modelName,
RequestCount: 1,
Date: todayTime,
}
return tx.Create(&stats).Error
} else if err != nil {
return err
}
return tx.Model(&stats).Update("request_count", gorm.Expr("request_count + 1")).Error
})
return r.db.Clauses(clause.OnConflict{
Columns: []clause.Column{
{Name: "provider_id"},
{Name: "model_name"},
{Name: "date"},
},
DoUpdates: clause.Assignments(map[string]interface{}{
"request_count": gorm.Expr("request_count + 1"),
}),
}).Create(&stats).Error
}
func (r *statsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
return r.db.Transaction(func(tx *gorm.DB) error {
var stats config.UsageStats
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
providerID, modelName, date).First(&stats).Error
stats := config.UsageStats{
ProviderID: providerID,
ModelName: modelName,
RequestCount: delta,
Date: date,
}
if errors.Is(err, gorm.ErrRecordNotFound) {
return tx.Create(&config.UsageStats{
ProviderID: providerID,
ModelName: modelName,
RequestCount: delta,
Date: date,
}).Error
} else if err != nil {
return err
}
return tx.Model(&stats).
Update("request_count", gorm.Expr("request_count + ?", delta)).Error
})
return r.db.Clauses(clause.OnConflict{
Columns: []clause.Column{
{Name: "provider_id"},
{Name: "model_name"},
{Name: "date"},
},
DoUpdates: clause.Assignments(map[string]interface{}{
"request_count": gorm.Expr("request_count + ?", delta),
}),
}).Create(&stats).Error
}
func (r *statsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {

View File

@@ -0,0 +1,44 @@
-- +goose Up
-- MySQL 方言初始迁移providers、models、usage_stats 完整表结构
CREATE TABLE IF NOT EXISTS providers (
id VARCHAR(36) PRIMARY KEY,
name VARCHAR(255) NOT NULL,
api_key VARCHAR(255) NOT NULL,
base_url VARCHAR(255) NOT NULL,
protocol VARCHAR(50) DEFAULT 'openai',
enabled BOOLEAN DEFAULT TRUE,
created_at DATETIME(3),
updated_at DATETIME(3)
);
CREATE TABLE IF NOT EXISTS models (
id VARCHAR(36) PRIMARY KEY,
provider_id VARCHAR(36) NOT NULL,
model_name VARCHAR(255) NOT NULL,
enabled BOOLEAN DEFAULT TRUE,
created_at DATETIME(3),
FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE,
UNIQUE(provider_id, model_name)
);
CREATE TABLE IF NOT EXISTS usage_stats (
id INT UNSIGNED AUTO_INCREMENT PRIMARY KEY,
provider_id VARCHAR(36) NOT NULL,
model_name VARCHAR(255) NOT NULL,
request_count INT DEFAULT 0,
date DATE NOT NULL,
UNIQUE(provider_id, model_name, date)
);
CREATE INDEX idx_models_provider_id ON models(provider_id);
CREATE INDEX idx_models_model_name ON models(model_name);
CREATE INDEX idx_usage_stats_provider_model_date ON usage_stats(provider_id, model_name, date);
-- +goose Down
DROP INDEX IF EXISTS idx_usage_stats_provider_model_date ON usage_stats;
DROP INDEX IF EXISTS idx_models_model_name ON models;
DROP INDEX IF EXISTS idx_models_provider_id ON models;
DROP TABLE IF EXISTS usage_stats;
DROP TABLE IF EXISTS models;
DROP TABLE IF EXISTS providers;

View File

@@ -158,7 +158,10 @@ func TestSaveAndLoadConfig(t *testing.T) {
WriteTimeout: 45 * time.Second,
},
Database: config.DatabaseConfig{
Driver: "sqlite",
Path: filepath.Join(tmpDir, "test.db"),
Port: 3306,
DBName: "nex",
MaxIdleConns: 15,
MaxOpenConns: 150,
ConnMaxLifetime: 2 * time.Hour,

View File

@@ -0,0 +1,158 @@
//go:build mysql
package mysql
import (
"sync"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/config"
"nex/backend/internal/repository"
)
func TestConcurrent_UsageStatsRecord(t *testing.T) {
db := SetupMySQLTestDB(t)
statsRepo := repository.NewStatsRepository(db)
providerID := "concurrent-test-provider"
modelName := "gpt-4"
concurrency := 10
var wg sync.WaitGroup
wg.Add(concurrency)
errChan := make(chan error, concurrency)
for i := 0; i < concurrency; i++ {
go func() {
defer wg.Done()
err := statsRepo.Record(providerID, modelName)
if err != nil {
errChan <- err
}
}()
}
wg.Wait()
close(errChan)
var errorCount int
uniqueErrors := make(map[string]int)
for err := range errChan {
errorCount++
uniqueErrors[err.Error()]++
}
t.Logf("并发 %d 次,错误 %d 次", concurrency, errorCount)
for errMsg, count := range uniqueErrors {
t.Logf(" 错误: %s (出现 %d 次)", errMsg, count)
}
var stats config.UsageStats
err := db.Where("provider_id = ? AND model_name = ?", providerID, modelName).
First(&stats).Error
require.NoError(t, err, "应能查到 usage_stats 记录")
successCount := concurrency - errorCount
t.Logf("成功次数: %d, 最终 request_count: %d", successCount, stats.RequestCount)
assert.Equal(t, concurrency, stats.RequestCount, "request_count 应等于并发数,无数据丢失或重复")
}
func TestConcurrent_ProviderCreate(t *testing.T) {
db := SetupMySQLTestDB(t)
providerID := "concurrent-provider-id"
concurrency := 10
var wg sync.WaitGroup
wg.Add(concurrency)
successCount := 0
var mu sync.Mutex
for i := 0; i < concurrency; i++ {
go func() {
defer wg.Done()
provider := config.Provider{
ID: providerID,
Name: "Concurrent Provider",
APIKey: "test-key",
BaseURL: "https://test.com",
Enabled: true,
}
err := db.Create(&provider).Error
if err == nil {
mu.Lock()
successCount++
mu.Unlock()
}
}()
}
wg.Wait()
assert.Equal(t, 1, successCount, "仅 1 个创建应成功")
var count int64
db.Model(&config.Provider{}).Where("id = ?", providerID).Count(&count)
assert.Equal(t, int64(1), count, "最终应有 1 条记录")
}
func TestConcurrent_ModelCreate(t *testing.T) {
db := SetupMySQLTestDB(t)
provider := config.Provider{
ID: "concurrent-model-provider",
Name: "Test Provider",
APIKey: "test-key",
BaseURL: "https://test.com",
Enabled: true,
}
err := db.Create(&provider).Error
require.NoError(t, err, "创建 provider 应成功")
modelName := "gpt-4-concurrent"
concurrency := 10
var wg sync.WaitGroup
wg.Add(concurrency)
successCount := 0
var mu sync.Mutex
for i := 0; i < concurrency; i++ {
go func(idx int) {
defer wg.Done()
model := config.Model{
ID: uuid.New().String(),
ProviderID: provider.ID,
ModelName: modelName,
Enabled: true,
}
err := db.Create(&model).Error
if err == nil {
mu.Lock()
successCount++
mu.Unlock()
}
}(i)
}
wg.Wait()
assert.Equal(t, 1, successCount, "仅 1 个创建应成功")
var count int64
db.Model(&config.Model{}).Where("provider_id = ? AND model_name = ?", provider.ID, modelName).Count(&count)
assert.Equal(t, int64(1), count, "最终应有 1 条记录")
}

View File

@@ -0,0 +1,130 @@
//go:build mysql
package mysql
import (
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"nex/backend/internal/config"
)
func TestConstraint_ForeignKeyEnforced(t *testing.T) {
db := SetupMySQLTestDB(t)
model := config.Model{
ID: "test-model-id",
ProviderID: "non-existent-provider",
ModelName: "gpt-4",
Enabled: true,
}
err := db.Create(&model).Error
assert.Error(t, err, "创建 model 时 provider_id 不存在应失败")
assert.Contains(t, err.Error(), "foreign key constraint", "错误应为外键约束错误")
}
func TestConstraint_CascadeDelete(t *testing.T) {
db := SetupMySQLTestDB(t)
provider := config.Provider{
ID: "test-provider-cascade",
Name: "Test Provider",
APIKey: "test-key",
BaseURL: "https://test.com",
Enabled: true,
}
err := db.Create(&provider).Error
require.NoError(t, err, "创建 provider 应成功")
model := config.Model{
ID: "test-model-cascade",
ProviderID: provider.ID,
ModelName: "gpt-4",
Enabled: true,
}
err = db.Create(&model).Error
require.NoError(t, err, "创建 model 应成功")
err = db.Delete(&provider).Error
require.NoError(t, err, "删除 provider 应成功")
var count int64
err = db.Model(&config.Model{}).Where("provider_id = ?", provider.ID).Count(&count).Error
require.NoError(t, err)
assert.Equal(t, int64(0), count, "删除 provider 后其 models 应被级联删除")
}
func TestConstraint_UniqueProviderModel(t *testing.T) {
db := SetupMySQLTestDB(t)
provider := config.Provider{
ID: "test-provider-unique",
Name: "Test Provider",
APIKey: "test-key",
BaseURL: "https://test.com",
Enabled: true,
}
err := db.Create(&provider).Error
require.NoError(t, err, "创建 provider 应成功")
model1 := config.Model{
ID: "test-model-unique-1",
ProviderID: provider.ID,
ModelName: "gpt-4",
Enabled: true,
}
err = db.Create(&model1).Error
require.NoError(t, err, "创建第一个 model 应成功")
model2 := config.Model{
ID: "test-model-unique-2",
ProviderID: provider.ID,
ModelName: "gpt-4",
Enabled: true,
}
err = db.Create(&model2).Error
assert.Error(t, err, "创建相同 (provider_id, model_name) 的 model 应失败")
assert.True(t, errors.Is(err, gorm.ErrDuplicatedKey) ||
(err != nil && (err.Error() == "Error 1062" || containsDuplicateError(err.Error()))),
"错误应为唯一约束错误")
}
func TestConstraint_UniqueUsageStats(t *testing.T) {
db := SetupMySQLTestDB(t)
today := time.Now().Format("2006-01-02")
todayTime, _ := time.Parse("2006-01-02", today)
providerID := "test-provider-unique-stats"
stats1 := config.UsageStats{
ProviderID: providerID,
ModelName: "gpt-4",
RequestCount: 10,
Date: todayTime,
}
err := db.Create(&stats1).Error
require.NoError(t, err, "创建第一个 usage_stats 应成功")
stats2 := config.UsageStats{
ProviderID: providerID,
ModelName: "gpt-4",
RequestCount: 20,
Date: todayTime,
}
err = db.Create(&stats2).Error
assert.Error(t, err, "创建相同 (provider_id, model_name, date) 的 usage_stats 应失败")
assert.True(t, errors.Is(err, gorm.ErrDuplicatedKey) ||
(err != nil && (err.Error() == "Error 1062" || containsDuplicateError(err.Error()))),
"错误应为唯一约束错误")
}
func containsDuplicateError(errStr string) bool {
return len(errStr) > 0 && (errStr[0:8] == "Error 10" || errStr[0:5] == "Dupli")
}

View File

@@ -0,0 +1,21 @@
version: '3.8'
services:
mysql:
image: mysql:8.0
container_name: nex-mysql-test
environment:
MYSQL_ROOT_PASSWORD: testpass
MYSQL_DATABASE: nex_test
MYSQL_USER: nex_test
MYSQL_PASSWORD: testpass
ports:
- "13306:3306"
tmpfs:
- /var/lib/mysql
healthcheck:
test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-u", "root", "-p$$MYSQL_ROOT_PASSWORD"]
interval: 1s
timeout: 5s
retries: 10
start_period: 10s

View File

@@ -0,0 +1,126 @@
//go:build mysql
package mysql
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMigration_TablesExist(t *testing.T) {
db := SetupMySQLTestDB(t)
var tables []string
err := db.Raw("SHOW TABLES").Scan(&tables).Error
require.NoError(t, err)
expectedTables := []string{"providers", "models", "usage_stats"}
for _, expected := range expectedTables {
assert.Contains(t, tables, expected, "表 %s 应存在", expected)
}
}
func TestMigration_TableColumns(t *testing.T) {
db := SetupMySQLTestDB(t)
t.Run("providers 表字段", func(t *testing.T) {
var columns []struct {
Field string
Type string
Null string
}
err := db.Raw("SHOW COLUMNS FROM providers").Scan(&columns).Error
require.NoError(t, err)
columnMap := make(map[string]string)
for _, col := range columns {
columnMap[col.Field] = col.Type
}
assert.Contains(t, columnMap["id"], "varchar", "id 应为 VARCHAR 类型")
assert.Contains(t, columnMap["name"], "varchar", "name 应为 VARCHAR 类型")
assert.Contains(t, columnMap["api_key"], "varchar", "api_key 应为 VARCHAR 类型")
assert.Contains(t, columnMap["base_url"], "varchar", "base_url 应为 VARCHAR 类型")
assert.Contains(t, columnMap["protocol"], "varchar", "protocol 应为 VARCHAR 类型")
assert.Contains(t, columnMap["enabled"], "tinyint", "enabled 应为 TINYINT (BOOLEAN) 类型")
assert.Contains(t, columnMap["created_at"], "datetime", "created_at 应为 DATETIME 类型")
assert.Contains(t, columnMap["updated_at"], "datetime", "updated_at 应为 DATETIME 类型")
})
t.Run("models 表字段", func(t *testing.T) {
var columns []struct {
Field string
Type string
}
err := db.Raw("SHOW COLUMNS FROM models").Scan(&columns).Error
require.NoError(t, err)
columnMap := make(map[string]string)
for _, col := range columns {
columnMap[col.Field] = col.Type
}
assert.Contains(t, columnMap["id"], "varchar", "id 应为 VARCHAR 类型")
assert.Contains(t, columnMap["provider_id"], "varchar", "provider_id 应为 VARCHAR 类型")
assert.Contains(t, columnMap["model_name"], "varchar", "model_name 应为 VARCHAR 类型")
assert.Contains(t, columnMap["enabled"], "tinyint", "enabled 应为 TINYINT (BOOLEAN) 类型")
assert.Contains(t, columnMap["created_at"], "datetime", "created_at 应为 DATETIME 类型")
})
t.Run("usage_stats 表字段", func(t *testing.T) {
var columns []struct {
Field string
Type string
}
err := db.Raw("SHOW COLUMNS FROM usage_stats").Scan(&columns).Error
require.NoError(t, err)
columnMap := make(map[string]string)
for _, col := range columns {
columnMap[col.Field] = col.Type
}
assert.Contains(t, columnMap["id"], "int", "id 应为 INT 类型")
assert.Contains(t, columnMap["provider_id"], "varchar", "provider_id 应为 VARCHAR 类型")
assert.Contains(t, columnMap["model_name"], "varchar", "model_name 应为 VARCHAR 类型")
assert.Contains(t, columnMap["request_count"], "int", "request_count 应为 INT 类型")
assert.Contains(t, columnMap["date"], "date", "date 应为 DATE 类型")
})
}
func TestMigration_IndexesExist(t *testing.T) {
db := SetupMySQLTestDB(t)
t.Run("models 表索引", func(t *testing.T) {
var indexes []struct {
KeyName string
}
err := db.Raw("SHOW INDEX FROM models").Scan(&indexes).Error
require.NoError(t, err)
indexMap := make(map[string]bool)
for _, idx := range indexes {
indexMap[idx.KeyName] = true
}
assert.True(t, indexMap["idx_models_provider_id"], "idx_models_provider_id 索引应存在")
assert.True(t, indexMap["idx_models_model_name"], "idx_models_model_name 索引应存在")
})
t.Run("usage_stats 表索引", func(t *testing.T) {
var indexes []struct {
KeyName string
}
err := db.Raw("SHOW INDEX FROM usage_stats").Scan(&indexes).Error
require.NoError(t, err)
indexMap := make(map[string]bool)
for _, idx := range indexes {
indexMap[idx.KeyName] = true
}
assert.True(t, indexMap["idx_usage_stats_provider_model_date"], "idx_usage_stats_provider_model_date 索引应存在")
})
}

View File

@@ -0,0 +1,160 @@
//go:build mysql
package mysql
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/pressly/goose/v3"
"github.com/stretchr/testify/require"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
type MySQLTestConfig struct {
Host string
Port int
User string
Password string
Database string
}
func getMySQLTestConfig() *MySQLTestConfig {
return &MySQLTestConfig{
Host: getEnvOrDefault("NEX_TEST_MYSQL_HOST", "localhost"),
Port: getEnvOrDefaultInt("NEX_TEST_MYSQL_PORT", 13306),
User: getEnvOrDefault("NEX_TEST_MYSQL_USER", "nex_test"),
Password: getEnvOrDefault("NEX_TEST_MYSQL_PASSWORD", "testpass"),
Database: getEnvOrDefault("NEX_TEST_MYSQL_DATABASE", "nex_test"),
}
}
func getEnvOrDefault(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
func getEnvOrDefaultInt(key string, defaultValue int) int {
if value := os.Getenv(key); value != "" {
var intValue int
if _, err := fmt.Sscanf(value, "%d", &intValue); err == nil {
return intValue
}
}
return defaultValue
}
func SkipIfMySQLUnavailable(t *testing.T) {
t.Helper()
cfg := getMySQLTestConfig()
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database)
db, err := sql.Open("mysql", dsn)
if err != nil {
t.Skipf("MySQL 不可用: %v", err)
}
defer db.Close()
if err := db.Ping(); err != nil {
t.Skipf("MySQL 不可用: %v", err)
}
}
func SetupMySQLTestDB(t *testing.T) *gorm.DB {
t.Helper()
SkipIfMySQLUnavailable(t)
cfg := getMySQLTestConfig()
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database)
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
require.NoError(t, err, "连接 MySQL 失败")
if err := runMigrations(db); err != nil {
require.NoError(t, err, "运行迁移失败")
}
if err := cleanupTables(db); err != nil {
require.NoError(t, err, "清理表数据失败")
}
sqlDB, err := db.DB()
require.NoError(t, err)
sqlDB.SetMaxIdleConns(10)
sqlDB.SetMaxOpenConns(100)
sqlDB.SetConnMaxLifetime(time.Hour)
t.Cleanup(func() {
time.Sleep(50 * time.Millisecond)
sqlDB, err := db.DB()
if err == nil {
sqlDB.Close()
}
})
return db
}
func cleanupTables(db *gorm.DB) error {
if err := db.Exec("SET FOREIGN_KEY_CHECKS = 0").Error; err != nil {
return err
}
if err := db.Exec("TRUNCATE TABLE usage_stats").Error; err != nil {
return err
}
if err := db.Exec("TRUNCATE TABLE models").Error; err != nil {
return err
}
if err := db.Exec("TRUNCATE TABLE providers").Error; err != nil {
return err
}
if err := db.Exec("SET FOREIGN_KEY_CHECKS = 1").Error; err != nil {
return err
}
return nil
}
func runMigrations(db *gorm.DB) error {
sqlDB, err := db.DB()
if err != nil {
return err
}
migrationsDir := getMigrationsDir()
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
}
goose.SetDialect("mysql")
if err := goose.Up(sqlDB, migrationsDir); err != nil {
return err
}
return nil
}
func getMigrationsDir() string {
_, filename, _, ok := runtime.Caller(0)
if ok {
dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations", "mysql")
if abs, err := filepath.Abs(dir); err == nil {
return abs
}
}
return "./migrations/mysql"
}

View File

@@ -77,6 +77,7 @@
- **THEN** SHALL 支持 `required` 规则
- **THEN** SHALL 支持 `min``max` 规则
- **THEN** SHALL 支持 `oneof` 规则
- **THEN** SHALL 支持 `required_if` 条件验证规则
#### Scenario: 验证执行
@@ -85,6 +86,17 @@
- **THEN** SHALL 返回验证错误
- **THEN** SHALL NOT 启动应用(如果验证失败)
#### Scenario: 数据库驱动条件验证
- **WHEN** `database.driver``sqlite`
- **THEN** SHALL 验证 `database.path` 必填
- **THEN** SHALL NOT 要求 MySQL 字段host/port/user/password/dbname
- **WHEN** `database.driver``mysql`
- **THEN** SHALL 验证 `database.host` 必填
- **THEN** SHALL 验证 `database.user` 必填
- **THEN** SHALL 验证 `database.dbname` 必填
- **THEN** SHALL NOT 要求 `database.path`
### Requirement: 配置结构定义
系统 SHALL 定义清晰的配置结构。
@@ -98,7 +110,14 @@
#### Scenario: Database 配置
- **WHEN** 加载 database 配置
- **THEN** SHALL 包含 path、max_idle_conns、max_open_conns、conn_max_lifetime 字段
- **THEN** SHALL 包含 driver 字段(值为 `sqlite``mysql`,默认 `sqlite`
- **THEN** SHALL 包含 path 字段SQLite 模式下的数据库文件路径)
- **THEN** SHALL 包含 host 字段MySQL 主机地址)
- **THEN** SHALL 包含 port 字段MySQL 端口,默认 3306
- **THEN** SHALL 包含 user 字段MySQL 用户名)
- **THEN** SHALL 包含 password 字段MySQL 密码,选填)
- **THEN** SHALL 包含 dbname 字段MySQL 数据库名)
- **THEN** SHALL 包含 max_idle_conns、max_open_conns、conn_max_lifetime 字段
- **THEN** SHALL 使用合理的默认值
#### Scenario: Log 配置
@@ -121,7 +140,13 @@
#### Scenario: Database 默认值
- **WHEN** 使用默认配置
- **THEN** database.driver SHALL 为 `sqlite`
- **THEN** database.path SHALL 为 `~/.nex/config.db`
- **THEN** database.host SHALL 为空字符串
- **THEN** database.port SHALL 为 3306
- **THEN** database.user SHALL 为空字符串
- **THEN** database.password SHALL 为空字符串
- **THEN** database.dbname SHALL 为 `nex`
- **THEN** database.max_idle_conns SHALL 为 10
- **THEN** database.max_open_conns SHALL 为 100
- **THEN** database.conn_max_lifetime SHALL 为 1h
@@ -248,18 +273,38 @@
- **THEN** SHALL 在日志中记录覆盖信息
- **THEN** SHALL 显示被覆盖的配置项名称
### Requirement: 配置文件安全
系统 SHALL 使用安全的文件权限保存配置文件。
#### Scenario: 配置文件权限
- **WHEN** 保存配置文件(`SaveConfig`
- **THEN** SHALL 使用 `0600` 权限写入文件(仅 owner 可读写)
- **THEN** SHALL 防止其他用户读取配置中的 MySQL 密码等敏感信息
### Requirement: 配置摘要输出
系统 SHALL 在启动时输出配置摘要。
#### Scenario: 摘要内容
#### Scenario: SQLite 模式摘要
- **WHEN** 配置加载完成
- **WHEN** `database.driver``sqlite`
- **THEN** SHALL 打印关键配置项(端口、数据库路径、日志级别等)
- **THEN** SHALL 打印配置文件路径
- **THEN** SHALL 打印环境变量数量
- **THEN** SHALL 打印 CLI 参数数量
#### Scenario: MySQL 模式摘要
- **WHEN** `database.driver``mysql`
- **THEN** SHALL 打印关键配置项(端口、数据库类型、数据库地址、日志级别等)
- **THEN** SHALL 打印数据库地址格式为 `{host}:{port}/{dbname}`
- **THEN** SHALL 不打印密码
- **THEN** SHALL 打印配置文件路径
- **THEN** SHALL 打印环境变量数量
- **THEN** SHALL 打印 CLI 参数数量
#### Scenario: 摘要格式
- **WHEN** 打印配置摘要
@@ -297,7 +342,7 @@
- **WHEN** 使用服务器相关参数
- **THEN** SHALL 支持 `--server-port``--server-read-timeout``--server-write-timeout`
- **WHEN** 使用数据库相关参数
- **THEN** SHALL 支持 `--database-path``--database-max-idle-conns``--database-max-open-conns``--database-conn-max-lifetime`
- **THEN** SHALL 支持 `--database-driver``--database-path``--database-host``--database-port``--database-user``--database-password``--database-dbname``--database-max-idle-conns``--database-max-open-conns``--database-conn-max-lifetime`
- **WHEN** 使用日志相关参数
- **THEN** SHALL 支持 `--log-level``--log-path``--log-max-size``--log-max-backups``--log-max-age``--log-compress`
@@ -348,7 +393,7 @@
- **WHEN** 设置服务器相关环境变量
- **THEN** SHALL 支持 `NEX_SERVER_PORT``NEX_SERVER_READ_TIMEOUT``NEX_SERVER_WRITE_TIMEOUT`
- **WHEN** 设置数据库相关环境变量
- **THEN** SHALL 支持 `NEX_DATABASE_PATH``NEX_DATABASE_MAX_IDLE_CONNS``NEX_DATABASE_MAX_OPEN_CONNS``NEX_DATABASE_CONN_MAX_LIFETIME`
- **THEN** SHALL 支持 `NEX_DATABASE_DRIVER``NEX_DATABASE_PATH``NEX_DATABASE_HOST``NEX_DATABASE_PORT``NEX_DATABASE_USER``NEX_DATABASE_PASSWORD``NEX_DATABASE_DBNAME``NEX_DATABASE_MAX_IDLE_CONNS``NEX_DATABASE_MAX_OPEN_CONNS``NEX_DATABASE_CONN_MAX_LIFETIME`
- **WHEN** 设置日志相关环境变量
- **THEN** SHALL 支持 `NEX_LOG_LEVEL``NEX_LOG_PATH``NEX_LOG_MAX_SIZE``NEX_LOG_MAX_BACKUPS``NEX_LOG_MAX_AGE``NEX_LOG_COMPRESS`

View File

@@ -46,6 +46,14 @@
- **THEN** SHALL 删除所有表和索引
- **THEN** SHALL 按正确顺序删除(避免外键约束错误)
#### Scenario: 按数据库方言拆分迁移目录
- **WHEN** 组织迁移文件
- **THEN** SHALL 将 SQLite 方言迁移文件存储在 `migrations/sqlite/` 目录
- **THEN** SHALL 将 MySQL 方言迁移文件存储在 `migrations/mysql/` 目录
- **THEN** SHALL 两个目录维护独立的版本号序列
- **THEN** SHALL 两个目录的迁移文件内容在逻辑上一致(相同的表结构和约束),但使用各自方言的 DDL
### Requirement: models 表 schema 变更
系统 SHALL 在初始迁移脚本中直接创建新的 models 表结构(服务未上线,无需考虑数据迁移,迁移脚本已合并为单个初始迁移文件)。
@@ -63,28 +71,37 @@
#### Scenario: 迁移 up 命令
- **WHEN** 执行 `make migrate-up`
- **WHEN** 执行 `make backend-migrate-up`
- **THEN** SHALL 执行所有待执行的迁移
- **THEN** SHALL 使用 `DB_DRIVER` 变量选择方言目录(默认 `sqlite3`
- **THEN** SHALL 使用 `DB_DSN` 变量作为数据库连接串
- **THEN** SHALL 显示迁移进度
#### Scenario: 迁移 down 命令
- **WHEN** 执行 `make migrate-down`
- **WHEN** 执行 `make backend-migrate-down`
- **THEN** SHALL 回滚最后一个迁移
- **THEN** SHALL 使用 `DB_DRIVER``DB_DSN` 变量
- **THEN** SHALL 显示回滚进度
#### Scenario: 迁移状态命令
- **WHEN** 执行 `make migrate-status`
- **WHEN** 执行 `make backend-migrate-status`
- **THEN** SHALL 显示当前迁移状态
- **THEN** SHALL 显示已执行和待执行的迁移
#### Scenario: 创建迁移命令
- **WHEN** 执行 `make migrate-create name=<name>`
- **THEN** SHALL 创建新的迁移文件模板
- **WHEN** 执行 `make backend-migrate-create`
- **THEN** SHALL 同时在 `migrations/sqlite/``migrations/mysql/` 两个目录创建新的迁移文件模板
- **THEN** SHALL 使用递增的版本号
#### Scenario: MySQL 迁移命令使用
- **WHEN** 使用 MySQL 驱动执行迁移
- **THEN** SHALL 设置 `DB_DRIVER=mysql`
- **THEN** SHALL 设置 `DB_DSN` 为 MySQL 连接串(如 `user:pass@tcp(localhost:3306)/nex`
### Requirement: 应用启动时迁移
应用 SHALL 在启动时执行迁移。
@@ -92,6 +109,9 @@
#### Scenario: 自动迁移
- **WHEN** 应用启动
- **THEN** SHALL 根据 `database.driver` 配置选择对应的迁移目录和 goose dialect
- **THEN** SHALL 在 `driver=sqlite` 时使用 `migrations/sqlite/` 目录goose dialect 为 `sqlite3`
- **THEN** SHALL 在 `driver=mysql` 时使用 `migrations/mysql/` 目录goose dialect 为 `mysql`
- **THEN** SHALL 自动执行待执行的迁移
- **THEN** SHALL 在迁移失败时拒绝启动
- **THEN** SHALL 记录迁移日志
@@ -149,5 +169,5 @@
#### Scenario: 迁移文件存储
- **WHEN** 创建迁移文件
- **THEN** SHALL 存储在 migrations/ 目录
- **THEN** SHALL 按 SQL 方言存储在对应子目录(`migrations/sqlite/``migrations/mysql/`
- **THEN** SHALL 提交到版本控制系统

View File

@@ -0,0 +1,107 @@
# MySQL Driver
## Purpose
支持 MySQL 作为可选数据库后端,通过配置选择 sqlite 或 mysql 驱动,提供 MySQL 连接管理、初始化和方言迁移文件。
## Requirements
### Requirement: MySQL 数据库驱动支持
系统 SHALL 支持通过配置项 `database.driver` 选择 `sqlite``mysql` 数据库驱动,默认值为 `sqlite`
#### Scenario: 默认使用 SQLite 驱动
- **WHEN** 配置中未指定 `database.driver`
- **THEN** SHALL 使用 `sqlite` 作为数据库驱动
- **THEN** SHALL 行为与现有逻辑完全一致
#### Scenario: 配置 MySQL 驱动
- **WHEN** 配置 `database.driver` 设置为 `mysql`
- **THEN** SHALL 使用 MySQL 驱动连接远程数据库
- **THEN** SHALL 使用 `gorm.io/driver/mysql` 打开连接
- **THEN** SHALL 构建 DSN 格式为 `{user}:{password}@tcp({host}:{port})/{dbname}?charset=utf8mb4&parseTime=true&loc=Local`
#### Scenario: driver 值不合法
- **WHEN** 配置 `database.driver` 不是 `sqlite``mysql`
- **THEN** SHALL 配置验证失败,拒绝启动
### Requirement: MySQL 连接配置
系统 SHALL 在 `DatabaseConfig` 中支持 MySQL 连接参数。
#### Scenario: MySQL 连接参数字段
- **WHEN** `database.driver``mysql`
- **THEN** SHALL 读取 `host`MySQL 主机地址,必填)
- **THEN** SHALL 读取 `port`MySQL 端口,默认 3306
- **THEN** SHALL 读取 `user`MySQL 用户名,必填)
- **THEN** SHALL 读取 `password`MySQL 密码,选填)
- **THEN** SHALL 读取 `dbname`(数据库名,必填)
#### Scenario: SQLite 模式忽略 MySQL 参数
- **WHEN** `database.driver``sqlite`
- **THEN** SHALL 忽略 MySQL 相关配置字段host/port/user/password/dbname
- **THEN** SHALL 仅使用 `path` 字段作为数据库文件路径
#### Scenario: MySQL 模式忽略 SQLite 参数
- **WHEN** `database.driver``mysql`
- **THEN** SHALL 忽略 `path` 字段
### Requirement: 数据库初始化公共包
系统 SHALL 提供 `internal/database` 公共包,封装数据库初始化、迁移执行和连接关闭逻辑,供 `cmd/server``cmd/desktop` 共同调用。
#### Scenario: 公共包 Init 函数
- **WHEN** 调用 `database.Init(cfg, logger)`
- **THEN** SHALL 根据 `cfg.Driver` 选择对应的 GORM 驱动打开连接
- **THEN** SHALL 执行对应方言的 goose 迁移
- **THEN** SHALL 配置连接池参数
- **THEN** SHALL 在 `driver=sqlite` 时执行 `PRAGMA journal_mode=WAL`
- **THEN** SHALL 在 `driver=mysql` 时跳过 SQLite 专有 PRAGMA
- **THEN** SHALL 返回 `*gorm.DB` 实例
#### Scenario: 公共包 Close 函数
- **WHEN** 调用 `database.Close(db)`
- **THEN** SHALL 获取底层 `sql.DB` 并关闭连接
#### Scenario: 迁移目录选择
- **WHEN** 执行迁移
- **THEN** SHALL 在 `driver=sqlite` 时使用 `migrations/sqlite/` 目录goose dialect 为 `sqlite3`
- **THEN** SHALL 在 `driver=mysql` 时使用 `migrations/mysql/` 目录goose dialect 为 `mysql`
### Requirement: MySQL 方言迁移文件
系统 SHALL 提供 MySQL 方言的初始迁移文件 `migrations/mysql/20260421000001_initial_schema.sql`
#### Scenario: providers 表
- **WHEN** 执行 MySQL 初始迁移
- **THEN** SHALL 创建 `providers` 表,字段:`id VARCHAR(36) PRIMARY KEY``name VARCHAR(255) NOT NULL``api_key VARCHAR(255) NOT NULL``base_url VARCHAR(255) NOT NULL``protocol VARCHAR(50) DEFAULT 'openai'``enabled BOOLEAN DEFAULT TRUE``created_at DATETIME(3)``updated_at DATETIME(3)`
#### Scenario: models 表
- **WHEN** 执行 MySQL 初始迁移
- **THEN** SHALL 创建 `models` 表,字段:`id VARCHAR(36) PRIMARY KEY``provider_id VARCHAR(36) NOT NULL``model_name VARCHAR(255) NOT NULL``enabled BOOLEAN DEFAULT TRUE``created_at DATETIME(3)`
- **THEN** SHALL 创建 `UNIQUE(provider_id, model_name)` 约束
- **THEN** SHALL 创建 `FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE` 约束
- **THEN** SHALL 创建 `idx_models_provider_id``idx_models_model_name` 索引
#### Scenario: usage_stats 表
- **WHEN** 执行 MySQL 初始迁移
- **THEN** SHALL 创建 `usage_stats` 表,字段:`id INT UNSIGNED AUTO_INCREMENT PRIMARY KEY``provider_id VARCHAR(36) NOT NULL``model_name VARCHAR(255) NOT NULL``request_count INT DEFAULT 0``date DATE NOT NULL`
- **THEN** SHALL 创建 `UNIQUE(provider_id, model_name, date)` 约束
- **THEN** SHALL 创建 `idx_usage_stats_provider_model_date` 复合索引
#### Scenario: Down 迁移
- **WHEN** 执行 MySQL down 迁移
- **THEN** SHALL 按正确顺序删除索引和表usage_stats → models → providers

View File

@@ -0,0 +1,104 @@
# MySQL Testing
## Purpose
提供 MySQL 数据库专项测试能力,验证迁移正确性、外键约束、并发写入等数据库特定行为。
## Requirements
### Requirement: MySQL 测试环境可启动
系统 SHALL 提供 Docker Compose 配置以启动 MySQL 8.0 测试环境。
#### Scenario: 启动 MySQL 测试容器
- **WHEN** 执行 `make test-mysql-up`
- **THEN** 启动 MySQL 8.0 容器,端口 13306
- **AND** 创建数据库 `nex_test`
- **AND** 容器数据存储在内存盘tmpfs
#### Scenario: 销毁 MySQL 测试容器
- **WHEN** 执行 `make test-mysql-down`
- **THEN** 停止并删除容器
- **AND** 所有数据被销毁
### Requirement: MySQL 测试可通过 build tag 控制
MySQL 测试 SHALL 使用 `// +build mysql` build tag默认不运行。
#### Scenario: 默认测试不包含 MySQL 测试
- **WHEN** 执行 `go test ./...`
- **THEN** 不运行 `tests/mysql/` 下的测试
#### Scenario: 启用 MySQL 测试
- **WHEN** 执行 `go test -tags=mysql ./tests/mysql/...`
- **THEN** 运行所有 MySQL 测试
### Requirement: MySQL 迁移正确执行
MySQL 测试 SHALL 验证迁移脚本在 MySQL 环境下正确执行。
#### Scenario: 迁移创建所有表
- **WHEN** 运行 MySQL 迁移
- **THEN** 创建 `providers``models``usage_stats`
- **AND** 字段类型符合 MySQL 迁移文件定义VARCHAR、DATETIME(3)、BOOLEAN 等)
- **AND** 索引 `idx_models_provider_id``idx_models_model_name``idx_usage_stats_provider_model_date` 创建成功
#### Scenario: 迁移可重复执行
- **WHEN** 在已迁移的数据库上再次运行迁移
- **THEN** 不报错,数据库状态不变
### Requirement: MySQL 外键约束生效
MySQL 测试 SHALL 验证外键约束行为符合预期。
#### Scenario: 外键约束阻止无效引用
- **WHEN** 创建 model 时 `provider_id` 不存在
- **THEN** 操作失败,返回外键约束错误
#### Scenario: 级联删除生效
- **WHEN** 删除 provider
- **THEN** 该 provider 的所有 models 被级联删除
### Requirement: MySQL UNIQUE 约束生效
MySQL 测试 SHALL 验证 UNIQUE 约束行为符合预期。
#### Scenario: models 表 UNIQUE 约束
- **WHEN** 尝试创建相同 `(provider_id, model_name)` 组合的 model
- **THEN** 操作失败,返回唯一约束错误
#### Scenario: usage_stats 表 UNIQUE 约束
- **WHEN** 尝试创建相同 `(provider_id, model_name, date)` 组合的 usage_stats
- **THEN** 操作失败,返回唯一约束错误
### Requirement: MySQL 并发写入正确
MySQL 测试 SHALL 验证并发写入不丢失数据。
#### Scenario: 并发记录 usage_stats
- **WHEN** 10 个 goroutine 并发调用 `statsRepo.Record(providerID, modelName)`
- **THEN** 最终 `request_count` 等于 10
- **AND** 无数据丢失或重复
#### Scenario: 并发创建相同 provider
- **WHEN** 10 个 goroutine 并发创建相同 ID 的 provider
- **THEN** 仅 1 个成功,其他 9 个失败
#### Scenario: 并发创建相同 model
- **WHEN** 10 个 goroutine 并发创建相同 `(provider_id, model_name)` 的 model
- **THEN** 仅 1 个成功,其他 9 个失败
### Requirement: MySQL 测试命令完整
Makefile SHALL 提供完整的 MySQL 测试命令。
#### Scenario: 完整测试流程
- **WHEN** 执行 `make test-mysql`
- **THEN** 启动 Docker MySQL
- **AND** 等待 MySQL 就绪
- **AND** 运行所有 MySQL 测试
- **AND** 销毁容器
#### Scenario: 快速测试(容器已运行)
- **WHEN** 执行 `make test-mysql-quick`
- **THEN** 直接运行测试,不管理容器生命周期

View File

@@ -93,7 +93,21 @@
- **WHEN** 同时处理多个并发请求
- **THEN** 网关 SHALL 使用原子操作正确增加每个请求的计数
- **THEN** 不 SHALL 因并发写入而丢失统计
- **THEN** SHALL 使用 StatsBuffer 的内存计数器
- **THEN** SHALL 使用 upsert 操作保证原子性
#### Scenario: 并发调用 Record 方法
- **WHEN** 多个 goroutine 并发调用 StatsRepository.Record
- **THEN** SHALL 使用 INSERT ... ON DUPLICATE KEY UPDATE (MySQL) 或 INSERT ... ON CONFLICT DO UPDATE (SQLite)
- **THEN** SHALL 保证所有并发调用的计数都被正确累加
- **THEN** 不 SHALL 因 UNIQUE 约束冲突而丢失数据
#### Scenario: 并发调用 BatchUpdate 方法
- **WHEN** 多个 goroutine 并发调用 StatsRepository.BatchUpdate
- **THEN** SHALL 使用 upsert 操作保证原子性
- **THEN** SHALL 正确累加所有 delta 值
- **THEN** 不 SHALL 因并发写入而丢失统计
### Requirement: 使用 service 层处理业务逻辑
@@ -125,14 +139,14 @@ Service SHALL 通过 StatsRepository 访问数据。
- **WHEN** StatsBuffer 刷新统计
- **THEN** SHALL 调用 StatsRepository.BatchUpdate
- **THEN** SHALL 使用事务更新或创建统计记录
- **THEN** SHALL 使用 upsert 操作更新或创建统计记录
- **THEN** SHALL 支持增量更新request_count + delta
#### Scenario: 事务处理
#### Scenario: upsert 操作
- **WHEN** 记录统计
- **THEN** SHALL 在 repository 层使用数据库事务
- **THEN** SHALL 保并发安全
- **THEN** SHALL 在 repository 层使用 upsert 操作
- **THEN** SHALL 保证原子性和并发安全
### Requirement: 统计查询优化
@@ -168,11 +182,18 @@ StatsRepository SHALL 新增 BatchUpdate 方法支持批量增量更新。
#### Scenario: BatchUpdate 更新现有记录
- **WHEN** 调用 BatchUpdate 且当日记录存在
- **THEN** SHALL 使用事务更新 request_count = request_count + delta
- **THEN** SHALL 使用 upsert 操作更新 request_count = request_count + delta
- **THEN** SHALL 保证原子性,无竞态条件
- **THEN** SHALL 不创建新记录
#### Scenario: BatchUpdate 创建新记录
- **WHEN** 调用 BatchUpdate 且当日记录不存在
- **THEN** SHALL 创建新记录request_count = delta
- **THEN** SHALL 使用事务保证原子性
- **THEN** SHALL 使用 upsert 操作保证原子性
#### Scenario: BatchUpdate 并发安全
- **WHEN** 多个 BatchUpdate 调用同时执行
- **THEN** SHALL 保证所有 delta 都被正确累加
- **THEN** SHALL 不因并发冲突而丢失数据