- 修复 viper SafeWriteConfig 与 SetConfigFile 不兼容问题 - 将 SafeWriteConfig() 替换为 SafeWriteConfigAs(configPath) - 绕过 viper 的 configPaths 检查 - 调整 Makefile 测试命令分类 - backend-test: 仅运行后端核心测试 - backend-test-all: 运行全部后端测试(含 desktop) - desktop-test: 单独运行桌面应用测试 - 同步 config-management 和 test-coverage 规范
358 lines
12 KiB
Go
358 lines
12 KiB
Go
package config
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/go-playground/validator/v10"
|
||
"github.com/mitchellh/mapstructure"
|
||
"github.com/spf13/pflag"
|
||
"github.com/spf13/viper"
|
||
"go.uber.org/zap"
|
||
"gopkg.in/yaml.v3"
|
||
|
||
appErrors "nex/backend/pkg/errors"
|
||
)
|
||
|
||
// Config 应用配置
|
||
type Config struct {
|
||
Server ServerConfig `yaml:"server" mapstructure:"server" validate:"required"`
|
||
Database DatabaseConfig `yaml:"database" mapstructure:"database" validate:"required"`
|
||
Log LogConfig `yaml:"log" mapstructure:"log" validate:"required"`
|
||
}
|
||
|
||
// ServerConfig 服务器配置
|
||
type ServerConfig struct {
|
||
Port int `yaml:"port" mapstructure:"port" validate:"required,min=1,max=65535"`
|
||
ReadTimeout time.Duration `yaml:"read_timeout" mapstructure:"read_timeout" validate:"required"`
|
||
WriteTimeout time.Duration `yaml:"write_timeout" mapstructure:"write_timeout" validate:"required"`
|
||
}
|
||
|
||
// DatabaseConfig 数据库配置
|
||
type DatabaseConfig struct {
|
||
Driver string `yaml:"driver" mapstructure:"driver" validate:"required,oneof=sqlite mysql"`
|
||
Path string `yaml:"path" mapstructure:"path" validate:"required_if=Driver sqlite"`
|
||
Host string `yaml:"host" mapstructure:"host" validate:"required_if=Driver mysql"`
|
||
Port int `yaml:"port" mapstructure:"port" validate:"required_if=Driver mysql,min=1,max=65535"`
|
||
User string `yaml:"user" mapstructure:"user" validate:"required_if=Driver mysql"`
|
||
Password string `yaml:"password" mapstructure:"password"`
|
||
DBName string `yaml:"dbname" mapstructure:"dbname" validate:"required_if=Driver mysql"`
|
||
MaxIdleConns int `yaml:"max_idle_conns" mapstructure:"max_idle_conns" validate:"required,min=1"`
|
||
MaxOpenConns int `yaml:"max_open_conns" mapstructure:"max_open_conns" validate:"required,min=1"`
|
||
ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" mapstructure:"conn_max_lifetime" validate:"required"`
|
||
}
|
||
|
||
// LogConfig holds the logging settings.
//
// MaxBackups and MaxAge only require a non-negative value (min=0). They
// deliberately do not carry `required`: for an int field the validator's
// `required` rejects the zero value, which contradicted the min=0 bound and
// made an explicit 0 in the config file fail validation.
type LogConfig struct {
	// Level is the minimum log level: debug, info, warn or error.
	Level string `yaml:"level" mapstructure:"level" validate:"required,oneof=debug info warn error"`
	// Path is the directory log files are written to.
	Path string `yaml:"path" mapstructure:"path" validate:"required"`
	// MaxSize is the maximum size of a single log file in MB (>= 1).
	MaxSize int `yaml:"max_size" mapstructure:"max_size" validate:"required,min=1"`
	// MaxBackups is the maximum number of old log files to keep (>= 0).
	MaxBackups int `yaml:"max_backups" mapstructure:"max_backups" validate:"min=0"`
	// MaxAge is the maximum number of days to keep old log files (>= 0).
	MaxAge int `yaml:"max_age" mapstructure:"max_age" validate:"min=0"`
	// Compress controls whether old log files are compressed.
	Compress bool `yaml:"compress" mapstructure:"compress"`
}
|
||
|
||
// DefaultConfig returns a newly allocated Config holding the built-in defaults.
//
// File locations live under ~/.nex: the SQLite database at ~/.nex/config.db and
// the log directory at ~/.nex/log. If the home directory cannot be resolved,
// ".nex" relative to the working directory is used instead. The server listens
// on 9826 with 30s read/write timeouts. MySQL fields are preset (port 3306,
// database "nex") but unused while the driver is sqlite. Logging is info level,
// with 100 MB files, up to 10 old files kept for up to 30 days, compressed.
func DefaultConfig() *Config {
	base := "."
	if home, err := os.UserHomeDir(); err == nil {
		base = home
	}
	nexDir := filepath.Join(base, ".nex")

	server := ServerConfig{
		Port:         9826,
		ReadTimeout:  30 * time.Second,
		WriteTimeout: 30 * time.Second,
	}

	database := DatabaseConfig{
		Driver:          "sqlite",
		Path:            filepath.Join(nexDir, "config.db"),
		Host:            "",
		Port:            3306,
		User:            "",
		Password:        "",
		DBName:          "nex",
		MaxIdleConns:    10,
		MaxOpenConns:    100,
		ConnMaxLifetime: time.Hour,
	}

	logging := LogConfig{
		Level:      "info",
		Path:       filepath.Join(nexDir, "log"),
		MaxSize:    100,
		MaxBackups: 10,
		MaxAge:     30,
		Compress:   true,
	}

	return &Config{Server: server, Database: database, Log: logging}
}
|
||
|
||
// GetConfigDir returns the configuration directory, ~/.nex, creating it
// (and any missing parents, mode 0o755) if it does not exist yet.
// Errors from resolving the home directory or creating the directory are
// returned unchanged.
func GetConfigDir() (string, error) {
	home, homeErr := os.UserHomeDir()
	if homeErr != nil {
		return "", homeErr
	}

	dir := filepath.Join(home, ".nex")
	if mkErr := os.MkdirAll(dir, 0o755); mkErr != nil {
		return "", mkErr
	}
	return dir, nil
}
|
||
|
||
// GetDBPath 获取数据库文件路径
|
||
func GetDBPath() (string, error) {
|
||
configDir, err := GetConfigDir()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return filepath.Join(configDir, "config.db"), nil
|
||
}
|
||
|
||
// GetConfigPath 获取配置文件路径
|
||
func GetConfigPath() (string, error) {
|
||
configDir, err := GetConfigDir()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return filepath.Join(configDir, "config.yaml"), nil
|
||
}
|
||
|
||
// setupDefaults registers the built-in default value of every configuration
// key on v.
//
// Paths are rooted at ~/.nex, falling back to ".nex" under the working
// directory when the home directory cannot be resolved. Durations are given as
// strings ("30s", "1h"); LoadConfigFromPath decodes them with a
// string-to-duration hook when unmarshalling.
func setupDefaults(v *viper.Viper) {
	root := "."
	if home, err := os.UserHomeDir(); err == nil {
		root = home
	}
	nexDir := filepath.Join(root, ".nex")

	defaults := []struct {
		key   string
		value any
	}{
		{"server.port", 9826},
		{"server.read_timeout", "30s"},
		{"server.write_timeout", "30s"},

		{"database.driver", "sqlite"},
		{"database.path", filepath.Join(nexDir, "config.db")},
		{"database.host", ""},
		{"database.port", 3306},
		{"database.user", ""},
		{"database.password", ""},
		{"database.dbname", "nex"},
		{"database.max_idle_conns", 10},
		{"database.max_open_conns", 100},
		{"database.conn_max_lifetime", "1h"},

		{"log.level", "info"},
		{"log.path", filepath.Join(nexDir, "log")},
		{"log.max_size", 100},
		{"log.max_backups", 10},
		{"log.max_age", 30},
		{"log.compress", true},
	}

	for _, d := range defaults {
		v.SetDefault(d.key, d.value)
	}
}
|
||
|
||
// setupFlags 定义和绑定 CLI 参数
|
||
func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
|
||
// 定义所有配置项的 CLI 参数
|
||
// 注意:这里不设置默认值,让 viper 的默认值生效
|
||
flagSet.Int("server-port", 0, "服务器端口")
|
||
flagSet.Duration("server-read-timeout", 0, "读超时")
|
||
flagSet.Duration("server-write-timeout", 0, "写超时")
|
||
|
||
flagSet.String("database-driver", "", "数据库驱动:sqlite/mysql")
|
||
flagSet.String("database-path", "", "数据库文件路径")
|
||
flagSet.String("database-host", "", "MySQL 主机地址")
|
||
flagSet.Int("database-port", 0, "MySQL 端口")
|
||
flagSet.String("database-user", "", "MySQL 用户名")
|
||
flagSet.String("database-password", "", "MySQL 密码")
|
||
flagSet.String("database-dbname", "", "MySQL 数据库名")
|
||
flagSet.Int("database-max-idle-conns", 0, "最大空闲连接数")
|
||
flagSet.Int("database-max-open-conns", 0, "最大打开连接数")
|
||
flagSet.Duration("database-conn-max-lifetime", 0, "连接最大生命周期")
|
||
|
||
flagSet.String("log-level", "", "日志级别:debug/info/warn/error")
|
||
flagSet.String("log-path", "", "日志文件目录")
|
||
flagSet.Int("log-max-size", 0, "单个日志文件最大大小 MB")
|
||
flagSet.Int("log-max-backups", 0, "保留的旧日志文件最大数量")
|
||
flagSet.Int("log-max-age", 0, "保留旧日志文件的最大天数")
|
||
flagSet.Bool("log-compress", false, "是否压缩旧日志文件")
|
||
|
||
// 绑定所有 flag 到 viper
|
||
// 注意:必须在设置默认值之后绑定
|
||
bindPFlag(v, "server.port", flagSet.Lookup("server-port"))
|
||
bindPFlag(v, "server.read_timeout", flagSet.Lookup("server-read-timeout"))
|
||
bindPFlag(v, "server.write_timeout", flagSet.Lookup("server-write-timeout"))
|
||
|
||
bindPFlag(v, "database.driver", flagSet.Lookup("database-driver"))
|
||
bindPFlag(v, "database.path", flagSet.Lookup("database-path"))
|
||
bindPFlag(v, "database.host", flagSet.Lookup("database-host"))
|
||
bindPFlag(v, "database.port", flagSet.Lookup("database-port"))
|
||
bindPFlag(v, "database.user", flagSet.Lookup("database-user"))
|
||
bindPFlag(v, "database.password", flagSet.Lookup("database-password"))
|
||
bindPFlag(v, "database.dbname", flagSet.Lookup("database-dbname"))
|
||
bindPFlag(v, "database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
|
||
bindPFlag(v, "database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
|
||
bindPFlag(v, "database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
|
||
|
||
bindPFlag(v, "log.level", flagSet.Lookup("log-level"))
|
||
bindPFlag(v, "log.path", flagSet.Lookup("log-path"))
|
||
bindPFlag(v, "log.max_size", flagSet.Lookup("log-max-size"))
|
||
bindPFlag(v, "log.max_backups", flagSet.Lookup("log-max-backups"))
|
||
bindPFlag(v, "log.max_age", flagSet.Lookup("log-max-age"))
|
||
bindPFlag(v, "log.compress", flagSet.Lookup("log-compress"))
|
||
}
|
||
|
||
// bindPFlag binds flag to the viper key on v. A binding failure is treated as
// a programming error in the flag table and panics with the key and cause.
func bindPFlag(v *viper.Viper, key string, flag *pflag.Flag) {
	err := v.BindPFlag(key, flag)
	if err == nil {
		return
	}
	panic(fmt.Sprintf("绑定 flag %s 失败: %v", key, err))
}
|
||
|
||
// setupEnv enables environment-variable overrides on v. Keys are looked up
// with the "NEX" prefix and with "." replaced by "_" (viper upper-cases the
// name, so server.port is read from NEX_SERVER_PORT).
func setupEnv(v *viper.Viper) {
	dotToUnderscore := strings.NewReplacer(".", "_")
	v.SetEnvKeyReplacer(dotToUnderscore)
	v.SetEnvPrefix("NEX")
	v.AutomaticEnv()
}
|
||
|
||
// setupConfigFile points v at the YAML file configPath and reads it.
//
// If the file does not exist, a default config file is created at configPath
// and nil is returned; v then keeps using the values already registered on it.
// Any other read failure, or a failure to create the default file, is wrapped
// as appErrors.ErrInternal.
func setupConfigFile(v *viper.Viper, configPath string) error {
	v.SetConfigFile(configPath)
	v.SetConfigType("yaml")

	err := v.ReadInConfig()
	if err == nil {
		return nil
	}

	// errors.Is also matches wrapped not-exist errors (os.IsNotExist does not).
	// ConfigFileNotFoundError covers viper's own "not found" reporting.
	var notFound viper.ConfigFileNotFoundError
	if !errors.Is(err, os.ErrNotExist) && !errors.As(err, &notFound) {
		return appErrors.Wrap(appErrors.ErrInternal, err)
	}

	if err := writeDefaultConfigFile(configPath); err != nil {
		return appErrors.Wrap(appErrors.ErrInternal, err)
	}
	return nil
}

// writeDefaultConfigFile creates configPath containing only the built-in defaults.
//
// A fresh viper instance is used on purpose: the caller's instance also carries
// environment-variable and CLI-flag overrides (possibly a database password),
// and those must not be persisted as "defaults". Missing parent directories are
// created. The file gets mode 0o600, matching SaveConfig. If the file already
// exists, including when another writer created it concurrently, it is left
// untouched and nil is returned.
func writeDefaultConfigFile(configPath string) error {
	if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
		return err
	}

	defaults := viper.New()
	setupDefaults(defaults)
	defaults.SetConfigType("yaml")
	defaults.SetConfigPermissions(0o600)

	err := defaults.SafeWriteConfigAs(configPath)
	if err == nil {
		return nil
	}
	var alreadyExists viper.ConfigFileAlreadyExistsError
	if errors.As(err, &alreadyExists) || errors.Is(err, os.ErrExist) {
		return nil
	}
	return err
}
|
||
|
||
// LoadConfig loads config from YAML file, creates default if not exists
|
||
func LoadConfig() (*Config, error) {
|
||
configPath, err := GetConfigPath()
|
||
if err != nil {
|
||
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
|
||
}
|
||
return LoadConfigFromPath(configPath)
|
||
}
|
||
|
||
// LoadConfigFromPath 从指定路径加载配置
|
||
func LoadConfigFromPath(configPath string) (*Config, error) {
|
||
// 1. 创建 Viper 实例
|
||
v := viper.New()
|
||
|
||
// 2. 定义 CLI 参数
|
||
flagSet := pflag.NewFlagSet("config", pflag.ContinueOnError)
|
||
flagSet.String("config", configPath, "配置文件路径")
|
||
setupFlags(v, flagSet)
|
||
|
||
// 3. 解析 CLI 参数(忽略错误,因为可能没有参数)
|
||
if err := flagSet.Parse(os.Args[1:]); err != nil {
|
||
return nil, appErrors.Wrap(appErrors.ErrInvalidRequest, err)
|
||
}
|
||
|
||
// 4. 获取配置文件路径(可能被 --config 参数覆盖)
|
||
if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" {
|
||
configPath = configPathFlag
|
||
}
|
||
|
||
// 5. 设置默认值
|
||
setupDefaults(v)
|
||
|
||
// 6. 绑定环境变量
|
||
setupEnv(v)
|
||
|
||
// 7. 读取配置文件
|
||
if err := setupConfigFile(v, configPath); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 8. 反序列化到结构体
|
||
cfg := &Config{}
|
||
if err := v.Unmarshal(cfg, viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
|
||
mapstructure.StringToTimeDurationHookFunc(),
|
||
mapstructure.StringToSliceHookFunc(","),
|
||
))); err != nil {
|
||
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
|
||
}
|
||
|
||
// 9. 验证配置
|
||
if err := cfg.Validate(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return cfg, nil
|
||
}
|
||
|
||
// SaveConfig writes cfg as YAML to the default config file (see GetConfigPath),
// creating the parent directory if necessary. A newly created file gets mode
// 0o600, since the marshalled config includes the database password.
//
// Every failure is returned wrapped as appErrors.ErrInternal: resolving the
// path, marshalling, creating the directory, or writing the file.
func SaveConfig(cfg *Config) error {
	configPath, err := GetConfigPath()
	if err != nil {
		return appErrors.Wrap(appErrors.ErrInternal, err)
	}

	data, err := yaml.Marshal(cfg)
	if err != nil {
		return appErrors.Wrap(appErrors.ErrInternal, err)
	}

	// GetConfigPath already creates ~/.nex; kept so SaveConfig does not rely on that.
	if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
		return appErrors.Wrap(appErrors.ErrInternal, err)
	}

	// Wrapped like every other failure path so callers see one error category
	// (the original returned the raw os error here).
	if err := os.WriteFile(configPath, data, 0o600); err != nil {
		return appErrors.Wrap(appErrors.ErrInternal, err)
	}
	return nil
}
|
||
|
||
// Validate checks c against the `validate` struct tags on Config and its
// section types. It returns nil when every rule passes. Otherwise it returns an
// appErrors.ErrInvalidRequest error whose message embeds the validator's
// description of the failed rules.
func (c *Config) Validate() error {
	checker := validator.New()
	err := checker.Struct(c)
	if err == nil {
		return nil
	}
	msg := fmt.Sprintf("配置验证失败: %v", err)
	return appErrors.WithMessage(appErrors.ErrInvalidRequest, msg)
}
|
||
|
||
// PrintSummary 打印配置摘要
|
||
func (c *Config) PrintSummary(logger *zap.Logger) {
|
||
logger.Info("AI Gateway 启动配置",
|
||
zap.Int("server_port", c.Server.Port),
|
||
zap.String("database_driver", c.Database.Driver),
|
||
zap.String("log_level", c.Log.Level),
|
||
)
|
||
|
||
if c.Database.Driver == "mysql" {
|
||
logger.Info("数据库配置",
|
||
zap.String("driver", "mysql"),
|
||
zap.String("host", c.Database.Host),
|
||
zap.Int("port", c.Database.Port),
|
||
zap.String("database", c.Database.DBName),
|
||
)
|
||
} else {
|
||
logger.Info("数据库配置",
|
||
zap.String("driver", "sqlite"),
|
||
zap.String("path", c.Database.Path),
|
||
)
|
||
}
|
||
}
|