package config import ( "fmt" "os" "path/filepath" "time" "gopkg.in/yaml.v3" appErrors "nex/backend/pkg/errors" ) // Config 应用配置 type Config struct { Server ServerConfig `yaml:"server"` Database DatabaseConfig `yaml:"database"` Log LogConfig `yaml:"log"` } // ServerConfig 服务器配置 type ServerConfig struct { Port int `yaml:"port"` ReadTimeout time.Duration `yaml:"read_timeout"` WriteTimeout time.Duration `yaml:"write_timeout"` } // DatabaseConfig 数据库配置 type DatabaseConfig struct { Path string `yaml:"path"` MaxIdleConns int `yaml:"max_idle_conns"` MaxOpenConns int `yaml:"max_open_conns"` ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime"` } // LogConfig 日志配置 type LogConfig struct { Level string `yaml:"level"` Path string `yaml:"path"` MaxSize int `yaml:"max_size"` MaxBackups int `yaml:"max_backups"` MaxAge int `yaml:"max_age"` Compress bool `yaml:"compress"` } // DefaultConfig returns default config values func DefaultConfig() *Config { // Use home dir for default paths homeDir, _ := os.UserHomeDir() nexDir := filepath.Join(homeDir, ".nex") return &Config{ Server: ServerConfig{ Port: 9826, ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, }, Database: DatabaseConfig{ Path: filepath.Join(nexDir, "config.db"), MaxIdleConns: 10, MaxOpenConns: 100, ConnMaxLifetime: 1 * time.Hour, }, Log: LogConfig{ Level: "info", Path: filepath.Join(nexDir, "log"), MaxSize: 100, MaxBackups: 10, MaxAge: 30, Compress: true, }, } } // GetConfigDir 获取配置目录路径(~/.nex/) func GetConfigDir() (string, error) { homeDir, err := os.UserHomeDir() if err != nil { return "", err } configDir := filepath.Join(homeDir, ".nex") if err := os.MkdirAll(configDir, 0755); err != nil { return "", err } return configDir, nil } // GetDBPath 获取数据库文件路径 func GetDBPath() (string, error) { configDir, err := GetConfigDir() if err != nil { return "", err } return filepath.Join(configDir, "config.db"), nil } // GetConfigPath 获取配置文件路径 func GetConfigPath() (string, error) { configDir, err := GetConfigDir() if err != nil { return "", err } return filepath.Join(configDir, "config.yaml"), nil } // LoadConfig loads config from YAML file, creates default if not exists func LoadConfig() (*Config, error) { configPath, err := GetConfigPath() if err != nil { return nil, appErrors.Wrap(appErrors.ErrInternal, err) } cfg := DefaultConfig() data, err := os.ReadFile(configPath) if err != nil { if os.IsNotExist(err) { // Create default config file if saveErr := SaveConfig(cfg); saveErr != nil { return nil, appErrors.WithMessage(appErrors.ErrInternal, "创建默认配置失败") } return cfg, nil } return nil, appErrors.Wrap(appErrors.ErrInternal, err) } if err := yaml.Unmarshal(data, cfg); err != nil { return nil, appErrors.Wrap(appErrors.ErrInternal, err) } return cfg, nil } // SaveConfig saves config to YAML file func SaveConfig(cfg *Config) error { configPath, err := GetConfigPath() if err != nil { return appErrors.Wrap(appErrors.ErrInternal, err) } data, err := yaml.Marshal(cfg) if err != nil { return appErrors.Wrap(appErrors.ErrInternal, err) } // Ensure directory exists dir := filepath.Dir(configPath) if err := os.MkdirAll(dir, 0755); err != nil { return appErrors.Wrap(appErrors.ErrInternal, err) } return os.WriteFile(configPath, data, 0644) } // Validate validates the config func (c *Config) Validate() error { if c.Server.Port < 1 || c.Server.Port > 65535 { return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("无效的端口号: %d", c.Server.Port)) } validLevels := map[string]bool{"debug": true, "info": true, "warn": true, "error": true} if !validLevels[c.Log.Level] { return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("无效的日志级别: %s", c.Log.Level)) } if c.Database.Path == "" { return appErrors.WithMessage(appErrors.ErrInvalidRequest, "数据库路径不能为空") } return nil }