1
0

feat: 实现分层架构,包含 domain、service、repository 和 pkg 层

- 新增 domain 层:model、provider、route、stats 实体
- 新增 service 层:models、providers、routing、stats 业务逻辑
- 新增 repository 层:models、providers、stats 数据访问
- 新增 pkg 工具包:errors、logger、validator
- 新增中间件:CORS、logging、recovery、request ID
- 新增数据库迁移:初始 schema 和索引
- 新增单元测试和集成测试
- 新增规范文档:config-management、database-migration、error-handling、layered-architecture、middleware-system、request-validation、structured-logging、test-coverage
- 移除 config 子包和 model_router(已迁移至分层架构)
This commit is contained in:
2026-04-16 00:47:20 +08:00
parent 915b004924
commit f18904af1e
77 changed files with 5727 additions and 1257 deletions

View File

@@ -1,24 +1,87 @@
package config
import (
"fmt"
"os"
"path/filepath"
"time"
"gopkg.in/yaml.v3"
appErrors "nex/backend/pkg/errors"
)
// Config 应用配置
type Config struct {
Server ServerConfig `yaml:"server"`
Database DatabaseConfig `yaml:"database"`
Log LogConfig `yaml:"log"`
}
// ServerConfig 服务器配置
type ServerConfig struct {
Port int `yaml:"port"`
ReadTimeout time.Duration `yaml:"read_timeout"`
WriteTimeout time.Duration `yaml:"write_timeout"`
}
// DatabaseConfig 数据库配置
type DatabaseConfig struct {
Path string `yaml:"path"`
MaxIdleConns int `yaml:"max_idle_conns"`
MaxOpenConns int `yaml:"max_open_conns"`
ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime"`
}
// LogConfig 日志配置
type LogConfig struct {
Level string `yaml:"level"`
Path string `yaml:"path"`
MaxSize int `yaml:"max_size"`
MaxBackups int `yaml:"max_backups"`
MaxAge int `yaml:"max_age"`
Compress bool `yaml:"compress"`
}
// DefaultConfig returns default config values
func DefaultConfig() *Config {
// Use home dir for default paths
homeDir, _ := os.UserHomeDir()
nexDir := filepath.Join(homeDir, ".nex")
return &Config{
Server: ServerConfig{
Port: 9826,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
},
Database: DatabaseConfig{
Path: filepath.Join(nexDir, "config.db"),
MaxIdleConns: 10,
MaxOpenConns: 100,
ConnMaxLifetime: 1 * time.Hour,
},
Log: LogConfig{
Level: "info",
Path: filepath.Join(nexDir, "log"),
MaxSize: 100,
MaxBackups: 10,
MaxAge: 30,
Compress: true,
},
}
}
// GetConfigDir 获取配置目录路径(~/.nex/
func GetConfigDir() (string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return "", err
}
configDir := filepath.Join(homeDir, ".nex")
// 确保目录存在
if err := os.MkdirAll(configDir, 0755); err != nil {
return "", err
}
return configDir, nil
}
@@ -30,3 +93,79 @@ func GetDBPath() (string, error) {
}
return filepath.Join(configDir, "config.db"), nil
}
// GetConfigPath 获取配置文件路径
func GetConfigPath() (string, error) {
configDir, err := GetConfigDir()
if err != nil {
return "", err
}
return filepath.Join(configDir, "config.yaml"), nil
}
// LoadConfig loads config from YAML file, creates default if not exists
func LoadConfig() (*Config, error) {
configPath, err := GetConfigPath()
if err != nil {
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
}
cfg := DefaultConfig()
data, err := os.ReadFile(configPath)
if err != nil {
if os.IsNotExist(err) {
// Create default config file
if saveErr := SaveConfig(cfg); saveErr != nil {
return nil, appErrors.WithMessage(appErrors.ErrInternal, "创建默认配置失败")
}
return cfg, nil
}
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
}
if err := yaml.Unmarshal(data, cfg); err != nil {
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
}
return cfg, nil
}
// SaveConfig saves config to YAML file
func SaveConfig(cfg *Config) error {
configPath, err := GetConfigPath()
if err != nil {
return appErrors.Wrap(appErrors.ErrInternal, err)
}
data, err := yaml.Marshal(cfg)
if err != nil {
return appErrors.Wrap(appErrors.ErrInternal, err)
}
// Ensure directory exists
dir := filepath.Dir(configPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return appErrors.Wrap(appErrors.ErrInternal, err)
}
return os.WriteFile(configPath, data, 0644)
}
// Validate validates the config
func (c *Config) Validate() error {
if c.Server.Port < 1 || c.Server.Port > 65535 {
return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("无效的端口号: %d", c.Server.Port))
}
validLevels := map[string]bool{"debug": true, "info": true, "warn": true, "error": true}
if !validLevels[c.Log.Level] {
return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("无效的日志级别: %s", c.Log.Level))
}
if c.Database.Path == "" {
return appErrors.WithMessage(appErrors.ErrInvalidRequest, "数据库路径不能为空")
}
return nil
}

