package config

import (
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/go-playground/validator/v10"
	"github.com/mitchellh/mapstructure"
	"github.com/spf13/pflag"
	"github.com/spf13/viper"
	"go.uber.org/zap"
	"gopkg.in/yaml.v3"

	appErrors "nex/backend/pkg/errors"
)

// Config is the complete application configuration.
type Config struct {
	Server   ServerConfig   `yaml:"server" mapstructure:"server" validate:"required"`
	Database DatabaseConfig `yaml:"database" mapstructure:"database" validate:"required"`
	Log      LogConfig      `yaml:"log" mapstructure:"log" validate:"required"`
}

// ServerConfig holds the server listen port and read/write timeouts.
type ServerConfig struct {
	Port         int           `yaml:"port" mapstructure:"port" validate:"required,min=1,max=65535"`
	ReadTimeout  time.Duration `yaml:"read_timeout" mapstructure:"read_timeout" validate:"required"`
	WriteTimeout time.Duration `yaml:"write_timeout" mapstructure:"write_timeout" validate:"required"`
}

// DatabaseConfig holds database connection settings.
//
// Driver selects which fields are mandatory: Path for "sqlite"; Host, Port,
// User and DBName for "mysql". Password is always optional.
type DatabaseConfig struct {
	Driver          string        `yaml:"driver" mapstructure:"driver" validate:"required,oneof=sqlite mysql"`
	Path            string        `yaml:"path" mapstructure:"path" validate:"required_if=Driver sqlite"`
	Host            string        `yaml:"host" mapstructure:"host" validate:"required_if=Driver mysql"`
	Port            int           `yaml:"port" mapstructure:"port" validate:"required_if=Driver mysql,min=1,max=65535"`
	User            string        `yaml:"user" mapstructure:"user" validate:"required_if=Driver mysql"`
	Password        string        `yaml:"password" mapstructure:"password"`
	DBName          string        `yaml:"dbname" mapstructure:"dbname" validate:"required_if=Driver mysql"`
	MaxIdleConns    int           `yaml:"max_idle_conns" mapstructure:"max_idle_conns" validate:"required,min=1"`
	MaxOpenConns    int           `yaml:"max_open_conns" mapstructure:"max_open_conns" validate:"required,min=1"`
	ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" mapstructure:"conn_max_lifetime" validate:"required"`
}

// LogConfig holds logging settings.
type LogConfig struct {
	Level string `yaml:"level" mapstructure:"level" validate:"required,oneof=debug info warn error"`
	// Path is the directory that log files are written to.
	Path string `yaml:"path" mapstructure:"path" validate:"required"`
	// MaxSize is the maximum size of a single log file, in MB.
	MaxSize int `yaml:"max_size" mapstructure:"max_size" validate:"required,min=1"`
	// MaxBackups (number of old files kept) and MaxAge (days old files are
	// kept) accept 0. They intentionally carry no `required` tag: for
	// integers, validator's `required` rejects the zero value, which would
	// make the `min=0` bound unreachable and fail an otherwise valid config.
	MaxBackups int `yaml:"max_backups" mapstructure:"max_backups" validate:"min=0"`
	MaxAge     int `yaml:"max_age" mapstructure:"max_age" validate:"min=0"`
	// Compress controls whether rotated (old) log files are compressed.
	Compress bool `yaml:"compress" mapstructure:"compress"`
}

// DefaultConfig returns the built-in configuration values.
//
// File paths are placed under ~/.nex; if the home directory cannot be
// determined, the current directory is used instead.
func DefaultConfig() *Config {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		homeDir = "."
	}
	nexDir := filepath.Join(homeDir, ".nex")

	return &Config{
		Server: ServerConfig{
			Port:         9826,
			ReadTimeout:  30 * time.Second,
			WriteTimeout: 30 * time.Second,
		},
		Database: DatabaseConfig{
			Driver:          "sqlite",
			Path:            filepath.Join(nexDir, "config.db"),
			Host:            "",
			Port:            3306,
			User:            "",
			Password:        "",
			DBName:          "nex",
			MaxIdleConns:    10,
			MaxOpenConns:    100,
			ConnMaxLifetime: 1 * time.Hour,
		},
		Log: LogConfig{
			Level:      "info",
			Path:       filepath.Join(nexDir, "log"),
			MaxSize:    100,
			MaxBackups: 10,
			MaxAge:     30,
			Compress:   true,
		},
	}
}

