package main import ( "context" "log" "net/http" "os" "os/signal" "syscall" "time" "github.com/gin-gonic/gin" "nex/backend/internal/config" "nex/backend/internal/handler" ) func main() { // 初始化数据库 if err := config.InitDB(); err != nil { log.Fatalf("初始化数据库失败: %v", err) } defer config.CloseDB() // 创建 Gin 引擎 gin.SetMode(gin.ReleaseMode) r := gin.Default() // 配置 CORS r.Use(func(c *gin.Context) { c.Writer.Header().Set("Access-Control-Allow-Origin", "*") c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") if c.Request.Method == "OPTIONS" { c.AbortWithStatus(204) return } c.Next() }) // 注册路由 setupRoutes(r) // 创建 HTTP 服务器 srv := &http.Server{ Addr: ":9826", Handler: r, } // 启动服务器(在 goroutine 中) go func() { log.Printf("AI Gateway 启动在端口 9826") if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("服务器启动失败: %v", err) } }() // 等待中断信号以优雅关闭服务器 quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit log.Println("正在关闭服务器...") // 给服务器 5 秒时间完成当前请求 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := srv.Shutdown(ctx); err != nil { log.Fatal("服务器强制关闭:", err) } log.Println("服务器已关闭") } // setupRoutes 配置路由 func setupRoutes(r *gin.Engine) { // OpenAI 协议代理 openaiHandler := handler.NewOpenAIHandler() r.POST("/v1/chat/completions", openaiHandler.HandleChatCompletions) // Anthropic 协议代理 anthropicHandler := handler.NewAnthropicHandler() r.POST("/v1/messages", anthropicHandler.HandleMessages) // 供应商管理 API providerHandler := handler.NewProviderHandler() providers := r.Group("/api/providers") { providers.GET("", providerHandler.ListProviders) providers.POST("", providerHandler.CreateProvider) providers.GET("/:id", providerHandler.GetProvider) providers.PUT("/:id", providerHandler.UpdateProvider) providers.DELETE("/:id", providerHandler.DeleteProvider) } // 模型管理 API modelHandler := handler.NewModelHandler() models := r.Group("/api/models") { models.GET("", modelHandler.ListModels) models.POST("", modelHandler.CreateModel) models.GET("/:id", modelHandler.GetModel) models.PUT("/:id", modelHandler.UpdateModel) models.DELETE("/:id", modelHandler.DeleteModel) } // 统计查询 API statsHandler := handler.NewStatsHandler() stats := r.Group("/api/stats") { stats.GET("", statsHandler.GetStats) stats.GET("/aggregate", statsHandler.AggregateStats) } // 健康检查 r.GET("/health", func(c *gin.Context) { c.JSON(200, gin.H{"status": "ok"}) }) }