View File

@@ -0,0 +1,176 @@
package config
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
)
func TestDefaultConfig(t *testing.T) {
cfg := DefaultConfig()
require.NotNil(t, cfg)
assert.Equal(t, 9826, cfg.Server.Port)
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
assert.Equal(t, "info", cfg.Log.Level)
assert.Equal(t, 100, cfg.Log.MaxSize)
assert.Equal(t, 10, cfg.Log.MaxBackups)
assert.Equal(t, 30, cfg.Log.MaxAge)
assert.Equal(t, true, cfg.Log.Compress)
}
func TestConfig_Validate(t *testing.T) {
tests := []struct {
name string
modify func(*Config)
wantErr bool
errMsg string
}{
{
name: "默认配置有效",
modify: func(c *Config) {},
wantErr: false,
},
{
name: "端口号为0无效",
modify: func(c *Config) { c.Server.Port = 0 },
wantErr: true,
errMsg: "无效的端口号",
},
{
name: "端口号超出范围无效",
modify: func(c *Config) { c.Server.Port = 70000 },
wantErr: true,
errMsg: "无效的端口号",
},
{
name: "端口号为1有效",
modify: func(c *Config) { c.Server.Port = 1 },
wantErr: false,
},
{
name: "端口号为65535有效",
modify: func(c *Config) { c.Server.Port = 65535 },
wantErr: false,
},
{
name: "无效日志级别",
modify: func(c *Config) { c.Log.Level = "invalid" },
wantErr: true,
errMsg: "无效的日志级别",
},
{
name: "debug级别有效",
modify: func(c *Config) { c.Log.Level = "debug" },
wantErr: false,
},
{
name: "warn级别有效",
modify: func(c *Config) { c.Log.Level = "warn" },
wantErr: false,
},
{
name: "error级别有效",
modify: func(c *Config) { c.Log.Level = "error" },
wantErr: false,
},
{
name: "数据库路径为空无效",
modify: func(c *Config) { c.Database.Path = "" },
wantErr: true,
errMsg: "数据库路径不能为空",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := DefaultConfig()
tt.modify(cfg)
err := cfg.Validate()
if tt.wantErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
} else {
assert.NoError(t, err)
}
})
}
}
func TestGetConfigDir(t *testing.T) {
dir, err := GetConfigDir()
require.NoError(t, err)
assert.NotEmpty(t, dir)
assert.Contains(t, dir, ".nex")
}
func TestGetDBPath(t *testing.T) {
path, err := GetDBPath()
require.NoError(t, err)
assert.NotEmpty(t, path)
assert.Contains(t, path, "config.db")
}
func TestGetConfigPath(t *testing.T) {
path, err := GetConfigPath()
require.NoError(t, err)
assert.NotEmpty(t, path)
assert.Contains(t, path, "config.yaml")
}
func TestSaveAndLoadConfig(t *testing.T) {
// 使用临时目录覆盖配置路径
dir := t.TempDir()
cfg := &Config{
Server: ServerConfig{
Port: 9999,
ReadTimeout: 10 * time.Second,
WriteTimeout: 20 * time.Second,
},
Database: DatabaseConfig{
Path: filepath.Join(dir, "test.db"),
MaxIdleConns: 5,
MaxOpenConns: 50,
ConnMaxLifetime: 30 * time.Minute,
},
Log: LogConfig{
Level: "debug",
Path: filepath.Join(dir, "log"),
MaxSize: 50,
MaxBackups: 5,
MaxAge: 7,
Compress: false,
},
}
// 保存配置
configPath := filepath.Join(dir, "config.yaml")
data, err := yaml.Marshal(cfg)
require.NoError(t, err)
err = os.WriteFile(configPath, data, 0644)
require.NoError(t, err)
// 加载配置
data, err = os.ReadFile(configPath)
require.NoError(t, err)
loaded := &Config{}
err = yaml.Unmarshal(data, loaded)
require.NoError(t, err)
assert.Equal(t, cfg.Server.Port, loaded.Server.Port)
assert.Equal(t, cfg.Log.Level, loaded.Log.Level)
assert.Equal(t, cfg.Database.MaxIdleConns, loaded.Database.MaxIdleConns)
assert.Equal(t, cfg.Log.Compress, loaded.Log.Compress)
}

View File

@@ -1,58 +0,0 @@
package config
import (
"fmt"
"log"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
var db *gorm.DB
// InitDB 初始化数据库连接并创建表
func InitDB() error {
dbPath, err := GetDBPath()
if err != nil {
return fmt.Errorf("获取数据库路径失败: %w", err)
}
// 打开数据库连接
db, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
})
if err != nil {
return fmt.Errorf("连接数据库失败: %w", err)
}
// 启用 WAL 模式以提升并发性能
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
log.Printf("警告: 启用 WAL 模式失败: %v", err)
}
// 自动迁移表结构
if err := db.AutoMigrate(&Provider{}, &Model{}, &UsageStats{}); err != nil {
return fmt.Errorf("创建表失败: %w", err)
}
log.Printf("数据库初始化成功: %s", dbPath)
return nil
}
// GetDB 获取数据库连接
func GetDB() *gorm.DB {
return db
}
// CloseDB 关闭数据库连接
func CloseDB() error {
if db != nil {
sqlDB, err := db.DB()
if err != nil {
return err
}
return sqlDB.Close()
}
return nil
}