// GetConfigDir returns the configuration directory (~/.nex), creating it
// with 0755 permissions if it does not exist yet.
func GetConfigDir() (string, error) {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}
	configDir := filepath.Join(homeDir, ".nex")
	if err := os.MkdirAll(configDir, 0o755); err != nil {
		return "", err
	}
	return configDir, nil
}

// GetDBPath returns the path of the SQLite database file inside the config
// directory (see GetConfigDir, which may create the directory).
func GetDBPath() (string, error) {
	configDir, err := GetConfigDir()
	if err != nil {
		return "", err
	}
	return filepath.Join(configDir, "config.db"), nil
}

// GetConfigPath returns the path of the YAML config file inside the config
// directory (see GetConfigDir, which may create the directory).
func GetConfigPath() (string, error) {
	configDir, err := GetConfigDir()
	if err != nil {
		return "", err
	}
	return filepath.Join(configDir, "config.yaml"), nil
}

// setupDefaults registers a default value for every config key on v.
//
// The values come from DefaultConfig so that the struct defaults and the
// viper defaults have a single source of truth and cannot drift apart.
// Durations are registered as time.Duration values, which decode directly
// into the time.Duration fields of Config.
func setupDefaults(v *viper.Viper) {
	d := DefaultConfig()

	v.SetDefault("server.port", d.Server.Port)
	v.SetDefault("server.read_timeout", d.Server.ReadTimeout)
	v.SetDefault("server.write_timeout", d.Server.WriteTimeout)

	v.SetDefault("database.driver", d.Database.Driver)
	v.SetDefault("database.path", d.Database.Path)
	v.SetDefault("database.host", d.Database.Host)
	v.SetDefault("database.port", d.Database.Port)
	v.SetDefault("database.user", d.Database.User)
	v.SetDefault("database.password", d.Database.Password)
	v.SetDefault("database.dbname", d.Database.DBName)
	v.SetDefault("database.max_idle_conns", d.Database.MaxIdleConns)
	v.SetDefault("database.max_open_conns", d.Database.MaxOpenConns)
	v.SetDefault("database.conn_max_lifetime", d.Database.ConnMaxLifetime)

	v.SetDefault("log.level", d.Log.Level)
	v.SetDefault("log.path", d.Log.Path)
	v.SetDefault("log.max_size", d.Log.MaxSize)
	v.SetDefault("log.max_backups", d.Log.MaxBackups)
	v.SetDefault("log.max_age", d.Log.MaxAge)
	v.SetDefault("log.compress", d.Log.Compress)
}

