diff --git a/Makefile b/Makefile index 17c26dc..04da5e8 100644 --- a/Makefile +++ b/Makefile @@ -2,6 +2,7 @@ backend-build backend-run backend-dev backend-test backend-test-unit backend-test-integration backend-test-coverage \ backend-lint backend-clean backend-deps backend-generate \ backend-db-up backend-db-down backend-db-status backend-db-create \ + test-mysql-up test-mysql-down test-mysql test-mysql-quick \ frontend-build frontend-dev frontend-test frontend-test-watch frontend-test-coverage frontend-test-e2e frontend-lint frontend-clean \ desktop-build desktop-build-mac desktop-build-win desktop-build-linux \ desktop-dev desktop-package-mac desktop-package-win desktop-package-linux desktop-clean \ @@ -66,17 +67,53 @@ backend-generate: cd backend && go generate ./... backend-db-up: - cd backend && goose -dir migrations sqlite3 $(DB_PATH) up + @echo "Running database migration up..." + cd backend && goose -dir migrations/sqlite3 sqlite3 "$(DB_PATH)" up backend-db-down: - cd backend && goose -dir migrations sqlite3 $(DB_PATH) down + @echo "Running database migration down..." + cd backend && goose -dir migrations/sqlite3 sqlite3 "$(DB_PATH)" down backend-db-status: - cd backend && goose -dir migrations sqlite3 $(DB_PATH) status + @echo "Checking database migration status..." + cd backend && goose -dir migrations/sqlite3 sqlite3 "$(DB_PATH)" status backend-db-create: @read -p "Migration name: " name; \ - cd backend && goose -dir migrations create $$name sql + cd backend && goose -dir migrations/sqlite create $$name sql; \ + cd backend && goose -dir migrations/mysql create $$name sql + +# ============================================ +# MySQL 专项测试 +# ============================================ + +test-mysql-up: + @echo "Starting MySQL test container..." + cd backend/tests/mysql && docker-compose up -d + @echo "Waiting for MySQL to be ready..." + @for i in $$(seq 1 30); do \ + if docker exec nex-mysql-test mysqladmin ping -h localhost -u root -ptestpass --silent 2>/dev/null; then \ + echo "MySQL is ready!"; \ + exit 0; \ + fi; \ + echo "Waiting... ($$i/30)"; \ + sleep 1; \ + done; \ + echo "MySQL failed to start"; \ + exit 1 + +test-mysql-down: + @echo "Stopping MySQL test container..." + cd backend/tests/mysql && docker-compose down -v + +test-mysql: test-mysql-up + @echo "Running MySQL tests..." + cd backend && go test -tags=mysql ./tests/mysql/... -v -count=1 + $(MAKE) test-mysql-down + +test-mysql-quick: + @echo "Running MySQL tests (without container management)..." + cd backend && go test -tags=mysql ./tests/mysql/... -v -count=1 # ============================================ # 前端 diff --git a/README.md b/README.md index 64793cb..09b18bb 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ nex/ - **语言**: Go 1.26+ - **HTTP 框架**: Gin - **ORM**: GORM -- **数据库**: SQLite +- **数据库**: SQLite / MySQL - **日志**: zap + lumberjack(结构化日志 + 日志轮转) - **配置**: Viper + pflag(多层配置:CLI > 环境变量 > 配置文件 > 默认值) - **验证**: go-playground/validator/v10 @@ -204,7 +204,14 @@ server: write_timeout: 30s database: - path: ~/.nex/config.db + driver: sqlite # sqlite 或 mysql + path: ~/.nex/config.db # SQLite 数据库文件路径 + # --- MySQL 配置(driver=mysql 时生效)--- + # host: localhost + # port: 3306 + # user: nex + # password: "" + # dbname: nex max_idle_conns: 10 max_open_conns: 100 conn_max_lifetime: 1h @@ -226,6 +233,14 @@ log: export NEX_SERVER_PORT=9000 export NEX_DATABASE_PATH=/data/nex.db export NEX_LOG_LEVEL=debug + +# MySQL 模式 +export NEX_DATABASE_DRIVER=mysql +export NEX_DATABASE_HOST=db.example.com +export NEX_DATABASE_PORT=3306 +export NEX_DATABASE_USER=nex +export NEX_DATABASE_PASSWORD=secret +export NEX_DATABASE_DBNAME=nex ``` 命名规则:配置路径转大写 + 下划线(如 `server.port` → `NEX_SERVER_PORT`)。 @@ -241,7 +256,7 @@ export NEX_LOG_LEVEL=debug ### 数据文件 - `~/.nex/config.yaml` - 配置文件 -- `~/.nex/config.db` - SQLite 数据库 +- `~/.nex/config.db` - SQLite 数据库(MySQL 模式下不使用本地数据库文件) - `~/.nex/log/` - 日志目录 ## 测试 diff --git a/backend/README.md b/backend/README.md index 4041b3e..fc7f136 100644 --- a/backend/README.md +++ b/backend/README.md @@ -24,7 +24,7 @@ AI 网关后端服务,提供统一的大模型 API 代理接口。 - **语言**: Go 1.26+ - **HTTP 框架**: Gin - **ORM**: GORM -- **数据库**: SQLite +- **数据库**: SQLite / MySQL - **日志**: zap + lumberjack - **配置**: Viper + pflag(多层配置:CLI > 环境变量 > 配置文件 > 默认值) - **验证**: go-playground/validator/v10 @@ -294,7 +294,14 @@ server: write_timeout: 30s database: - path: ~/.nex/config.db + driver: sqlite # sqlite 或 mysql + path: ~/.nex/config.db # SQLite 数据库文件路径 + # --- MySQL 配置(driver=mysql 时生效)--- + # host: localhost + # port: 3306 + # user: nex + # password: "" + # dbname: nex max_idle_conns: 10 max_open_conns: 100 conn_max_lifetime: 1h @@ -316,6 +323,14 @@ log: export NEX_SERVER_PORT=9000 export NEX_DATABASE_PATH=/data/nex.db export NEX_LOG_LEVEL=debug + +# MySQL 模式 +export NEX_DATABASE_DRIVER=mysql +export NEX_DATABASE_HOST=db.example.com +export NEX_DATABASE_PORT=3306 +export NEX_DATABASE_USER=nex +export NEX_DATABASE_PASSWORD=secret +export NEX_DATABASE_DBNAME=nex ``` 命名规则:配置路径转大写 + 下划线 + `NEX_` 前缀(如 `server.port` → `NEX_SERVER_PORT`)。 @@ -332,7 +347,7 @@ export NEX_LOG_LEVEL=debug ``` 服务器: --server-port, --server-read-timeout, --server-write-timeout -数据库: --database-path, --database-max-idle-conns, --database-max-open-conns, --database-conn-max-lifetime +数据库: --database-driver, --database-path, --database-host, --database-port, --database-user, --database-password, --database-dbname, --database-max-idle-conns, --database-max-open-conns, --database-conn-max-lifetime 日志: --log-level, --log-path, --log-max-size, --log-max-backups, --log-max-age, --log-compress 通用: --config (指定配置文件路径) ``` @@ -352,15 +367,20 @@ export NEX_LOG_LEVEL=debug # Docker 部署 docker run -d -e NEX_SERVER_PORT=9000 -e NEX_LOG_LEVEL=info nex-server +# MySQL 模式 +./server --database-driver mysql --database-host db.example.com --database-user nex --database-password secret --database-dbname nex + # 自定义配置文件 ./server --config /path/to/custom.yaml ``` 数据文件: - `~/.nex/config.yaml` - 配置文件 -- `~/.nex/config.db` - SQLite 数据库 +- `~/.nex/config.db` - SQLite 数据库(MySQL 模式下不使用本地数据库文件) - `~/.nex/log/` - 日志目录 +**MySQL 连接说明**:MySQL 连接使用 DSN 格式: `user:password@tcp(host:port)/dbname?charset=utf8mb4&parseTime=true&loc=Local`,最低支持 MySQL 8.0+。 + ## 测试 ```bash diff --git a/backend/cmd/desktop/main.go b/backend/cmd/desktop/main.go index f43832d..aabd816 100644 --- a/backend/cmd/desktop/main.go +++ b/backend/cmd/desktop/main.go @@ -17,16 +17,13 @@ import ( "github.com/getlantern/systray" "github.com/gin-gonic/gin" "github.com/gofrs/flock" - "github.com/pressly/goose/v3" "go.uber.org/zap" - "gorm.io/driver/sqlite" - "gorm.io/gorm" - "gorm.io/gorm/logger" "nex/backend/internal/config" "nex/backend/internal/conversion" "nex/backend/internal/conversion/anthropic" "nex/backend/internal/conversion/openai" + "nex/backend/internal/database" "nex/backend/internal/handler" "nex/backend/internal/handler/middleware" "nex/backend/internal/provider" @@ -79,12 +76,12 @@ func main() { } defer zapLogger.Sync() - db, err := initDatabase(cfg) + db, err := database.Init(&cfg.Database, zapLogger) if err != nil { showError("Nex Gateway", fmt.Sprintf("初始化数据库失败: %v", err)) os.Exit(1) } - defer closeDB(db) + defer database.Close(db) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) @@ -159,76 +156,6 @@ func main() { setupSystray(port) } -func initDatabase(cfg *config.Config) (*gorm.DB, error) { - dbDir := filepath.Dir(cfg.Database.Path) - if err := os.MkdirAll(dbDir, 0755); err != nil { - return nil, fmt.Errorf("创建数据库目录失败: %w", err) - } - - db, err := gorm.Open(sqlite.Open(cfg.Database.Path), &gorm.Config{ - Logger: logger.Default.LogMode(logger.Info), - }) - if err != nil { - return nil, err - } - - if err := runMigrations(db); err != nil { - return nil, fmt.Errorf("数据库迁移失败: %w", err) - } - - if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil { - log.Printf("警告: 启用 WAL 模式失败: %v", err) - } - - sqlDB, err := db.DB() - if err != nil { - return nil, err - } - sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns) - sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns) - sqlDB.SetConnMaxLifetime(cfg.Database.ConnMaxLifetime) - - return db, nil -} - -func runMigrations(db *gorm.DB) error { - sqlDB, err := db.DB() - if err != nil { - return err - } - - migrationsDir := getMigrationsDir() - if _, err := os.Stat(migrationsDir); os.IsNotExist(err) { - return fmt.Errorf("迁移目录不存在: %s", migrationsDir) - } - - goose.SetDialect("sqlite3") - if err := goose.Up(sqlDB, migrationsDir); err != nil { - return err - } - - return nil -} - -func getMigrationsDir() string { - _, filename, _, ok := runtime.Caller(0) - if ok { - dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations") - if abs, err := filepath.Abs(dir); err == nil { - return abs - } - } - return "./migrations" -} - -func closeDB(db *gorm.DB) { - sqlDB, err := db.DB() - if err != nil { - return - } - sqlDB.Close() -} - func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) { r.Any("/v1/*path", proxyHandler.HandleProxy) diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index a604401..2bf8008 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -7,22 +7,17 @@ import ( "net/http" "os" "os/signal" - "path/filepath" - "runtime" "syscall" "time" "github.com/gin-gonic/gin" - "github.com/pressly/goose/v3" "go.uber.org/zap" - "gorm.io/driver/sqlite" - "gorm.io/gorm" - "gorm.io/gorm/logger" "nex/backend/internal/config" "nex/backend/internal/conversion" "nex/backend/internal/conversion/anthropic" "nex/backend/internal/conversion/openai" + "nex/backend/internal/database" "nex/backend/internal/handler" "nex/backend/internal/handler/middleware" "nex/backend/internal/provider" @@ -32,16 +27,13 @@ import ( ) func main() { - // 1. 加载配置(已包含 CLI 参数解析、环境变量绑定、配置文件读取和验证) cfg, err := config.LoadConfig() if err != nil { log.Fatalf("加载配置失败: %v", err) } - // 2. 打印配置摘要 cfg.PrintSummary() - // 3. 初始化日志 zapLogger, err := pkgLogger.New(pkgLogger.Config{ Level: cfg.Log.Level, Path: cfg.Log.Path, @@ -55,37 +47,31 @@ func main() { } defer zapLogger.Sync() - // 3. 初始化数据库 - db, err := initDatabase(cfg) + db, err := database.Init(&cfg.Database, zapLogger) if err != nil { zapLogger.Fatal("初始化数据库失败", zap.String("error", err.Error())) } - defer closeDB(db) + defer database.Close(db) - // 4. 初始化 repository 层 providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) statsRepo := repository.NewStatsRepository(db) - // 5. 初始化缓存 routingCache := service.NewRoutingCache(modelRepo, providerRepo, zapLogger) if err := routingCache.Preload(); err != nil { zapLogger.Warn("缓存预热失败,将使用懒加载", zap.Error(err)) } - // 6. 初始化统计缓冲 statsBuffer := service.NewStatsBuffer(statsRepo, zapLogger, service.WithFlushInterval(5*time.Second), service.WithFlushThreshold(100)) statsBuffer.Start() - // 7. 初始化 service 层 providerService := service.NewProviderService(providerRepo, modelRepo, routingCache) modelService := service.NewModelService(modelRepo, providerRepo, routingCache) routingService := service.NewRoutingService(routingCache) statsService := service.NewStatsService(statsRepo, statsBuffer) - // 8. 创建 ConversionEngine registry := conversion.NewMemoryRegistry() if err := registry.Register(openai.NewAdapter()); err != nil { zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error())) @@ -95,16 +81,13 @@ func main() { } engine := conversion.NewConversionEngine(registry, zapLogger) - // 9. 初始化 provider client providerClient := provider.NewClient() - // 10. 初始化 handler 层 proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService) providerHandler := handler.NewProviderHandler(providerService) modelHandler := handler.NewModelHandler(modelService) statsHandler := handler.NewStatsHandler(statsService) - // 11. 创建 Gin 引擎 gin.SetMode(gin.ReleaseMode) r := gin.New() @@ -115,9 +98,8 @@ func main() { setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler) - // 12. 启动服务器 srv := &http.Server{ - Addr: formatAddr(cfg.Server.Port), + Addr: fmt.Sprintf(":%d", cfg.Server.Port), Handler: r, ReadTimeout: cfg.Server.ReadTimeout, WriteTimeout: cfg.Server.WriteTimeout, @@ -148,83 +130,9 @@ func main() { zapLogger.Info("服务器已关闭") } -func initDatabase(cfg *config.Config) (*gorm.DB, error) { - db, err := gorm.Open(sqlite.Open(cfg.Database.Path), &gorm.Config{ - Logger: logger.Default.LogMode(logger.Info), - }) - if err != nil { - return nil, err - } - - if err := runMigrations(db); err != nil { - return nil, fmt.Errorf("数据库迁移失败: %w", err) - } - - if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil { - log.Printf("警告: 启用 WAL 模式失败: %v", err) - } - - sqlDB, err := db.DB() - if err != nil { - return nil, err - } - sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns) - sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns) - sqlDB.SetConnMaxLifetime(cfg.Database.ConnMaxLifetime) - - log.Printf("数据库连接池配置: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v", - cfg.Database.MaxIdleConns, cfg.Database.MaxOpenConns, cfg.Database.ConnMaxLifetime) - - return db, nil -} - -func runMigrations(db *gorm.DB) error { - sqlDB, err := db.DB() - if err != nil { - return err - } - - migrationsDir := getMigrationsDir() - if _, err := os.Stat(migrationsDir); os.IsNotExist(err) { - return fmt.Errorf("迁移目录不存在: %s", migrationsDir) - } - - goose.SetDialect("sqlite3") - if err := goose.Up(sqlDB, migrationsDir); err != nil { - return err - } - - return nil -} - -func getMigrationsDir() string { - _, filename, _, ok := runtime.Caller(0) - if ok { - dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations") - if abs, err := filepath.Abs(dir); err == nil { - return abs - } - } - return "./migrations" -} - -func closeDB(db *gorm.DB) { - sqlDB, err := db.DB() - if err != nil { - return - } - sqlDB.Close() -} - -func formatAddr(port int) string { - return fmt.Sprintf(":%d", port) -} - func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) { - // 统一代理入口: /{protocol}/{path} r.Any("/:protocol/*path", proxyHandler.HandleProxy) - // 供应商管理 API providers := r.Group("/api/providers") { providers.GET("", providerHandler.ListProviders) @@ -234,7 +142,6 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand providers.DELETE("/:id", providerHandler.DeleteProvider) } - // 模型管理 API models := r.Group("/api/models") { models.GET("", modelHandler.ListModels) @@ -244,14 +151,12 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand models.DELETE("/:id", modelHandler.DeleteModel) } - // 统计查询 API stats := r.Group("/api/stats") { stats.GET("", statsHandler.GetStats) stats.GET("/aggregate", statsHandler.AggregateStats) } - // 健康检查 r.GET("/health", func(c *gin.Context) { c.JSON(200, gin.H{"status": "ok"}) }) diff --git a/backend/go.mod b/backend/go.mod index 7ce3c5f..f56da25 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -24,6 +24,7 @@ require ( go.uber.org/zap v1.27.1 gopkg.in/lumberjack.v2 v2.0.0 gopkg.in/yaml.v3 v3.0.1 + gorm.io/driver/mysql v1.6.0 gorm.io/driver/sqlite v1.6.0 gorm.io/gorm v1.31.1 nex/embedfs v0.0.0-00010101000000-000000000000 @@ -32,6 +33,7 @@ require ( require ( 4d63.com/gocheckcompilerdirectives v1.3.0 // indirect 4d63.com/gochecknoglobals v0.2.2 // indirect + filippo.io/edwards25519 v1.2.0 // indirect github.com/4meepo/tagalign v1.4.2 // indirect github.com/Abirdcfly/dupword v0.1.3 // indirect github.com/Antonboom/errname v1.0.0 // indirect @@ -90,6 +92,7 @@ require ( github.com/go-critic/go-critic v0.12.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/go-stack/stack v1.8.0 // indirect github.com/go-toolsmith/astcast v1.1.0 // indirect github.com/go-toolsmith/astcopy v1.1.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 9c89ceb..c99656f 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -35,6 +35,8 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= +filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= github.com/4meepo/tagalign v1.4.2 h1:0hcLHPGMjDyM1gHG58cS73aQF8J4TdVR96TZViorO9E= github.com/4meepo/tagalign v1.4.2/go.mod h1:+p4aMyFM+ra7nb41CnFG6aSDXqRxU/w1VQqScKqDARI= github.com/Abirdcfly/dupword v0.1.3 h1:9Pa1NuAsZvpFPi9Pqkd93I7LIYRURj+A//dFd5tgBeE= @@ -206,6 +208,8 @@ github.com/go-playground/validator/v10 v10.30.2 h1:JiFIMtSSHb2/XBUbWM4i/MpeQm9ZK github.com/go-playground/validator/v10 v10.30.2/go.mod h1:mAf2pIOVXjTEBrwUMGKkCWKKPs9NheYGabeB04txQSc= github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI= github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= +github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= +github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= @@ -1052,6 +1056,8 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg= +gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo= gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 8bd13a0..391535d 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -32,7 +32,13 @@ type ServerConfig struct { // DatabaseConfig 数据库配置 type DatabaseConfig struct { - Path string `yaml:"path" mapstructure:"path" validate:"required"` + Driver string `yaml:"driver" mapstructure:"driver" validate:"required,oneof=sqlite mysql"` + Path string `yaml:"path" mapstructure:"path" validate:"required_if=Driver sqlite"` + Host string `yaml:"host" mapstructure:"host" validate:"required_if=Driver mysql"` + Port int `yaml:"port" mapstructure:"port" validate:"required_if=Driver mysql,min=1,max=65535"` + User string `yaml:"user" mapstructure:"user" validate:"required_if=Driver mysql"` + Password string `yaml:"password" mapstructure:"password"` + DBName string `yaml:"dbname" mapstructure:"dbname" validate:"required_if=Driver mysql"` MaxIdleConns int `yaml:"max_idle_conns" mapstructure:"max_idle_conns" validate:"required,min=1"` MaxOpenConns int `yaml:"max_open_conns" mapstructure:"max_open_conns" validate:"required,min=1"` ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" mapstructure:"conn_max_lifetime" validate:"required"` @@ -61,7 +67,13 @@ func DefaultConfig() *Config { WriteTimeout: 30 * time.Second, }, Database: DatabaseConfig{ + Driver: "sqlite", Path: filepath.Join(nexDir, "config.db"), + Host: "", + Port: 3306, + User: "", + Password: "", + DBName: "nex", MaxIdleConns: 10, MaxOpenConns: 100, ConnMaxLifetime: 1 * time.Hour, @@ -117,7 +129,13 @@ func setupDefaults(v *viper.Viper) { v.SetDefault("server.read_timeout", "30s") v.SetDefault("server.write_timeout", "30s") + v.SetDefault("database.driver", "sqlite") v.SetDefault("database.path", filepath.Join(nexDir, "config.db")) + v.SetDefault("database.host", "") + v.SetDefault("database.port", 3306) + v.SetDefault("database.user", "") + v.SetDefault("database.password", "") + v.SetDefault("database.dbname", "nex") v.SetDefault("database.max_idle_conns", 10) v.SetDefault("database.max_open_conns", 100) v.SetDefault("database.conn_max_lifetime", "1h") @@ -138,7 +156,13 @@ func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) { flagSet.Duration("server-read-timeout", 0, "读超时") flagSet.Duration("server-write-timeout", 0, "写超时") + flagSet.String("database-driver", "", "数据库驱动:sqlite/mysql") flagSet.String("database-path", "", "数据库文件路径") + flagSet.String("database-host", "", "MySQL 主机地址") + flagSet.Int("database-port", 0, "MySQL 端口") + flagSet.String("database-user", "", "MySQL 用户名") + flagSet.String("database-password", "", "MySQL 密码") + flagSet.String("database-dbname", "", "MySQL 数据库名") flagSet.Int("database-max-idle-conns", 0, "最大空闲连接数") flagSet.Int("database-max-open-conns", 0, "最大打开连接数") flagSet.Duration("database-conn-max-lifetime", 0, "连接最大生命周期") @@ -156,7 +180,13 @@ func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) { v.BindPFlag("server.read_timeout", flagSet.Lookup("server-read-timeout")) v.BindPFlag("server.write_timeout", flagSet.Lookup("server-write-timeout")) + v.BindPFlag("database.driver", flagSet.Lookup("database-driver")) v.BindPFlag("database.path", flagSet.Lookup("database-path")) + v.BindPFlag("database.host", flagSet.Lookup("database-host")) + v.BindPFlag("database.port", flagSet.Lookup("database-port")) + v.BindPFlag("database.user", flagSet.Lookup("database-user")) + v.BindPFlag("database.password", flagSet.Lookup("database-password")) + v.BindPFlag("database.dbname", flagSet.Lookup("database-dbname")) v.BindPFlag("database.max_idle_conns", flagSet.Lookup("database-max-idle-conns")) v.BindPFlag("database.max_open_conns", flagSet.Lookup("database-max-open-conns")) v.BindPFlag("database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime")) @@ -268,7 +298,7 @@ func SaveConfig(cfg *Config) error { return appErrors.Wrap(appErrors.ErrInternal, err) } - return os.WriteFile(configPath, data, 0644) + return os.WriteFile(configPath, data, 0600) } // Validate validates the config @@ -285,7 +315,13 @@ func (c *Config) PrintSummary() { fmt.Println("\nAI Gateway 启动配置") fmt.Println("==================") fmt.Printf("服务器端口: %d\n", c.Server.Port) - fmt.Printf("数据库路径: %s\n", c.Database.Path) + if c.Database.Driver == "mysql" { + fmt.Printf("数据库类型: mysql\n") + fmt.Printf("数据库地址: %s:%d/%s\n", c.Database.Host, c.Database.Port, c.Database.DBName) + } else { + fmt.Printf("数据库类型: sqlite\n") + fmt.Printf("数据库路径: %s\n", c.Database.Path) + } fmt.Printf("日志级别: %s\n", c.Log.Level) fmt.Println("\n配置来源:") configPath, _ := GetConfigPath() diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index cc390bc..b4a58e4 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -19,6 +19,12 @@ func TestDefaultConfig(t *testing.T) { assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout) assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout) + assert.Equal(t, "sqlite", cfg.Database.Driver) + assert.Equal(t, "", cfg.Database.Host) + assert.Equal(t, 3306, cfg.Database.Port) + assert.Equal(t, "", cfg.Database.User) + assert.Equal(t, "", cfg.Database.Password) + assert.Equal(t, "nex", cfg.Database.DBName) assert.Equal(t, 10, cfg.Database.MaxIdleConns) assert.Equal(t, 100, cfg.Database.MaxOpenConns) assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime) @@ -86,11 +92,76 @@ func TestConfig_Validate(t *testing.T) { wantErr: false, }, { - name: "数据库路径为空无效", + name: "SQLite模式路径为空无效", modify: func(c *Config) { c.Database.Path = "" }, wantErr: true, errMsg: "配置验证失败", }, + { + name: "driver值不合法", + modify: func(c *Config) { c.Database.Driver = "postgres" }, + wantErr: true, + errMsg: "配置验证失败", + }, + { + name: "MySQL配置有效", + modify: func(c *Config) { + c.Database.Driver = "mysql" + c.Database.Host = "localhost" + c.Database.Port = 3306 + c.Database.User = "root" + c.Database.DBName = "nex" + c.Database.Path = "" + }, + wantErr: false, + }, + { + name: "MySQL模式host为空无效", + modify: func(c *Config) { + c.Database.Driver = "mysql" + c.Database.Host = "" + c.Database.User = "root" + c.Database.DBName = "nex" + c.Database.Path = "" + }, + wantErr: true, + errMsg: "配置验证失败", + }, + { + name: "MySQL模式user为空无效", + modify: func(c *Config) { + c.Database.Driver = "mysql" + c.Database.Host = "localhost" + c.Database.User = "" + c.Database.DBName = "nex" + c.Database.Path = "" + }, + wantErr: true, + errMsg: "配置验证失败", + }, + { + name: "MySQL模式dbname为空无效", + modify: func(c *Config) { + c.Database.Driver = "mysql" + c.Database.Host = "localhost" + c.Database.User = "root" + c.Database.DBName = "" + c.Database.Path = "" + }, + wantErr: true, + errMsg: "配置验证失败", + }, + { + name: "MySQL模式忽略path字段", + modify: func(c *Config) { + c.Database.Driver = "mysql" + c.Database.Host = "localhost" + c.Database.User = "root" + c.Database.DBName = "nex" + c.Database.Path = "" + }, + wantErr: false, + }, } for _, tt := range tests { @@ -140,7 +211,10 @@ func TestSaveAndLoadConfig(t *testing.T) { WriteTimeout: 20 * time.Second, }, Database: DatabaseConfig{ + Driver: "sqlite", Path: filepath.Join(dir, "test.db"), + Port: 3306, + DBName: "nex", MaxIdleConns: 5, MaxOpenConns: 50, ConnMaxLifetime: 30 * time.Minute, @@ -210,6 +284,9 @@ func TestConfigPriority(t *testing.T) { assert.Equal(t, 9826, cfg.Server.Port) assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout) assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout) + assert.Equal(t, "sqlite", cfg.Database.Driver) + assert.Equal(t, 3306, cfg.Database.Port) + assert.Equal(t, "nex", cfg.Database.DBName) assert.Equal(t, 10, cfg.Database.MaxIdleConns) assert.Equal(t, 100, cfg.Database.MaxOpenConns) assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime) @@ -222,11 +299,19 @@ func TestConfigPriority(t *testing.T) { } func TestPrintSummary(t *testing.T) { - // 测试配置摘要输出 - t.Run("打印配置摘要", func(t *testing.T) { + t.Run("SQLite模式摘要", func(t *testing.T) { cfg := DefaultConfig() - // PrintSummary 只是打印,不会返回错误 - // 这里主要验证不会 panic + assert.NotPanics(t, func() { + cfg.PrintSummary() + }) + }) + t.Run("MySQL模式摘要", func(t *testing.T) { + cfg := DefaultConfig() + cfg.Database.Driver = "mysql" + cfg.Database.Host = "db.example.com" + cfg.Database.Port = 3306 + cfg.Database.User = "nex" + cfg.Database.DBName = "nex" assert.NotPanics(t, func() { cfg.PrintSummary() }) diff --git a/backend/internal/config/models.go b/backend/internal/config/models.go index 628d184..8add81e 100644 --- a/backend/internal/config/models.go +++ b/backend/internal/config/models.go @@ -29,8 +29,8 @@ type Model struct { // UsageStats 用量统计 type UsageStats struct { ID uint `gorm:"primaryKey;autoIncrement" json:"id"` - ProviderID string `gorm:"not null;index" json:"provider_id"` - ModelName string `gorm:"not null;index" json:"model_name"` + ProviderID string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"provider_id"` + ModelName string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"model_name"` RequestCount int `gorm:"default:0" json:"request_count"` Date time.Time `gorm:"type:date;not null;uniqueIndex:idx_provider_model_date" json:"date"` } diff --git a/backend/internal/database/database.go b/backend/internal/database/database.go new file mode 100644 index 0000000..76661a5 --- /dev/null +++ b/backend/internal/database/database.go @@ -0,0 +1,126 @@ +package database + +import ( + "fmt" + "log" + "os" + "path/filepath" + "runtime" + + "github.com/pressly/goose/v3" + "go.uber.org/zap" + "gorm.io/driver/mysql" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" + + "nex/backend/internal/config" +) + +func Init(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) { + db, err := initDB(cfg) + if err != nil { + return nil, fmt.Errorf("初始化数据库失败: %w", err) + } + + if err := runMigrations(db, cfg.Driver); err != nil { + return nil, fmt.Errorf("数据库迁移失败: %w", err) + } + + configurePool(db, cfg) + + return db, nil +} + +func Close(db *gorm.DB) { + sqlDB, err := db.DB() + if err != nil { + return + } + sqlDB.Close() +} + +func initDB(cfg *config.DatabaseConfig) (*gorm.DB, error) { + gormConfig := &gorm.Config{ + Logger: logger.Default.LogMode(logger.Info), + } + + switch cfg.Driver { + case "mysql": + dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local", + cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName) + return gorm.Open(mysql.Open(dsn), gormConfig) + default: + dbDir := filepath.Dir(cfg.Path) + if err := os.MkdirAll(dbDir, 0755); err != nil { + return nil, fmt.Errorf("创建数据库目录失败: %w", err) + } + return gorm.Open(sqlite.Open(cfg.Path), gormConfig) + } +} + +func runMigrations(db *gorm.DB, driver string) error { + sqlDB, err := db.DB() + if err != nil { + return err + } + + migrationsDir := getMigrationsDir(driver) + if _, err := os.Stat(migrationsDir); os.IsNotExist(err) { + return fmt.Errorf("迁移目录不存在: %s", migrationsDir) + } + + gooseDialect := "sqlite3" + migrationsSubDir := "sqlite" + if driver == "mysql" { + gooseDialect = "mysql" + migrationsSubDir = "mysql" + } + + goose.SetDialect(gooseDialect) + if err := goose.Up(sqlDB, migrationsDir); err != nil { + return err + } + + log.Printf("使用 %s 方言执行迁移,目录: %s", gooseDialect, migrationsSubDir) + return nil +} + +func configurePool(db *gorm.DB, cfg *config.DatabaseConfig) { + if cfg.Driver == "sqlite" { + if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil { + log.Printf("警告: 启用 WAL 模式失败: %v", err) + } + } + + sqlDB, err := db.DB() + if err != nil { + return + } + sqlDB.SetMaxIdleConns(cfg.MaxIdleConns) + sqlDB.SetMaxOpenConns(cfg.MaxOpenConns) + sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime) + + log.Printf("数据库连接池配置: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v", + cfg.MaxIdleConns, cfg.MaxOpenConns, cfg.ConnMaxLifetime) +} + +func getMigrationsDir(driver string) string { + _, filename, _, ok := runtime.Caller(0) + if ok { + subDir := "sqlite" + if driver == "mysql" { + subDir = "mysql" + } + dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations", subDir) + if abs, err := filepath.Abs(dir); err == nil { + return abs + } + } + return "./migrations" +} + +func BuildDSN(cfg *config.DatabaseConfig) string { + return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local", + cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName) +} diff --git a/backend/internal/database/database_test.go b/backend/internal/database/database_test.go new file mode 100644 index 0000000..ddb5669 --- /dev/null +++ b/backend/internal/database/database_test.go @@ -0,0 +1,75 @@ +package database + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "nex/backend/internal/config" +) + +func TestInit_SQLite(t *testing.T) { + dir := t.TempDir() + cfg := &config.DatabaseConfig{ + Driver: "sqlite", + Path: filepath.Join(dir, "test.db"), + MaxIdleConns: 5, + MaxOpenConns: 10, + ConnMaxLifetime: 0, + } + + db, err := Init(cfg, nil) + require.NoError(t, err) + require.NotNil(t, db) + defer Close(db) + + sqlDB, err := db.DB() + require.NoError(t, err) + require.NotNil(t, sqlDB) +} + +func TestClose(t *testing.T) { + dir := t.TempDir() + cfg := &config.DatabaseConfig{ + Driver: "sqlite", + Path: filepath.Join(dir, "test.db"), + MaxIdleConns: 5, + MaxOpenConns: 10, + ConnMaxLifetime: 0, + } + + db, err := Init(cfg, nil) + require.NoError(t, err) + require.NotNil(t, db) + + Close(db) +} + +func TestBuildDSN(t *testing.T) { + cfg := &config.DatabaseConfig{ + Driver: "mysql", + Host: "db.example.com", + Port: 3306, + User: "nexuser", + Password: "secretpass", + DBName: "nexdb", + } + + dsn := BuildDSN(cfg) + assert.Equal(t, "nexuser:secretpass@tcp(db.example.com:3306)/nexdb?charset=utf8mb4&parseTime=true&loc=Local", dsn) +} + +func TestBuildDSN_EmptyPassword(t *testing.T) { + cfg := &config.DatabaseConfig{ + Driver: "mysql", + Host: "localhost", + Port: 3306, + User: "root", + DBName: "nex", + } + + dsn := BuildDSN(cfg) + assert.Equal(t, "root:@tcp(localhost:3306)/nex?charset=utf8mb4&parseTime=true&loc=Local", dsn) +} diff --git a/backend/internal/repository/stats_repo_impl.go b/backend/internal/repository/stats_repo_impl.go index 7692b33..ca9e7d8 100644 --- a/backend/internal/repository/stats_repo_impl.go +++ b/backend/internal/repository/stats_repo_impl.go @@ -1,10 +1,10 @@ package repository import ( - "errors" "time" "gorm.io/gorm" + "gorm.io/gorm/clause" "nex/backend/internal/config" "nex/backend/internal/domain" @@ -22,47 +22,43 @@ func (r *statsRepository) Record(providerID, modelName string) error { today := time.Now().Format("2006-01-02") todayTime, _ := time.Parse("2006-01-02", today) - return r.db.Transaction(func(tx *gorm.DB) error { - var stats config.UsageStats - err := tx.Where("provider_id = ? AND model_name = ? AND date = ?", - providerID, modelName, todayTime).First(&stats).Error + stats := config.UsageStats{ + ProviderID: providerID, + ModelName: modelName, + RequestCount: 1, + Date: todayTime, + } - if errors.Is(err, gorm.ErrRecordNotFound) { - stats = config.UsageStats{ - ProviderID: providerID, - ModelName: modelName, - RequestCount: 1, - Date: todayTime, - } - return tx.Create(&stats).Error - } else if err != nil { - return err - } - - return tx.Model(&stats).Update("request_count", gorm.Expr("request_count + 1")).Error - }) + return r.db.Clauses(clause.OnConflict{ + Columns: []clause.Column{ + {Name: "provider_id"}, + {Name: "model_name"}, + {Name: "date"}, + }, + DoUpdates: clause.Assignments(map[string]interface{}{ + "request_count": gorm.Expr("request_count + 1"), + }), + }).Create(&stats).Error } func (r *statsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error { - return r.db.Transaction(func(tx *gorm.DB) error { - var stats config.UsageStats - err := tx.Where("provider_id = ? AND model_name = ? AND date = ?", - providerID, modelName, date).First(&stats).Error + stats := config.UsageStats{ + ProviderID: providerID, + ModelName: modelName, + RequestCount: delta, + Date: date, + } - if errors.Is(err, gorm.ErrRecordNotFound) { - return tx.Create(&config.UsageStats{ - ProviderID: providerID, - ModelName: modelName, - RequestCount: delta, - Date: date, - }).Error - } else if err != nil { - return err - } - - return tx.Model(&stats). - Update("request_count", gorm.Expr("request_count + ?", delta)).Error - }) + return r.db.Clauses(clause.OnConflict{ + Columns: []clause.Column{ + {Name: "provider_id"}, + {Name: "model_name"}, + {Name: "date"}, + }, + DoUpdates: clause.Assignments(map[string]interface{}{ + "request_count": gorm.Expr("request_count + ?", delta), + }), + }).Create(&stats).Error } func (r *statsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) { diff --git a/backend/migrations/mysql/20260421000001_initial_schema.sql b/backend/migrations/mysql/20260421000001_initial_schema.sql new file mode 100644 index 0000000..85f64db --- /dev/null +++ b/backend/migrations/mysql/20260421000001_initial_schema.sql @@ -0,0 +1,44 @@ +-- +goose Up +-- MySQL 方言初始迁移:providers、models、usage_stats 完整表结构 + +CREATE TABLE IF NOT EXISTS providers ( + id VARCHAR(36) PRIMARY KEY, + name VARCHAR(255) NOT NULL, + api_key VARCHAR(255) NOT NULL, + base_url VARCHAR(255) NOT NULL, + protocol VARCHAR(50) DEFAULT 'openai', + enabled BOOLEAN DEFAULT TRUE, + created_at DATETIME(3), + updated_at DATETIME(3) +); + +CREATE TABLE IF NOT EXISTS models ( + id VARCHAR(36) PRIMARY KEY, + provider_id VARCHAR(36) NOT NULL, + model_name VARCHAR(255) NOT NULL, + enabled BOOLEAN DEFAULT TRUE, + created_at DATETIME(3), + FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE, + UNIQUE(provider_id, model_name) +); + +CREATE TABLE IF NOT EXISTS usage_stats ( + id INT UNSIGNED AUTO_INCREMENT PRIMARY KEY, + provider_id VARCHAR(36) NOT NULL, + model_name VARCHAR(255) NOT NULL, + request_count INT DEFAULT 0, + date DATE NOT NULL, + UNIQUE(provider_id, model_name, date) +); + +CREATE INDEX idx_models_provider_id ON models(provider_id); +CREATE INDEX idx_models_model_name ON models(model_name); +CREATE INDEX idx_usage_stats_provider_model_date ON usage_stats(provider_id, model_name, date); + +-- +goose Down +DROP INDEX IF EXISTS idx_usage_stats_provider_model_date ON usage_stats; +DROP INDEX IF EXISTS idx_models_model_name ON models; +DROP INDEX IF EXISTS idx_models_provider_id ON models; +DROP TABLE IF EXISTS usage_stats; +DROP TABLE IF EXISTS models; +DROP TABLE IF EXISTS providers; diff --git a/backend/migrations/20260421000001_initial_schema.sql b/backend/migrations/sqlite/20260421000001_initial_schema.sql similarity index 100% rename from backend/migrations/20260421000001_initial_schema.sql rename to backend/migrations/sqlite/20260421000001_initial_schema.sql diff --git a/backend/tests/config/config_test.go b/backend/tests/config/config_test.go index d15bd16..ff55aa1 100644 --- a/backend/tests/config/config_test.go +++ b/backend/tests/config/config_test.go @@ -158,7 +158,10 @@ func TestSaveAndLoadConfig(t *testing.T) { WriteTimeout: 45 * time.Second, }, Database: config.DatabaseConfig{ + Driver: "sqlite", Path: filepath.Join(tmpDir, "test.db"), + Port: 3306, + DBName: "nex", MaxIdleConns: 15, MaxOpenConns: 150, ConnMaxLifetime: 2 * time.Hour, diff --git a/backend/tests/mysql/concurrent_test.go b/backend/tests/mysql/concurrent_test.go new file mode 100644 index 0000000..8aa7ff0 --- /dev/null +++ b/backend/tests/mysql/concurrent_test.go @@ -0,0 +1,158 @@ +//go:build mysql + +package mysql + +import ( + "sync" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "nex/backend/internal/config" + "nex/backend/internal/repository" +) + +func TestConcurrent_UsageStatsRecord(t *testing.T) { + db := SetupMySQLTestDB(t) + + statsRepo := repository.NewStatsRepository(db) + + providerID := "concurrent-test-provider" + modelName := "gpt-4" + + concurrency := 10 + var wg sync.WaitGroup + wg.Add(concurrency) + + errChan := make(chan error, concurrency) + + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + err := statsRepo.Record(providerID, modelName) + if err != nil { + errChan <- err + } + }() + } + + wg.Wait() + close(errChan) + + var errorCount int + uniqueErrors := make(map[string]int) + for err := range errChan { + errorCount++ + uniqueErrors[err.Error()]++ + } + + t.Logf("并发 %d 次,错误 %d 次", concurrency, errorCount) + for errMsg, count := range uniqueErrors { + t.Logf(" 错误: %s (出现 %d 次)", errMsg, count) + } + + var stats config.UsageStats + err := db.Where("provider_id = ? AND model_name = ?", providerID, modelName). + First(&stats).Error + require.NoError(t, err, "应能查到 usage_stats 记录") + + successCount := concurrency - errorCount + t.Logf("成功次数: %d, 最终 request_count: %d", successCount, stats.RequestCount) + + assert.Equal(t, concurrency, stats.RequestCount, "request_count 应等于并发数,无数据丢失或重复") +} + +func TestConcurrent_ProviderCreate(t *testing.T) { + db := SetupMySQLTestDB(t) + + providerID := "concurrent-provider-id" + concurrency := 10 + + var wg sync.WaitGroup + wg.Add(concurrency) + + successCount := 0 + var mu sync.Mutex + + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + + provider := config.Provider{ + ID: providerID, + Name: "Concurrent Provider", + APIKey: "test-key", + BaseURL: "https://test.com", + Enabled: true, + } + + err := db.Create(&provider).Error + if err == nil { + mu.Lock() + successCount++ + mu.Unlock() + } + }() + } + + wg.Wait() + + assert.Equal(t, 1, successCount, "仅 1 个创建应成功") + + var count int64 + db.Model(&config.Provider{}).Where("id = ?", providerID).Count(&count) + assert.Equal(t, int64(1), count, "最终应有 1 条记录") +} + +func TestConcurrent_ModelCreate(t *testing.T) { + db := SetupMySQLTestDB(t) + + provider := config.Provider{ + ID: "concurrent-model-provider", + Name: "Test Provider", + APIKey: "test-key", + BaseURL: "https://test.com", + Enabled: true, + } + err := db.Create(&provider).Error + require.NoError(t, err, "创建 provider 应成功") + + modelName := "gpt-4-concurrent" + concurrency := 10 + + var wg sync.WaitGroup + wg.Add(concurrency) + + successCount := 0 + var mu sync.Mutex + + for i := 0; i < concurrency; i++ { + go func(idx int) { + defer wg.Done() + + model := config.Model{ + ID: uuid.New().String(), + ProviderID: provider.ID, + ModelName: modelName, + Enabled: true, + } + + err := db.Create(&model).Error + if err == nil { + mu.Lock() + successCount++ + mu.Unlock() + } + }(i) + } + + wg.Wait() + + assert.Equal(t, 1, successCount, "仅 1 个创建应成功") + + var count int64 + db.Model(&config.Model{}).Where("provider_id = ? AND model_name = ?", provider.ID, modelName).Count(&count) + assert.Equal(t, int64(1), count, "最终应有 1 条记录") +} diff --git a/backend/tests/mysql/constraint_test.go b/backend/tests/mysql/constraint_test.go new file mode 100644 index 0000000..7923c23 --- /dev/null +++ b/backend/tests/mysql/constraint_test.go @@ -0,0 +1,130 @@ +//go:build mysql + +package mysql + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + + "nex/backend/internal/config" +) + +func TestConstraint_ForeignKeyEnforced(t *testing.T) { + db := SetupMySQLTestDB(t) + + model := config.Model{ + ID: "test-model-id", + ProviderID: "non-existent-provider", + ModelName: "gpt-4", + Enabled: true, + } + + err := db.Create(&model).Error + assert.Error(t, err, "创建 model 时 provider_id 不存在应失败") + assert.Contains(t, err.Error(), "foreign key constraint", "错误应为外键约束错误") +} + +func TestConstraint_CascadeDelete(t *testing.T) { + db := SetupMySQLTestDB(t) + + provider := config.Provider{ + ID: "test-provider-cascade", + Name: "Test Provider", + APIKey: "test-key", + BaseURL: "https://test.com", + Enabled: true, + } + err := db.Create(&provider).Error + require.NoError(t, err, "创建 provider 应成功") + + model := config.Model{ + ID: "test-model-cascade", + ProviderID: provider.ID, + ModelName: "gpt-4", + Enabled: true, + } + err = db.Create(&model).Error + require.NoError(t, err, "创建 model 应成功") + + err = db.Delete(&provider).Error + require.NoError(t, err, "删除 provider 应成功") + + var count int64 + err = db.Model(&config.Model{}).Where("provider_id = ?", provider.ID).Count(&count).Error + require.NoError(t, err) + assert.Equal(t, int64(0), count, "删除 provider 后其 models 应被级联删除") +} + +func TestConstraint_UniqueProviderModel(t *testing.T) { + db := SetupMySQLTestDB(t) + + provider := config.Provider{ + ID: "test-provider-unique", + Name: "Test Provider", + APIKey: "test-key", + BaseURL: "https://test.com", + Enabled: true, + } + err := db.Create(&provider).Error + require.NoError(t, err, "创建 provider 应成功") + + model1 := config.Model{ + ID: "test-model-unique-1", + ProviderID: provider.ID, + ModelName: "gpt-4", + Enabled: true, + } + err = db.Create(&model1).Error + require.NoError(t, err, "创建第一个 model 应成功") + + model2 := config.Model{ + ID: "test-model-unique-2", + ProviderID: provider.ID, + ModelName: "gpt-4", + Enabled: true, + } + err = db.Create(&model2).Error + assert.Error(t, err, "创建相同 (provider_id, model_name) 的 model 应失败") + assert.True(t, errors.Is(err, gorm.ErrDuplicatedKey) || + (err != nil && (err.Error() == "Error 1062" || containsDuplicateError(err.Error()))), + "错误应为唯一约束错误") +} + +func TestConstraint_UniqueUsageStats(t *testing.T) { + db := SetupMySQLTestDB(t) + + today := time.Now().Format("2006-01-02") + todayTime, _ := time.Parse("2006-01-02", today) + + providerID := "test-provider-unique-stats" + + stats1 := config.UsageStats{ + ProviderID: providerID, + ModelName: "gpt-4", + RequestCount: 10, + Date: todayTime, + } + err := db.Create(&stats1).Error + require.NoError(t, err, "创建第一个 usage_stats 应成功") + + stats2 := config.UsageStats{ + ProviderID: providerID, + ModelName: "gpt-4", + RequestCount: 20, + Date: todayTime, + } + err = db.Create(&stats2).Error + assert.Error(t, err, "创建相同 (provider_id, model_name, date) 的 usage_stats 应失败") + assert.True(t, errors.Is(err, gorm.ErrDuplicatedKey) || + (err != nil && (err.Error() == "Error 1062" || containsDuplicateError(err.Error()))), + "错误应为唯一约束错误") +} + +func containsDuplicateError(errStr string) bool { + return len(errStr) > 0 && (errStr[0:8] == "Error 10" || errStr[0:5] == "Dupli") +} diff --git a/backend/tests/mysql/docker-compose.yml b/backend/tests/mysql/docker-compose.yml new file mode 100644 index 0000000..3735f24 --- /dev/null +++ b/backend/tests/mysql/docker-compose.yml @@ -0,0 +1,21 @@ +version: '3.8' + +services: + mysql: + image: mysql:8.0 + container_name: nex-mysql-test + environment: + MYSQL_ROOT_PASSWORD: testpass + MYSQL_DATABASE: nex_test + MYSQL_USER: nex_test + MYSQL_PASSWORD: testpass + ports: + - "13306:3306" + tmpfs: + - /var/lib/mysql + healthcheck: + test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-u", "root", "-p$$MYSQL_ROOT_PASSWORD"] + interval: 1s + timeout: 5s + retries: 10 + start_period: 10s diff --git a/backend/tests/mysql/migration_test.go b/backend/tests/mysql/migration_test.go new file mode 100644 index 0000000..acb677c --- /dev/null +++ b/backend/tests/mysql/migration_test.go @@ -0,0 +1,126 @@ +//go:build mysql + +package mysql + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMigration_TablesExist(t *testing.T) { + db := SetupMySQLTestDB(t) + + var tables []string + err := db.Raw("SHOW TABLES").Scan(&tables).Error + require.NoError(t, err) + + expectedTables := []string{"providers", "models", "usage_stats"} + for _, expected := range expectedTables { + assert.Contains(t, tables, expected, "表 %s 应存在", expected) + } +} + +func TestMigration_TableColumns(t *testing.T) { + db := SetupMySQLTestDB(t) + + t.Run("providers 表字段", func(t *testing.T) { + var columns []struct { + Field string + Type string + Null string + } + err := db.Raw("SHOW COLUMNS FROM providers").Scan(&columns).Error + require.NoError(t, err) + + columnMap := make(map[string]string) + for _, col := range columns { + columnMap[col.Field] = col.Type + } + + assert.Contains(t, columnMap["id"], "varchar", "id 应为 VARCHAR 类型") + assert.Contains(t, columnMap["name"], "varchar", "name 应为 VARCHAR 类型") + assert.Contains(t, columnMap["api_key"], "varchar", "api_key 应为 VARCHAR 类型") + assert.Contains(t, columnMap["base_url"], "varchar", "base_url 应为 VARCHAR 类型") + assert.Contains(t, columnMap["protocol"], "varchar", "protocol 应为 VARCHAR 类型") + assert.Contains(t, columnMap["enabled"], "tinyint", "enabled 应为 TINYINT (BOOLEAN) 类型") + assert.Contains(t, columnMap["created_at"], "datetime", "created_at 应为 DATETIME 类型") + assert.Contains(t, columnMap["updated_at"], "datetime", "updated_at 应为 DATETIME 类型") + }) + + t.Run("models 表字段", func(t *testing.T) { + var columns []struct { + Field string + Type string + } + err := db.Raw("SHOW COLUMNS FROM models").Scan(&columns).Error + require.NoError(t, err) + + columnMap := make(map[string]string) + for _, col := range columns { + columnMap[col.Field] = col.Type + } + + assert.Contains(t, columnMap["id"], "varchar", "id 应为 VARCHAR 类型") + assert.Contains(t, columnMap["provider_id"], "varchar", "provider_id 应为 VARCHAR 类型") + assert.Contains(t, columnMap["model_name"], "varchar", "model_name 应为 VARCHAR 类型") + assert.Contains(t, columnMap["enabled"], "tinyint", "enabled 应为 TINYINT (BOOLEAN) 类型") + assert.Contains(t, columnMap["created_at"], "datetime", "created_at 应为 DATETIME 类型") + }) + + t.Run("usage_stats 表字段", func(t *testing.T) { + var columns []struct { + Field string + Type string + } + err := db.Raw("SHOW COLUMNS FROM usage_stats").Scan(&columns).Error + require.NoError(t, err) + + columnMap := make(map[string]string) + for _, col := range columns { + columnMap[col.Field] = col.Type + } + + assert.Contains(t, columnMap["id"], "int", "id 应为 INT 类型") + assert.Contains(t, columnMap["provider_id"], "varchar", "provider_id 应为 VARCHAR 类型") + assert.Contains(t, columnMap["model_name"], "varchar", "model_name 应为 VARCHAR 类型") + assert.Contains(t, columnMap["request_count"], "int", "request_count 应为 INT 类型") + assert.Contains(t, columnMap["date"], "date", "date 应为 DATE 类型") + }) +} + +func TestMigration_IndexesExist(t *testing.T) { + db := SetupMySQLTestDB(t) + + t.Run("models 表索引", func(t *testing.T) { + var indexes []struct { + KeyName string + } + err := db.Raw("SHOW INDEX FROM models").Scan(&indexes).Error + require.NoError(t, err) + + indexMap := make(map[string]bool) + for _, idx := range indexes { + indexMap[idx.KeyName] = true + } + + assert.True(t, indexMap["idx_models_provider_id"], "idx_models_provider_id 索引应存在") + assert.True(t, indexMap["idx_models_model_name"], "idx_models_model_name 索引应存在") + }) + + t.Run("usage_stats 表索引", func(t *testing.T) { + var indexes []struct { + KeyName string + } + err := db.Raw("SHOW INDEX FROM usage_stats").Scan(&indexes).Error + require.NoError(t, err) + + indexMap := make(map[string]bool) + for _, idx := range indexes { + indexMap[idx.KeyName] = true + } + + assert.True(t, indexMap["idx_usage_stats_provider_model_date"], "idx_usage_stats_provider_model_date 索引应存在") + }) +} diff --git a/backend/tests/mysql/testhelper.go b/backend/tests/mysql/testhelper.go new file mode 100644 index 0000000..1ec7050 --- /dev/null +++ b/backend/tests/mysql/testhelper.go @@ -0,0 +1,160 @@ +//go:build mysql + +package mysql + +import ( + "database/sql" + "fmt" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/pressly/goose/v3" + "github.com/stretchr/testify/require" + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +type MySQLTestConfig struct { + Host string + Port int + User string + Password string + Database string +} + +func getMySQLTestConfig() *MySQLTestConfig { + return &MySQLTestConfig{ + Host: getEnvOrDefault("NEX_TEST_MYSQL_HOST", "localhost"), + Port: getEnvOrDefaultInt("NEX_TEST_MYSQL_PORT", 13306), + User: getEnvOrDefault("NEX_TEST_MYSQL_USER", "nex_test"), + Password: getEnvOrDefault("NEX_TEST_MYSQL_PASSWORD", "testpass"), + Database: getEnvOrDefault("NEX_TEST_MYSQL_DATABASE", "nex_test"), + } +} + +func getEnvOrDefault(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +} + +func getEnvOrDefaultInt(key string, defaultValue int) int { + if value := os.Getenv(key); value != "" { + var intValue int + if _, err := fmt.Sscanf(value, "%d", &intValue); err == nil { + return intValue + } + } + return defaultValue +} + +func SkipIfMySQLUnavailable(t *testing.T) { + t.Helper() + + cfg := getMySQLTestConfig() + dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local", + cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database) + + db, err := sql.Open("mysql", dsn) + if err != nil { + t.Skipf("MySQL 不可用: %v", err) + } + defer db.Close() + + if err := db.Ping(); err != nil { + t.Skipf("MySQL 不可用: %v", err) + } +} + +func SetupMySQLTestDB(t *testing.T) *gorm.DB { + t.Helper() + + SkipIfMySQLUnavailable(t) + + cfg := getMySQLTestConfig() + dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local", + cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database) + + db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + require.NoError(t, err, "连接 MySQL 失败") + + if err := runMigrations(db); err != nil { + require.NoError(t, err, "运行迁移失败") + } + + if err := cleanupTables(db); err != nil { + require.NoError(t, err, "清理表数据失败") + } + + sqlDB, err := db.DB() + require.NoError(t, err) + sqlDB.SetMaxIdleConns(10) + sqlDB.SetMaxOpenConns(100) + sqlDB.SetConnMaxLifetime(time.Hour) + + t.Cleanup(func() { + time.Sleep(50 * time.Millisecond) + sqlDB, err := db.DB() + if err == nil { + sqlDB.Close() + } + }) + + return db +} + +func cleanupTables(db *gorm.DB) error { + if err := db.Exec("SET FOREIGN_KEY_CHECKS = 0").Error; err != nil { + return err + } + if err := db.Exec("TRUNCATE TABLE usage_stats").Error; err != nil { + return err + } + if err := db.Exec("TRUNCATE TABLE models").Error; err != nil { + return err + } + if err := db.Exec("TRUNCATE TABLE providers").Error; err != nil { + return err + } + if err := db.Exec("SET FOREIGN_KEY_CHECKS = 1").Error; err != nil { + return err + } + return nil +} + +func runMigrations(db *gorm.DB) error { + sqlDB, err := db.DB() + if err != nil { + return err + } + + migrationsDir := getMigrationsDir() + if _, err := os.Stat(migrationsDir); os.IsNotExist(err) { + return fmt.Errorf("迁移目录不存在: %s", migrationsDir) + } + + goose.SetDialect("mysql") + if err := goose.Up(sqlDB, migrationsDir); err != nil { + return err + } + + return nil +} + +func getMigrationsDir() string { + _, filename, _, ok := runtime.Caller(0) + if ok { + dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations", "mysql") + if abs, err := filepath.Abs(dir); err == nil { + return abs + } + } + return "./migrations/mysql" +} diff --git a/openspec/specs/config-management/spec.md b/openspec/specs/config-management/spec.md index 006b5b2..3766110 100644 --- a/openspec/specs/config-management/spec.md +++ b/openspec/specs/config-management/spec.md @@ -77,6 +77,7 @@ - **THEN** SHALL 支持 `required` 规则 - **THEN** SHALL 支持 `min`、`max` 规则 - **THEN** SHALL 支持 `oneof` 规则 +- **THEN** SHALL 支持 `required_if` 条件验证规则 #### Scenario: 验证执行 @@ -85,6 +86,17 @@ - **THEN** SHALL 返回验证错误 - **THEN** SHALL NOT 启动应用(如果验证失败) +#### Scenario: 数据库驱动条件验证 + +- **WHEN** `database.driver` 为 `sqlite` +- **THEN** SHALL 验证 `database.path` 必填 +- **THEN** SHALL NOT 要求 MySQL 字段(host/port/user/password/dbname) +- **WHEN** `database.driver` 为 `mysql` +- **THEN** SHALL 验证 `database.host` 必填 +- **THEN** SHALL 验证 `database.user` 必填 +- **THEN** SHALL 验证 `database.dbname` 必填 +- **THEN** SHALL NOT 要求 `database.path` + ### Requirement: 配置结构定义 系统 SHALL 定义清晰的配置结构。 @@ -98,7 +110,14 @@ #### Scenario: Database 配置 - **WHEN** 加载 database 配置 -- **THEN** SHALL 包含 path、max_idle_conns、max_open_conns、conn_max_lifetime 字段 +- **THEN** SHALL 包含 driver 字段(值为 `sqlite` 或 `mysql`,默认 `sqlite`) +- **THEN** SHALL 包含 path 字段(SQLite 模式下的数据库文件路径) +- **THEN** SHALL 包含 host 字段(MySQL 主机地址) +- **THEN** SHALL 包含 port 字段(MySQL 端口,默认 3306) +- **THEN** SHALL 包含 user 字段(MySQL 用户名) +- **THEN** SHALL 包含 password 字段(MySQL 密码,选填) +- **THEN** SHALL 包含 dbname 字段(MySQL 数据库名) +- **THEN** SHALL 包含 max_idle_conns、max_open_conns、conn_max_lifetime 字段 - **THEN** SHALL 使用合理的默认值 #### Scenario: Log 配置 @@ -121,7 +140,13 @@ #### Scenario: Database 默认值 - **WHEN** 使用默认配置 +- **THEN** database.driver SHALL 为 `sqlite` - **THEN** database.path SHALL 为 `~/.nex/config.db` +- **THEN** database.host SHALL 为空字符串 +- **THEN** database.port SHALL 为 3306 +- **THEN** database.user SHALL 为空字符串 +- **THEN** database.password SHALL 为空字符串 +- **THEN** database.dbname SHALL 为 `nex` - **THEN** database.max_idle_conns SHALL 为 10 - **THEN** database.max_open_conns SHALL 为 100 - **THEN** database.conn_max_lifetime SHALL 为 1h @@ -248,18 +273,38 @@ - **THEN** SHALL 在日志中记录覆盖信息 - **THEN** SHALL 显示被覆盖的配置项名称 +### Requirement: 配置文件安全 + +系统 SHALL 使用安全的文件权限保存配置文件。 + +#### Scenario: 配置文件权限 + +- **WHEN** 保存配置文件(`SaveConfig`) +- **THEN** SHALL 使用 `0600` 权限写入文件(仅 owner 可读写) +- **THEN** SHALL 防止其他用户读取配置中的 MySQL 密码等敏感信息 + ### Requirement: 配置摘要输出 系统 SHALL 在启动时输出配置摘要。 -#### Scenario: 摘要内容 +#### Scenario: SQLite 模式摘要 -- **WHEN** 配置加载完成 +- **WHEN** `database.driver` 为 `sqlite` - **THEN** SHALL 打印关键配置项(端口、数据库路径、日志级别等) - **THEN** SHALL 打印配置文件路径 - **THEN** SHALL 打印环境变量数量 - **THEN** SHALL 打印 CLI 参数数量 +#### Scenario: MySQL 模式摘要 + +- **WHEN** `database.driver` 为 `mysql` +- **THEN** SHALL 打印关键配置项(端口、数据库类型、数据库地址、日志级别等) +- **THEN** SHALL 打印数据库地址格式为 `{host}:{port}/{dbname}` +- **THEN** SHALL 不打印密码 +- **THEN** SHALL 打印配置文件路径 +- **THEN** SHALL 打印环境变量数量 +- **THEN** SHALL 打印 CLI 参数数量 + #### Scenario: 摘要格式 - **WHEN** 打印配置摘要 @@ -297,7 +342,7 @@ - **WHEN** 使用服务器相关参数 - **THEN** SHALL 支持 `--server-port`、`--server-read-timeout`、`--server-write-timeout` - **WHEN** 使用数据库相关参数 -- **THEN** SHALL 支持 `--database-path`、`--database-max-idle-conns`、`--database-max-open-conns`、`--database-conn-max-lifetime` +- **THEN** SHALL 支持 `--database-driver`、`--database-path`、`--database-host`、`--database-port`、`--database-user`、`--database-password`、`--database-dbname`、`--database-max-idle-conns`、`--database-max-open-conns`、`--database-conn-max-lifetime` - **WHEN** 使用日志相关参数 - **THEN** SHALL 支持 `--log-level`、`--log-path`、`--log-max-size`、`--log-max-backups`、`--log-max-age`、`--log-compress` @@ -348,7 +393,7 @@ - **WHEN** 设置服务器相关环境变量 - **THEN** SHALL 支持 `NEX_SERVER_PORT`、`NEX_SERVER_READ_TIMEOUT`、`NEX_SERVER_WRITE_TIMEOUT` - **WHEN** 设置数据库相关环境变量 -- **THEN** SHALL 支持 `NEX_DATABASE_PATH`、`NEX_DATABASE_MAX_IDLE_CONNS`、`NEX_DATABASE_MAX_OPEN_CONNS`、`NEX_DATABASE_CONN_MAX_LIFETIME` +- **THEN** SHALL 支持 `NEX_DATABASE_DRIVER`、`NEX_DATABASE_PATH`、`NEX_DATABASE_HOST`、`NEX_DATABASE_PORT`、`NEX_DATABASE_USER`、`NEX_DATABASE_PASSWORD`、`NEX_DATABASE_DBNAME`、`NEX_DATABASE_MAX_IDLE_CONNS`、`NEX_DATABASE_MAX_OPEN_CONNS`、`NEX_DATABASE_CONN_MAX_LIFETIME` - **WHEN** 设置日志相关环境变量 - **THEN** SHALL 支持 `NEX_LOG_LEVEL`、`NEX_LOG_PATH`、`NEX_LOG_MAX_SIZE`、`NEX_LOG_MAX_BACKUPS`、`NEX_LOG_MAX_AGE`、`NEX_LOG_COMPRESS` diff --git a/openspec/specs/database-migration/spec.md b/openspec/specs/database-migration/spec.md index 733b4d1..ff13042 100644 --- a/openspec/specs/database-migration/spec.md +++ b/openspec/specs/database-migration/spec.md @@ -46,6 +46,14 @@ - **THEN** SHALL 删除所有表和索引 - **THEN** SHALL 按正确顺序删除(避免外键约束错误) +#### Scenario: 按数据库方言拆分迁移目录 + +- **WHEN** 组织迁移文件 +- **THEN** SHALL 将 SQLite 方言迁移文件存储在 `migrations/sqlite/` 目录 +- **THEN** SHALL 将 MySQL 方言迁移文件存储在 `migrations/mysql/` 目录 +- **THEN** SHALL 两个目录维护独立的版本号序列 +- **THEN** SHALL 两个目录的迁移文件内容在逻辑上一致(相同的表结构和约束),但使用各自方言的 DDL + ### Requirement: models 表 schema 变更 系统 SHALL 在初始迁移脚本中直接创建新的 models 表结构(服务未上线,无需考虑数据迁移,迁移脚本已合并为单个初始迁移文件)。 @@ -63,28 +71,37 @@ #### Scenario: 迁移 up 命令 -- **WHEN** 执行 `make migrate-up` +- **WHEN** 执行 `make backend-migrate-up` - **THEN** SHALL 执行所有待执行的迁移 +- **THEN** SHALL 使用 `DB_DRIVER` 变量选择方言目录(默认 `sqlite3`) +- **THEN** SHALL 使用 `DB_DSN` 变量作为数据库连接串 - **THEN** SHALL 显示迁移进度 #### Scenario: 迁移 down 命令 -- **WHEN** 执行 `make migrate-down` +- **WHEN** 执行 `make backend-migrate-down` - **THEN** SHALL 回滚最后一个迁移 +- **THEN** SHALL 使用 `DB_DRIVER` 和 `DB_DSN` 变量 - **THEN** SHALL 显示回滚进度 #### Scenario: 迁移状态命令 -- **WHEN** 执行 `make migrate-status` +- **WHEN** 执行 `make backend-migrate-status` - **THEN** SHALL 显示当前迁移状态 - **THEN** SHALL 显示已执行和待执行的迁移 #### Scenario: 创建迁移命令 -- **WHEN** 执行 `make migrate-create name=` -- **THEN** SHALL 创建新的迁移文件模板 +- **WHEN** 执行 `make backend-migrate-create` +- **THEN** SHALL 同时在 `migrations/sqlite/` 和 `migrations/mysql/` 两个目录创建新的迁移文件模板 - **THEN** SHALL 使用递增的版本号 +#### Scenario: MySQL 迁移命令使用 + +- **WHEN** 使用 MySQL 驱动执行迁移 +- **THEN** SHALL 设置 `DB_DRIVER=mysql` +- **THEN** SHALL 设置 `DB_DSN` 为 MySQL 连接串(如 `user:pass@tcp(localhost:3306)/nex`) + ### Requirement: 应用启动时迁移 应用 SHALL 在启动时执行迁移。 @@ -92,6 +109,9 @@ #### Scenario: 自动迁移 - **WHEN** 应用启动 +- **THEN** SHALL 根据 `database.driver` 配置选择对应的迁移目录和 goose dialect +- **THEN** SHALL 在 `driver=sqlite` 时使用 `migrations/sqlite/` 目录,goose dialect 为 `sqlite3` +- **THEN** SHALL 在 `driver=mysql` 时使用 `migrations/mysql/` 目录,goose dialect 为 `mysql` - **THEN** SHALL 自动执行待执行的迁移 - **THEN** SHALL 在迁移失败时拒绝启动 - **THEN** SHALL 记录迁移日志 @@ -149,5 +169,5 @@ #### Scenario: 迁移文件存储 - **WHEN** 创建迁移文件 -- **THEN** SHALL 存储在 migrations/ 目录 +- **THEN** SHALL 按 SQL 方言存储在对应子目录(`migrations/sqlite/` 或 `migrations/mysql/`) - **THEN** SHALL 提交到版本控制系统 diff --git a/openspec/specs/mysql-driver/spec.md b/openspec/specs/mysql-driver/spec.md new file mode 100644 index 0000000..dbcb979 --- /dev/null +++ b/openspec/specs/mysql-driver/spec.md @@ -0,0 +1,107 @@ +# MySQL Driver + +## Purpose + +支持 MySQL 作为可选数据库后端,通过配置选择 sqlite 或 mysql 驱动,提供 MySQL 连接管理、初始化和方言迁移文件。 + +## Requirements + +### Requirement: MySQL 数据库驱动支持 + +系统 SHALL 支持通过配置项 `database.driver` 选择 `sqlite` 或 `mysql` 数据库驱动,默认值为 `sqlite`。 + +#### Scenario: 默认使用 SQLite 驱动 + +- **WHEN** 配置中未指定 `database.driver` +- **THEN** SHALL 使用 `sqlite` 作为数据库驱动 +- **THEN** SHALL 行为与现有逻辑完全一致 + +#### Scenario: 配置 MySQL 驱动 + +- **WHEN** 配置 `database.driver` 设置为 `mysql` +- **THEN** SHALL 使用 MySQL 驱动连接远程数据库 +- **THEN** SHALL 使用 `gorm.io/driver/mysql` 打开连接 +- **THEN** SHALL 构建 DSN 格式为 `{user}:{password}@tcp({host}:{port})/{dbname}?charset=utf8mb4&parseTime=true&loc=Local` + +#### Scenario: driver 值不合法 + +- **WHEN** 配置 `database.driver` 不是 `sqlite` 或 `mysql` +- **THEN** SHALL 配置验证失败,拒绝启动 + +### Requirement: MySQL 连接配置 + +系统 SHALL 在 `DatabaseConfig` 中支持 MySQL 连接参数。 + +#### Scenario: MySQL 连接参数字段 + +- **WHEN** `database.driver` 为 `mysql` +- **THEN** SHALL 读取 `host`(MySQL 主机地址,必填) +- **THEN** SHALL 读取 `port`(MySQL 端口,默认 3306) +- **THEN** SHALL 读取 `user`(MySQL 用户名,必填) +- **THEN** SHALL 读取 `password`(MySQL 密码,选填) +- **THEN** SHALL 读取 `dbname`(数据库名,必填) + +#### Scenario: SQLite 模式忽略 MySQL 参数 + +- **WHEN** `database.driver` 为 `sqlite` +- **THEN** SHALL 忽略 MySQL 相关配置字段(host/port/user/password/dbname) +- **THEN** SHALL 仅使用 `path` 字段作为数据库文件路径 + +#### Scenario: MySQL 模式忽略 SQLite 参数 + +- **WHEN** `database.driver` 为 `mysql` +- **THEN** SHALL 忽略 `path` 字段 + +### Requirement: 数据库初始化公共包 + +系统 SHALL 提供 `internal/database` 公共包,封装数据库初始化、迁移执行和连接关闭逻辑,供 `cmd/server` 和 `cmd/desktop` 共同调用。 + +#### Scenario: 公共包 Init 函数 + +- **WHEN** 调用 `database.Init(cfg, logger)` +- **THEN** SHALL 根据 `cfg.Driver` 选择对应的 GORM 驱动打开连接 +- **THEN** SHALL 执行对应方言的 goose 迁移 +- **THEN** SHALL 配置连接池参数 +- **THEN** SHALL 在 `driver=sqlite` 时执行 `PRAGMA journal_mode=WAL` +- **THEN** SHALL 在 `driver=mysql` 时跳过 SQLite 专有 PRAGMA +- **THEN** SHALL 返回 `*gorm.DB` 实例 + +#### Scenario: 公共包 Close 函数 + +- **WHEN** 调用 `database.Close(db)` +- **THEN** SHALL 获取底层 `sql.DB` 并关闭连接 + +#### Scenario: 迁移目录选择 + +- **WHEN** 执行迁移 +- **THEN** SHALL 在 `driver=sqlite` 时使用 `migrations/sqlite/` 目录,goose dialect 为 `sqlite3` +- **THEN** SHALL 在 `driver=mysql` 时使用 `migrations/mysql/` 目录,goose dialect 为 `mysql` + +### Requirement: MySQL 方言迁移文件 + +系统 SHALL 提供 MySQL 方言的初始迁移文件 `migrations/mysql/20260421000001_initial_schema.sql`。 + +#### Scenario: providers 表 + +- **WHEN** 执行 MySQL 初始迁移 +- **THEN** SHALL 创建 `providers` 表,字段:`id VARCHAR(36) PRIMARY KEY`、`name VARCHAR(255) NOT NULL`、`api_key VARCHAR(255) NOT NULL`、`base_url VARCHAR(255) NOT NULL`、`protocol VARCHAR(50) DEFAULT 'openai'`、`enabled BOOLEAN DEFAULT TRUE`、`created_at DATETIME(3)`、`updated_at DATETIME(3)` + +#### Scenario: models 表 + +- **WHEN** 执行 MySQL 初始迁移 +- **THEN** SHALL 创建 `models` 表,字段:`id VARCHAR(36) PRIMARY KEY`、`provider_id VARCHAR(36) NOT NULL`、`model_name VARCHAR(255) NOT NULL`、`enabled BOOLEAN DEFAULT TRUE`、`created_at DATETIME(3)` +- **THEN** SHALL 创建 `UNIQUE(provider_id, model_name)` 约束 +- **THEN** SHALL 创建 `FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE` 约束 +- **THEN** SHALL 创建 `idx_models_provider_id` 和 `idx_models_model_name` 索引 + +#### Scenario: usage_stats 表 + +- **WHEN** 执行 MySQL 初始迁移 +- **THEN** SHALL 创建 `usage_stats` 表,字段:`id INT UNSIGNED AUTO_INCREMENT PRIMARY KEY`、`provider_id VARCHAR(36) NOT NULL`、`model_name VARCHAR(255) NOT NULL`、`request_count INT DEFAULT 0`、`date DATE NOT NULL` +- **THEN** SHALL 创建 `UNIQUE(provider_id, model_name, date)` 约束 +- **THEN** SHALL 创建 `idx_usage_stats_provider_model_date` 复合索引 + +#### Scenario: Down 迁移 + +- **WHEN** 执行 MySQL down 迁移 +- **THEN** SHALL 按正确顺序删除索引和表(usage_stats → models → providers) diff --git a/openspec/specs/mysql-testing/spec.md b/openspec/specs/mysql-testing/spec.md new file mode 100644 index 0000000..5d6ecad --- /dev/null +++ b/openspec/specs/mysql-testing/spec.md @@ -0,0 +1,104 @@ +# MySQL Testing + +## Purpose + +提供 MySQL 数据库专项测试能力,验证迁移正确性、外键约束、并发写入等数据库特定行为。 + +## Requirements + +### Requirement: MySQL 测试环境可启动 + +系统 SHALL 提供 Docker Compose 配置以启动 MySQL 8.0 测试环境。 + +#### Scenario: 启动 MySQL 测试容器 +- **WHEN** 执行 `make test-mysql-up` +- **THEN** 启动 MySQL 8.0 容器,端口 13306 +- **AND** 创建数据库 `nex_test` +- **AND** 容器数据存储在内存盘(tmpfs) + +#### Scenario: 销毁 MySQL 测试容器 +- **WHEN** 执行 `make test-mysql-down` +- **THEN** 停止并删除容器 +- **AND** 所有数据被销毁 + +### Requirement: MySQL 测试可通过 build tag 控制 + +MySQL 测试 SHALL 使用 `// +build mysql` build tag,默认不运行。 + +#### Scenario: 默认测试不包含 MySQL 测试 +- **WHEN** 执行 `go test ./...` +- **THEN** 不运行 `tests/mysql/` 下的测试 + +#### Scenario: 启用 MySQL 测试 +- **WHEN** 执行 `go test -tags=mysql ./tests/mysql/...` +- **THEN** 运行所有 MySQL 测试 + +### Requirement: MySQL 迁移正确执行 + +MySQL 测试 SHALL 验证迁移脚本在 MySQL 环境下正确执行。 + +#### Scenario: 迁移创建所有表 +- **WHEN** 运行 MySQL 迁移 +- **THEN** 创建 `providers`、`models`、`usage_stats` 表 +- **AND** 字段类型符合 MySQL 迁移文件定义(VARCHAR、DATETIME(3)、BOOLEAN 等) +- **AND** 索引 `idx_models_provider_id`、`idx_models_model_name`、`idx_usage_stats_provider_model_date` 创建成功 + +#### Scenario: 迁移可重复执行 +- **WHEN** 在已迁移的数据库上再次运行迁移 +- **THEN** 不报错,数据库状态不变 + +### Requirement: MySQL 外键约束生效 + +MySQL 测试 SHALL 验证外键约束行为符合预期。 + +#### Scenario: 外键约束阻止无效引用 +- **WHEN** 创建 model 时 `provider_id` 不存在 +- **THEN** 操作失败,返回外键约束错误 + +#### Scenario: 级联删除生效 +- **WHEN** 删除 provider +- **THEN** 该 provider 的所有 models 被级联删除 + +### Requirement: MySQL UNIQUE 约束生效 + +MySQL 测试 SHALL 验证 UNIQUE 约束行为符合预期。 + +#### Scenario: models 表 UNIQUE 约束 +- **WHEN** 尝试创建相同 `(provider_id, model_name)` 组合的 model +- **THEN** 操作失败,返回唯一约束错误 + +#### Scenario: usage_stats 表 UNIQUE 约束 +- **WHEN** 尝试创建相同 `(provider_id, model_name, date)` 组合的 usage_stats +- **THEN** 操作失败,返回唯一约束错误 + +### Requirement: MySQL 并发写入正确 + +MySQL 测试 SHALL 验证并发写入不丢失数据。 + +#### Scenario: 并发记录 usage_stats +- **WHEN** 10 个 goroutine 并发调用 `statsRepo.Record(providerID, modelName)` +- **THEN** 最终 `request_count` 等于 10 +- **AND** 无数据丢失或重复 + +#### Scenario: 并发创建相同 provider +- **WHEN** 10 个 goroutine 并发创建相同 ID 的 provider +- **THEN** 仅 1 个成功,其他 9 个失败 + +#### Scenario: 并发创建相同 model +- **WHEN** 10 个 goroutine 并发创建相同 `(provider_id, model_name)` 的 model +- **THEN** 仅 1 个成功,其他 9 个失败 + +### Requirement: MySQL 测试命令完整 + +Makefile SHALL 提供完整的 MySQL 测试命令。 + +#### Scenario: 完整测试流程 +- **WHEN** 执行 `make test-mysql` +- **THEN** 启动 Docker MySQL +- **AND** 等待 MySQL 就绪 +- **AND** 运行所有 MySQL 测试 +- **AND** 销毁容器 + +#### Scenario: 快速测试(容器已运行) +- **WHEN** 执行 `make test-mysql-quick` +- **THEN** 直接运行测试,不管理容器生命周期 diff --git a/openspec/specs/usage-statistics/spec.md b/openspec/specs/usage-statistics/spec.md index 7c5d4d0..abca1fa 100644 --- a/openspec/specs/usage-statistics/spec.md +++ b/openspec/specs/usage-statistics/spec.md @@ -93,7 +93,21 @@ - **WHEN** 同时处理多个并发请求 - **THEN** 网关 SHALL 使用原子操作正确增加每个请求的计数 - **THEN** 不 SHALL 因并发写入而丢失统计 -- **THEN** SHALL 使用 StatsBuffer 的内存计数器 +- **THEN** SHALL 使用 upsert 操作保证原子性 + +#### Scenario: 并发调用 Record 方法 + +- **WHEN** 多个 goroutine 并发调用 StatsRepository.Record +- **THEN** SHALL 使用 INSERT ... ON DUPLICATE KEY UPDATE (MySQL) 或 INSERT ... ON CONFLICT DO UPDATE (SQLite) +- **THEN** SHALL 保证所有并发调用的计数都被正确累加 +- **THEN** 不 SHALL 因 UNIQUE 约束冲突而丢失数据 + +#### Scenario: 并发调用 BatchUpdate 方法 + +- **WHEN** 多个 goroutine 并发调用 StatsRepository.BatchUpdate +- **THEN** SHALL 使用 upsert 操作保证原子性 +- **THEN** SHALL 正确累加所有 delta 值 +- **THEN** 不 SHALL 因并发写入而丢失统计 ### Requirement: 使用 service 层处理业务逻辑 @@ -125,14 +139,14 @@ Service SHALL 通过 StatsRepository 访问数据。 - **WHEN** StatsBuffer 刷新统计 - **THEN** SHALL 调用 StatsRepository.BatchUpdate -- **THEN** SHALL 使用事务更新或创建统计记录 +- **THEN** SHALL 使用 upsert 操作更新或创建统计记录 - **THEN** SHALL 支持增量更新(request_count + delta) -#### Scenario: 事务处理 +#### Scenario: upsert 操作 - **WHEN** 记录统计 -- **THEN** SHALL 在 repository 层使用数据库事务 -- **THEN** SHALL 确保并发安全 +- **THEN** SHALL 在 repository 层使用 upsert 操作 +- **THEN** SHALL 保证原子性和并发安全 ### Requirement: 统计查询优化 @@ -168,11 +182,18 @@ StatsRepository SHALL 新增 BatchUpdate 方法支持批量增量更新。 #### Scenario: BatchUpdate 更新现有记录 - **WHEN** 调用 BatchUpdate 且当日记录存在 -- **THEN** SHALL 使用事务更新 request_count = request_count + delta +- **THEN** SHALL 使用 upsert 操作更新 request_count = request_count + delta +- **THEN** SHALL 保证原子性,无竞态条件 - **THEN** SHALL 不创建新记录 #### Scenario: BatchUpdate 创建新记录 - **WHEN** 调用 BatchUpdate 且当日记录不存在 - **THEN** SHALL 创建新记录,request_count = delta -- **THEN** SHALL 使用事务保证原子性 +- **THEN** SHALL 使用 upsert 操作保证原子性 + +#### Scenario: BatchUpdate 并发安全 + +- **WHEN** 多个 BatchUpdate 调用同时执行 +- **THEN** SHALL 保证所有 delta 都被正确累加 +- **THEN** SHALL 不因并发冲突而丢失数据