View File

@@ -1,119 +0,0 @@
package config
import (
"errors"
"time"
"gorm.io/gorm"
)
// CreateModel 创建模型
func CreateModel(model *Model) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
// 验证供应商是否存在
var provider Provider
err := db.First(&provider, "id = ?", model.ProviderID).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("供应商不存在")
}
return err
}
model.CreatedAt = time.Now()
return db.Create(model).Error
}
// GetModel 获取模型
func GetModel(id string) (*Model, error) {
db := GetDB()
if db == nil {
return nil, errors.New("数据库未初始化")
}
var model Model
err := db.First(&model, "id = ?", id).Error
if err != nil {
return nil, err
}
return &model, nil
}
// ListModels 列出模型
func ListModels(providerID string) ([]Model, error) {
db := GetDB()
if db == nil {
return nil, errors.New("数据库未初始化")
}
var models []Model
var err error
if providerID != "" {
err = db.Where("provider_id = ?", providerID).Find(&models).Error
} else {
err = db.Find(&models).Error
}
if err != nil {
return nil, err
}
return models, nil
}
// UpdateModel 更新模型
func UpdateModel(id string, updates map[string]interface{}) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
// 如果更新了 provider_id验证新供应商是否存在
if providerID, ok := updates["provider_id"].(string); ok {
var provider Provider
err := db.First(&provider, "id = ?", providerID).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("供应商不存在")
}
return err
}
}
result := db.Model(&Model{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}
// DeleteModel 删除模型
func DeleteModel(id string) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
result := db.Delete(&Model{}, "id = ?", id)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}

View File

@@ -1,102 +0,0 @@
package config
import (
"errors"
"time"
"gorm.io/gorm"
)
// CreateProvider 创建供应商
func CreateProvider(provider *Provider) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
provider.CreatedAt = time.Now()
provider.UpdatedAt = time.Now()
return db.Create(provider).Error
}
// GetProvider 获取供应商
func GetProvider(id string, maskKey bool) (*Provider, error) {
db := GetDB()
if db == nil {
return nil, errors.New("数据库未初始化")
}
var provider Provider
err := db.First(&provider, "id = ?", id).Error
if err != nil {
return nil, err
}
if maskKey {
provider.MaskAPIKey()
}
return &provider, nil
}
// ListProviders 列出所有供应商
func ListProviders() ([]Provider, error) {
db := GetDB()
if db == nil {
return nil, errors.New("数据库未初始化")
}
var providers []Provider
err := db.Find(&providers).Error
if err != nil {
return nil, err
}
// 掩码所有 API Key
for i := range providers {
providers[i].MaskAPIKey()
}
return providers, nil
}
// UpdateProvider 更新供应商
func UpdateProvider(id string, updates map[string]interface{}) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
updates["updated_at"] = time.Now()
result := db.Model(&Provider{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}
// DeleteProvider 删除供应商
func DeleteProvider(id string) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
result := db.Delete(&Provider{}, "id = ?", id)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}

View File

@@ -1,79 +0,0 @@
package config
import (
"errors"
"time"
"gorm.io/gorm"
)
// RecordRequest 记录请求统计
func RecordRequest(providerID, modelName string) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
today := time.Now().Format("2006-01-02")
todayTime, _ := time.Parse("2006-01-02", today)
// 使用事务确保并发安全
return db.Transaction(func(tx *gorm.DB) error {
var stats UsageStats
// 查找或创建统计记录
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
providerID, modelName, todayTime).
First(&stats).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
// 创建新记录
stats = UsageStats{
ProviderID: providerID,
ModelName: modelName,
RequestCount: 1,
Date: todayTime,
}
return tx.Create(&stats).Error
} else if err != nil {
return err
}
// 更新计数
return tx.Model(&stats).Update("request_count", gorm.Expr("request_count + 1")).Error
})
}
// GetStats 查询统计
func GetStats(providerID, modelName string, startDate, endDate *time.Time) ([]UsageStats, error) {
db := GetDB()
if db == nil {
return nil, errors.New("数据库未初始化")
}
var stats []UsageStats
query := db.Model(&UsageStats{})
if providerID != "" {
query = query.Where("provider_id = ?", providerID)
}
if modelName != "" {
query = query.Where("model_name = ?", modelName)
}
if startDate != nil {
query = query.Where("date >= ?", startDate)
}
if endDate != nil {
query = query.Where("date <= ?", endDate)
}
err := query.Order("date DESC").Find(&stats).Error
if err != nil {
return nil, err
}
return stats, nil
}