// setupFlags defines one CLI flag per config key on flagSet and binds each
// flag to its viper key.
func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
	// Flags are declared with zero defaults on purpose: viper only takes a
	// bound flag's value when the flag was explicitly set on the command
	// line, so the defaults registered by setupDefaults still apply otherwise.
	flagSet.Int("server-port", 0, "服务器端口")
	flagSet.Duration("server-read-timeout", 0, "读超时")
	flagSet.Duration("server-write-timeout", 0, "写超时")

	flagSet.String("database-driver", "", "数据库驱动:sqlite/mysql")
	flagSet.String("database-path", "", "数据库文件路径")
	flagSet.String("database-host", "", "MySQL 主机地址")
	flagSet.Int("database-port", 0, "MySQL 端口")
	flagSet.String("database-user", "", "MySQL 用户名")
	flagSet.String("database-password", "", "MySQL 密码")
	flagSet.String("database-dbname", "", "MySQL 数据库名")
	flagSet.Int("database-max-idle-conns", 0, "最大空闲连接数")
	flagSet.Int("database-max-open-conns", 0, "最大打开连接数")
	flagSet.Duration("database-conn-max-lifetime", 0, "连接最大生命周期")

	flagSet.String("log-level", "", "日志级别:debug/info/warn/error")
	flagSet.String("log-path", "", "日志文件目录")
	flagSet.Int("log-max-size", 0, "单个日志文件最大大小 MB")
	flagSet.Int("log-max-backups", 0, "保留的旧日志文件最大数量")
	flagSet.Int("log-max-age", 0, "保留旧日志文件的最大天数")
	flagSet.Bool("log-compress", false, "是否压缩旧日志文件")

	// Bind every flag to its dotted viper key. The caller (loadConfig)
	// registers defaults before this runs.
	bindPFlag(v, "server.port", flagSet.Lookup("server-port"))
	bindPFlag(v, "server.read_timeout", flagSet.Lookup("server-read-timeout"))
	bindPFlag(v, "server.write_timeout", flagSet.Lookup("server-write-timeout"))

	bindPFlag(v, "database.driver", flagSet.Lookup("database-driver"))
	bindPFlag(v, "database.path", flagSet.Lookup("database-path"))
	bindPFlag(v, "database.host", flagSet.Lookup("database-host"))
	bindPFlag(v, "database.port", flagSet.Lookup("database-port"))
	bindPFlag(v, "database.user", flagSet.Lookup("database-user"))
	bindPFlag(v, "database.password", flagSet.Lookup("database-password"))
	bindPFlag(v, "database.dbname", flagSet.Lookup("database-dbname"))
	bindPFlag(v, "database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
	bindPFlag(v, "database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
	bindPFlag(v, "database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))

	bindPFlag(v, "log.level", flagSet.Lookup("log-level"))
	bindPFlag(v, "log.path", flagSet.Lookup("log-path"))
	bindPFlag(v, "log.max_size", flagSet.Lookup("log-max-size"))
	bindPFlag(v, "log.max_backups", flagSet.Lookup("log-max-backups"))
	bindPFlag(v, "log.max_age", flagSet.Lookup("log-max-age"))
	bindPFlag(v, "log.compress", flagSet.Lookup("log-compress"))
}

// bindPFlag binds flag to key on v.
//
// It panics on failure. A failed binding (e.g. a nil flag from a misspelled
// Lookup name) is a programming error in setupFlags, not a runtime condition.
func bindPFlag(v *viper.Viper, key string, flag *pflag.Flag) {
	if err := v.BindPFlag(key, flag); err != nil {
		panic(fmt.Sprintf("绑定 flag %s 失败: %v", key, err))
	}
}

// setupEnv lets NEX_-prefixed environment variables override config keys.
// Dots in a key are replaced by underscores, so database.max_idle_conns is
// read from NEX_DATABASE_MAX_IDLE_CONNS.
func setupEnv(v *viper.Viper) {
	v.SetEnvPrefix("NEX")
	v.AutomaticEnv()
	v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
}

// setupConfigFile reads the YAML config file at configPath into v.
//
// A missing config file is not an error: loading continues with defaults
// (plus CLI/env overrides when enabled). This covers both a non-existent
// path and an empty configPath; for the latter, viper reports
// ConfigFileNotFoundError. Any other read or parse failure is wrapped as
// appErrors.ErrInternal.
func setupConfigFile(v *viper.Viper, configPath string) error {
	v.SetConfigFile(configPath)
	v.SetConfigType("yaml")

	err := v.ReadInConfig()
	if err == nil {
		return nil
	}

	// errors.Is also matches wrapped not-exist errors, unlike os.IsNotExist.
	if errors.Is(err, os.ErrNotExist) {
		return nil
	}
	var notFound viper.ConfigFileNotFoundError
	if errors.As(err, &notFound) {
		return nil
	}
	return appErrors.Wrap(appErrors.ErrInternal, err)
}

// loadOptions controls which override sources the config loader enables.
type loadOptions struct {
	configPathOverride string // config file to read unless --config replaces it
	useCLI             bool   // define and bind the per-key CLI flags
	useEnv             bool   // honor NEX_* environment variables
	useConfigFlag      bool   // accept --config to replace configPathOverride
}
loadOptions 解析 CLI 参数并返回最终配置文件路径 func resolveConfigPath(v *viper.Viper, opts loadOptions) (string, error) { configPath := opts.configPathOverride if !opts.useCLI && !opts.useConfigFlag { return configPath, nil } flagSet := pflag.NewFlagSet("config", pflag.ContinueOnError) if opts.useConfigFlag { flagSet.String("config", opts.configPathOverride, "配置文件路径") } if opts.useCLI { setupFlags(v, flagSet) } if err := flagSet.Parse(os.Args[1:]); err != nil { return "", appErrors.Wrap(appErrors.ErrInvalidRequest, err) } if opts.useConfigFlag { if f, err := flagSet.GetString("config"); err == nil && f != "" { configPath = f } } return configPath, nil } // loadConfig 共享配置加载逻辑,通过 loadOptions 控制是否启用 CLI、环境变量和 --config 覆盖 func loadConfig(opts loadOptions) (*Config, error) { v := viper.New() setupDefaults(v) configPath, err := resolveConfigPath(v, opts) if err != nil { return nil, err } if opts.useEnv { setupEnv(v) } if err := setupConfigFile(v, configPath); err != nil { return nil, err } cfg := &Config{} if err := v.Unmarshal(cfg, viper.DecodeHook(mapstructure.ComposeDecodeHookFunc( mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToSliceHookFunc(","), ))); err != nil { return nil, appErrors.Wrap(appErrors.ErrInternal, err) } if err := cfg.Validate(); err != nil { return nil, err } return cfg, nil } // LoadServerConfig 为 server 入口加载配置,支持 CLI 参数、环境变量和 --config func LoadServerConfig() (*Config, error) { configPath, err := GetConfigPath() if err != nil { return nil, appErrors.Wrap(appErrors.ErrInternal, err) } return loadConfig(loadOptions{ configPathOverride: configPath, useCLI: true, useEnv: true, useConfigFlag: true, }) } // LoadDesktopConfig 为 desktop 入口加载配置,固定使用默认配置文件,不支持 CLI、环境变量和 --config func LoadDesktopConfig() (*Config, error) { configPath, err := GetConfigPath() if err != nil { return nil, appErrors.Wrap(appErrors.ErrInternal, err) } return loadConfig(loadOptions{ configPathOverride: configPath, useCLI: false, useEnv: false, useConfigFlag: false, }) } // 
LoadConfig loads config from YAML file. // 向后兼容,等同于 LoadServerConfig。 func LoadConfig() (*Config, error) { return LoadServerConfig() } // LoadConfigFromPath 从指定路径加载配置。 // 保留向后兼容,沿用 server 语义(支持 CLI、env 和 --config 覆盖)。 func LoadConfigFromPath(configPath string) (*Config, error) { return loadConfig(loadOptions{ configPathOverride: configPath, useCLI: true, useEnv: true, useConfigFlag: true, }) } // LoadDesktopConfigAtPath 从指定路径以 desktop 语义加载配置(仅配置文件和默认值),用于测试场景。 func LoadDesktopConfigAtPath(configPath string) (*Config, error) { return loadConfig(loadOptions{ configPathOverride: configPath, useCLI: false, useEnv: false, useConfigFlag: false, }) } // SaveConfig saves config to YAML file func SaveConfig(cfg *Config) error { configPath, err := GetConfigPath() if err != nil { return appErrors.Wrap(appErrors.ErrInternal, err) } data, err := yaml.Marshal(cfg) if err != nil { return appErrors.Wrap(appErrors.ErrInternal, err) } // Ensure directory exists dir := filepath.Dir(configPath) if err := os.MkdirAll(dir, 0o755); err != nil { return appErrors.Wrap(appErrors.ErrInternal, err) } return os.WriteFile(configPath, data, 0o600) } // Validate validates the config func (c *Config) Validate() error { validate := validator.New() if err := validate.Struct(c); err != nil { return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("配置验证失败: %v", err)) } return nil } // PrintSummary 打印配置摘要 func (c *Config) PrintSummary(logger *zap.Logger) { logger.Info("AI Gateway 启动配置", zap.Int("server_port", c.Server.Port), zap.String("database_driver", c.Database.Driver), zap.String("log_level", c.Log.Level), ) if c.Database.Driver == "mysql" { logger.Info("数据库配置", zap.String("driver", "mysql"), zap.String("host", c.Database.Host), zap.Int("port", c.Database.Port), zap.String("database", c.Database.DBName), ) } else { logger.Info("数据库配置", zap.String("driver", "sqlite"), zap.String("path", c.Database.Path), ) } }