package main import ( "context" "fmt" "log" "net/http" "os" "os/signal" "path/filepath" "runtime" "syscall" "time" "github.com/gin-gonic/gin" "github.com/pressly/goose/v3" "go.uber.org/zap" "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/logger" "nex/backend/internal/config" "nex/backend/internal/handler" "nex/backend/internal/handler/middleware" "nex/backend/internal/provider" "nex/backend/internal/repository" "nex/backend/internal/service" pkgLogger "nex/backend/pkg/logger" ) func main() { // 1. 加载配置 cfg, err := config.LoadConfig() if err != nil { log.Fatalf("加载配置失败: %v", err) } if err := cfg.Validate(); err != nil { log.Fatalf("配置验证失败: %v", err) } // 2. 初始化日志 zapLogger, err := pkgLogger.New(pkgLogger.Config{ Level: cfg.Log.Level, Path: cfg.Log.Path, MaxSize: cfg.Log.MaxSize, MaxBackups: cfg.Log.MaxBackups, MaxAge: cfg.Log.MaxAge, Compress: cfg.Log.Compress, }) if err != nil { log.Fatalf("初始化日志失败: %v", err) } defer zapLogger.Sync() // 3. 初始化数据库 db, err := initDatabase(cfg) if err != nil { zapLogger.Fatal("初始化数据库失败", zap.String("error", err.Error())) } defer closeDB(db) // 4. 初始化 repository 层 providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) statsRepo := repository.NewStatsRepository(db) // 5. 初始化 service 层 providerService := service.NewProviderService(providerRepo) modelService := service.NewModelService(modelRepo, providerRepo) routingService := service.NewRoutingService(modelRepo, providerRepo) statsService := service.NewStatsService(statsRepo) // 6. 初始化 provider client providerClient := provider.NewClient() // 7. 初始化 handler 层 openaiHandler := handler.NewOpenAIHandler(providerClient, routingService, statsService) anthropicHandler := handler.NewAnthropicHandler(providerClient, routingService, statsService) providerHandler := handler.NewProviderHandler(providerService) modelHandler := handler.NewModelHandler(modelService) statsHandler := handler.NewStatsHandler(statsService) // 8. 创建 Gin 引擎 gin.SetMode(gin.ReleaseMode) r := gin.New() // 注册中间件(按正确顺序) r.Use(middleware.RequestID()) r.Use(middleware.Recovery(zapLogger)) r.Use(middleware.Logging(zapLogger)) r.Use(middleware.CORS()) // 注册路由 setupRoutes(r, openaiHandler, anthropicHandler, providerHandler, modelHandler, statsHandler) // 9. 启动服务器 srv := &http.Server{ Addr: formatAddr(cfg.Server.Port), Handler: r, ReadTimeout: cfg.Server.ReadTimeout, WriteTimeout: cfg.Server.WriteTimeout, } go func() { zapLogger.Info("AI Gateway 启动", zap.String("addr", srv.Addr)) if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error())) } }() // 等待中断信号 quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit zapLogger.Info("正在关闭服务器...") ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := srv.Shutdown(ctx); err != nil { zapLogger.Fatal("服务器强制关闭", zap.String("error", err.Error())) } zapLogger.Info("服务器已关闭") } func initDatabase(cfg *config.Config) (*gorm.DB, error) { db, err := gorm.Open(sqlite.Open(cfg.Database.Path), &gorm.Config{ Logger: logger.Default.LogMode(logger.Info), }) if err != nil { return nil, err } if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil { log.Printf("警告: 启用 WAL 模式失败: %v", err) } // 运行数据库迁移 if err := runMigrations(db); err != nil { return nil, fmt.Errorf("数据库迁移失败: %w", err) } // 配置连接池 sqlDB, err := db.DB() if err != nil { return nil, err } sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns) sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns) sqlDB.SetConnMaxLifetime(cfg.Database.ConnMaxLifetime) // 记录连接池状态 log.Printf("数据库连接池配置: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v", cfg.Database.MaxIdleConns, cfg.Database.MaxOpenConns, cfg.Database.ConnMaxLifetime) return db, nil } // runMigrations 使用 goose 执行数据库迁移 func runMigrations(db *gorm.DB) error { sqlDB, err := db.DB() if err != nil { return err } migrationsDir := getMigrationsDir() if _, err := os.Stat(migrationsDir); os.IsNotExist(err) { return fmt.Errorf("迁移目录不存在: %s", migrationsDir) } goose.SetDialect("sqlite3") if err := goose.Up(sqlDB, migrationsDir); err != nil { return err } return nil } // getMigrationsDir 获取迁移文件目录路径 func getMigrationsDir() string { // 从可执行文件位置推断迁移目录 _, filename, _, ok := runtime.Caller(0) if ok { // cmd/server/main.go → backend/ → backend/migrations/ dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations") if abs, err := filepath.Abs(dir); err == nil { return abs } } // 回退到相对路径 return "./migrations" } func closeDB(db *gorm.DB) { sqlDB, err := db.DB() if err != nil { return } sqlDB.Close() } func formatAddr(port int) string { return fmt.Sprintf(":%d", port) } func setupRoutes(r *gin.Engine, openaiHandler *handler.OpenAIHandler, anthropicHandler *handler.AnthropicHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) { // OpenAI 协议代理 r.POST("/v1/chat/completions", openaiHandler.HandleChatCompletions) // Anthropic 协议代理 r.POST("/v1/messages", anthropicHandler.HandleMessages) // 供应商管理 API providers := r.Group("/api/providers") { providers.GET("", providerHandler.ListProviders) providers.POST("", providerHandler.CreateProvider) providers.GET("/:id", providerHandler.GetProvider) providers.PUT("/:id", providerHandler.UpdateProvider) providers.DELETE("/:id", providerHandler.DeleteProvider) } // 模型管理 API models := r.Group("/api/models") { models.GET("", modelHandler.ListModels) models.POST("", modelHandler.CreateModel) models.GET("/:id", modelHandler.GetModel) models.PUT("/:id", modelHandler.UpdateModel) models.DELETE("/:id", modelHandler.DeleteModel) } // 统计查询 API stats := r.Group("/api/stats") { stats.GET("", statsHandler.GetStats) stats.GET("/aggregate", statsHandler.AggregateStats) } // 健康检查 r.GET("/health", func(c *gin.Context) { c.JSON(200, gin.H{"status": "ok"}) }) }