1
0
Files
nex/backend/internal/config/config.go
lanyuanxiaoyao 5513f0c13d feat: 区分 server 与 desktop 配置加载入口,取消自动创建配置文件
- config.go 重构:抽取 loadConfig 共享逻辑,新增 LoadServerConfig/LoadDesktopConfig/LoadDesktopConfigAtPath,LoadConfig 保持向后兼容
- setupConfigFile 移除 SafeWriteConfigAs 自动创建逻辑,文件不存在时仅使用默认值
- cmd/desktop 切换为 LoadDesktopConfig,端口/HTTP/浏览器/托盘统一使用 cfg.Server.Port
- cmd/server 显式使用 LoadServerConfig 明确入口语义
- 提取 desktop 可测 helper:desktopListenAddr/desktopURL/desktopPortMenuTitle/desktopConfigErrorMessage
- 新增测试:desktop 忽略 CLI/env/未知参数、配置快照不变、无效配置文件不静默回退、端口 helper 一致性
- README 区分 server/desktop 配置源,移除首次启动自动创建配置文件描述
- 同步 delta specs 到 openspec/specs/ 主规范
2026-05-06 11:59:19 +08:00

414 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package config
import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/go-playground/validator/v10"
"github.com/mitchellh/mapstructure"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
appErrors "nex/backend/pkg/errors"
)
// Config is the root application configuration. It groups the server,
// database and logging sections; every section is mandatory for validation.
type Config struct {
	// Server holds the server section (port and timeouts).
	Server ServerConfig `yaml:"server" mapstructure:"server" validate:"required"`
	// Database holds the database connection section.
	Database DatabaseConfig `yaml:"database" mapstructure:"database" validate:"required"`
	// Log holds the logging section.
	Log LogConfig `yaml:"log" mapstructure:"log" validate:"required"`
}
// ServerConfig is the server section of the configuration.
type ServerConfig struct {
	// Port is the server port; validation requires it to be in 1-65535.
	Port int `yaml:"port" mapstructure:"port" validate:"required,min=1,max=65535"`
	// ReadTimeout is the read timeout; it must be non-zero.
	ReadTimeout time.Duration `yaml:"read_timeout" mapstructure:"read_timeout" validate:"required"`
	// WriteTimeout is the write timeout; it must be non-zero.
	WriteTimeout time.Duration `yaml:"write_timeout" mapstructure:"write_timeout" validate:"required"`
}
// DatabaseConfig is the database section of the configuration.
// Driver selects between "sqlite" and "mysql"; the remaining fields are
// conditionally required depending on that choice.
type DatabaseConfig struct {
	// Driver is the database driver; one of "sqlite" or "mysql".
	Driver string `yaml:"driver" mapstructure:"driver" validate:"required,oneof=sqlite mysql"`
	// Path is the database file path; required when Driver is "sqlite".
	Path string `yaml:"path" mapstructure:"path" validate:"required_if=Driver sqlite"`
	// Host is the MySQL host address; required when Driver is "mysql".
	Host string `yaml:"host" mapstructure:"host" validate:"required_if=Driver mysql"`
	// Port is the MySQL port; required when Driver is "mysql" and bounded to 1-65535.
	Port int `yaml:"port" mapstructure:"port" validate:"required_if=Driver mysql,min=1,max=65535"`
	// User is the MySQL user name; required when Driver is "mysql".
	User string `yaml:"user" mapstructure:"user" validate:"required_if=Driver mysql"`
	// Password is the MySQL password; it carries no validation rule.
	Password string `yaml:"password" mapstructure:"password"`
	// DBName is the MySQL database name; required when Driver is "mysql".
	DBName string `yaml:"dbname" mapstructure:"dbname" validate:"required_if=Driver mysql"`
	// MaxIdleConns is the maximum number of idle connections (at least 1).
	MaxIdleConns int `yaml:"max_idle_conns" mapstructure:"max_idle_conns" validate:"required,min=1"`
	// MaxOpenConns is the maximum number of open connections (at least 1).
	MaxOpenConns int `yaml:"max_open_conns" mapstructure:"max_open_conns" validate:"required,min=1"`
	// ConnMaxLifetime is the maximum lifetime of a connection; it must be non-zero.
	ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" mapstructure:"conn_max_lifetime" validate:"required"`
}
// LogConfig is the logging section of the configuration.
type LogConfig struct {
	// Level is the log level; one of "debug", "info", "warn" or "error".
	Level string `yaml:"level" mapstructure:"level" validate:"required,oneof=debug info warn error"`
	// Path is the directory that log files are written to.
	Path string `yaml:"path" mapstructure:"path" validate:"required"`
	// MaxSize is the maximum size of a single log file in MB (at least 1).
	MaxSize int `yaml:"max_size" mapstructure:"max_size" validate:"required,min=1"`
	// MaxBackups is the maximum number of old log files to keep.
	// The rule is min=0 without "required": "required" rejects the int zero
	// value, which made the documented lower bound of 0 unreachable.
	MaxBackups int `yaml:"max_backups" mapstructure:"max_backups" validate:"min=0"`
	// MaxAge is the maximum number of days to keep old log files.
	// As with MaxBackups, 0 is an accepted value, so "required" is omitted.
	MaxAge int `yaml:"max_age" mapstructure:"max_age" validate:"min=0"`
	// Compress reports whether old log files are compressed.
	Compress bool `yaml:"compress" mapstructure:"compress"`
}
// DefaultConfig returns a freshly allocated Config populated with the
// built-in defaults. File-system defaults live under "<home>/.nex"; when the
// home directory cannot be determined, "." is used as the base instead.
func DefaultConfig() *Config {
	home, err := os.UserHomeDir()
	if err != nil {
		// Fall back to a path relative to the working directory.
		home = "."
	}
	baseDir := filepath.Join(home, ".nex")

	server := ServerConfig{
		Port:         9826,
		ReadTimeout:  30 * time.Second,
		WriteTimeout: 30 * time.Second,
	}

	database := DatabaseConfig{
		Driver:          "sqlite",
		Path:            filepath.Join(baseDir, "config.db"),
		Host:            "",
		Port:            3306,
		User:            "",
		Password:        "",
		DBName:          "nex",
		MaxIdleConns:    10,
		MaxOpenConns:    100,
		ConnMaxLifetime: 1 * time.Hour,
	}

	logging := LogConfig{
		Level:      "info",
		Path:       filepath.Join(baseDir, "log"),
		MaxSize:    100,
		MaxBackups: 10,
		MaxAge:     30,
		Compress:   true,
	}

	return &Config{
		Server:   server,
		Database: database,
		Log:      logging,
	}
}
// GetConfigDir returns the configuration directory, "<home>/.nex".
// As a side effect the directory (and any missing parents) is created with
// mode 0o755. An error is returned when the home directory cannot be
// resolved or the directory cannot be created.
func GetConfigDir() (string, error) {
	home, homeErr := os.UserHomeDir()
	if homeErr != nil {
		return "", homeErr
	}

	dir := filepath.Join(home, ".nex")
	if mkErr := os.MkdirAll(dir, 0o755); mkErr != nil {
		return "", mkErr
	}
	return dir, nil
}
// GetDBPath returns the path of the database file, "config.db", inside the
// directory reported by GetConfigDir. Any error from GetConfigDir is
// returned unchanged.
func GetDBPath() (string, error) {
	dir, err := GetConfigDir()
	if err != nil {
		return "", err
	}
	dbPath := filepath.Join(dir, "config.db")
	return dbPath, nil
}
// GetConfigPath returns the path of the configuration file, "config.yaml",
// inside the directory reported by GetConfigDir. Any error from
// GetConfigDir is returned unchanged.
func GetConfigPath() (string, error) {
	dir, err := GetConfigDir()
	if err != nil {
		return "", err
	}
	yamlPath := filepath.Join(dir, "config.yaml")
	return yamlPath, nil
}
// setupDefaults registers the default value of every configuration key on v.
// File-system defaults live under "<home>/.nex"; when the home directory
// cannot be determined, "." is used as the base instead.
func setupDefaults(v *viper.Viper) {
	home, err := os.UserHomeDir()
	if err != nil {
		home = "."
	}
	baseDir := filepath.Join(home, ".nex")

	// Durations are registered as strings ("30s", "1h").
	defaults := []struct {
		key   string
		value interface{}
	}{
		{"server.port", 9826},
		{"server.read_timeout", "30s"},
		{"server.write_timeout", "30s"},

		{"database.driver", "sqlite"},
		{"database.path", filepath.Join(baseDir, "config.db")},
		{"database.host", ""},
		{"database.port", 3306},
		{"database.user", ""},
		{"database.password", ""},
		{"database.dbname", "nex"},
		{"database.max_idle_conns", 10},
		{"database.max_open_conns", 100},
		{"database.conn_max_lifetime", "1h"},

		{"log.level", "info"},
		{"log.path", filepath.Join(baseDir, "log")},
		{"log.max_size", 100},
		{"log.max_backups", 10},
		{"log.max_age", 30},
		{"log.compress", true},
	}
	for _, d := range defaults {
		v.SetDefault(d.key, d.value)
	}
}
// setupFlags declares one CLI flag per configuration key on flagSet and
// binds each flag to its key on v.
//
// Every flag is declared with the zero value of its type rather than the
// real default, so that the defaults registered on viper take effect when a
// flag is not given. All flags are declared first and bound afterwards.
func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
	type flagKind int
	const (
		intFlag flagKind = iota
		durationFlag
		stringFlag
		boolFlag
	)

	specs := []struct {
		kind  flagKind
		name  string // CLI flag name
		key   string // viper configuration key
		usage string
	}{
		{intFlag, "server-port", "server.port", "服务器端口"},
		{durationFlag, "server-read-timeout", "server.read_timeout", "读超时"},
		{durationFlag, "server-write-timeout", "server.write_timeout", "写超时"},
		{stringFlag, "database-driver", "database.driver", "数据库驱动sqlite/mysql"},
		{stringFlag, "database-path", "database.path", "数据库文件路径"},
		{stringFlag, "database-host", "database.host", "MySQL 主机地址"},
		{intFlag, "database-port", "database.port", "MySQL 端口"},
		{stringFlag, "database-user", "database.user", "MySQL 用户名"},
		{stringFlag, "database-password", "database.password", "MySQL 密码"},
		{stringFlag, "database-dbname", "database.dbname", "MySQL 数据库名"},
		{intFlag, "database-max-idle-conns", "database.max_idle_conns", "最大空闲连接数"},
		{intFlag, "database-max-open-conns", "database.max_open_conns", "最大打开连接数"},
		{durationFlag, "database-conn-max-lifetime", "database.conn_max_lifetime", "连接最大生命周期"},
		{stringFlag, "log-level", "log.level", "日志级别debug/info/warn/error"},
		{stringFlag, "log-path", "log.path", "日志文件目录"},
		{intFlag, "log-max-size", "log.max_size", "单个日志文件最大大小 MB"},
		{intFlag, "log-max-backups", "log.max_backups", "保留的旧日志文件最大数量"},
		{intFlag, "log-max-age", "log.max_age", "保留旧日志文件的最大天数"},
		{boolFlag, "log-compress", "log.compress", "是否压缩旧日志文件"},
	}

	// Pass 1: declare every flag with a zero-value default.
	for _, s := range specs {
		switch s.kind {
		case intFlag:
			flagSet.Int(s.name, 0, s.usage)
		case durationFlag:
			flagSet.Duration(s.name, 0, s.usage)
		case stringFlag:
			flagSet.String(s.name, "", s.usage)
		case boolFlag:
			flagSet.Bool(s.name, false, s.usage)
		}
	}

	// Pass 2: bind each declared flag to its configuration key.
	for _, s := range specs {
		bindPFlag(v, s.key, flagSet.Lookup(s.name))
	}
}
// bindPFlag binds flag to the configuration key on v and panics if the
// binding is rejected.
func bindPFlag(v *viper.Viper, key string, flag *pflag.Flag) {
	err := v.BindPFlag(key, flag)
	if err == nil {
		return
	}
	panic(fmt.Sprintf("绑定 flag %s 失败: %v", key, err))
}
// setupEnv enables environment-variable overrides on v: variables use the
// "NEX" prefix and "." in a configuration key is replaced by "_" when the
// variable name is derived (so server.port is expected to map to
// NEX_SERVER_PORT — confirm against viper's naming rules).
func setupEnv(v *viper.Viper) {
	dotToUnderscore := strings.NewReplacer(".", "_")

	v.SetEnvKeyReplacer(dotToUnderscore)
	v.SetEnvPrefix("NEX")
	v.AutomaticEnv()
}
// setupConfigFile points v at the YAML file configPath and reads it.
// A read error for which os.IsNotExist reports true is not a failure: nil
// is returned and nothing is loaded from the file. Any other read error is
// returned wrapped as an internal error.
func setupConfigFile(v *viper.Viper, configPath string) error {
	v.SetConfigFile(configPath)
	v.SetConfigType("yaml")

	err := v.ReadInConfig()
	switch {
	case err == nil, os.IsNotExist(err):
		return nil
	default:
		return appErrors.Wrap(appErrors.ErrInternal, err)
	}
}
// loadOptions controls which configuration sources the loader consults.
type loadOptions struct {
	// configPathOverride is the configuration file path to use, and the
	// default value of the --config flag when that flag is enabled.
	configPathOverride string
	// useCLI enables the per-key CLI flags.
	useCLI bool
	// useEnv enables environment-variable overrides.
	useEnv bool
	// useConfigFlag enables the --config flag for choosing the file path.
	useConfigFlag bool
}
// resolveConfigPath parses the process arguments as dictated by opts and
// returns the configuration file path to load.
//
// When neither CLI flags nor --config are enabled, nothing is parsed and
// opts.configPathOverride is returned as is. Otherwise os.Args[1:] is
// parsed; a parse failure is returned wrapped as an invalid-request error.
// A non-empty --config value (when that flag is enabled) takes precedence
// over opts.configPathOverride.
func resolveConfigPath(v *viper.Viper, opts loadOptions) (string, error) {
	fallback := opts.configPathOverride

	needsParsing := opts.useCLI || opts.useConfigFlag
	if !needsParsing {
		return fallback, nil
	}

	flags := pflag.NewFlagSet("config", pflag.ContinueOnError)
	if opts.useConfigFlag {
		flags.String("config", opts.configPathOverride, "配置文件路径")
	}
	if opts.useCLI {
		setupFlags(v, flags)
	}

	if parseErr := flags.Parse(os.Args[1:]); parseErr != nil {
		return "", appErrors.Wrap(appErrors.ErrInvalidRequest, parseErr)
	}

	if !opts.useConfigFlag {
		return fallback, nil
	}
	// A lookup error or an empty value keeps the fallback path.
	chosen, lookupErr := flags.GetString("config")
	if lookupErr != nil || chosen == "" {
		return fallback, nil
	}
	return chosen, nil
}
// loadConfig is the shared loading routine behind every public loader.
// opts decides whether CLI flags, environment variables and the --config
// override take part.
//
// Steps, in order: register defaults, resolve the config path (parsing
// flags when enabled), enable environment overrides when requested, read
// the config file, decode into a Config, and validate the result.
func loadConfig(opts loadOptions) (*Config, error) {
	v := viper.New()
	setupDefaults(v)

	path, err := resolveConfigPath(v, opts)
	if err != nil {
		return nil, err
	}

	if opts.useEnv {
		setupEnv(v)
	}

	if err = setupConfigFile(v, path); err != nil {
		return nil, err
	}

	// Strings are converted to time.Duration values and, for slice
	// targets, split on ",".
	hooks := mapstructure.ComposeDecodeHookFunc(
		mapstructure.StringToTimeDurationHookFunc(),
		mapstructure.StringToSliceHookFunc(","),
	)

	loaded := &Config{}
	if err = v.Unmarshal(loaded, viper.DecodeHook(hooks)); err != nil {
		return nil, appErrors.Wrap(appErrors.ErrInternal, err)
	}

	if err = loaded.Validate(); err != nil {
		return nil, err
	}
	return loaded, nil
}
// LoadServerConfig loads the configuration for the server entry point.
// It starts from the default config file path and enables CLI flags,
// environment variables and the --config override. A failure to resolve
// the default path is returned wrapped as an internal error.
func LoadServerConfig() (*Config, error) {
	defaultPath, err := GetConfigPath()
	if err != nil {
		return nil, appErrors.Wrap(appErrors.ErrInternal, err)
	}

	serverOpts := loadOptions{
		configPathOverride: defaultPath,
		useCLI:             true,
		useEnv:             true,
		useConfigFlag:      true,
	}
	return loadConfig(serverOpts)
}
// LoadDesktopConfig loads the configuration for the desktop entry point.
// It always uses the default config file path; CLI flags, environment
// variables and the --config override are all disabled. A failure to
// resolve the default path is returned wrapped as an internal error.
func LoadDesktopConfig() (*Config, error) {
	defaultPath, err := GetConfigPath()
	if err != nil {
		return nil, appErrors.Wrap(appErrors.ErrInternal, err)
	}

	// useCLI, useEnv and useConfigFlag are left at their zero value (false).
	desktopOpts := loadOptions{configPathOverride: defaultPath}
	return loadConfig(desktopOpts)
}
// LoadConfig loads the configuration with server semantics.
// It is kept for backward compatibility and simply delegates to
// LoadServerConfig.
func LoadConfig() (*Config, error) {
	return LoadServerConfig()
}
// LoadConfigFromPath loads the configuration starting from configPath.
// It is kept for backward compatibility and follows server semantics: CLI
// flags, environment variables and the --config override are all enabled,
// so a non-empty --config value still takes precedence over configPath.
func LoadConfigFromPath(configPath string) (*Config, error) {
	serverOpts := loadOptions{
		configPathOverride: configPath,
		useCLI:             true,
		useEnv:             true,
		useConfigFlag:      true,
	}
	return loadConfig(serverOpts)
}
// LoadDesktopConfigAtPath loads the configuration from configPath with
// desktop semantics: only the config file and the built-in defaults are
// consulted. It is intended for tests.
func LoadDesktopConfigAtPath(configPath string) (*Config, error) {
	// useCLI, useEnv and useConfigFlag are left at their zero value (false).
	desktopOpts := loadOptions{configPathOverride: configPath}
	return loadConfig(desktopOpts)
}
// SaveConfig serializes cfg as YAML and writes it to the default config
// file path reported by GetConfigPath, creating the parent directory when
// needed.
//
// Every failure (resolving the path, marshalling, creating the directory,
// writing the file) is returned wrapped as an internal error.
func SaveConfig(cfg *Config) error {
	configPath, err := GetConfigPath()
	if err != nil {
		return appErrors.Wrap(appErrors.ErrInternal, err)
	}

	data, err := yaml.Marshal(cfg)
	if err != nil {
		return appErrors.Wrap(appErrors.ErrInternal, err)
	}

	// Ensure the target directory exists before writing.
	if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
		return appErrors.Wrap(appErrors.ErrInternal, err)
	}

	// The file is owner-only (0o600): the serialized config includes the
	// database password field.
	if err := os.WriteFile(configPath, data, 0o600); err != nil {
		// Wrapped like every other failure path in this function.
		return appErrors.Wrap(appErrors.ErrInternal, err)
	}
	return nil
}
// Validate checks c against the `validate` struct tags declared on the
// configuration types. It returns nil when every rule passes, and an
// invalid-request error carrying the validator's message otherwise.
func (c *Config) Validate() error {
	err := validator.New().Struct(c)
	if err == nil {
		return nil
	}
	msg := fmt.Sprintf("配置验证失败: %v", err)
	return appErrors.WithMessage(appErrors.ErrInvalidRequest, msg)
}
// PrintSummary logs a short summary of the configuration through logger:
// one entry with the server port, database driver and log level, followed
// by one entry with the database details. For the "mysql" driver the host,
// port and database name are logged; for any other driver the entry is
// labelled "sqlite" and carries the database file path. The password is
// never logged.
func (c *Config) PrintSummary(logger *zap.Logger) {
	logger.Info("AI Gateway 启动配置",
		zap.Int("server_port", c.Server.Port),
		zap.String("database_driver", c.Database.Driver),
		zap.String("log_level", c.Log.Level),
	)

	db := c.Database
	if db.Driver != "mysql" {
		logger.Info("数据库配置",
			zap.String("driver", "sqlite"),
			zap.String("path", db.Path),
		)
		return
	}

	logger.Info("数据库配置",
		zap.String("driver", "mysql"),
		zap.String("host", db.Host),
		zap.Int("port", db.Port),
		zap.String("database", db.DBName),
	)
}