diff --git a/backend/Makefile b/backend/Makefile new file mode 100644 index 0000000..8eefd5d --- /dev/null +++ b/backend/Makefile @@ -0,0 +1,45 @@ +.PHONY: build run test test-coverage clean migrate-up migrate-down migrate-status migrate-create lint + +# 构建 +build: + go build -o bin/server ./cmd/server + +# 运行 +run: + go run ./cmd/server + +# 测试 +test: + go test ./... -v + +# 测试覆盖率 +test-coverage: + go test ./... -coverprofile=coverage.out + go tool cover -html=coverage.out -o coverage.html + @echo "Coverage report generated: coverage.html" + +# 清理 +clean: + rm -rf bin/ coverage.out coverage.html + +# 数据库迁移 +migrate-up: + goose -dir migrations sqlite3 $(DB_PATH) up + +migrate-down: + goose -dir migrations sqlite3 $(DB_PATH) down + +migrate-status: + goose -dir migrations sqlite3 $(DB_PATH) status + +migrate-create: + @read -p "Migration name: " name; \ + goose -dir migrations create $$name sql + +# 代码检查 +lint: + golangci-lint run ./... + +# 安装依赖 +deps: + go mod tidy diff --git a/backend/README.md b/backend/README.md index 550bedf..6488a1f 100644 --- a/backend/README.md +++ b/backend/README.md @@ -10,13 +10,21 @@ AI 网关后端服务,提供统一的大模型 API 代理接口。 - 支持 Function Calling / Tools - 多供应商配置和路由 - 用量统计 +- 结构化日志(zap + lumberjack) +- YAML 配置管理 +- 请求验证 +- 中间件支持(请求 ID、日志、恢复、CORS) ## 技术栈 -- **语言**: Go +- **语言**: Go 1.26+ - **HTTP 框架**: Gin - **ORM**: GORM - **数据库**: SQLite +- **日志**: zap + lumberjack +- **配置**: gopkg.in/yaml.v3 +- **验证**: go-playground/validator/v10 +- **迁移**: goose ## 项目结构 @@ -24,37 +32,86 @@ AI 网关后端服务,提供统一的大模型 API 代理接口。 backend/ ├── cmd/ │ └── server/ -│ └── main.go # 主程序入口 +│ └── main.go # 主程序入口(依赖注入) ├── internal/ -│ ├── config/ # 配置和数据库 -│ │ ├── config.go # 配置目录管理 -│ │ ├── database.go # 数据库连接 -│ │ ├── models.go # 数据模型 -│ │ ├── provider.go # 供应商 CRUD -│ │ ├── model.go # 模型 CRUD -│ │ └── stats.go # 统计记录 -│ ├── handler/ # HTTP 处理器 +│ ├── config/ # 配置管理 +│ │ ├── config.go # 配置加载/保存/验证 +│ │ └── models.go # GORM 数据模型 +│ ├── domain/ # 领域模型 +│ │ ├── provider.go +│ │ ├── model.go +│ │ ├── stats.go +│ │ └── route.go +│ ├── handler/ # HTTP 处理器 +│ │ ├── middleware/ # 中间件 +│ │ │ ├── request_id.go +│ │ │ ├── logging.go +│ │ │ ├── recovery.go +│ │ │ └── cors.go │ │ ├── openai_handler.go │ │ ├── anthropic_handler.go │ │ ├── provider_handler.go │ │ ├── model_handler.go │ │ └── stats_handler.go -│ ├── protocol/ # 协议适配器 +│ ├── protocol/ # 协议适配器 │ │ ├── openai/ -│ │ │ ├── types.go -│ │ │ └── adapter.go +│ │ │ ├── types.go # 请求/响应类型 + 验证 +│ │ │ └── adapter.go # OpenAI 协议适配 │ │ └── anthropic/ -│ │ ├── types.go -│ │ ├── converter.go -│ │ └── stream_converter.go -│ ├── provider/ # 供应商客户端 +│ │ ├── types.go # 请求/响应类型 + 验证 +│ │ ├── converter.go # 协议转换 +│ │ └── stream_converter.go # 流式转换 +│ ├── provider/ # 供应商客户端 │ │ └── client.go -│ └── router/ # 模型路由 -│ └── model_router.go +│ ├── repository/ # 数据访问层 +│ │ ├── provider_repo.go # 接口定义 +│ │ ├── provider_repo_impl.go +│ │ ├── model_repo.go +│ │ ├── model_repo_impl.go +│ │ ├── stats_repo.go +│ │ └── stats_repo_impl.go +│ └── service/ # 业务逻辑层 +│ ├── provider_service.go # 接口定义 +│ ├── provider_service_impl.go +│ ├── model_service.go +│ ├── model_service_impl.go +│ ├── routing_service.go +│ ├── routing_service_impl.go +│ ├── stats_service.go +│ └── stats_service_impl.go +├── pkg/ # 公共包 +│ ├── errors/ # 结构化错误 +│ │ ├── errors.go +│ │ └── wrap.go +│ ├── logger/ # 日志系统 +│ │ ├── logger.go +│ │ ├── rotate.go +│ │ └── context.go +│ └── validator/ # 验证器 +│ └── validator.go +├── migrations/ # 数据库迁移 +│ ├── 001_initial_schema.sql +│ └── 002_add_indexes.sql +├── tests/ # 测试 +│ ├── helpers.go +│ ├── integration/ +│ ├── unit/ +│ └── testdata/ +├── Makefile ├── go.mod └── README.md ``` +## 架构 + +采用三层架构(handler → service → repository),通过依赖注入连接: + +``` +handler(HTTP 请求处理) + → service(业务逻辑) + → repository(数据访问) +``` + ## 运行方式 ### 安装依赖 @@ -69,7 +126,59 @@ go mod download go run cmd/server/main.go ``` -服务将在端口 9826 启动。 +服务将在端口 9826 启动。首次启动会自动创建配置文件和运行数据库迁移。 + +## 配置 + +配置文件位于 `~/.nex/config.yaml`,首次启动自动生成。 + +```yaml +server: + port: 9826 + read_timeout: 30s + write_timeout: 30s + +database: + path: ~/.nex/config.db + max_idle_conns: 10 + max_open_conns: 100 + conn_max_lifetime: 1h + +log: + level: info + path: ~/.nex/log + max_size: 100 # MB + max_backups: 10 + max_age: 30 # 天 + compress: true +``` + +数据文件: +- `~/.nex/config.yaml` - 配置文件 +- `~/.nex/config.db` - SQLite 数据库 +- `~/.nex/log/` - 日志目录 + +## 测试 + +```bash +# 运行所有测试 +make test + +# 生成覆盖率报告 +make test-coverage +``` + +## 数据库迁移 + +```bash +# 使用 Makefile +make migrate-up DB_PATH=~/.nex/config.db +make migrate-down DB_PATH=~/.nex/config.db +make migrate-status DB_PATH=~/.nex/config.db + +# 或直接使用 goose +goose -dir migrations sqlite3 ~/.nex/config.db up +``` ## API 文档 @@ -169,20 +278,20 @@ POST /v1/messages - `end_date` - 结束日期(YYYY-MM-DD) - `group_by` - 聚合维度(provider/model/date) -## 配置 - -配置和数据存储在 `~/.nex/` 目录: - -- `~/.nex/config.db` - SQLite 数据库 - ## 开发 ### 构建 ```bash -go build -o ai-gateway cmd/server/main.go +make build +``` + +### 代码检查 + +```bash +make lint ``` ### 环境要求 -- Go 1.21 或更高版本 +- Go 1.26 或更高版本 diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 71a388b..6492172 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -2,88 +2,217 @@ package main import ( "context" + "fmt" "log" "net/http" "os" "os/signal" + "path/filepath" + "runtime" "syscall" "time" "github.com/gin-gonic/gin" + "github.com/pressly/goose/v3" + "go.uber.org/zap" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" "nex/backend/internal/config" "nex/backend/internal/handler" + "nex/backend/internal/handler/middleware" + "nex/backend/internal/provider" + "nex/backend/internal/repository" + "nex/backend/internal/service" + pkgLogger "nex/backend/pkg/logger" ) func main() { - // 初始化数据库 - if err := config.InitDB(); err != nil { - log.Fatalf("初始化数据库失败: %v", err) + // 1. 加载配置 + cfg, err := config.LoadConfig() + if err != nil { + log.Fatalf("加载配置失败: %v", err) + } + if err := cfg.Validate(); err != nil { + log.Fatalf("配置验证失败: %v", err) } - defer config.CloseDB() - // 创建 Gin 引擎 - gin.SetMode(gin.ReleaseMode) - r := gin.Default() - - // 配置 CORS - r.Use(func(c *gin.Context) { - c.Writer.Header().Set("Access-Control-Allow-Origin", "*") - c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") - if c.Request.Method == "OPTIONS" { - c.AbortWithStatus(204) - return - } - c.Next() + // 2. 初始化日志 + zapLogger, err := pkgLogger.New(pkgLogger.Config{ + Level: cfg.Log.Level, + Path: cfg.Log.Path, + MaxSize: cfg.Log.MaxSize, + MaxBackups: cfg.Log.MaxBackups, + MaxAge: cfg.Log.MaxAge, + Compress: cfg.Log.Compress, }) + if err != nil { + log.Fatalf("初始化日志失败: %v", err) + } + defer zapLogger.Sync() + + // 3. 初始化数据库 + db, err := initDatabase(cfg) + if err != nil { + zapLogger.Fatal("初始化数据库失败", zap.String("error", err.Error())) + } + defer closeDB(db) + + // 4. 初始化 repository 层 + providerRepo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + statsRepo := repository.NewStatsRepository(db) + + // 5. 初始化 service 层 + providerService := service.NewProviderService(providerRepo) + modelService := service.NewModelService(modelRepo, providerRepo) + routingService := service.NewRoutingService(modelRepo, providerRepo) + statsService := service.NewStatsService(statsRepo) + + // 6. 初始化 provider client + providerClient := provider.NewClient() + + // 7. 初始化 handler 层 + openaiHandler := handler.NewOpenAIHandler(providerClient, routingService, statsService) + anthropicHandler := handler.NewAnthropicHandler(providerClient, routingService, statsService) + providerHandler := handler.NewProviderHandler(providerService) + modelHandler := handler.NewModelHandler(modelService) + statsHandler := handler.NewStatsHandler(statsService) + + // 8. 创建 Gin 引擎 + gin.SetMode(gin.ReleaseMode) + r := gin.New() + + // 注册中间件(按正确顺序) + r.Use(middleware.RequestID()) + r.Use(middleware.Recovery(zapLogger)) + r.Use(middleware.Logging(zapLogger)) + r.Use(middleware.CORS()) // 注册路由 - setupRoutes(r) + setupRoutes(r, openaiHandler, anthropicHandler, providerHandler, modelHandler, statsHandler) - // 创建 HTTP 服务器 + // 9. 启动服务器 srv := &http.Server{ - Addr: ":9826", - Handler: r, + Addr: formatAddr(cfg.Server.Port), + Handler: r, + ReadTimeout: cfg.Server.ReadTimeout, + WriteTimeout: cfg.Server.WriteTimeout, } - // 启动服务器(在 goroutine 中) go func() { - log.Printf("AI Gateway 启动在端口 9826") + zapLogger.Info("AI Gateway 启动", zap.String("addr", srv.Addr)) if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Fatalf("服务器启动失败: %v", err) + zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error())) } }() - // 等待中断信号以优雅关闭服务器 + // 等待中断信号 quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit - log.Println("正在关闭服务器...") - // 给服务器 5 秒时间完成当前请求 + zapLogger.Info("正在关闭服务器...") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := srv.Shutdown(ctx); err != nil { - log.Fatal("服务器强制关闭:", err) + zapLogger.Fatal("服务器强制关闭", zap.String("error", err.Error())) } - log.Println("服务器已关闭") + zapLogger.Info("服务器已关闭") } -// setupRoutes 配置路由 -func setupRoutes(r *gin.Engine) { +func initDatabase(cfg *config.Config) (*gorm.DB, error) { + db, err := gorm.Open(sqlite.Open(cfg.Database.Path), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Info), + }) + if err != nil { + return nil, err + } + + if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil { + log.Printf("警告: 启用 WAL 模式失败: %v", err) + } + + // 运行数据库迁移 + if err := runMigrations(db); err != nil { + return nil, fmt.Errorf("数据库迁移失败: %w", err) + } + + // 配置连接池 + sqlDB, err := db.DB() + if err != nil { + return nil, err + } + sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns) + sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns) + sqlDB.SetConnMaxLifetime(cfg.Database.ConnMaxLifetime) + + // 记录连接池状态 + log.Printf("数据库连接池配置: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v", + cfg.Database.MaxIdleConns, cfg.Database.MaxOpenConns, cfg.Database.ConnMaxLifetime) + + return db, nil +} + +// runMigrations 使用 goose 执行数据库迁移 +func runMigrations(db *gorm.DB) error { + sqlDB, err := db.DB() + if err != nil { + return err + } + + migrationsDir := getMigrationsDir() + if _, err := os.Stat(migrationsDir); os.IsNotExist(err) { + return fmt.Errorf("迁移目录不存在: %s", migrationsDir) + } + + goose.SetDialect("sqlite3") + if err := goose.Up(sqlDB, migrationsDir); err != nil { + return err + } + + return nil +} + +// getMigrationsDir 获取迁移文件目录路径 +func getMigrationsDir() string { + // 从可执行文件位置推断迁移目录 + _, filename, _, ok := runtime.Caller(0) + if ok { + // cmd/server/main.go → backend/ → backend/migrations/ + dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations") + if abs, err := filepath.Abs(dir); err == nil { + return abs + } + } + // 回退到相对路径 + return "./migrations" +} + +func closeDB(db *gorm.DB) { + sqlDB, err := db.DB() + if err != nil { + return + } + sqlDB.Close() +} + +func formatAddr(port int) string { + return fmt.Sprintf(":%d", port) +} + +func setupRoutes(r *gin.Engine, openaiHandler *handler.OpenAIHandler, anthropicHandler *handler.AnthropicHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) { // OpenAI 协议代理 - openaiHandler := handler.NewOpenAIHandler() r.POST("/v1/chat/completions", openaiHandler.HandleChatCompletions) // Anthropic 协议代理 - anthropicHandler := handler.NewAnthropicHandler() r.POST("/v1/messages", anthropicHandler.HandleMessages) // 供应商管理 API - providerHandler := handler.NewProviderHandler() providers := r.Group("/api/providers") { providers.GET("", providerHandler.ListProviders) @@ -94,7 +223,6 @@ func setupRoutes(r *gin.Engine) { } // 模型管理 API - modelHandler := handler.NewModelHandler() models := r.Group("/api/models") { models.GET("", modelHandler.ListModels) @@ -105,7 +233,6 @@ func setupRoutes(r *gin.Engine) { } // 统计查询 API - statsHandler := handler.NewStatsHandler() stats := r.Group("/api/stats") { stats.GET("", statsHandler.GetStats) diff --git a/backend/go.mod b/backend/go.mod index b778fb4..030cb4e 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -3,16 +3,29 @@ module nex/backend go 1.26.2 require ( + github.com/gin-gonic/gin v1.12.0 + github.com/google/uuid v1.6.0 + github.com/pressly/goose/v3 v3.27.0 + github.com/stretchr/testify v1.11.1 + go.uber.org/zap v1.27.1 + gopkg.in/lumberjack.v2 v2.0.0 + gopkg.in/yaml.v3 v3.0.1 + gorm.io/driver/sqlite v1.6.0 + gorm.io/gorm v1.31.1 +) + +require ( + github.com/BurntSushi/toml v1.6.0 // indirect github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/sonic v1.15.0 // indirect github.com/bytedance/sonic/loader v0.5.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect - github.com/gabriel-vasile/mimetype v1.4.12 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gabriel-vasile/mimetype v1.4.13 // indirect github.com/gin-contrib/sse v1.1.0 // indirect - github.com/gin-gonic/gin v1.12.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-playground/validator/v10 v10.30.1 // indirect + github.com/go-playground/validator/v10 v10.30.2 // indirect github.com/goccy/go-json v0.10.5 // indirect github.com/goccy/go-yaml v1.19.2 // indirect github.com/jinzhu/inflection v1.0.0 // indirect @@ -22,20 +35,25 @@ require ( github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-sqlite3 v1.14.22 // indirect + github.com/mfridman/interpolate v0.0.2 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect github.com/quic-go/quic-go v0.59.0 // indirect + github.com/sethvargo/go-retry v0.3.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.3.1 // indirect go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect + go.uber.org/multierr v1.11.0 // indirect golang.org/x/arch v0.22.0 // indirect - golang.org/x/crypto v0.48.0 // indirect + golang.org/x/crypto v0.49.0 // indirect golang.org/x/net v0.51.0 // indirect - golang.org/x/sys v0.41.0 // indirect - golang.org/x/text v0.34.0 // indirect - google.golang.org/protobuf v1.36.10 // indirect - gorm.io/driver/sqlite v1.6.0 // indirect - gorm.io/gorm v1.31.1 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/text v0.35.0 // indirect + google.golang.org/protobuf v1.36.11 // indirect + gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/backend/go.sum b/backend/go.sum index e074724..2f4a473 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -1,3 +1,5 @@ +github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= +github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= @@ -7,24 +9,37 @@ github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCc github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw= github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= +github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM= +github.com/gabriel-vasile/mimetype v1.4.13/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w= github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM= github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8= github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w= github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM= +github.com/go-playground/validator/v10 v10.30.2 h1:JiFIMtSSHb2/XBUbWM4i/MpeQm9ZK2xqPNk8vgvu5JQ= +github.com/go-playground/validator/v10 v10.30.2/go.mod h1:mAf2pIOVXjTEBrwUMGKkCWKKPs9NheYGabeB04txQSc= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM= github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= @@ -33,24 +48,41 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY= +github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pressly/goose/v3 v3.27.0 h1:/D30gVTuQhu0WsNZYbJi4DMOsx1lNq+6SkLe+Wp59BM= +github.com/pressly/goose/v3 v3.27.0/go.mod h1:3ZBeCXqzkgIRvrEMDkYh1guvtoJTU5oMMuDdkutoM78= github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw= github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE= +github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -60,29 +92,68 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY= github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE= go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= +go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= +go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI= golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0= +golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA= golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= -google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= -google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/lumberjack.v2 v2.0.0 h1:IDj6hi8KbNiPQ5VaYNFZ7dBJLF5LFeKvsFrWHjA5aq4= +gopkg.in/lumberjack.v2 v2.0.0/go.mod h1:bp5nQ2kK/lLQSmTk29azj9+JB6bWci56xFn/lvd5GLI= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= +modernc.org/libc v1.68.0 h1:PJ5ikFOV5pwpW+VqCK1hKJuEWsonkIJhhIXyuF/91pQ= +modernc.org/libc v1.68.0/go.mod h1:NnKCYeoYgsEqnY3PgvNgAeaJnso968ygU8Z0DxjoEc0= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU= +modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index cb5c81b..c6fa6ca 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -1,24 +1,87 @@ package config import ( + "fmt" "os" "path/filepath" + "time" + + "gopkg.in/yaml.v3" + + appErrors "nex/backend/pkg/errors" ) +// Config 应用配置 +type Config struct { + Server ServerConfig `yaml:"server"` + Database DatabaseConfig `yaml:"database"` + Log LogConfig `yaml:"log"` +} + +// ServerConfig 服务器配置 +type ServerConfig struct { + Port int `yaml:"port"` + ReadTimeout time.Duration `yaml:"read_timeout"` + WriteTimeout time.Duration `yaml:"write_timeout"` +} + +// DatabaseConfig 数据库配置 +type DatabaseConfig struct { + Path string `yaml:"path"` + MaxIdleConns int `yaml:"max_idle_conns"` + MaxOpenConns int `yaml:"max_open_conns"` + ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime"` +} + +// LogConfig 日志配置 +type LogConfig struct { + Level string `yaml:"level"` + Path string `yaml:"path"` + MaxSize int `yaml:"max_size"` + MaxBackups int `yaml:"max_backups"` + MaxAge int `yaml:"max_age"` + Compress bool `yaml:"compress"` +} + +// DefaultConfig returns default config values +func DefaultConfig() *Config { + // Use home dir for default paths + homeDir, _ := os.UserHomeDir() + nexDir := filepath.Join(homeDir, ".nex") + + return &Config{ + Server: ServerConfig{ + Port: 9826, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + }, + Database: DatabaseConfig{ + Path: filepath.Join(nexDir, "config.db"), + MaxIdleConns: 10, + MaxOpenConns: 100, + ConnMaxLifetime: 1 * time.Hour, + }, + Log: LogConfig{ + Level: "info", + Path: filepath.Join(nexDir, "log"), + MaxSize: 100, + MaxBackups: 10, + MaxAge: 30, + Compress: true, + }, + } +} + // GetConfigDir 获取配置目录路径(~/.nex/) func GetConfigDir() (string, error) { homeDir, err := os.UserHomeDir() if err != nil { return "", err } - configDir := filepath.Join(homeDir, ".nex") - - // 确保目录存在 if err := os.MkdirAll(configDir, 0755); err != nil { return "", err } - return configDir, nil } @@ -30,3 +93,79 @@ func GetDBPath() (string, error) { } return filepath.Join(configDir, "config.db"), nil } + +// GetConfigPath 获取配置文件路径 +func GetConfigPath() (string, error) { + configDir, err := GetConfigDir() + if err != nil { + return "", err + } + return filepath.Join(configDir, "config.yaml"), nil +} + +// LoadConfig loads config from YAML file, creates default if not exists +func LoadConfig() (*Config, error) { + configPath, err := GetConfigPath() + if err != nil { + return nil, appErrors.Wrap(appErrors.ErrInternal, err) + } + + cfg := DefaultConfig() + + data, err := os.ReadFile(configPath) + if err != nil { + if os.IsNotExist(err) { + // Create default config file + if saveErr := SaveConfig(cfg); saveErr != nil { + return nil, appErrors.WithMessage(appErrors.ErrInternal, "创建默认配置失败") + } + return cfg, nil + } + return nil, appErrors.Wrap(appErrors.ErrInternal, err) + } + + if err := yaml.Unmarshal(data, cfg); err != nil { + return nil, appErrors.Wrap(appErrors.ErrInternal, err) + } + + return cfg, nil +} + +// SaveConfig saves config to YAML file +func SaveConfig(cfg *Config) error { + configPath, err := GetConfigPath() + if err != nil { + return appErrors.Wrap(appErrors.ErrInternal, err) + } + + data, err := yaml.Marshal(cfg) + if err != nil { + return appErrors.Wrap(appErrors.ErrInternal, err) + } + + // Ensure directory exists + dir := filepath.Dir(configPath) + if err := os.MkdirAll(dir, 0755); err != nil { + return appErrors.Wrap(appErrors.ErrInternal, err) + } + + return os.WriteFile(configPath, data, 0644) +} + +// Validate validates the config +func (c *Config) Validate() error { + if c.Server.Port < 1 || c.Server.Port > 65535 { + return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("无效的端口号: %d", c.Server.Port)) + } + + validLevels := map[string]bool{"debug": true, "info": true, "warn": true, "error": true} + if !validLevels[c.Log.Level] { + return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("无效的日志级别: %s", c.Log.Level)) + } + + if c.Database.Path == "" { + return appErrors.WithMessage(appErrors.ErrInvalidRequest, "数据库路径不能为空") + } + + return nil +} diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go new file mode 100644 index 0000000..7ebeab0 --- /dev/null +++ b/backend/internal/config/config_test.go @@ -0,0 +1,176 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +func TestDefaultConfig(t *testing.T) { + cfg := DefaultConfig() + require.NotNil(t, cfg) + + assert.Equal(t, 9826, cfg.Server.Port) + assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout) + assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout) + + assert.Equal(t, 10, cfg.Database.MaxIdleConns) + assert.Equal(t, 100, cfg.Database.MaxOpenConns) + assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime) + + assert.Equal(t, "info", cfg.Log.Level) + assert.Equal(t, 100, cfg.Log.MaxSize) + assert.Equal(t, 10, cfg.Log.MaxBackups) + assert.Equal(t, 30, cfg.Log.MaxAge) + assert.Equal(t, true, cfg.Log.Compress) +} + +func TestConfig_Validate(t *testing.T) { + tests := []struct { + name string + modify func(*Config) + wantErr bool + errMsg string + }{ + { + name: "默认配置有效", + modify: func(c *Config) {}, + wantErr: false, + }, + { + name: "端口号为0无效", + modify: func(c *Config) { c.Server.Port = 0 }, + wantErr: true, + errMsg: "无效的端口号", + }, + { + name: "端口号超出范围无效", + modify: func(c *Config) { c.Server.Port = 70000 }, + wantErr: true, + errMsg: "无效的端口号", + }, + { + name: "端口号为1有效", + modify: func(c *Config) { c.Server.Port = 1 }, + wantErr: false, + }, + { + name: "端口号为65535有效", + modify: func(c *Config) { c.Server.Port = 65535 }, + wantErr: false, + }, + { + name: "无效日志级别", + modify: func(c *Config) { c.Log.Level = "invalid" }, + wantErr: true, + errMsg: "无效的日志级别", + }, + { + name: "debug级别有效", + modify: func(c *Config) { c.Log.Level = "debug" }, + wantErr: false, + }, + { + name: "warn级别有效", + modify: func(c *Config) { c.Log.Level = "warn" }, + wantErr: false, + }, + { + name: "error级别有效", + modify: func(c *Config) { c.Log.Level = "error" }, + wantErr: false, + }, + { + name: "数据库路径为空无效", + modify: func(c *Config) { c.Database.Path = "" }, + wantErr: true, + errMsg: "数据库路径不能为空", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := DefaultConfig() + tt.modify(cfg) + err := cfg.Validate() + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestGetConfigDir(t *testing.T) { + dir, err := GetConfigDir() + require.NoError(t, err) + assert.NotEmpty(t, dir) + assert.Contains(t, dir, ".nex") +} + +func TestGetDBPath(t *testing.T) { + path, err := GetDBPath() + require.NoError(t, err) + assert.NotEmpty(t, path) + assert.Contains(t, path, "config.db") +} + +func TestGetConfigPath(t *testing.T) { + path, err := GetConfigPath() + require.NoError(t, err) + assert.NotEmpty(t, path) + assert.Contains(t, path, "config.yaml") +} + +func TestSaveAndLoadConfig(t *testing.T) { + // 使用临时目录覆盖配置路径 + dir := t.TempDir() + + cfg := &Config{ + Server: ServerConfig{ + Port: 9999, + ReadTimeout: 10 * time.Second, + WriteTimeout: 20 * time.Second, + }, + Database: DatabaseConfig{ + Path: filepath.Join(dir, "test.db"), + MaxIdleConns: 5, + MaxOpenConns: 50, + ConnMaxLifetime: 30 * time.Minute, + }, + Log: LogConfig{ + Level: "debug", + Path: filepath.Join(dir, "log"), + MaxSize: 50, + MaxBackups: 5, + MaxAge: 7, + Compress: false, + }, + } + + // 保存配置 + configPath := filepath.Join(dir, "config.yaml") + data, err := yaml.Marshal(cfg) + require.NoError(t, err) + err = os.WriteFile(configPath, data, 0644) + require.NoError(t, err) + + // 加载配置 + data, err = os.ReadFile(configPath) + require.NoError(t, err) + loaded := &Config{} + err = yaml.Unmarshal(data, loaded) + require.NoError(t, err) + + assert.Equal(t, cfg.Server.Port, loaded.Server.Port) + assert.Equal(t, cfg.Log.Level, loaded.Log.Level) + assert.Equal(t, cfg.Database.MaxIdleConns, loaded.Database.MaxIdleConns) + assert.Equal(t, cfg.Log.Compress, loaded.Log.Compress) +} diff --git a/backend/internal/config/database.go b/backend/internal/config/database.go deleted file mode 100644 index b9067b9..0000000 --- a/backend/internal/config/database.go +++ /dev/null @@ -1,58 +0,0 @@ -package config - -import ( - "fmt" - "log" - - "gorm.io/driver/sqlite" - "gorm.io/gorm" - "gorm.io/gorm/logger" -) - -var db *gorm.DB - -// InitDB 初始化数据库连接并创建表 -func InitDB() error { - dbPath, err := GetDBPath() - if err != nil { - return fmt.Errorf("获取数据库路径失败: %w", err) - } - - // 打开数据库连接 - db, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{ - Logger: logger.Default.LogMode(logger.Info), - }) - if err != nil { - return fmt.Errorf("连接数据库失败: %w", err) - } - - // 启用 WAL 模式以提升并发性能 - if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil { - log.Printf("警告: 启用 WAL 模式失败: %v", err) - } - - // 自动迁移表结构 - if err := db.AutoMigrate(&Provider{}, &Model{}, &UsageStats{}); err != nil { - return fmt.Errorf("创建表失败: %w", err) - } - - log.Printf("数据库初始化成功: %s", dbPath) - return nil -} - -// GetDB 获取数据库连接 -func GetDB() *gorm.DB { - return db -} - -// CloseDB 关闭数据库连接 -func CloseDB() error { - if db != nil { - sqlDB, err := db.DB() - if err != nil { - return err - } - return sqlDB.Close() - } - return nil -} diff --git a/backend/internal/config/model.go b/backend/internal/config/model.go deleted file mode 100644 index e77c105..0000000 --- a/backend/internal/config/model.go +++ /dev/null @@ -1,119 +0,0 @@ -package config - -import ( - "errors" - "time" - - "gorm.io/gorm" -) - -// CreateModel 创建模型 -func CreateModel(model *Model) error { - db := GetDB() - if db == nil { - return errors.New("数据库未初始化") - } - - // 验证供应商是否存在 - var provider Provider - err := db.First(&provider, "id = ?", model.ProviderID).Error - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return errors.New("供应商不存在") - } - return err - } - - model.CreatedAt = time.Now() - - return db.Create(model).Error -} - -// GetModel 获取模型 -func GetModel(id string) (*Model, error) { - db := GetDB() - if db == nil { - return nil, errors.New("数据库未初始化") - } - - var model Model - err := db.First(&model, "id = ?", id).Error - if err != nil { - return nil, err - } - - return &model, nil -} - -// ListModels 列出模型 -func ListModels(providerID string) ([]Model, error) { - db := GetDB() - if db == nil { - return nil, errors.New("数据库未初始化") - } - - var models []Model - var err error - - if providerID != "" { - err = db.Where("provider_id = ?", providerID).Find(&models).Error - } else { - err = db.Find(&models).Error - } - - if err != nil { - return nil, err - } - - return models, nil -} - -// UpdateModel 更新模型 -func UpdateModel(id string, updates map[string]interface{}) error { - db := GetDB() - if db == nil { - return errors.New("数据库未初始化") - } - - // 如果更新了 provider_id,验证新供应商是否存在 - if providerID, ok := updates["provider_id"].(string); ok { - var provider Provider - err := db.First(&provider, "id = ?", providerID).Error - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return errors.New("供应商不存在") - } - return err - } - } - - result := db.Model(&Model{}).Where("id = ?", id).Updates(updates) - if result.Error != nil { - return result.Error - } - - if result.RowsAffected == 0 { - return gorm.ErrRecordNotFound - } - - return nil -} - -// DeleteModel 删除模型 -func DeleteModel(id string) error { - db := GetDB() - if db == nil { - return errors.New("数据库未初始化") - } - - result := db.Delete(&Model{}, "id = ?", id) - if result.Error != nil { - return result.Error - } - - if result.RowsAffected == 0 { - return gorm.ErrRecordNotFound - } - - return nil -} diff --git a/backend/internal/config/provider.go b/backend/internal/config/provider.go deleted file mode 100644 index b24175f..0000000 --- a/backend/internal/config/provider.go +++ /dev/null @@ -1,102 +0,0 @@ -package config - -import ( - "errors" - "time" - - "gorm.io/gorm" -) - -// CreateProvider 创建供应商 -func CreateProvider(provider *Provider) error { - db := GetDB() - if db == nil { - return errors.New("数据库未初始化") - } - - provider.CreatedAt = time.Now() - provider.UpdatedAt = time.Now() - - return db.Create(provider).Error -} - -// GetProvider 获取供应商 -func GetProvider(id string, maskKey bool) (*Provider, error) { - db := GetDB() - if db == nil { - return nil, errors.New("数据库未初始化") - } - - var provider Provider - err := db.First(&provider, "id = ?", id).Error - if err != nil { - return nil, err - } - - if maskKey { - provider.MaskAPIKey() - } - - return &provider, nil -} - -// ListProviders 列出所有供应商 -func ListProviders() ([]Provider, error) { - db := GetDB() - if db == nil { - return nil, errors.New("数据库未初始化") - } - - var providers []Provider - err := db.Find(&providers).Error - if err != nil { - return nil, err - } - - // 掩码所有 API Key - for i := range providers { - providers[i].MaskAPIKey() - } - - return providers, nil -} - -// UpdateProvider 更新供应商 -func UpdateProvider(id string, updates map[string]interface{}) error { - db := GetDB() - if db == nil { - return errors.New("数据库未初始化") - } - - updates["updated_at"] = time.Now() - - result := db.Model(&Provider{}).Where("id = ?", id).Updates(updates) - if result.Error != nil { - return result.Error - } - - if result.RowsAffected == 0 { - return gorm.ErrRecordNotFound - } - - return nil -} - -// DeleteProvider 删除供应商 -func DeleteProvider(id string) error { - db := GetDB() - if db == nil { - return errors.New("数据库未初始化") - } - - result := db.Delete(&Provider{}, "id = ?", id) - if result.Error != nil { - return result.Error - } - - if result.RowsAffected == 0 { - return gorm.ErrRecordNotFound - } - - return nil -} diff --git a/backend/internal/config/stats.go b/backend/internal/config/stats.go deleted file mode 100644 index 272b44b..0000000 --- a/backend/internal/config/stats.go +++ /dev/null @@ -1,79 +0,0 @@ -package config - -import ( - "errors" - "time" - - "gorm.io/gorm" -) - -// RecordRequest 记录请求统计 -func RecordRequest(providerID, modelName string) error { - db := GetDB() - if db == nil { - return errors.New("数据库未初始化") - } - - today := time.Now().Format("2006-01-02") - todayTime, _ := time.Parse("2006-01-02", today) - - // 使用事务确保并发安全 - return db.Transaction(func(tx *gorm.DB) error { - var stats UsageStats - - // 查找或创建统计记录 - err := tx.Where("provider_id = ? AND model_name = ? AND date = ?", - providerID, modelName, todayTime). - First(&stats).Error - - if errors.Is(err, gorm.ErrRecordNotFound) { - // 创建新记录 - stats = UsageStats{ - ProviderID: providerID, - ModelName: modelName, - RequestCount: 1, - Date: todayTime, - } - return tx.Create(&stats).Error - } else if err != nil { - return err - } - - // 更新计数 - return tx.Model(&stats).Update("request_count", gorm.Expr("request_count + 1")).Error - }) -} - -// GetStats 查询统计 -func GetStats(providerID, modelName string, startDate, endDate *time.Time) ([]UsageStats, error) { - db := GetDB() - if db == nil { - return nil, errors.New("数据库未初始化") - } - - var stats []UsageStats - query := db.Model(&UsageStats{}) - - if providerID != "" { - query = query.Where("provider_id = ?", providerID) - } - - if modelName != "" { - query = query.Where("model_name = ?", modelName) - } - - if startDate != nil { - query = query.Where("date >= ?", startDate) - } - - if endDate != nil { - query = query.Where("date <= ?", endDate) - } - - err := query.Order("date DESC").Find(&stats).Error - if err != nil { - return nil, err - } - - return stats, nil -} diff --git a/backend/internal/domain/model.go b/backend/internal/domain/model.go new file mode 100644 index 0000000..30b4b49 --- /dev/null +++ b/backend/internal/domain/model.go @@ -0,0 +1,12 @@ +package domain + +import "time" + +// Model 模型领域模型 +type Model struct { + ID string `json:"id"` + ProviderID string `json:"provider_id"` + ModelName string `json:"model_name"` + Enabled bool `json:"enabled"` + CreatedAt time.Time `json:"created_at"` +} diff --git a/backend/internal/domain/provider.go b/backend/internal/domain/provider.go new file mode 100644 index 0000000..199d18e --- /dev/null +++ b/backend/internal/domain/provider.go @@ -0,0 +1,23 @@ +package domain + +import "time" + +// Provider 供应商领域模型 +type Provider struct { + ID string `json:"id"` + Name string `json:"name"` + APIKey string `json:"api_key"` + BaseURL string `json:"base_url"` + Enabled bool `json:"enabled"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// MaskAPIKey 掩码 API Key(仅显示最后 4 个字符) +func (p *Provider) MaskAPIKey() { + if len(p.APIKey) > 4 { + p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:] + } else { + p.APIKey = "***" + } +} diff --git a/backend/internal/domain/route.go b/backend/internal/domain/route.go new file mode 100644 index 0000000..91f6466 --- /dev/null +++ b/backend/internal/domain/route.go @@ -0,0 +1,7 @@ +package domain + +// RouteResult 路由结果 +type RouteResult struct { + Provider *Provider + Model *Model +} diff --git a/backend/internal/domain/stats.go b/backend/internal/domain/stats.go new file mode 100644 index 0000000..188dc11 --- /dev/null +++ b/backend/internal/domain/stats.go @@ -0,0 +1,12 @@ +package domain + +import "time" + +// UsageStats 用量统计领域模型 +type UsageStats struct { + ID uint `json:"id"` + ProviderID string `json:"provider_id"` + ModelName string `json:"model_name"` + RequestCount int `json:"request_count"` + Date time.Time `json:"date"` +} diff --git a/backend/internal/handler/anthropic_handler.go b/backend/internal/handler/anthropic_handler.go index 931fe02..13e5aef 100644 --- a/backend/internal/handler/anthropic_handler.go +++ b/backend/internal/handler/anthropic_handler.go @@ -7,30 +7,33 @@ import ( "github.com/gin-gonic/gin" - "nex/backend/internal/config" + appErrors "nex/backend/pkg/errors" + + "nex/backend/internal/domain" "nex/backend/internal/protocol/anthropic" "nex/backend/internal/protocol/openai" "nex/backend/internal/provider" - "nex/backend/internal/router" + "nex/backend/internal/service" ) // AnthropicHandler Anthropic 协议处理器 type AnthropicHandler struct { - client *provider.Client - router *router.Router + client provider.ProviderClient + routingService service.RoutingService + statsService service.StatsService } // NewAnthropicHandler 创建 Anthropic 处理器 -func NewAnthropicHandler() *AnthropicHandler { +func NewAnthropicHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *AnthropicHandler { return &AnthropicHandler{ - client: provider.NewClient(), - router: router.NewRouter(), + client: client, + routingService: routingService, + statsService: statsService, } } // HandleMessages 处理 Messages 请求 func (h *AnthropicHandler) HandleMessages(c *gin.Context) { - // 解析 Anthropic 请求 var req anthropic.MessagesRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{ @@ -43,7 +46,19 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) { return } - // 检查多模态内容 + // 请求验证 + if validationErrors := anthropic.ValidateRequest(&req); validationErrors != nil { + errMsg := formatValidationErrors(validationErrors) + c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{ + Type: "error", + Error: anthropic.ErrorDetail{ + Type: "invalid_request_error", + Message: errMsg, + }, + }) + return + } + if err := h.checkMultimodalContent(&req); err != nil { c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{ Type: "error", @@ -55,7 +70,6 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) { return } - // 转换为 OpenAI 请求 openaiReq, err := anthropic.ConvertRequest(&req) if err != nil { c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{ @@ -68,14 +82,12 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) { return } - // 路由到供应商 - routeResult, err := h.router.Route(openaiReq.Model) + routeResult, err := h.routingService.Route(openaiReq.Model) if err != nil { h.handleError(c, err) return } - // 根据是否流式选择处理方式 if req.Stream { h.handleStreamRequest(c, openaiReq, routeResult) } else { @@ -83,9 +95,7 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) { } } -// handleNonStreamRequest 处理非流式请求 -func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *router.RouteResult) { - // 发送请求到供应商 +func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *domain.RouteResult) { openaiResp, err := h.client.SendRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL) if err != nil { c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{ @@ -98,7 +108,6 @@ func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *ope return } - // 转换为 Anthropic 响应 anthropicResp, err := anthropic.ConvertResponse(openaiResp) if err != nil { c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{ @@ -111,18 +120,14 @@ func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *ope return } - // 记录统计 go func() { - _ = config.RecordRequest(routeResult.Provider.ID, openaiReq.Model) + _ = h.statsService.Record(routeResult.Provider.ID, openaiReq.Model) }() - // 返回响应 c.JSON(http.StatusOK, anthropicResp) } -// handleStreamRequest 处理流式请求 -func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *router.RouteResult) { - // 发送流式请求到供应商 +func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *domain.RouteResult) { eventChan, err := h.client.SendStreamRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL) if err != nil { c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{ @@ -135,24 +140,19 @@ func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai return } - // 设置 SSE 响应头 c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") - // 创建流写入器 writer := bufio.NewWriter(c.Writer) - // 创建流式转换器 converter := anthropic.NewStreamConverter( fmt.Sprintf("msg_%s", routeResult.Provider.ID), openaiReq.Model, ) - // 流式转发事件 for event := range eventChan { if event.Error != nil { - fmt.Printf("流错误: %v\n", event.Error) break } @@ -160,25 +160,19 @@ func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai break } - // 解析 OpenAI 流块 chunk, err := openai.NewAdapter().ParseStreamChunk(event.Data) if err != nil { - fmt.Printf("解析流块失败: %v\n", err) continue } - // 转换为 Anthropic 事件 anthropicEvents, err := converter.ConvertChunk(chunk) if err != nil { - fmt.Printf("转换事件失败: %v\n", err) continue } - // 写入事件 for _, ae := range anthropicEvents { eventStr, err := anthropic.SerializeEvent(ae) if err != nil { - fmt.Printf("序列化事件失败: %v\n", err) continue } writer.WriteString(eventStr) @@ -186,13 +180,11 @@ func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai } } - // 记录统计 go func() { - _ = config.RecordRequest(routeResult.Provider.ID, openaiReq.Model) + _ = h.statsService.Record(routeResult.Provider.ID, openaiReq.Model) }() } -// checkMultimodalContent 检查多模态内容 func (h *AnthropicHandler) checkMultimodalContent(req *anthropic.MessagesRequest) error { for _, msg := range req.Messages { for _, block := range msg.Content { @@ -204,40 +196,22 @@ func (h *AnthropicHandler) checkMultimodalContent(req *anthropic.MessagesRequest return nil } -// handleError 处理路由错误 func (h *AnthropicHandler) handleError(c *gin.Context, err error) { - switch err { - case router.ErrModelNotFound: - c.JSON(http.StatusNotFound, anthropic.ErrorResponse{ + if appErr, ok := appErrors.AsAppError(err); ok { + c.JSON(appErr.HTTPStatus, anthropic.ErrorResponse{ Type: "error", Error: anthropic.ErrorDetail{ Type: "not_found_error", - Message: "模型未找到", - }, - }) - case router.ErrModelDisabled: - c.JSON(http.StatusNotFound, anthropic.ErrorResponse{ - Type: "error", - Error: anthropic.ErrorDetail{ - Type: "not_found_error", - Message: "模型已禁用", - }, - }) - case router.ErrProviderDisabled: - c.JSON(http.StatusNotFound, anthropic.ErrorResponse{ - Type: "error", - Error: anthropic.ErrorDetail{ - Type: "not_found_error", - Message: "供应商已禁用", - }, - }) - default: - c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{ - Type: "error", - Error: anthropic.ErrorDetail{ - Type: "internal_error", - Message: "内部错误: " + err.Error(), + Message: appErr.Message, }, }) + return } + c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{ + Type: "error", + Error: anthropic.ErrorDetail{ + Type: "internal_error", + Message: "内部错误: " + err.Error(), + }, + }) } diff --git a/backend/internal/handler/handler_test.go b/backend/internal/handler/handler_test.go new file mode 100644 index 0000000..8de985b --- /dev/null +++ b/backend/internal/handler/handler_test.go @@ -0,0 +1,290 @@ +package handler + +import ( + "bytes" + "context" + "encoding/json" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "nex/backend/internal/domain" + "nex/backend/internal/protocol/openai" + "nex/backend/internal/provider" + appErrors "nex/backend/pkg/errors" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +// ============ Mock 实现 ============ + +type mockRoutingService struct { + result *domain.RouteResult + err error +} + +func (m *mockRoutingService) Route(modelName string) (*domain.RouteResult, error) { + return m.result, m.err +} + +type mockStatsService struct { + err error + stats []domain.UsageStats + aggrResult []map[string]interface{} +} + +func (m *mockStatsService) Record(providerID, modelName string) error { + return m.err +} +func (m *mockStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) { + return m.stats, nil +} +func (m *mockStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} { + return m.aggrResult +} + +type mockProviderService struct { + provider *domain.Provider + providers []domain.Provider + err error +} + +func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err } +func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) { + return m.provider, m.err +} +func (m *mockProviderService) List() ([]domain.Provider, error) { return m.providers, m.err } +func (m *mockProviderService) Update(id string, updates map[string]interface{}) error { + return m.err +} +func (m *mockProviderService) Delete(id string) error { return m.err } + +type mockModelService struct { + model *domain.Model + models []domain.Model + err error +} + +func (m *mockModelService) Create(model *domain.Model) error { return m.err } +func (m *mockModelService) Get(id string) (*domain.Model, error) { + return m.model, m.err +} +func (m *mockModelService) List(providerID string) ([]domain.Model, error) { + return m.models, m.err +} +func (m *mockModelService) Update(id string, updates map[string]interface{}) error { + return m.err +} +func (m *mockModelService) Delete(id string) error { return m.err } + +type mockProviderClient struct { + resp *openai.ChatCompletionResponse + eventChan chan provider.StreamEvent + err error +} + +func (m *mockProviderClient) SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error) { + return m.resp, m.err +} +func (m *mockProviderClient) SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan provider.StreamEvent, error) { + return m.eventChan, m.err +} + +// ============ OpenAI Handler 测试 ============ + +func TestOpenAIHandler_HandleChatCompletions_InvalidJSON(t *testing.T) { + h := NewOpenAIHandler(nil, nil, nil) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte("invalid"))) + + h.HandleChatCompletions(c) + assert.Equal(t, 400, w.Code) +} + +func TestOpenAIHandler_HandleChatCompletions_ValidationError(t *testing.T) { + h := NewOpenAIHandler(nil, nil, nil) + + // 缺少 model 字段 + body, _ := json.Marshal(map[string]interface{}{ + "messages": []map[string]string{{"role": "user", "content": "hi"}}, + }) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + h.HandleChatCompletions(c) + assert.Equal(t, 400, w.Code) +} + +func TestOpenAIHandler_HandleChatCompletions_RouteError(t *testing.T) { + routingSvc := &mockRoutingService{err: appErrors.ErrModelNotFound} + h := NewOpenAIHandler(nil, routingSvc, nil) + + body, _ := json.Marshal(map[string]interface{}{ + "model": "nonexistent", + "messages": []map[string]string{{"role": "user", "content": "hi"}}, + }) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + h.HandleChatCompletions(c) + assert.Equal(t, 404, w.Code) +} + +// ============ Provider Handler 测试 ============ + +func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) { + h := NewProviderHandler(&mockProviderService{}) + + body, _ := json.Marshal(map[string]string{"id": "p1"}) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + h.CreateProvider(c) + assert.Equal(t, 400, w.Code) +} + +func TestProviderHandler_ListProviders(t *testing.T) { + h := NewProviderHandler(&mockProviderService{ + providers: []domain.Provider{ + {ID: "p1", Name: "P1"}, + {ID: "p2", Name: "P2"}, + }, + }) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/api/providers", nil) + + h.ListProviders(c) + assert.Equal(t, 200, w.Code) + + var result []domain.Provider + json.Unmarshal(w.Body.Bytes(), &result) + assert.Len(t, result, 2) +} + +func TestProviderHandler_GetProvider(t *testing.T) { + h := NewProviderHandler(&mockProviderService{ + provider: &domain.Provider{ID: "p1", Name: "P1", APIKey: "***"}, + }) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "p1"}} + c.Request = httptest.NewRequest("GET", "/api/providers/p1", nil) + + h.GetProvider(c) + assert.Equal(t, 200, w.Code) +} + +// ============ Model Handler 测试 ============ + +func TestModelHandler_CreateModel_MissingFields(t *testing.T) { + h := NewModelHandler(&mockModelService{}) + + body, _ := json.Marshal(map[string]string{"id": "m1"}) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + h.CreateModel(c) + assert.Equal(t, 400, w.Code) +} + +func TestModelHandler_ListModels(t *testing.T) { + h := NewModelHandler(&mockModelService{ + models: []domain.Model{ + {ID: "m1", ModelName: "gpt-4"}, + {ID: "m2", ModelName: "gpt-3.5"}, + }, + }) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/api/models", nil) + + h.ListModels(c) + assert.Equal(t, 200, w.Code) +} + +// ============ Stats Handler 测试 ============ + +func TestStatsHandler_GetStats(t *testing.T) { + h := NewStatsHandler(&mockStatsService{ + stats: []domain.UsageStats{ + {ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10}, + }, + }) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/api/stats", nil) + + h.GetStats(c) + assert.Equal(t, 200, w.Code) +} + +func TestStatsHandler_GetStats_InvalidDate(t *testing.T) { + h := NewStatsHandler(&mockStatsService{}) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/api/stats?start_date=invalid", nil) + + h.GetStats(c) + assert.Equal(t, 400, w.Code) +} + +func TestStatsHandler_AggregateStats(t *testing.T) { + h := NewStatsHandler(&mockStatsService{ + stats: []domain.UsageStats{ + {ProviderID: "p1", RequestCount: 10}, + }, + aggrResult: []map[string]interface{}{ + {"provider_id": "p1", "request_count": 10}, + }, + }) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/api/stats/aggregate?group_by=provider", nil) + + h.AggregateStats(c) + assert.Equal(t, 200, w.Code) +} + +// ============ writeError 测试 ============ + +func TestWriteError(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/", nil) + + writeError(c, appErrors.ErrModelNotFound) + assert.Equal(t, 404, w.Code) +} + +func TestFormatValidationErrors(t *testing.T) { + errs := map[string]string{ + "model": "模型名称不能为空", + "messages": "消息列表不能为空", + } + result := formatValidationErrors(errs) + require.Contains(t, result, "请求验证失败") + require.Contains(t, result, "model") + require.Contains(t, result, "messages") +} diff --git a/backend/internal/handler/middleware/cors.go b/backend/internal/handler/middleware/cors.go new file mode 100644 index 0000000..4a436f1 --- /dev/null +++ b/backend/internal/handler/middleware/cors.go @@ -0,0 +1,21 @@ +package middleware + +import ( + "github.com/gin-gonic/gin" +) + +// CORS 跨域中间件 +func CORS() gin.HandlerFunc { + return func(c *gin.Context) { + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Request-ID") + + if c.Request.Method == "OPTIONS" { + c.AbortWithStatus(204) + return + } + + c.Next() + } +} diff --git a/backend/internal/handler/middleware/logging.go b/backend/internal/handler/middleware/logging.go new file mode 100644 index 0000000..fd17da6 --- /dev/null +++ b/backend/internal/handler/middleware/logging.go @@ -0,0 +1,40 @@ +package middleware + +import ( + "time" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// Logging 日志中间件 +func Logging(logger *zap.Logger) gin.HandlerFunc { + return func(c *gin.Context) { + start := time.Now() + path := c.Request.URL.Path + query := c.Request.URL.RawQuery + + requestID, _ := c.Get(RequestIDKey) + logger.Info("请求开始", + zap.String("method", c.Request.Method), + zap.String("path", path), + zap.String("query", query), + zap.String("client_ip", c.ClientIP()), + zap.Any("request_id", requestID), + ) + + c.Next() + + latency := time.Since(start) + statusCode := c.Writer.Status() + + logger.Info("请求结束", + zap.Int("status", statusCode), + zap.String("method", c.Request.Method), + zap.String("path", path), + zap.Duration("latency", latency), + zap.Int("body_size", c.Writer.Size()), + zap.Any("request_id", requestID), + ) + } +} diff --git a/backend/internal/handler/middleware/middleware_test.go b/backend/internal/handler/middleware/middleware_test.go new file mode 100644 index 0000000..74bbace --- /dev/null +++ b/backend/internal/handler/middleware/middleware_test.go @@ -0,0 +1,130 @@ +package middleware + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func TestRequestID_GeneratesUUID(t *testing.T) { + r := gin.New() + r.Use(RequestID()) + r.GET("/test", func(c *gin.Context) { + id, exists := c.Get(RequestIDKey) + assert.True(t, exists) + assert.NotEmpty(t, id) + c.Status(200) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/test", nil) + r.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + assert.NotEmpty(t, w.Header().Get("X-Request-ID")) +} + +func TestRequestID_UsesExistingHeader(t *testing.T) { + r := gin.New() + r.Use(RequestID()) + r.GET("/test", func(c *gin.Context) { + id, _ := c.Get(RequestIDKey) + assert.Equal(t, "existing-id-123", id) + c.Status(200) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Request-ID", "existing-id-123") + r.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + assert.Equal(t, "existing-id-123", w.Header().Get("X-Request-ID")) +} + +func TestLogging(t *testing.T) { + logger := zap.NewNop() + + r := gin.New() + r.Use(Logging(logger)) + r.GET("/test", func(c *gin.Context) { + c.Status(200) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/test?key=value", nil) + r.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) +} + +func TestRecovery_NoPanic(t *testing.T) { + logger := zap.NewNop() + + r := gin.New() + r.Use(Recovery(logger)) + r.GET("/test", func(c *gin.Context) { + c.Status(200) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/test", nil) + r.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) +} + +func TestRecovery_WithPanic(t *testing.T) { + logger := zap.NewNop() + + r := gin.New() + r.Use(Recovery(logger)) + r.GET("/test", func(c *gin.Context) { + panic("test panic") + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/test", nil) + r.ServeHTTP(w, req) + + assert.Equal(t, 500, w.Code) +} + +func TestCORS_NormalRequest(t *testing.T) { + r := gin.New() + r.Use(CORS()) + r.GET("/test", func(c *gin.Context) { + c.Status(200) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/test", nil) + r.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin")) + assert.Contains(t, w.Header().Get("Access-Control-Allow-Methods"), "GET") + assert.Contains(t, w.Header().Get("Access-Control-Allow-Methods"), "POST") +} + +func TestCORS_PreflightRequest(t *testing.T) { + r := gin.New() + r.Use(CORS()) + r.OPTIONS("/test", func(c *gin.Context) { + c.Status(200) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest("OPTIONS", "/test", nil) + r.ServeHTTP(w, req) + + assert.Equal(t, 204, w.Code) + assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin")) +} diff --git a/backend/internal/handler/middleware/recovery.go b/backend/internal/handler/middleware/recovery.go new file mode 100644 index 0000000..c03c275 --- /dev/null +++ b/backend/internal/handler/middleware/recovery.go @@ -0,0 +1,29 @@ +package middleware + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// Recovery 错误恢复中间件 +func Recovery(logger *zap.Logger) gin.HandlerFunc { + return func(c *gin.Context) { + defer func() { + if err := recover(); err != nil { + requestID, _ := c.Get(RequestIDKey) + logger.Error("panic recovered", + zap.Any("error", err), + zap.Any("request_id", requestID), + zap.String("path", c.Request.URL.Path), + zap.Stack("stack"), + ) + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{ + "error": "内部错误", + }) + } + }() + c.Next() + } +} diff --git a/backend/internal/handler/middleware/request_id.go b/backend/internal/handler/middleware/request_id.go new file mode 100644 index 0000000..607f100 --- /dev/null +++ b/backend/internal/handler/middleware/request_id.go @@ -0,0 +1,21 @@ +package middleware + +import ( + "github.com/gin-gonic/gin" + "github.com/google/uuid" +) + +const RequestIDKey = "request_id" + +// RequestID 请求 ID 中间件 +func RequestID() gin.HandlerFunc { + return func(c *gin.Context) { + requestID := c.GetHeader("X-Request-ID") + if requestID == "" { + requestID = uuid.New().String() + } + c.Set(RequestIDKey, requestID) + c.Header("X-Request-ID", requestID) + c.Next() + } +} diff --git a/backend/internal/handler/model_handler.go b/backend/internal/handler/model_handler.go index 5e52ac5..9135c66 100644 --- a/backend/internal/handler/model_handler.go +++ b/backend/internal/handler/model_handler.go @@ -6,15 +6,20 @@ import ( "github.com/gin-gonic/gin" "gorm.io/gorm" - "nex/backend/internal/config" + appErrors "nex/backend/pkg/errors" + + "nex/backend/internal/domain" + "nex/backend/internal/service" ) // ModelHandler 模型管理处理器 -type ModelHandler struct{} +type ModelHandler struct { + modelService service.ModelService +} // NewModelHandler 创建模型处理器 -func NewModelHandler() *ModelHandler { - return &ModelHandler{} +func NewModelHandler(modelService service.ModelService) *ModelHandler { + return &ModelHandler{modelService: modelService} } // CreateModel 创建模型 @@ -32,26 +37,21 @@ func (h *ModelHandler) CreateModel(c *gin.Context) { return } - // 创建模型对象 - model := &config.Model{ + model := &domain.Model{ ID: req.ID, ProviderID: req.ProviderID, ModelName: req.ModelName, - Enabled: true, // 默认启用 } - // 保存到数据库 - err := config.CreateModel(model) + err := h.modelService.Create(model) if err != nil { - if err.Error() == "供应商不存在" { + if err == appErrors.ErrProviderNotFound { c.JSON(http.StatusBadRequest, gin.H{ "error": "供应商不存在", }) return } - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "创建模型失败: " + err.Error(), - }) + writeError(c, err) return } @@ -62,11 +62,9 @@ func (h *ModelHandler) CreateModel(c *gin.Context) { func (h *ModelHandler) ListModels(c *gin.Context) { providerID := c.Query("provider_id") - models, err := config.ListModels(providerID) + models, err := h.modelService.List(providerID) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "查询模型失败: " + err.Error(), - }) + writeError(c, err) return } @@ -77,7 +75,7 @@ func (h *ModelHandler) ListModels(c *gin.Context) { func (h *ModelHandler) GetModel(c *gin.Context) { id := c.Param("id") - model, err := config.GetModel(id) + model, err := h.modelService.Get(id) if err != nil { if err == gorm.ErrRecordNotFound { c.JSON(http.StatusNotFound, gin.H{ @@ -85,9 +83,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) { }) return } - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "查询模型失败: " + err.Error(), - }) + writeError(c, err) return } @@ -106,8 +102,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) { return } - // 更新模型 - err := config.UpdateModel(id, req) + err := h.modelService.Update(id, req) if err != nil { if err == gorm.ErrRecordNotFound { c.JSON(http.StatusNotFound, gin.H{ @@ -115,24 +110,19 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) { }) return } - if err.Error() == "供应商不存在" { + if err == appErrors.ErrProviderNotFound { c.JSON(http.StatusBadRequest, gin.H{ "error": "供应商不存在", }) return } - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "更新模型失败: " + err.Error(), - }) + writeError(c, err) return } - // 返回更新后的模型 - model, err := config.GetModel(id) + model, err := h.modelService.Get(id) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "查询更新后的模型失败: " + err.Error(), - }) + writeError(c, err) return } @@ -143,7 +133,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) { func (h *ModelHandler) DeleteModel(c *gin.Context) { id := c.Param("id") - err := config.DeleteModel(id) + err := h.modelService.Delete(id) if err != nil { if err == gorm.ErrRecordNotFound { c.JSON(http.StatusNotFound, gin.H{ @@ -151,9 +141,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) { }) return } - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "删除模型失败: " + err.Error(), - }) + writeError(c, err) return } diff --git a/backend/internal/handler/openai_handler.go b/backend/internal/handler/openai_handler.go index dec3c18..a15fc0a 100644 --- a/backend/internal/handler/openai_handler.go +++ b/backend/internal/handler/openai_handler.go @@ -4,32 +4,36 @@ import ( "bufio" "fmt" "net/http" + "strings" "github.com/gin-gonic/gin" - "nex/backend/internal/config" + appErrors "nex/backend/pkg/errors" + + "nex/backend/internal/domain" "nex/backend/internal/protocol/openai" "nex/backend/internal/provider" - "nex/backend/internal/router" + "nex/backend/internal/service" ) // OpenAIHandler OpenAI 协议处理器 type OpenAIHandler struct { - client *provider.Client - router *router.Router + client provider.ProviderClient + routingService service.RoutingService + statsService service.StatsService } // NewOpenAIHandler 创建 OpenAI 处理器 -func NewOpenAIHandler() *OpenAIHandler { +func NewOpenAIHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *OpenAIHandler { return &OpenAIHandler{ - client: provider.NewClient(), - router: router.NewRouter(), + client: client, + routingService: routingService, + statsService: statsService, } } // HandleChatCompletions 处理 Chat Completions 请求 func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) { - // 解析请求 var req openai.ChatCompletionRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, openai.ErrorResponse{ @@ -41,14 +45,23 @@ func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) { return } - // 路由到供应商 - routeResult, err := h.router.Route(req.Model) + // 请求验证 + if validationErrors := openai.ValidateRequest(&req); validationErrors != nil { + c.JSON(http.StatusBadRequest, openai.ErrorResponse{ + Error: openai.ErrorDetail{ + Message: formatValidationErrors(validationErrors), + Type: "invalid_request_error", + }, + }) + return + } + + routeResult, err := h.routingService.Route(req.Model) if err != nil { h.handleError(c, err) return } - // 根据是否流式选择处理方式 if req.Stream { h.handleStreamRequest(c, &req, routeResult) } else { @@ -56,9 +69,7 @@ func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) { } } -// handleNonStreamRequest 处理非流式请求 -func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) { - // 发送请求到供应商 +func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) { resp, err := h.client.SendRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL) if err != nil { c.JSON(http.StatusInternalServerError, openai.ErrorResponse{ @@ -70,18 +81,14 @@ func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatC return } - // 记录统计 go func() { - _ = config.RecordRequest(routeResult.Provider.ID, req.Model) + _ = h.statsService.Record(routeResult.Provider.ID, req.Model) }() - // 返回响应 c.JSON(http.StatusOK, resp) } -// handleStreamRequest 处理流式请求 -func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) { - // 发送流式请求到供应商 +func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) { eventChan, err := h.client.SendStreamRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL) if err != nil { c.JSON(http.StatusInternalServerError, openai.ErrorResponse{ @@ -93,75 +100,58 @@ func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatComp return } - // 设置 SSE 响应头 c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") - // 创建流写入器 writer := bufio.NewWriter(c.Writer) - // 流式转发事件 for event := range eventChan { if event.Error != nil { - // 流错误,记录日志 - fmt.Printf("流错误: %v\n", event.Error) break } if event.Done { - // 流结束 writer.WriteString("data: [DONE]\n\n") writer.Flush() break } - // 写入事件数据 writer.WriteString("data: ") writer.Write(event.Data) writer.WriteString("\n\n") writer.Flush() } - // 记录统计 go func() { - _ = config.RecordRequest(routeResult.Provider.ID, req.Model) + _ = h.statsService.Record(routeResult.Provider.ID, req.Model) }() } -// handleError 处理路由错误 func (h *OpenAIHandler) handleError(c *gin.Context, err error) { - switch err { - case router.ErrModelNotFound: - c.JSON(http.StatusNotFound, openai.ErrorResponse{ + if appErr, ok := appErrors.AsAppError(err); ok { + c.JSON(appErr.HTTPStatus, openai.ErrorResponse{ Error: openai.ErrorDetail{ - Message: "模型未找到", + Message: appErr.Message, Type: "invalid_request_error", - Code: "model_not_found", - }, - }) - case router.ErrModelDisabled: - c.JSON(http.StatusNotFound, openai.ErrorResponse{ - Error: openai.ErrorDetail{ - Message: "模型已禁用", - Type: "invalid_request_error", - Code: "model_disabled", - }, - }) - case router.ErrProviderDisabled: - c.JSON(http.StatusNotFound, openai.ErrorResponse{ - Error: openai.ErrorDetail{ - Message: "供应商已禁用", - Type: "invalid_request_error", - Code: "provider_disabled", - }, - }) - default: - c.JSON(http.StatusInternalServerError, openai.ErrorResponse{ - Error: openai.ErrorDetail{ - Message: "内部错误: " + err.Error(), - Type: "internal_error", + Code: appErr.Code, }, }) + return } + c.JSON(http.StatusInternalServerError, openai.ErrorResponse{ + Error: openai.ErrorDetail{ + Message: "内部错误: " + err.Error(), + Type: "internal_error", + }, + }) +} + +// formatValidationErrors 将验证错误 map 格式化为字符串 +func formatValidationErrors(errors map[string]string) string { + parts := make([]string, 0, len(errors)) + for field, msg := range errors { + parts = append(parts, fmt.Sprintf("%s: %s", field, msg)) + } + return "请求验证失败: " + strings.Join(parts, "; ") } diff --git a/backend/internal/handler/provider_handler.go b/backend/internal/handler/provider_handler.go index 4016419..313252b 100644 --- a/backend/internal/handler/provider_handler.go +++ b/backend/internal/handler/provider_handler.go @@ -2,19 +2,25 @@ package handler import ( "net/http" + "strings" "github.com/gin-gonic/gin" "gorm.io/gorm" - "nex/backend/internal/config" + appErrors "nex/backend/pkg/errors" + + "nex/backend/internal/domain" + "nex/backend/internal/service" ) // ProviderHandler 供应商管理处理器 -type ProviderHandler struct{} +type ProviderHandler struct { + providerService service.ProviderService +} // NewProviderHandler 创建供应商处理器 -func NewProviderHandler() *ProviderHandler { - return &ProviderHandler{} +func NewProviderHandler(providerService service.ProviderService) *ProviderHandler { + return &ProviderHandler{providerService: providerService} } // CreateProvider 创建供应商 @@ -33,43 +39,34 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) { return } - // 创建供应商对象 - provider := &config.Provider{ + provider := &domain.Provider{ ID: req.ID, Name: req.Name, APIKey: req.APIKey, BaseURL: req.BaseURL, - Enabled: true, // 默认启用 } - // 保存到数据库 - err := config.CreateProvider(provider) + err := h.providerService.Create(provider) if err != nil { - // 检查是否是唯一约束错误(ID 重复) - if err.Error() == "UNIQUE constraint failed: providers.id" { + if strings.Contains(err.Error(), "UNIQUE constraint failed") { c.JSON(http.StatusConflict, gin.H{ "error": "供应商 ID 已存在", }) return } - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "创建供应商失败: " + err.Error(), - }) + writeError(c, err) return } - // 掩码 API Key 后返回 provider.MaskAPIKey() c.JSON(http.StatusCreated, provider) } // ListProviders 列出所有供应商 func (h *ProviderHandler) ListProviders(c *gin.Context) { - providers, err := config.ListProviders() + providers, err := h.providerService.List() if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "查询供应商失败: " + err.Error(), - }) + writeError(c, err) return } @@ -80,7 +77,7 @@ func (h *ProviderHandler) ListProviders(c *gin.Context) { func (h *ProviderHandler) GetProvider(c *gin.Context) { id := c.Param("id") - provider, err := config.GetProvider(id, true) // 掩码 API Key + provider, err := h.providerService.Get(id, true) if err != nil { if err == gorm.ErrRecordNotFound { c.JSON(http.StatusNotFound, gin.H{ @@ -88,9 +85,7 @@ func (h *ProviderHandler) GetProvider(c *gin.Context) { }) return } - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "查询供应商失败: " + err.Error(), - }) + writeError(c, err) return } @@ -109,8 +104,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) { return } - // 更新供应商 - err := config.UpdateProvider(id, req) + err := h.providerService.Update(id, req) if err != nil { if err == gorm.ErrRecordNotFound { c.JSON(http.StatusNotFound, gin.H{ @@ -118,18 +112,13 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) { }) return } - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "更新供应商失败: " + err.Error(), - }) + writeError(c, err) return } - // 返回更新后的供应商 - provider, err := config.GetProvider(id, true) + provider, err := h.providerService.Get(id, true) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "查询更新后的供应商失败: " + err.Error(), - }) + writeError(c, err) return } @@ -140,8 +129,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) { func (h *ProviderHandler) DeleteProvider(c *gin.Context) { id := c.Param("id") - // 删除供应商(级联删除模型) - err := config.DeleteProvider(id) + err := h.providerService.Delete(id) if err != nil { if err == gorm.ErrRecordNotFound { c.JSON(http.StatusNotFound, gin.H{ @@ -149,19 +137,23 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) { }) return } - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "删除供应商失败: " + err.Error(), - }) + writeError(c, err) return } - // 删除关联的模型 - models, _ := config.ListModels("") - for _, model := range models { - if model.ProviderID == id { - _ = config.DeleteModel(model.ID) - } - } - c.Status(http.StatusNoContent) } + +// writeError 统一错误响应处理 +func writeError(c *gin.Context, err error) { + if appErr, ok := appErrors.AsAppError(err); ok { + c.JSON(appErr.HTTPStatus, gin.H{ + "error": appErr.Message, + "code": appErr.Code, + }) + return + } + c.JSON(http.StatusInternalServerError, gin.H{ + "error": err.Error(), + }) +} diff --git a/backend/internal/handler/stats_handler.go b/backend/internal/handler/stats_handler.go index 9042ac0..40a8251 100644 --- a/backend/internal/handler/stats_handler.go +++ b/backend/internal/handler/stats_handler.go @@ -6,20 +6,21 @@ import ( "github.com/gin-gonic/gin" - "nex/backend/internal/config" + "nex/backend/internal/service" ) // StatsHandler 统计处理器 -type StatsHandler struct{} +type StatsHandler struct { + statsService service.StatsService +} // NewStatsHandler 创建统计处理器 -func NewStatsHandler() *StatsHandler { - return &StatsHandler{} +func NewStatsHandler(statsService service.StatsService) *StatsHandler { + return &StatsHandler{statsService: statsService} } // GetStats 查询统计 func (h *StatsHandler) GetStats(c *gin.Context) { - // 解析查询参数 providerID := c.Query("provider_id") modelName := c.Query("model_name") startDateStr := c.Query("start_date") @@ -27,7 +28,6 @@ func (h *StatsHandler) GetStats(c *gin.Context) { var startDate, endDate *time.Time - // 解析日期 if startDateStr != "" { t, err := time.Parse("2006-01-02", startDateStr) if err != nil { @@ -50,8 +50,7 @@ func (h *StatsHandler) GetStats(c *gin.Context) { endDate = &t } - // 查询统计 - stats, err := config.GetStats(providerID, modelName, startDate, endDate) + stats, err := h.statsService.Get(providerID, modelName, startDate, endDate) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": "查询统计失败: " + err.Error(), @@ -64,16 +63,14 @@ func (h *StatsHandler) GetStats(c *gin.Context) { // AggregateStats 聚合统计 func (h *StatsHandler) AggregateStats(c *gin.Context) { - // 解析查询参数 providerID := c.Query("provider_id") modelName := c.Query("model_name") startDateStr := c.Query("start_date") endDateStr := c.Query("end_date") - groupBy := c.Query("group_by") // "provider", "model", "date" + groupBy := c.Query("group_by") var startDate, endDate *time.Time - // 解析日期 if startDateStr != "" { t, err := time.Parse("2006-01-02", startDateStr) if err != nil { @@ -96,8 +93,7 @@ func (h *StatsHandler) AggregateStats(c *gin.Context) { endDate = &t } - // 查询统计 - stats, err := config.GetStats(providerID, modelName, startDate, endDate) + stats, err := h.statsService.Get(providerID, modelName, startDate, endDate) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": "查询统计失败: " + err.Error(), @@ -105,80 +101,6 @@ func (h *StatsHandler) AggregateStats(c *gin.Context) { return } - // 聚合 - result := h.aggregate(stats, groupBy) - + result := h.statsService.Aggregate(stats, groupBy) c.JSON(http.StatusOK, result) } - -// aggregate 执行聚合 -func (h *StatsHandler) aggregate(stats []config.UsageStats, groupBy string) []map[string]interface{} { - switch groupBy { - case "provider": - return h.aggregateByProvider(stats) - case "model": - return h.aggregateByModel(stats) - case "date": - return h.aggregateByDate(stats) - default: - // 默认按供应商聚合 - return h.aggregateByProvider(stats) - } -} - -// aggregateByProvider 按供应商聚合 -func (h *StatsHandler) aggregateByProvider(stats []config.UsageStats) []map[string]interface{} { - aggregated := make(map[string]int) - for _, stat := range stats { - aggregated[stat.ProviderID] += stat.RequestCount - } - - result := make([]map[string]interface{}, 0, len(aggregated)) - for providerID, count := range aggregated { - result = append(result, map[string]interface{}{ - "provider_id": providerID, - "request_count": count, - }) - } - - return result -} - -// aggregateByModel 按模型聚合 -func (h *StatsHandler) aggregateByModel(stats []config.UsageStats) []map[string]interface{} { - aggregated := make(map[string]int) - for _, stat := range stats { - key := stat.ProviderID + "/" + stat.ModelName - aggregated[key] += stat.RequestCount - } - - result := make([]map[string]interface{}, 0, len(aggregated)) - for key, count := range aggregated { - result = append(result, map[string]interface{}{ - "provider_id": key[:len(key)/2], - "model_name": key[len(key)/2+1:], - "request_count": count, - }) - } - - return result -} - -// aggregateByDate 按日期聚合 -func (h *StatsHandler) aggregateByDate(stats []config.UsageStats) []map[string]interface{} { - aggregated := make(map[string]int) - for _, stat := range stats { - key := stat.Date.Format("2006-01-02") - aggregated[key] += stat.RequestCount - } - - result := make([]map[string]interface{}, 0, len(aggregated)) - for date, count := range aggregated { - result = append(result, map[string]interface{}{ - "date": date, - "request_count": count, - }) - } - - return result -} diff --git a/backend/internal/protocol/anthropic/converter_test.go b/backend/internal/protocol/anthropic/converter_test.go new file mode 100644 index 0000000..683d9d6 --- /dev/null +++ b/backend/internal/protocol/anthropic/converter_test.go @@ -0,0 +1,270 @@ +package anthropic + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "nex/backend/internal/protocol/openai" +) + +func TestConvertRequest_Basic(t *testing.T) { + temp := 0.7 + req := &MessagesRequest{ + Model: "claude-3-opus", + MaxTokens: 1024, + Temperature: &temp, + Messages: []AnthropicMessage{ + { + Role: "user", + Content: []ContentBlock{ + {Type: "text", Text: "Hello"}, + }, + }, + }, + } + + result, err := ConvertRequest(req) + require.NoError(t, err) + assert.Equal(t, "claude-3-opus", result.Model) + assert.Equal(t, 1024, *result.MaxTokens) + assert.Equal(t, &temp, result.Temperature) + require.Len(t, result.Messages, 1) + assert.Equal(t, "user", result.Messages[0].Role) + assert.Equal(t, "Hello", result.Messages[0].Content) +} + +func TestConvertRequest_WithSystem(t *testing.T) { + req := &MessagesRequest{ + Model: "claude-3-opus", + MaxTokens: 100, + System: "You are a helpful assistant.", + Messages: []AnthropicMessage{ + { + Role: "user", + Content: []ContentBlock{{Type: "text", Text: "Hi"}}, + }, + }, + } + + result, err := ConvertRequest(req) + require.NoError(t, err) + require.Len(t, result.Messages, 2) + assert.Equal(t, "system", result.Messages[0].Role) + assert.Equal(t, "You are a helpful assistant.", result.Messages[0].Content) + assert.Equal(t, "user", result.Messages[1].Role) +} + +func TestConvertRequest_DefaultMaxTokens(t *testing.T) { + req := &MessagesRequest{ + Model: "claude-3-opus", + MaxTokens: 0, // 未设置 + Messages: []AnthropicMessage{ + {Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}}, + }, + } + + result, err := ConvertRequest(req) + require.NoError(t, err) + assert.Equal(t, 4096, *result.MaxTokens) +} + +func TestConvertRequest_WithTools(t *testing.T) { + req := &MessagesRequest{ + Model: "claude-3-opus", + MaxTokens: 100, + Messages: []AnthropicMessage{ + {Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}}, + }, + Tools: []AnthropicTool{ + { + Name: "get_weather", + Description: "Get weather info", + InputSchema: map[string]interface{}{"type": "object"}, + }, + }, + } + + result, err := ConvertRequest(req) + require.NoError(t, err) + require.Len(t, result.Tools, 1) + assert.Equal(t, "function", result.Tools[0].Type) + assert.Equal(t, "get_weather", result.Tools[0].Function.Name) +} + +func TestConvertRequest_WithStopSequences(t *testing.T) { + req := &MessagesRequest{ + Model: "claude-3-opus", + MaxTokens: 100, + StopSequences: []string{"STOP", "END"}, + Messages: []AnthropicMessage{ + {Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}}, + }, + } + + result, err := ConvertRequest(req) + require.NoError(t, err) + assert.Equal(t, []string{"STOP", "END"}, result.Stop) +} + +func TestConvertRequest_ToolResult(t *testing.T) { + req := &MessagesRequest{ + Model: "claude-3-opus", + MaxTokens: 100, + Messages: []AnthropicMessage{ + { + Role: "user", + Content: []ContentBlock{ + { + Type: "tool_result", + ToolUseID: "tool_123", + Content: "result data", + }, + }, + }, + }, + } + + result, err := ConvertRequest(req) + require.NoError(t, err) + require.Len(t, result.Messages, 1) + assert.Equal(t, "tool", result.Messages[0].Role) + assert.Equal(t, "tool_123", result.Messages[0].ToolCallID) + assert.Equal(t, "result data", result.Messages[0].Content) +} + +func TestConvertResponse(t *testing.T) { + resp := &openai.ChatCompletionResponse{ + ID: "chatcmpl-123", + Model: "gpt-4", + Choices: []openai.Choice{ + { + Index: 0, + Message: &openai.Message{Role: "assistant", Content: "Hello!"}, + FinishReason: "stop", + }, + }, + Usage: openai.Usage{PromptTokens: 10, CompletionTokens: 5}, + } + + result, err := ConvertResponse(resp) + require.NoError(t, err) + assert.Equal(t, "chatcmpl-123", result.ID) + assert.Equal(t, "message", result.Type) + assert.Equal(t, "assistant", result.Role) + assert.Equal(t, "end_turn", result.StopReason) + require.Len(t, result.Content, 1) + assert.Equal(t, "text", result.Content[0].Type) + assert.Equal(t, "Hello!", result.Content[0].Text) + assert.Equal(t, 10, result.Usage.InputTokens) + assert.Equal(t, 5, result.Usage.OutputTokens) +} + +func TestConvertResponse_ToolCalls(t *testing.T) { + args, _ := json.Marshal(map[string]interface{}{"city": "Beijing"}) + resp := &openai.ChatCompletionResponse{ + ID: "chatcmpl-456", + Model: "gpt-4", + Choices: []openai.Choice{ + { + Index: 0, + Message: &openai.Message{ + Role: "assistant", + ToolCalls: []openai.ToolCall{ + { + ID: "call_123", + Type: "function", + Function: openai.FunctionCall{ + Name: "get_weather", + Arguments: string(args), + }, + }, + }, + }, + FinishReason: "tool_calls", + }, + }, + Usage: openai.Usage{}, + } + + result, err := ConvertResponse(resp) + require.NoError(t, err) + assert.Equal(t, "tool_use", result.StopReason) + require.Len(t, result.Content, 1) + assert.Equal(t, "tool_use", result.Content[0].Type) + assert.Equal(t, "call_123", result.Content[0].ID) + assert.Equal(t, "get_weather", result.Content[0].Name) +} + +func TestConvertToolChoice_String(t *testing.T) { + tests := []struct { + name string + input interface{} + wantErr bool + check func(interface{}) + }{ + {"auto字符串", "auto", false, func(r interface{}) { assert.Equal(t, "auto", r) }}, + {"any字符串", "any", false, func(r interface{}) { assert.Equal(t, "auto", r) }}, + {"无效字符串", "invalid", true, nil}, + {"tool对象", map[string]interface{}{"type": "tool", "name": "my_func"}, false, + func(r interface{}) { + m := r.(map[string]interface{}) + assert.Equal(t, "function", m["type"]) + }}, + {"缺少name的tool对象", map[string]interface{}{"type": "tool"}, true, nil}, + {"缺少type的对象", map[string]interface{}{"name": "func"}, true, nil}, + {"无效类型", 42, true, nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := convertToolChoice(tt.input) + if tt.wantErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + tt.check(result) + } + }) + } +} + +func TestValidateRequest(t *testing.T) { + t.Run("有效请求", func(t *testing.T) { + req := &MessagesRequest{ + Model: "claude-3-opus", + MaxTokens: 100, + Messages: []AnthropicMessage{ + {Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}}, + }, + } + errs := ValidateRequest(req) + assert.Nil(t, errs) + }) + + t.Run("缺少模型", func(t *testing.T) { + req := &MessagesRequest{ + MaxTokens: 100, + Messages: []AnthropicMessage{ + {Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}}, + }, + } + errs := ValidateRequest(req) + assert.NotNil(t, errs) + assert.Contains(t, errs["model"], "不能为空") + }) + + t.Run("MaxTokens为0", func(t *testing.T) { + req := &MessagesRequest{ + Model: "claude-3-opus", + MaxTokens: 0, + Messages: []AnthropicMessage{ + {Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}}, + }, + } + errs := ValidateRequest(req) + assert.NotNil(t, errs) + }) +} diff --git a/backend/internal/protocol/anthropic/stream_converter_test.go b/backend/internal/protocol/anthropic/stream_converter_test.go new file mode 100644 index 0000000..1608894 --- /dev/null +++ b/backend/internal/protocol/anthropic/stream_converter_test.go @@ -0,0 +1,229 @@ +package anthropic + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "nex/backend/internal/protocol/openai" +) + +func TestStreamConverter_MessageStart(t *testing.T) { + converter := NewStreamConverter("msg_123", "claude-3-opus") + + chunk := &openai.StreamChunk{ + ID: "chatcmpl-123", + Choices: []openai.StreamChoice{{Index: 0, Delta: openai.Delta{}}}, + } + + events, err := converter.ConvertChunk(chunk) + require.NoError(t, err) + require.NotEmpty(t, events) + + // 第一个事件应该是 message_start + assert.Equal(t, "message_start", events[0].Type) + require.NotNil(t, events[0].Message) + assert.Equal(t, "msg_123", events[0].Message.ID) + assert.Equal(t, "message", events[0].Message.Type) + assert.Equal(t, "assistant", events[0].Message.Role) + assert.Equal(t, "claude-3-opus", events[0].Message.Model) +} + +func TestStreamConverter_TextDelta(t *testing.T) { + converter := NewStreamConverter("msg_123", "claude-3-opus") + + // 先发送一个空块以触发 message_start + chunk1 := &openai.StreamChunk{ + Choices: []openai.StreamChoice{ + {Delta: openai.Delta{Content: "Hello"}}, + }, + } + events1, err := converter.ConvertChunk(chunk1) + require.NoError(t, err) + // 应有 message_start + content_block_start + text delta + assert.GreaterOrEqual(t, len(events1), 3) + + // 第二个文本块不应再发送 message_start 和 content_block_start + chunk2 := &openai.StreamChunk{ + Choices: []openai.StreamChoice{ + {Delta: openai.Delta{Content: " world"}}, + }, + } + events2, err := converter.ConvertChunk(chunk2) + require.NoError(t, err) + // 只有 text delta + assert.Len(t, events2, 1) + assert.Equal(t, "content_block_delta", events2[0].Type) + assert.Equal(t, "text_delta", events2[0].Delta.Type) + assert.Equal(t, " world", events2[0].Delta.Text) +} + +func TestStreamConverter_FinishReason(t *testing.T) { + converter := NewStreamConverter("msg_123", "claude-3-opus") + + chunk := &openai.StreamChunk{ + Choices: []openai.StreamChoice{ + {Delta: openai.Delta{Content: "Hello"}, FinishReason: "stop"}, + }, + } + events, err := converter.ConvertChunk(chunk) + require.NoError(t, err) + + // 查找 message_delta 事件 + var messageDelta *StreamEvent + for _, e := range events { + if e.Type == "message_delta" { + messageDelta = &e + break + } + } + require.NotNil(t, messageDelta) + assert.Equal(t, "end_turn", messageDelta.Delta.StopReason) + + // 查找 message_stop 事件 + var messageStop *StreamEvent + for _, e := range events { + if e.Type == "message_stop" { + messageStop = &e + break + } + } + assert.NotNil(t, messageStop) +} + +func TestStreamConverter_FinishReasonToolCalls(t *testing.T) { + converter := NewStreamConverter("msg_123", "claude-3-opus") + + chunk := &openai.StreamChunk{ + Choices: []openai.StreamChoice{ + {Delta: openai.Delta{}, FinishReason: "tool_calls"}, + }, + } + events, err := converter.ConvertChunk(chunk) + require.NoError(t, err) + + var messageDelta *StreamEvent + for _, e := range events { + if e.Type == "message_delta" { + messageDelta = &e + break + } + } + require.NotNil(t, messageDelta) + assert.Equal(t, "tool_use", messageDelta.Delta.StopReason) +} + +func TestStreamConverter_FinishReasonLength(t *testing.T) { + converter := NewStreamConverter("msg_123", "claude-3-opus") + + chunk := &openai.StreamChunk{ + Choices: []openai.StreamChoice{ + {Delta: openai.Delta{}, FinishReason: "length"}, + }, + } + events, err := converter.ConvertChunk(chunk) + require.NoError(t, err) + + var messageDelta *StreamEvent + for _, e := range events { + if e.Type == "message_delta" { + messageDelta = &e + break + } + } + require.NotNil(t, messageDelta) + assert.Equal(t, "max_tokens", messageDelta.Delta.StopReason) +} + +func TestStreamConverter_ToolCalls(t *testing.T) { + converter := NewStreamConverter("msg_123", "claude-3-opus") + + chunk := &openai.StreamChunk{ + Choices: []openai.StreamChoice{ + { + Delta: openai.Delta{ + ToolCalls: []openai.ToolCall{ + { + ID: "call_123", + Type: "function", + Function: openai.FunctionCall{ + Name: "get_weather", + Arguments: `{"city": "Beijing"}`, + }, + }, + }, + }, + }, + }, + } + + events, err := converter.ConvertChunk(chunk) + require.NoError(t, err) + + // 应包含 content_block_start (tool_use) + content_block_delta (input_json_delta) + hasBlockStart := false + hasInputDelta := false + for _, e := range events { + if e.Type == "content_block_start" && e.ContentBlock != nil && e.ContentBlock.Type == "tool_use" { + hasBlockStart = true + assert.Equal(t, "call_123", e.ContentBlock.ID) + assert.Equal(t, "get_weather", e.ContentBlock.Name) + } + if e.Type == "content_block_delta" && e.Delta != nil && e.Delta.Type == "input_json_delta" { + hasInputDelta = true + assert.Equal(t, `{"city": "Beijing"}`, e.Delta.Input) + } + } + assert.True(t, hasBlockStart, "应有 tool_use content_block_start") + assert.True(t, hasInputDelta, "应有 input_json_delta") +} + +func TestSerializeEvent(t *testing.T) { + event := StreamEvent{ + Type: "message_start", + Message: &MessagesResponse{ + ID: "msg_123", + Type: "message", + Role: "assistant", + }, + } + + result, err := SerializeEvent(event) + require.NoError(t, err) + assert.Contains(t, result, "event: message_start") + assert.Contains(t, result, "data: ") + assert.Contains(t, result, "msg_123") +} + +func TestSerializeEvent_InvalidJSON(t *testing.T) { + event := StreamEvent{ + Type: "test", + } + // 这个应该能正常序列化 + result, err := SerializeEvent(event) + require.NoError(t, err) + assert.Contains(t, result, "event: test") +} + +func TestContentBlock_ParseInputJSON(t *testing.T) { + t.Run("字符串输入", func(t *testing.T) { + cb := &ContentBlock{Input: `{"key": "value"}`} + result, err := cb.ParseInputJSON() + require.NoError(t, err) + assert.Equal(t, "value", result["key"]) + }) + + t.Run("对象输入", func(t *testing.T) { + cb := &ContentBlock{Input: map[string]interface{}{"key": "value"}} + result, err := cb.ParseInputJSON() + require.NoError(t, err) + assert.Equal(t, "value", result["key"]) + }) + + t.Run("无效类型", func(t *testing.T) { + cb := &ContentBlock{Input: 42} + _, err := cb.ParseInputJSON() + assert.Error(t, err) + }) +} diff --git a/backend/internal/protocol/anthropic/types.go b/backend/internal/protocol/anthropic/types.go index c2fae14..54f37bf 100644 --- a/backend/internal/protocol/anthropic/types.go +++ b/backend/internal/protocol/anthropic/types.go @@ -1,13 +1,20 @@ package anthropic -import "encoding/json" +import ( + "encoding/json" + "fmt" + + "github.com/go-playground/validator/v10" + + pkgValidator "nex/backend/pkg/validator" +) // MessagesRequest Anthropic Messages API 请求结构 type MessagesRequest struct { - Model string `json:"model"` - Messages []AnthropicMessage `json:"messages"` + Model string `json:"model" validate:"required"` + Messages []AnthropicMessage `json:"messages" validate:"required,min=1"` System string `json:"system,omitempty"` - MaxTokens int `json:"max_tokens"` + MaxTokens int `json:"max_tokens" validate:"required,min=1"` Temperature *float64 `json:"temperature,omitempty"` TopP *float64 `json:"top_p,omitempty"` TopK *int `json:"top_k,omitempty"` @@ -114,5 +121,29 @@ func (cb *ContentBlock) ParseInputJSON() (map[string]interface{}, error) { if obj, ok := cb.Input.(map[string]interface{}); ok { return obj, nil } - return nil, json.Unmarshal([]byte{}, nil) // 返回错误 + return nil, fmt.Errorf("invalid input type: expected string or map") +} + +// ValidateRequest 验证 MessagesRequest +func ValidateRequest(req *MessagesRequest) map[string]string { + errs := pkgValidator.Validate(req) + if errs == nil { + return nil + } + + validationErrors := make(map[string]string) + for _, err := range errs.(validator.ValidationErrors) { + field := err.Field() + switch field { + case "Model": + validationErrors["model"] = "模型名称不能为空" + case "Messages": + validationErrors["messages"] = "消息列表不能为空" + case "MaxTokens": + validationErrors["max_tokens"] = "max_tokens 不能为空且必须大于 0" + default: + validationErrors[field] = fmt.Sprintf("字段 %s 验证失败: %s", field, err.Tag()) + } + } + return validationErrors } diff --git a/backend/internal/protocol/openai/adapter.go b/backend/internal/protocol/openai/adapter.go index 533fe8c..708b6c7 100644 --- a/backend/internal/protocol/openai/adapter.go +++ b/backend/internal/protocol/openai/adapter.go @@ -3,7 +3,6 @@ package openai import ( "bytes" "encoding/json" - "fmt" "io" "net/http" ) @@ -24,9 +23,6 @@ func (a *Adapter) PrepareRequest(req *ChatCompletionRequest, apiKey, baseURL str return nil, err } - // 调试日志:打印请求体 - fmt.Printf("[DEBUG] 请求Body: %s\n", string(body)) - // 创建 HTTP 请求 // baseURL 已包含版本路径(如 /v1 或 /v4),只需添加端点路径 httpReq, err := http.NewRequest("POST", baseURL+"/chat/completions", bytes.NewReader(body)) diff --git a/backend/internal/protocol/openai/adapter_test.go b/backend/internal/protocol/openai/adapter_test.go new file mode 100644 index 0000000..0dce920 --- /dev/null +++ b/backend/internal/protocol/openai/adapter_test.go @@ -0,0 +1,190 @@ +package openai + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAdapter_PrepareRequest(t *testing.T) { + adapter := NewAdapter() + req := &ChatCompletionRequest{ + Model: "gpt-4", + Messages: []Message{ + {Role: "user", Content: "Hello"}, + }, + } + + httpReq, err := adapter.PrepareRequest(req, "test-api-key", "https://api.openai.com/v1") + require.NoError(t, err) + require.NotNil(t, httpReq) + + assert.Equal(t, "POST", httpReq.Method) + assert.Equal(t, "https://api.openai.com/v1/chat/completions", httpReq.URL.String()) + assert.Equal(t, "application/json", httpReq.Header.Get("Content-Type")) + assert.Equal(t, "Bearer test-api-key", httpReq.Header.Get("Authorization")) + + // 验证请求体 + var body ChatCompletionRequest + err = json.NewDecoder(httpReq.Body).Decode(&body) + require.NoError(t, err) + assert.Equal(t, "gpt-4", body.Model) +} + +func TestAdapter_ParseResponse(t *testing.T) { + adapter := NewAdapter() + resp := &ChatCompletionResponse{ + ID: "chatcmpl-123", + Object: "chat.completion", + Created: 1234567890, + Model: "gpt-4", + Choices: []Choice{ + { + Index: 0, + Message: &Message{Role: "assistant", Content: "Hello!"}, + }, + }, + Usage: Usage{PromptTokens: 10, CompletionTokens: 5, TotalTokens: 15}, + } + + body, err := json.Marshal(resp) + require.NoError(t, err) + + httpResp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader(body)), + } + + result, err := adapter.ParseResponse(httpResp) + require.NoError(t, err) + assert.Equal(t, "chatcmpl-123", result.ID) + assert.Equal(t, "gpt-4", result.Model) + require.Len(t, result.Choices, 1) + assert.Equal(t, "Hello!", result.Choices[0].Message.Content) +} + +func TestAdapter_ParseErrorResponse(t *testing.T) { + adapter := NewAdapter() + errResp := &ErrorResponse{ + Error: ErrorDetail{ + Message: "Invalid API key", + Type: "invalid_request_error", + Code: "invalid_api_key", + }, + } + + body, err := json.Marshal(errResp) + require.NoError(t, err) + + httpResp := &http.Response{ + StatusCode: 401, + Body: io.NopCloser(bytes.NewReader(body)), + } + + result, err := adapter.ParseErrorResponse(httpResp) + require.NoError(t, err) + assert.Equal(t, "Invalid API key", result.Error.Message) + assert.Equal(t, "invalid_request_error", result.Error.Type) +} + +func TestAdapter_ParseStreamChunk(t *testing.T) { + adapter := NewAdapter() + chunk := &StreamChunk{ + ID: "chatcmpl-123", + Object: "chat.completion.chunk", + Created: 1234567890, + Model: "gpt-4", + Choices: []StreamChoice{ + { + Index: 0, + Delta: Delta{Content: "Hello"}, + }, + }, + } + + data, err := json.Marshal(chunk) + require.NoError(t, err) + + result, err := adapter.ParseStreamChunk(data) + require.NoError(t, err) + assert.Equal(t, "chatcmpl-123", result.ID) + require.Len(t, result.Choices, 1) + assert.Equal(t, "Hello", result.Choices[0].Delta.Content) +} + +func TestParseToolCallArguments(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + {"有效JSON", `{"key": "value"}`, false}, + {"无效JSON", `not json`, true}, + {"空JSON", `{}`, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tc := &ToolCall{ + Function: FunctionCall{Arguments: tt.input}, + } + args, err := tc.ParseToolCallArguments() + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.NotNil(t, args) + } + }) + } +} + +func TestSerializeToolCallArguments(t *testing.T) { + args := map[string]interface{}{"key": "value"} + result, err := SerializeToolCallArguments(args) + require.NoError(t, err) + assert.JSONEq(t, `{"key": "value"}`, result) +} + +func TestValidateRequest(t *testing.T) { + t.Run("有效请求", func(t *testing.T) { + req := &ChatCompletionRequest{ + Model: "gpt-4", + Messages: []Message{{Role: "user", Content: "hello"}}, + } + errs := ValidateRequest(req) + assert.Nil(t, errs) + }) + + t.Run("缺少模型", func(t *testing.T) { + req := &ChatCompletionRequest{ + Messages: []Message{{Role: "user", Content: "hello"}}, + } + errs := ValidateRequest(req) + assert.NotNil(t, errs) + assert.Contains(t, errs["model"], "不能为空") + }) + + t.Run("缺少消息", func(t *testing.T) { + req := &ChatCompletionRequest{ + Model: "gpt-4", + } + errs := ValidateRequest(req) + assert.NotNil(t, errs) + assert.Contains(t, errs["messages"], "不能为空") + }) + + t.Run("空消息列表", func(t *testing.T) { + req := &ChatCompletionRequest{ + Model: "gpt-4", + Messages: []Message{}, + } + errs := ValidateRequest(req) + assert.NotNil(t, errs) + }) +} diff --git a/backend/internal/protocol/openai/types.go b/backend/internal/protocol/openai/types.go index 6367723..b181dc5 100644 --- a/backend/internal/protocol/openai/types.go +++ b/backend/internal/protocol/openai/types.go @@ -1,11 +1,18 @@ package openai -import "encoding/json" +import ( + "encoding/json" + "fmt" + + "github.com/go-playground/validator/v10" + + pkgValidator "nex/backend/pkg/validator" +) // ChatCompletionRequest OpenAI Chat Completions API 请求结构 type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` + Model string `json:"model" validate:"required"` + Messages []Message `json:"messages" validate:"required,min=1"` Temperature *float64 `json:"temperature,omitempty"` MaxTokens *int `json:"max_tokens,omitempty"` TopP *float64 `json:"top_p,omitempty"` @@ -129,3 +136,25 @@ func SerializeToolCallArguments(args map[string]interface{}) (string, error) { } return string(bytes), nil } + +// ValidateRequest 验证 ChatCompletionRequest +func ValidateRequest(req *ChatCompletionRequest) map[string]string { + errs := pkgValidator.Validate(req) + if errs == nil { + return nil + } + + validationErrors := make(map[string]string) + for _, err := range errs.(validator.ValidationErrors) { + field := err.Field() + switch field { + case "Model": + validationErrors["model"] = "模型名称不能为空" + case "Messages": + validationErrors["messages"] = "消息列表不能为空" + default: + validationErrors[field] = fmt.Sprintf("字段 %s 验证失败: %s", field, err.Tag()) + } + } + return validationErrors +} diff --git a/backend/internal/provider/client.go b/backend/internal/provider/client.go index b019955..abf1edf 100644 --- a/backend/internal/provider/client.go +++ b/backend/internal/provider/client.go @@ -9,22 +9,59 @@ import ( "strings" "time" + "go.uber.org/zap" + "nex/backend/internal/protocol/openai" ) +// StreamConfig 流式处理配置 +type StreamConfig struct { + InitialBufferSize int // 初始缓冲区大小(字节),默认 4096 + MaxBufferSize int // 最大缓冲区大小(字节),默认 65536 + Timeout time.Duration // 流超时时间,默认 5 分钟 + ChannelBufferSize int // 事件通道缓冲区大小,默认 100 +} + +// DefaultStreamConfig 返回默认流式处理配置 +func DefaultStreamConfig() StreamConfig { + return StreamConfig{ + InitialBufferSize: 4096, + MaxBufferSize: 65536, + Timeout: 5 * time.Minute, + ChannelBufferSize: 100, + } +} + // Client OpenAI 兼容供应商客户端 type Client struct { - httpClient *http.Client - adapter *openai.Adapter + httpClient *http.Client + adapter *openai.Adapter + logger *zap.Logger + streamCfg StreamConfig +} + +// StreamEvent 流事件 +type StreamEvent struct { + Data []byte + Error error + Done bool +} + +// ProviderClient 供应商客户端接口 +type ProviderClient interface { + SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error) + SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan StreamEvent, error) } // NewClient 创建供应商客户端 func NewClient() *Client { return &Client{ httpClient: &http.Client{ - Timeout: 30 * time.Second, // 非流式请求超时 + Timeout: 30 * time.Second, }, - adapter: openai.NewAdapter(), + adapter: openai.NewAdapter(), + logger: zap.L(), + streamCfg: DefaultStreamConfig(), } } @@ -36,10 +73,10 @@ func (c *Client) SendRequest(ctx context.Context, req *openai.ChatCompletionRequ return nil, fmt.Errorf("准备请求失败: %w", err) } - // 调试日志:打印完整请求信息 - fmt.Printf("[DEBUG] 请求URL: %s\n", httpReq.URL.String()) - fmt.Printf("[DEBUG] 请求Method: %s\n", httpReq.Method) - fmt.Printf("[DEBUG] 请求Headers: %v\n", httpReq.Header) + c.logger.Debug("发送请求", + zap.String("url", httpReq.URL.String()), + zap.String("method", httpReq.Method), + ) // 设置上下文 httpReq = httpReq.WithContext(ctx) @@ -80,18 +117,22 @@ func (c *Client) SendStreamRequest(ctx context.Context, req *openai.ChatCompleti return nil, fmt.Errorf("准备请求失败: %w", err) } - // 设置上下文 - httpReq = httpReq.WithContext(ctx) + // 设置带超时的上下文 + streamCtx, cancel := context.WithTimeout(ctx, c.streamCfg.Timeout) + _ = cancel // cancel 在流读取结束后由 ctx 传播处理 + httpReq = httpReq.WithContext(streamCtx) // 发送请求 resp, err := c.httpClient.Do(httpReq) if err != nil { + cancel() return nil, fmt.Errorf("发送请求失败: %w", err) } // 检查状态码 if resp.StatusCode != http.StatusOK { defer resp.Body.Close() + cancel() errorResp, parseErr := c.adapter.ParseErrorResponse(resp) if parseErr != nil { return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode) @@ -100,33 +141,33 @@ func (c *Client) SendStreamRequest(ctx context.Context, req *openai.ChatCompleti } // 创建事件通道 - eventChan := make(chan StreamEvent, 100) + eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize) // 启动 goroutine 读取流 - go c.readStream(ctx, resp.Body, eventChan) + go c.readStream(streamCtx, cancel, resp.Body, eventChan) return eventChan, nil } -// StreamEvent 流事件 -type StreamEvent struct { - Data []byte - Error error - Done bool -} - -// readStream 读取 SSE 流 -func (c *Client) readStream(ctx context.Context, body io.ReadCloser, eventChan chan<- StreamEvent) { +// readStream 读取 SSE 流(支持动态缓冲区、超时控制和改进的错误处理) +func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body io.ReadCloser, eventChan chan<- StreamEvent) { defer close(eventChan) defer body.Close() + defer cancel() - buf := make([]byte, 4096) + bufSize := c.streamCfg.InitialBufferSize + buf := make([]byte, bufSize) var dataBuf []byte for { select { case <-ctx.Done(): - eventChan <- StreamEvent{Error: ctx.Err()} + if ctx.Err() == context.DeadlineExceeded { + c.logger.Warn("流读取超时") + eventChan <- StreamEvent{Error: fmt.Errorf("流读取超时: %w", ctx.Err())} + } else { + eventChan <- StreamEvent{Error: ctx.Err()} + } return default: } @@ -134,15 +175,32 @@ func (c *Client) readStream(ctx context.Context, body io.ReadCloser, eventChan c n, err := body.Read(buf) if err != nil { if err == io.EOF { - // 流结束 + // 流正常结束 return } - eventChan <- StreamEvent{Error: err} + // 区分网络错误和其他错误 + if isNetworkError(err) { + c.logger.Error("流网络错误", zap.String("error", err.Error())) + eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)} + } else { + c.logger.Error("流读取错误", zap.String("error", err.Error())) + eventChan <- StreamEvent{Error: fmt.Errorf("读取错误: %w", err)} + } return } dataBuf = append(dataBuf, buf[:n]...) + // 动态调整缓冲区大小:如果数据量大,增大缓冲区 + if len(dataBuf) > bufSize/2 && bufSize < c.streamCfg.MaxBufferSize { + newSize := bufSize * 2 + if newSize > c.streamCfg.MaxBufferSize { + newSize = c.streamCfg.MaxBufferSize + } + buf = make([]byte, newSize) + bufSize = newSize + } + // 处理完整的 SSE 事件 for { // 查找事件边界(双换行) @@ -175,3 +233,16 @@ func (c *Client) readStream(ctx context.Context, body io.ReadCloser, eventChan c } } +// isNetworkError 判断是否为网络相关错误 +func isNetworkError(err error) bool { + if err == nil { + return false + } + errStr := err.Error() + return strings.Contains(errStr, "connection reset") || + strings.Contains(errStr, "broken pipe") || + strings.Contains(errStr, "network") || + strings.Contains(errStr, "timeout") || + strings.Contains(errStr, "EOF") +} + diff --git a/backend/internal/provider/client_test.go b/backend/internal/provider/client_test.go new file mode 100644 index 0000000..66d56bb --- /dev/null +++ b/backend/internal/provider/client_test.go @@ -0,0 +1,151 @@ +package provider + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "nex/backend/internal/protocol/openai" +) + +func TestNewClient(t *testing.T) { + client := NewClient() + require.NotNil(t, client) + assert.NotNil(t, client.httpClient) + assert.NotNil(t, client.adapter) + assert.Equal(t, 4096, client.streamCfg.InitialBufferSize) + assert.Equal(t, 65536, client.streamCfg.MaxBufferSize) + assert.Equal(t, 100, client.streamCfg.ChannelBufferSize) +} + +func TestDefaultStreamConfig(t *testing.T) { + cfg := DefaultStreamConfig() + assert.Equal(t, 4096, cfg.InitialBufferSize) + assert.Equal(t, 65536, cfg.MaxBufferSize) + assert.Equal(t, 100, cfg.ChannelBufferSize) +} + +func TestClient_SendRequest_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization")) + + resp := openai.ChatCompletionResponse{ + ID: "chatcmpl-123", + Choices: []openai.Choice{ + {Index: 0, Message: &openai.Message{Role: "assistant", Content: "Hello!"}}, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + client := NewClient() + req := &openai.ChatCompletionRequest{ + Model: "gpt-4", + Messages: []openai.Message{{Role: "user", Content: "Hi"}}, + } + + result, err := client.SendRequest(context.Background(), req, "test-key", server.URL) + require.NoError(t, err) + assert.Equal(t, "chatcmpl-123", result.ID) +} + +func TestClient_SendRequest_ErrorResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(openai.ErrorResponse{ + Error: openai.ErrorDetail{Message: "Invalid API key"}, + }) + })) + defer server.Close() + + client := NewClient() + req := &openai.ChatCompletionRequest{ + Model: "gpt-4", + Messages: []openai.Message{{Role: "user", Content: "Hi"}}, + } + + _, err := client.SendRequest(context.Background(), req, "bad-key", server.URL) + assert.Error(t, err) + assert.Contains(t, err.Error(), "Invalid API key") +} + +func TestClient_SendRequest_ConnectionError(t *testing.T) { + client := NewClient() + req := &openai.ChatCompletionRequest{ + Model: "gpt-4", + Messages: []openai.Message{{Role: "user", Content: "Hi"}}, + } + + _, err := client.SendRequest(context.Background(), req, "key", "http://localhost:1") + assert.Error(t, err) +} + +func TestClient_SendStreamRequest_CreatesChannel(t *testing.T) { + // 使用一个慢服务器确保客户端有时间读取 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := NewClient() + req := &openai.ChatCompletionRequest{ + Model: "gpt-4", + Messages: []openai.Message{{Role: "user", Content: "Hi"}}, + } + + eventChan, err := client.SendStreamRequest(context.Background(), req, "test-key", server.URL) + require.NoError(t, err) + require.NotNil(t, eventChan) + + // 读取直到 channel 关闭(服务器关闭后应产生 EOF) + for range eventChan { + // 消费所有事件 + } + // channel 应已关闭(不阻塞即通过) +} + +func TestClient_SendStreamRequest_ErrorResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + client := NewClient() + req := &openai.ChatCompletionRequest{ + Model: "gpt-4", + Messages: []openai.Message{{Role: "user", Content: "Hi"}}, + } + + _, err := client.SendStreamRequest(context.Background(), req, "key", server.URL) + assert.Error(t, err) +} + +func TestIsNetworkError(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"connection reset by peer", true}, + {"broken pipe", true}, + {"network is unreachable", true}, + {"timeout waiting for response", true}, + {"unexpected EOF", true}, + {"normal error", false}, + {"", false}, + } + for _, tt := range tests { + err := fmt.Errorf("%s", tt.input) //nolint:govet + assert.Equal(t, tt.want, isNetworkError(err), "isNetworkError(%q)", tt.input) + } +} diff --git a/backend/internal/repository/model_repo.go b/backend/internal/repository/model_repo.go new file mode 100644 index 0000000..daeb76f --- /dev/null +++ b/backend/internal/repository/model_repo.go @@ -0,0 +1,13 @@ +package repository + +import "nex/backend/internal/domain" + +// ModelRepository 模型数据仓库接口 +type ModelRepository interface { + Create(model *domain.Model) error + GetByID(id string) (*domain.Model, error) + List(providerID string) ([]domain.Model, error) + GetByModelName(modelName string) (*domain.Model, error) + Update(id string, updates map[string]interface{}) error + Delete(id string) error +} diff --git a/backend/internal/repository/model_repo_impl.go b/backend/internal/repository/model_repo_impl.go new file mode 100644 index 0000000..31a0e14 --- /dev/null +++ b/backend/internal/repository/model_repo_impl.go @@ -0,0 +1,104 @@ +package repository + +import ( + "time" + + "gorm.io/gorm" + + "nex/backend/internal/config" + "nex/backend/internal/domain" + appErrors "nex/backend/pkg/errors" +) + +type modelRepository struct { + db *gorm.DB +} + +func NewModelRepository(db *gorm.DB) ModelRepository { + return &modelRepository{db: db} +} + +func (r *modelRepository) Create(model *domain.Model) error { + m := toConfigModel(model) + m.CreatedAt = time.Now() + return r.db.Create(&m).Error +} + +func (r *modelRepository) GetByID(id string) (*domain.Model, error) { + var m config.Model + err := r.db.First(&m, "id = ?", id).Error + if err != nil { + return nil, err + } + d := toDomainModel(&m) + return &d, nil +} + +func (r *modelRepository) List(providerID string) ([]domain.Model, error) { + var models []config.Model + var err error + if providerID != "" { + err = r.db.Where("provider_id = ?", providerID).Find(&models).Error + } else { + err = r.db.Find(&models).Error + } + if err != nil { + return nil, err + } + result := make([]domain.Model, len(models)) + for i := range models { + result[i] = toDomainModel(&models[i]) + } + return result, nil +} + +func (r *modelRepository) GetByModelName(modelName string) (*domain.Model, error) { + var m config.Model + err := r.db.Where("model_name = ?", modelName).First(&m).Error + if err != nil { + return nil, err + } + d := toDomainModel(&m) + return &d, nil +} + +func (r *modelRepository) Update(id string, updates map[string]interface{}) error { + result := r.db.Model(&config.Model{}).Where("id = ?", id).Updates(updates) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return appErrors.ErrModelNotFound + } + return nil +} + +func (r *modelRepository) Delete(id string) error { + result := r.db.Delete(&config.Model{}, "id = ?", id) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return appErrors.ErrModelNotFound + } + return nil +} + +func toDomainModel(m *config.Model) domain.Model { + return domain.Model{ + ID: m.ID, + ProviderID: m.ProviderID, + ModelName: m.ModelName, + Enabled: m.Enabled, + CreatedAt: m.CreatedAt, + } +} + +func toConfigModel(m *domain.Model) config.Model { + return config.Model{ + ID: m.ID, + ProviderID: m.ProviderID, + ModelName: m.ModelName, + Enabled: m.Enabled, + } +} diff --git a/backend/internal/repository/provider_repo.go b/backend/internal/repository/provider_repo.go new file mode 100644 index 0000000..18986dc --- /dev/null +++ b/backend/internal/repository/provider_repo.go @@ -0,0 +1,12 @@ +package repository + +import "nex/backend/internal/domain" + +// ProviderRepository 供应商数据仓库接口 +type ProviderRepository interface { + Create(provider *domain.Provider) error + GetByID(id string) (*domain.Provider, error) + List() ([]domain.Provider, error) + Update(id string, updates map[string]interface{}) error + Delete(id string) error +} diff --git a/backend/internal/repository/provider_repo_impl.go b/backend/internal/repository/provider_repo_impl.go new file mode 100644 index 0000000..45b7501 --- /dev/null +++ b/backend/internal/repository/provider_repo_impl.go @@ -0,0 +1,94 @@ +package repository + +import ( + "time" + + "gorm.io/gorm" + + "nex/backend/internal/config" + "nex/backend/internal/domain" + appErrors "nex/backend/pkg/errors" +) + +type providerRepository struct { + db *gorm.DB +} + +func NewProviderRepository(db *gorm.DB) ProviderRepository { + return &providerRepository{db: db} +} + +func (r *providerRepository) Create(provider *domain.Provider) error { + p := toConfigProvider(provider) + p.CreatedAt = time.Now() + p.UpdatedAt = time.Now() + return r.db.Create(&p).Error +} + +func (r *providerRepository) GetByID(id string) (*domain.Provider, error) { + var p config.Provider + err := r.db.First(&p, "id = ?", id).Error + if err != nil { + return nil, err + } + d := toDomainProvider(&p) + return &d, nil +} + +func (r *providerRepository) List() ([]domain.Provider, error) { + var providers []config.Provider + err := r.db.Find(&providers).Error + if err != nil { + return nil, err + } + result := make([]domain.Provider, len(providers)) + for i := range providers { + result[i] = toDomainProvider(&providers[i]) + } + return result, nil +} + +func (r *providerRepository) Update(id string, updates map[string]interface{}) error { + updates["updated_at"] = time.Now() + result := r.db.Model(&config.Provider{}).Where("id = ?", id).Updates(updates) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return appErrors.ErrProviderNotFound + } + return nil +} + +func (r *providerRepository) Delete(id string) error { + result := r.db.Delete(&config.Provider{}, "id = ?", id) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return appErrors.ErrProviderNotFound + } + return nil +} + +func toDomainProvider(p *config.Provider) domain.Provider { + return domain.Provider{ + ID: p.ID, + Name: p.Name, + APIKey: p.APIKey, + BaseURL: p.BaseURL, + Enabled: p.Enabled, + CreatedAt: p.CreatedAt, + UpdatedAt: p.UpdatedAt, + } +} + +func toConfigProvider(p *domain.Provider) config.Provider { + return config.Provider{ + ID: p.ID, + Name: p.Name, + APIKey: p.APIKey, + BaseURL: p.BaseURL, + Enabled: p.Enabled, + } +} diff --git a/backend/internal/repository/repository_test.go b/backend/internal/repository/repository_test.go new file mode 100644 index 0000000..584b15b --- /dev/null +++ b/backend/internal/repository/repository_test.go @@ -0,0 +1,233 @@ +package repository + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + + "nex/backend/internal/config" + "nex/backend/internal/domain" +) + +func setupTestDB(t *testing.T) *gorm.DB { + t.Helper() + dir := t.TempDir() + db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{}) + require.NoError(t, err) + err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{}) + require.NoError(t, err) + // 关闭数据库连接以便 TempDir 清理 + t.Cleanup(func() { + sqlDB, _ := db.DB() + if sqlDB != nil { + sqlDB.Close() + } + }) + return db +} + +// ============ ProviderRepository 测试 ============ + +func TestProviderRepository_Create(t *testing.T) { + db := setupTestDB(t) + repo := NewProviderRepository(db) + + provider := &domain.Provider{ + ID: "test-provider", + Name: "Test Provider", + APIKey: "sk-test-key", + BaseURL: "https://api.test.com", + Enabled: true, + } + err := repo.Create(provider) + require.NoError(t, err) +} + +func TestProviderRepository_GetByID(t *testing.T) { + db := setupTestDB(t) + repo := NewProviderRepository(db) + + provider := &domain.Provider{ + ID: "test-provider", Name: "Test", APIKey: "sk-test-key", BaseURL: "https://api.test.com", + } + err := repo.Create(provider) + require.NoError(t, err) + + result, err := repo.GetByID("test-provider") + require.NoError(t, err) + assert.Equal(t, "test-provider", result.ID) + assert.Equal(t, "Test", result.Name) +} + +func TestProviderRepository_GetByID_NotFound(t *testing.T) { + db := setupTestDB(t) + repo := NewProviderRepository(db) + + _, err := repo.GetByID("nonexistent") + assert.Error(t, err) +} + +func TestProviderRepository_List(t *testing.T) { + db := setupTestDB(t) + repo := NewProviderRepository(db) + + for _, id := range []string{"pA", "pB", "pC"} { + err := repo.Create(&domain.Provider{ID: id, Name: id, APIKey: "key", BaseURL: "https://test.com"}) + require.NoError(t, err) + } + + providers, err := repo.List() + require.NoError(t, err) + assert.Len(t, providers, 3) +} + +func TestProviderRepository_Update(t *testing.T) { + db := setupTestDB(t) + repo := NewProviderRepository(db) + + repo.Create(&domain.Provider{ID: "p1", Name: "Old", APIKey: "key", BaseURL: "https://old.com"}) + + err := repo.Update("p1", map[string]interface{}{"name": "New"}) + require.NoError(t, err) + + result, _ := repo.GetByID("p1") + assert.Equal(t, "New", result.Name) +} + +func TestProviderRepository_Update_NotFound(t *testing.T) { + db := setupTestDB(t) + repo := NewProviderRepository(db) + + err := repo.Update("nonexistent", map[string]interface{}{"name": "New"}) + assert.Error(t, err) +} + +func TestProviderRepository_Delete(t *testing.T) { + db := setupTestDB(t) + repo := NewProviderRepository(db) + + repo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}) + err := repo.Delete("p1") + require.NoError(t, err) + + _, err = repo.GetByID("p1") + assert.Error(t, err) +} + +func TestProviderRepository_Delete_NotFound(t *testing.T) { + db := setupTestDB(t) + repo := NewProviderRepository(db) + + err := repo.Delete("nonexistent") + assert.Error(t, err) +} + +// ============ ModelRepository 测试 ============ + +func TestModelRepository_Create(t *testing.T) { + db := setupTestDB(t) + repo := NewModelRepository(db) + + err := repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}) + require.NoError(t, err) +} + +func TestModelRepository_GetByID(t *testing.T) { + db := setupTestDB(t) + repo := NewModelRepository(db) + + repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}) + + result, err := repo.GetByID("m1") + require.NoError(t, err) + assert.Equal(t, "m1", result.ID) + assert.Equal(t, "gpt-4", result.ModelName) +} + +func TestModelRepository_GetByModelName(t *testing.T) { + db := setupTestDB(t) + repo := NewModelRepository(db) + + repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}) + + result, err := repo.GetByModelName("gpt-4") + require.NoError(t, err) + assert.Equal(t, "m1", result.ID) +} + +func TestModelRepository_List(t *testing.T) { + db := setupTestDB(t) + repo := NewModelRepository(db) + + repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}) + repo.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"}) + repo.Create(&domain.Model{ID: "m3", ProviderID: "p2", ModelName: "claude-3"}) + + all, err := repo.List("") + require.NoError(t, err) + assert.Len(t, all, 3) + + p1Models, err := repo.List("p1") + require.NoError(t, err) + assert.Len(t, p1Models, 2) +} + +func TestModelRepository_Update(t *testing.T) { + db := setupTestDB(t) + repo := NewModelRepository(db) + + repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}) + + err := repo.Update("m1", map[string]interface{}{"enabled": false}) + require.NoError(t, err) + + result, _ := repo.GetByID("m1") + assert.False(t, result.Enabled) +} + +func TestModelRepository_Delete(t *testing.T) { + db := setupTestDB(t) + repo := NewModelRepository(db) + + repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}) + + err := repo.Delete("m1") + require.NoError(t, err) + + _, err = repo.GetByID("m1") + assert.Error(t, err) +} + +// ============ StatsRepository 测试 ============ + +func TestStatsRepository_Record(t *testing.T) { + db := setupTestDB(t) + repo := NewStatsRepository(db) + + err := repo.Record("provider-1", "gpt-4") + require.NoError(t, err) + + // 再次记录应递增 + err = repo.Record("provider-1", "gpt-4") + require.NoError(t, err) + + stats, err := repo.Query("provider-1", "", nil, nil) + require.NoError(t, err) + require.Len(t, stats, 1) + assert.Equal(t, 2, stats[0].RequestCount) +} + +func TestStatsRepository_Query(t *testing.T) { + db := setupTestDB(t) + repo := NewStatsRepository(db) + + repo.Record("p1", "gpt-4") + // 注意:当前 schema 只有 date 字段有唯一约束 + // 所以同一 provider + model 只能有一条记录 + stats, err := repo.Query("p1", "", nil, nil) + require.NoError(t, err) + assert.Len(t, stats, 1) +} diff --git a/backend/internal/repository/stats_repo.go b/backend/internal/repository/stats_repo.go new file mode 100644 index 0000000..d7c3df7 --- /dev/null +++ b/backend/internal/repository/stats_repo.go @@ -0,0 +1,13 @@ +package repository + +import ( + "time" + + "nex/backend/internal/domain" +) + +// StatsRepository 统计数据仓库接口 +type StatsRepository interface { + Record(providerID, modelName string) error + Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) +} diff --git a/backend/internal/repository/stats_repo_impl.go b/backend/internal/repository/stats_repo_impl.go new file mode 100644 index 0000000..dd2ef62 --- /dev/null +++ b/backend/internal/repository/stats_repo_impl.go @@ -0,0 +1,79 @@ +package repository + +import ( + "errors" + "time" + + "gorm.io/gorm" + + "nex/backend/internal/config" + "nex/backend/internal/domain" +) + +type statsRepository struct { + db *gorm.DB +} + +func NewStatsRepository(db *gorm.DB) StatsRepository { + return &statsRepository{db: db} +} + +func (r *statsRepository) Record(providerID, modelName string) error { + today := time.Now().Format("2006-01-02") + todayTime, _ := time.Parse("2006-01-02", today) + + return r.db.Transaction(func(tx *gorm.DB) error { + var stats config.UsageStats + err := tx.Where("provider_id = ? AND model_name = ? AND date = ?", + providerID, modelName, todayTime).First(&stats).Error + + if errors.Is(err, gorm.ErrRecordNotFound) { + stats = config.UsageStats{ + ProviderID: providerID, + ModelName: modelName, + RequestCount: 1, + Date: todayTime, + } + return tx.Create(&stats).Error + } else if err != nil { + return err + } + + return tx.Model(&stats).Update("request_count", gorm.Expr("request_count + 1")).Error + }) +} + +func (r *statsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) { + var stats []config.UsageStats + query := r.db.Model(&config.UsageStats{}) + + if providerID != "" { + query = query.Where("provider_id = ?", providerID) + } + if modelName != "" { + query = query.Where("model_name = ?", modelName) + } + if startDate != nil { + query = query.Where("date >= ?", startDate) + } + if endDate != nil { + query = query.Where("date <= ?", endDate) + } + + err := query.Order("date DESC").Find(&stats).Error + if err != nil { + return nil, err + } + + result := make([]domain.UsageStats, len(stats)) + for i := range stats { + result[i] = domain.UsageStats{ + ID: stats[i].ID, + ProviderID: stats[i].ProviderID, + ModelName: stats[i].ModelName, + RequestCount: stats[i].RequestCount, + Date: stats[i].Date, + } + } + return result, nil +} diff --git a/backend/internal/router/model_router.go b/backend/internal/router/model_router.go deleted file mode 100644 index aa0120f..0000000 --- a/backend/internal/router/model_router.go +++ /dev/null @@ -1,71 +0,0 @@ -package router - -import ( - "errors" - "fmt" - - "nex/backend/internal/config" -) - -var ( - ErrModelNotFound = errors.New("模型未找到") - ErrModelDisabled = errors.New("模型已禁用") - ErrProviderDisabled = errors.New("供应商已禁用") -) - -// RouteResult 路由结果 -type RouteResult struct { - Provider *config.Provider - Model *config.Model -} - -// Router 模型路由器 -type Router struct{} - -// NewRouter 创建路由器 -func NewRouter() *Router { - return &Router{} -} - -// Route 根据模型名称路由到供应商 -func (r *Router) Route(modelName string) (*RouteResult, error) { - // 查询模型 - models, err := config.ListModels("") - if err != nil { - return nil, fmt.Errorf("查询模型失败: %w", err) - } - - // 查找匹配的模型 - var targetModel *config.Model - for i := range models { - if models[i].ModelName == modelName { - targetModel = &models[i] - break - } - } - - if targetModel == nil { - return nil, ErrModelNotFound - } - - // 检查模型是否启用 - if !targetModel.Enabled { - return nil, ErrModelDisabled - } - - // 查询供应商 - provider, err := config.GetProvider(targetModel.ProviderID, false) - if err != nil { - return nil, fmt.Errorf("查询供应商失败: %w", err) - } - - // 检查供应商是否启用 - if !provider.Enabled { - return nil, ErrProviderDisabled - } - - return &RouteResult{ - Provider: provider, - Model: targetModel, - }, nil -} diff --git a/backend/internal/service/model_service.go b/backend/internal/service/model_service.go new file mode 100644 index 0000000..e927abb --- /dev/null +++ b/backend/internal/service/model_service.go @@ -0,0 +1,12 @@ +package service + +import "nex/backend/internal/domain" + +// ModelService 模型服务接口 +type ModelService interface { + Create(model *domain.Model) error + Get(id string) (*domain.Model, error) + List(providerID string) ([]domain.Model, error) + Update(id string, updates map[string]interface{}) error + Delete(id string) error +} diff --git a/backend/internal/service/model_service_impl.go b/backend/internal/service/model_service_impl.go new file mode 100644 index 0000000..990c6c2 --- /dev/null +++ b/backend/internal/service/model_service_impl.go @@ -0,0 +1,50 @@ +package service + +import ( + appErrors "nex/backend/pkg/errors" + + "nex/backend/internal/domain" + "nex/backend/internal/repository" +) + +type modelService struct { + modelRepo repository.ModelRepository + providerRepo repository.ProviderRepository +} + +func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) ModelService { + return &modelService{modelRepo: modelRepo, providerRepo: providerRepo} +} + +func (s *modelService) Create(model *domain.Model) error { + // Verify provider exists + _, err := s.providerRepo.GetByID(model.ProviderID) + if err != nil { + return appErrors.ErrProviderNotFound + } + model.Enabled = true + return s.modelRepo.Create(model) +} + +func (s *modelService) Get(id string) (*domain.Model, error) { + return s.modelRepo.GetByID(id) +} + +func (s *modelService) List(providerID string) ([]domain.Model, error) { + return s.modelRepo.List(providerID) +} + +func (s *modelService) Update(id string, updates map[string]interface{}) error { + // If updating provider_id, verify new provider exists + if providerID, ok := updates["provider_id"].(string); ok { + _, err := s.providerRepo.GetByID(providerID) + if err != nil { + return appErrors.ErrProviderNotFound + } + } + return s.modelRepo.Update(id, updates) +} + +func (s *modelService) Delete(id string) error { + return s.modelRepo.Delete(id) +} diff --git a/backend/internal/service/provider_service.go b/backend/internal/service/provider_service.go new file mode 100644 index 0000000..fdebc7c --- /dev/null +++ b/backend/internal/service/provider_service.go @@ -0,0 +1,12 @@ +package service + +import "nex/backend/internal/domain" + +// ProviderService 供应商服务接口 +type ProviderService interface { + Create(provider *domain.Provider) error + Get(id string, maskKey bool) (*domain.Provider, error) + List() ([]domain.Provider, error) + Update(id string, updates map[string]interface{}) error + Delete(id string) error +} diff --git a/backend/internal/service/provider_service_impl.go b/backend/internal/service/provider_service_impl.go new file mode 100644 index 0000000..b34883a --- /dev/null +++ b/backend/internal/service/provider_service_impl.go @@ -0,0 +1,49 @@ +package service + +import ( + "nex/backend/internal/domain" + "nex/backend/internal/repository" +) + +type providerService struct { + providerRepo repository.ProviderRepository +} + +func NewProviderService(providerRepo repository.ProviderRepository) ProviderService { + return &providerService{providerRepo: providerRepo} +} + +func (s *providerService) Create(provider *domain.Provider) error { + provider.Enabled = true + return s.providerRepo.Create(provider) +} + +func (s *providerService) Get(id string, maskKey bool) (*domain.Provider, error) { + provider, err := s.providerRepo.GetByID(id) + if err != nil { + return nil, err + } + if maskKey { + provider.MaskAPIKey() + } + return provider, nil +} + +func (s *providerService) List() ([]domain.Provider, error) { + providers, err := s.providerRepo.List() + if err != nil { + return nil, err + } + for i := range providers { + providers[i].MaskAPIKey() + } + return providers, nil +} + +func (s *providerService) Update(id string, updates map[string]interface{}) error { + return s.providerRepo.Update(id, updates) +} + +func (s *providerService) Delete(id string) error { + return s.providerRepo.Delete(id) +} diff --git a/backend/internal/service/routing_service.go b/backend/internal/service/routing_service.go new file mode 100644 index 0000000..85db8a4 --- /dev/null +++ b/backend/internal/service/routing_service.go @@ -0,0 +1,8 @@ +package service + +import "nex/backend/internal/domain" + +// RoutingService 路由服务接口 +type RoutingService interface { + Route(modelName string) (*domain.RouteResult, error) +} diff --git a/backend/internal/service/routing_service_impl.go b/backend/internal/service/routing_service_impl.go new file mode 100644 index 0000000..482c136 --- /dev/null +++ b/backend/internal/service/routing_service_impl.go @@ -0,0 +1,42 @@ +package service + +import ( + appErrors "nex/backend/pkg/errors" + + "nex/backend/internal/domain" + "nex/backend/internal/repository" +) + +type routingService struct { + modelRepo repository.ModelRepository + providerRepo repository.ProviderRepository +} + +func NewRoutingService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) RoutingService { + return &routingService{modelRepo: modelRepo, providerRepo: providerRepo} +} + +func (s *routingService) Route(modelName string) (*domain.RouteResult, error) { + model, err := s.modelRepo.GetByModelName(modelName) + if err != nil { + return nil, appErrors.ErrModelNotFound + } + + if !model.Enabled { + return nil, appErrors.ErrModelDisabled + } + + provider, err := s.providerRepo.GetByID(model.ProviderID) + if err != nil { + return nil, appErrors.ErrProviderNotFound + } + + if !provider.Enabled { + return nil, appErrors.ErrProviderDisabled + } + + return &domain.RouteResult{ + Provider: provider, + Model: model, + }, nil +} diff --git a/backend/internal/service/service_test.go b/backend/internal/service/service_test.go new file mode 100644 index 0000000..a5ceb6c --- /dev/null +++ b/backend/internal/service/service_test.go @@ -0,0 +1,245 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + + "nex/backend/internal/config" + "nex/backend/internal/domain" + "nex/backend/internal/repository" +) + +func setupServiceTestDB(t *testing.T) *gorm.DB { + t.Helper() + dir := t.TempDir() + db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{}) + require.NoError(t, err) + err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{}) + require.NoError(t, err) + t.Cleanup(func() { + sqlDB, _ := db.DB() + if sqlDB != nil { + sqlDB.Close() + } + }) + return db +} + +// ============ ProviderService 测试 ============ + +func TestProviderService_Create(t *testing.T) { + db := setupServiceTestDB(t) + repo := repository.NewProviderRepository(db) + svc := NewProviderService(repo) + + provider := &domain.Provider{ + ID: "test-p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", + } + err := svc.Create(provider) + require.NoError(t, err) + assert.True(t, provider.Enabled) +} + +func TestProviderService_Get_MaskKey(t *testing.T) { + db := setupServiceTestDB(t) + repo := repository.NewProviderRepository(db) + svc := NewProviderService(repo) + + svc.Create(&domain.Provider{ + ID: "p1", Name: "Test", APIKey: "sk-long-api-key-12345", BaseURL: "https://test.com", + }) + + result, err := svc.Get("p1", true) + require.NoError(t, err) + assert.Equal(t, "***2345", result.APIKey) + + result, err = svc.Get("p1", false) + require.NoError(t, err) + assert.Equal(t, "sk-long-api-key-12345", result.APIKey) +} + +func TestProviderService_List(t *testing.T) { + db := setupServiceTestDB(t) + repo := repository.NewProviderRepository(db) + svc := NewProviderService(repo) + + svc.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key123", BaseURL: "https://a.com"}) + svc.Create(&domain.Provider{ID: "p2", Name: "P2", APIKey: "key456", BaseURL: "https://b.com"}) + + providers, err := svc.List() + require.NoError(t, err) + assert.Len(t, providers, 2) + assert.Contains(t, providers[0].APIKey, "***") +} + +func TestProviderService_Delete(t *testing.T) { + db := setupServiceTestDB(t) + repo := repository.NewProviderRepository(db) + svc := NewProviderService(repo) + + svc.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}) + err := svc.Delete("p1") + require.NoError(t, err) + + _, err = svc.Get("p1", false) + assert.Error(t, err) +} + +// ============ ModelService 测试 ============ + +func TestModelService_Create(t *testing.T) { + db := setupServiceTestDB(t) + providerRepo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + svc := NewModelService(modelRepo, providerRepo) + + providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}) + + model := &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"} + err := svc.Create(model) + require.NoError(t, err) + assert.True(t, model.Enabled) +} + +func TestModelService_Create_ProviderNotFound(t *testing.T) { + db := setupServiceTestDB(t) + providerRepo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + svc := NewModelService(modelRepo, providerRepo) + + model := &domain.Model{ID: "m1", ProviderID: "nonexistent", ModelName: "gpt-4"} + err := svc.Create(model) + assert.Error(t, err) +} + +func TestModelService_List(t *testing.T) { + db := setupServiceTestDB(t) + providerRepo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + svc := NewModelService(modelRepo, providerRepo) + + providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}) + svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}) + svc.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"}) + + models, err := svc.List("p1") + require.NoError(t, err) + assert.Len(t, models, 2) +} + +// ============ RoutingService 测试 ============ + +func TestRoutingService_Route(t *testing.T) { + db := setupServiceTestDB(t) + providerRepo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + svc := NewRoutingService(modelRepo, providerRepo) + + providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true}) + modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}) + + result, err := svc.Route("gpt-4") + require.NoError(t, err) + assert.Equal(t, "p1", result.Provider.ID) + assert.Equal(t, "gpt-4", result.Model.ModelName) +} + +func TestRoutingService_Route_ModelNotFound(t *testing.T) { + db := setupServiceTestDB(t) + providerRepo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + svc := NewRoutingService(modelRepo, providerRepo) + + _, err := svc.Route("nonexistent-model") + assert.Error(t, err) +} + +func TestRoutingService_Route_ModelDisabled(t *testing.T) { + db := setupServiceTestDB(t) + providerRepo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + svc := NewRoutingService(modelRepo, providerRepo) + + providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true}) + // 先创建启用的模型,然后通过 Update 禁用 + modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}) + modelRepo.Update("m1", map[string]interface{}{"enabled": false}) + + _, err := svc.Route("gpt-4") + assert.Error(t, err) +} + +func TestRoutingService_Route_ProviderDisabled(t *testing.T) { + db := setupServiceTestDB(t) + providerRepo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + svc := NewRoutingService(modelRepo, providerRepo) + + // 先创建启用的 provider,然后禁用 + providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true}) + providerRepo.Update("p1", map[string]interface{}{"enabled": false}) + modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}) + + _, err := svc.Route("gpt-4") + assert.Error(t, err) +} + +// ============ StatsService 测试 ============ + +func TestStatsService_RecordAndGet(t *testing.T) { + db := setupServiceTestDB(t) + statsRepo := repository.NewStatsRepository(db) + svc := NewStatsService(statsRepo) + + err := svc.Record("p1", "gpt-4") + require.NoError(t, err) + + stats, err := svc.Get("p1", "", nil, nil) + require.NoError(t, err) + assert.Len(t, stats, 1) +} + +func TestStatsService_Aggregate_ByProvider(t *testing.T) { + statsRepo := repository.NewStatsRepository(nil) + svc := NewStatsService(statsRepo) + + stats := []domain.UsageStats{ + {ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10}, + {ProviderID: "p1", ModelName: "gpt-3.5", RequestCount: 5}, + {ProviderID: "p2", ModelName: "claude-3", RequestCount: 8}, + } + + result := svc.Aggregate(stats, "provider") + assert.Len(t, result, 2) + + p1Count := 0 + p2Count := 0 + for _, r := range result { + if r["provider_id"] == "p1" { + p1Count = r["request_count"].(int) + } + if r["provider_id"] == "p2" { + p2Count = r["request_count"].(int) + } + } + assert.Equal(t, 15, p1Count) + assert.Equal(t, 8, p2Count) +} + +func TestStatsService_Aggregate_ByDate(t *testing.T) { + statsRepo := repository.NewStatsRepository(nil) + svc := NewStatsService(statsRepo) + + stats := []domain.UsageStats{ + {ProviderID: "p1", RequestCount: 10}, + {ProviderID: "p2", RequestCount: 5}, + } + + result := svc.Aggregate(stats, "date") + assert.Len(t, result, 1) + assert.Equal(t, 15, result[0]["request_count"]) +} diff --git a/backend/internal/service/stats_service.go b/backend/internal/service/stats_service.go new file mode 100644 index 0000000..1d7278e --- /dev/null +++ b/backend/internal/service/stats_service.go @@ -0,0 +1,14 @@ +package service + +import ( + "time" + + "nex/backend/internal/domain" +) + +// StatsService 统计服务接口 +type StatsService interface { + Record(providerID, modelName string) error + Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) + Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} +} diff --git a/backend/internal/service/stats_service_impl.go b/backend/internal/service/stats_service_impl.go new file mode 100644 index 0000000..99e2f4d --- /dev/null +++ b/backend/internal/service/stats_service_impl.go @@ -0,0 +1,85 @@ +package service + +import ( + "time" + + "nex/backend/internal/domain" + "nex/backend/internal/repository" +) + +type statsService struct { + statsRepo repository.StatsRepository +} + +func NewStatsService(statsRepo repository.StatsRepository) StatsService { + return &statsService{statsRepo: statsRepo} +} + +func (s *statsService) Record(providerID, modelName string) error { + return s.statsRepo.Record(providerID, modelName) +} + +func (s *statsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) { + return s.statsRepo.Query(providerID, modelName, startDate, endDate) +} + +func (s *statsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} { + switch groupBy { + case "provider": + return s.aggregateByProvider(stats) + case "model": + return s.aggregateByModel(stats) + case "date": + return s.aggregateByDate(stats) + default: + return s.aggregateByProvider(stats) + } +} + +func (s *statsService) aggregateByProvider(stats []domain.UsageStats) []map[string]interface{} { + aggregated := make(map[string]int) + for _, stat := range stats { + aggregated[stat.ProviderID] += stat.RequestCount + } + result := make([]map[string]interface{}, 0, len(aggregated)) + for providerID, count := range aggregated { + result = append(result, map[string]interface{}{ + "provider_id": providerID, + "request_count": count, + }) + } + return result +} + +func (s *statsService) aggregateByModel(stats []domain.UsageStats) []map[string]interface{} { + aggregated := make(map[string]int) + for _, stat := range stats { + key := stat.ProviderID + "/" + stat.ModelName + aggregated[key] += stat.RequestCount + } + result := make([]map[string]interface{}, 0, len(aggregated)) + for key, count := range aggregated { + result = append(result, map[string]interface{}{ + "provider_id": key[:len(key)/2], + "model_name": key[len(key)/2+1:], + "request_count": count, + }) + } + return result +} + +func (s *statsService) aggregateByDate(stats []domain.UsageStats) []map[string]interface{} { + aggregated := make(map[string]int) + for _, stat := range stats { + key := stat.Date.Format("2006-01-02") + aggregated[key] += stat.RequestCount + } + result := make([]map[string]interface{}, 0, len(aggregated)) + for date, count := range aggregated { + result = append(result, map[string]interface{}{ + "date": date, + "request_count": count, + }) + } + return result +} diff --git a/backend/migrations/001_initial_schema.sql b/backend/migrations/001_initial_schema.sql new file mode 100644 index 0000000..5c94dfa --- /dev/null +++ b/backend/migrations/001_initial_schema.sql @@ -0,0 +1,33 @@ +-- +goose Up +CREATE TABLE IF NOT EXISTS providers ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + api_key TEXT NOT NULL, + base_url TEXT NOT NULL, + enabled INTEGER DEFAULT 1, + created_at DATETIME, + updated_at DATETIME +); + +CREATE TABLE IF NOT EXISTS models ( + id TEXT PRIMARY KEY, + provider_id TEXT NOT NULL, + model_name TEXT NOT NULL, + enabled INTEGER DEFAULT 1, + created_at DATETIME, + FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS usage_stats ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + provider_id TEXT NOT NULL, + model_name TEXT NOT NULL, + request_count INTEGER DEFAULT 0, + date DATE NOT NULL, + UNIQUE(provider_id, model_name, date) +); + +-- +goose Down +DROP TABLE IF EXISTS usage_stats; +DROP TABLE IF EXISTS models; +DROP TABLE IF EXISTS providers; diff --git a/backend/migrations/002_add_indexes.sql b/backend/migrations/002_add_indexes.sql new file mode 100644 index 0000000..a3900ad --- /dev/null +++ b/backend/migrations/002_add_indexes.sql @@ -0,0 +1,9 @@ +-- +goose Up +CREATE INDEX IF NOT EXISTS idx_models_provider_id ON models(provider_id); +CREATE INDEX IF NOT EXISTS idx_models_model_name ON models(model_name); +CREATE INDEX IF NOT EXISTS idx_usage_stats_provider_model_date ON usage_stats(provider_id, model_name, date); + +-- +goose Down +DROP INDEX IF EXISTS idx_usage_stats_provider_model_date; +DROP INDEX IF EXISTS idx_models_model_name; +DROP INDEX IF EXISTS idx_models_provider_id; diff --git a/backend/pkg/errors/errors.go b/backend/pkg/errors/errors.go new file mode 100644 index 0000000..911981e --- /dev/null +++ b/backend/pkg/errors/errors.go @@ -0,0 +1,74 @@ +package errors + +import ( + "fmt" + "net/http" +) + +// AppError 结构化应用错误 +type AppError struct { + Code string `json:"code"` + Message string `json:"message"` + HTTPStatus int `json:"-"` + Cause error `json:"-"` + Context map[string]interface{} `json:"-"` +} + +// Error implements error interface +func (e *AppError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("%s: %s (%v)", e.Code, e.Message, e.Cause) + } + return fmt.Sprintf("%s: %s", e.Code, e.Message) +} + +// Unwrap returns the underlying error +func (e *AppError) Unwrap() error { + return e.Cause +} + +// NewAppError creates a new AppError +func NewAppError(code, message string, httpStatus int) *AppError { + return &AppError{ + Code: code, + Message: message, + HTTPStatus: httpStatus, + } +} + +// Predefined errors +var ( + ErrModelNotFound = NewAppError("model_not_found", "模型未找到", http.StatusNotFound) + ErrModelDisabled = NewAppError("model_disabled", "模型已禁用", http.StatusNotFound) + ErrProviderNotFound = NewAppError("provider_not_found", "供应商未找到", http.StatusNotFound) + ErrProviderDisabled = NewAppError("provider_disabled", "供应商已禁用", http.StatusNotFound) + ErrInvalidRequest = NewAppError("invalid_request", "无效的请求", http.StatusBadRequest) + ErrInternal = NewAppError("internal_error", "内部错误", http.StatusInternalServerError) + ErrDatabaseNotInit = NewAppError("database_not_initialized", "数据库未初始化", http.StatusInternalServerError) + ErrConflict = NewAppError("conflict", "资源已存在", http.StatusConflict) +) + +// AsAppError 尝试将 error 转换为 *AppError +func AsAppError(err error) (*AppError, bool) { + if err == nil { + return nil, false + } + var appErr *AppError + if ok := is(err, &appErr); ok { + return appErr, true + } + return nil, false +} + +func is(err error, target interface{}) bool { + // 简单的类型断言 + if e, ok := err.(*AppError); ok { + // 直接赋值 + switch t := target.(type) { + case **AppError: + *t = e + return true + } + } + return false +} diff --git a/backend/pkg/errors/errors_test.go b/backend/pkg/errors/errors_test.go new file mode 100644 index 0000000..edf39bb --- /dev/null +++ b/backend/pkg/errors/errors_test.go @@ -0,0 +1,125 @@ +package errors + +import ( + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewAppError(t *testing.T) { + err := NewAppError("test_code", "测试消息", http.StatusBadRequest) + assert.Equal(t, "test_code", err.Code) + assert.Equal(t, "测试消息", err.Message) + assert.Equal(t, http.StatusBadRequest, err.HTTPStatus) + assert.Nil(t, err.Cause) + assert.Nil(t, err.Context) +} + +func TestAppError_Error(t *testing.T) { + tests := []struct { + name string + err *AppError + expected string + }{ + { + name: "无原因错误", + err: NewAppError("code1", "消息1", 400), + expected: "code1: 消息1", + }, + { + name: "带原因错误", + err: Wrap(NewAppError("code2", "消息2", 500), errors.New("原始错误")), + expected: "code2: 消息2 (原始错误)", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.err.Error()) + }) + } +} + +func TestAppError_Unwrap(t *testing.T) { + cause := errors.New("原始错误") + err := Wrap(ErrInternal, cause) + assert.Equal(t, cause, err.Unwrap()) +} + +func TestWrap(t *testing.T) { + cause := errors.New("网络超时") + wrapped := Wrap(ErrInternal, cause) + assert.Equal(t, "internal_error", wrapped.Code) + assert.Equal(t, "内部错误", wrapped.Message) + assert.Equal(t, http.StatusInternalServerError, wrapped.HTTPStatus) + assert.Equal(t, cause, wrapped.Cause) +} + +func TestWithContext(t *testing.T) { + err := WithContext(ErrModelNotFound, "model", "gpt-4") + assert.Equal(t, "model_not_found", err.Code) + assert.NotNil(t, err.Context) + assert.Equal(t, "gpt-4", err.Context["model"]) + + // 测试链式添加上下文 + err2 := WithContext(err, "provider", "openai") + assert.Equal(t, "gpt-4", err2.Context["model"]) + assert.Equal(t, "openai", err2.Context["provider"]) +} + +func TestWithMessage(t *testing.T) { + err := WithMessage(ErrInvalidRequest, "自定义错误消息") + assert.Equal(t, "invalid_request", err.Code) + assert.Equal(t, "自定义错误消息", err.Message) + assert.Equal(t, http.StatusBadRequest, err.HTTPStatus) +} + +func TestPredefinedErrors(t *testing.T) { + tests := []struct { + name string + err *AppError + code string + httpStatus int + }{ + {"ErrModelNotFound", ErrModelNotFound, "model_not_found", http.StatusNotFound}, + {"ErrModelDisabled", ErrModelDisabled, "model_disabled", http.StatusNotFound}, + {"ErrProviderNotFound", ErrProviderNotFound, "provider_not_found", http.StatusNotFound}, + {"ErrProviderDisabled", ErrProviderDisabled, "provider_disabled", http.StatusNotFound}, + {"ErrInvalidRequest", ErrInvalidRequest, "invalid_request", http.StatusBadRequest}, + {"ErrInternal", ErrInternal, "internal_error", http.StatusInternalServerError}, + {"ErrDatabaseNotInit", ErrDatabaseNotInit, "database_not_initialized", http.StatusInternalServerError}, + {"ErrConflict", ErrConflict, "conflict", http.StatusConflict}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.code, tt.err.Code) + assert.Equal(t, tt.httpStatus, tt.err.HTTPStatus) + }) + } +} + +func TestAsAppError(t *testing.T) { + t.Run("nil输入", func(t *testing.T) { + _, ok := AsAppError(nil) + assert.False(t, ok) + }) + + t.Run("AppError类型", func(t *testing.T) { + appErr, ok := AsAppError(ErrModelNotFound) + assert.True(t, ok) + assert.Equal(t, ErrModelNotFound, appErr) + }) + + t.Run("Wrapped AppError", func(t *testing.T) { + wrapped := Wrap(ErrInternal, errors.New("cause")) + appErr, ok := AsAppError(wrapped) + assert.True(t, ok) + assert.Equal(t, "internal_error", appErr.Code) + }) + + t.Run("非AppError类型", func(t *testing.T) { + _, ok := AsAppError(errors.New("普通错误")) + assert.False(t, ok) + }) +} diff --git a/backend/pkg/errors/wrap.go b/backend/pkg/errors/wrap.go new file mode 100644 index 0000000..859f1d4 --- /dev/null +++ b/backend/pkg/errors/wrap.go @@ -0,0 +1,42 @@ +package errors + +// Wrap wraps an error with cause +func Wrap(err *AppError, cause error) *AppError { + return &AppError{ + Code: err.Code, + Message: err.Message, + HTTPStatus: err.HTTPStatus, + Cause: cause, + } +} + +// WithContext adds context to an AppError +func WithContext(err *AppError, key string, value interface{}) *AppError { + newErr := &AppError{ + Code: err.Code, + Message: err.Message, + HTTPStatus: err.HTTPStatus, + Cause: err.Cause, + } + if err.Context != nil { + newErr.Context = make(map[string]interface{}) + for k, v := range err.Context { + newErr.Context[k] = v + } + } else { + newErr.Context = make(map[string]interface{}) + } + newErr.Context[key] = value + return newErr +} + +// WithMessage creates a new AppError with a custom message +func WithMessage(err *AppError, message string) *AppError { + return &AppError{ + Code: err.Code, + Message: message, + HTTPStatus: err.HTTPStatus, + Cause: err.Cause, + Context: err.Context, + } +} diff --git a/backend/pkg/logger/context.go b/backend/pkg/logger/context.go new file mode 100644 index 0000000..f8a336d --- /dev/null +++ b/backend/pkg/logger/context.go @@ -0,0 +1,17 @@ +package logger + +import "go.uber.org/zap" + +// WithRequestID 向 logger 添加 request_id 字段 +func WithRequestID(logger *zap.Logger, requestID string) *zap.Logger { + return logger.With(zap.String("request_id", requestID)) +} + +// WithContext 向 logger 添加多个自定义字段 +func WithContext(logger *zap.Logger, fields map[string]interface{}) *zap.Logger { + zapFields := make([]zap.Field, 0, len(fields)) + for k, v := range fields { + zapFields = append(zapFields, zap.Any(k, v)) + } + return logger.With(zapFields...) +} diff --git a/backend/pkg/logger/logger.go b/backend/pkg/logger/logger.go new file mode 100644 index 0000000..6dec031 --- /dev/null +++ b/backend/pkg/logger/logger.go @@ -0,0 +1,109 @@ +package logger + +import ( + "os" + "path/filepath" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// Config 日志配置 +type Config struct { + Level string // 日志级别: debug, info, warn, error + Path string // 日志文件目录,为空则仅输出到 stdout + MaxSize int // 单个日志文件最大尺寸 (MB) + MaxBackups int // 保留的旧日志文件最大数量 + MaxAge int // 保留旧日志文件的最大天数 + Compress bool // 是否压缩旧日志文件 +} + +// New 根据配置创建 zap.Logger +// 如果 Path 为空,仅输出到 stdout; +// 如果 Path 已设置,同时输出到 stdout 和文件(文件使用 JSON 格式,stdout 使用 console 格式) +func New(cfg Config) (*zap.Logger, error) { + level, err := parseLevel(cfg.Level) + if err != nil { + return nil, err + } + + // stdout encoder — console 格式 + stdoutEncoder := zapcore.NewConsoleEncoder(zapcore.EncoderConfig{ + TimeKey: "ts", + LevelKey: "level", + NameKey: "logger", + CallerKey: "caller", + FunctionKey: zapcore.OmitKey, + MessageKey: "msg", + StacktraceKey: "stacktrace", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.CapitalColorLevelEncoder, + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeDuration: zapcore.StringDurationEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + }) + + stdoutCore := zapcore.NewCore( + stdoutEncoder, + zapcore.AddSync(os.Stdout), + level, + ) + + // 仅 stdout 模式 + if cfg.Path == "" { + return zap.New(stdoutCore, zap.AddCaller(), zap.AddStacktrace(zap.ErrorLevel)), nil + } + + // 文件 encoder — JSON 格式 + fileEncoder := zapcore.NewJSONEncoder(zapcore.EncoderConfig{ + TimeKey: "ts", + LevelKey: "level", + NameKey: "logger", + CallerKey: "caller", + FunctionKey: zapcore.OmitKey, + MessageKey: "msg", + StacktraceKey: "stacktrace", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.LowercaseLevelEncoder, + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeDuration: zapcore.StringDurationEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + }) + + rotateWriter := newRotateWriter(cfg) + fileCore := zapcore.NewCore( + fileEncoder, + zapcore.AddSync(rotateWriter), + level, + ) + + core := zapcore.NewTee(stdoutCore, fileCore) + return zap.New(core, zap.AddCaller(), zap.AddStacktrace(zap.ErrorLevel)), nil +} + +// parseLevel 将字符串解析为 zapcore.Level +func parseLevel(s string) (zapcore.Level, error) { + switch s { + case "debug": + return zapcore.DebugLevel, nil + case "info": + return zapcore.InfoLevel, nil + case "warn": + return zapcore.WarnLevel, nil + case "error": + return zapcore.ErrorLevel, nil + default: + return zapcore.InfoLevel, nil + } +} + +// logFileName 生成当日日志文件名: nex-YYYY-MM-DD.log +func logFileName() string { + return "nex-" + time.Now().Format("2006-01-02") + ".log" +} + +// logFilePath 拼接完整日志文件路径 +func logFilePath(dir string) string { + return filepath.Join(dir, logFileName()) +} diff --git a/backend/pkg/logger/logger_test.go b/backend/pkg/logger/logger_test.go new file mode 100644 index 0000000..7590e21 --- /dev/null +++ b/backend/pkg/logger/logger_test.go @@ -0,0 +1,138 @@ +package logger + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestNew_StdoutOnly(t *testing.T) { + logger, err := New(Config{Level: "info"}) + require.NoError(t, err) + require.NotNil(t, logger) + assert.NoError(t, logger.Sync()) +} + +func TestNew_WithFileOutput(t *testing.T) { + dir := filepath.Join(os.TempDir(), "nex-logger-test") + os.MkdirAll(dir, 0755) + defer os.RemoveAll(dir) + + logger, err := New(Config{ + Level: "debug", + Path: dir, + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + }) + require.NoError(t, err) + require.NotNil(t, logger) + + logger.Info("test log message") + _ = logger.Sync() + + // 验证日志文件已创建 + files, err := os.ReadDir(dir) + require.NoError(t, err) + assert.NotEmpty(t, files, "日志目录应包含文件") +} + +func TestNew_AllLevels(t *testing.T) { + levels := []string{"debug", "info", "warn", "error"} + for _, level := range levels { + logger, err := New(Config{Level: level}) + assert.NoError(t, err, "级别 %s 应有效", level) + assert.NotNil(t, logger) + assert.NoError(t, logger.Sync()) + } +} + +func TestNew_EmptyLevel(t *testing.T) { + // 空级别应默认为 info + logger, err := New(Config{Level: ""}) + require.NoError(t, err) + require.NotNil(t, logger) + assert.NoError(t, logger.Sync()) +} + +func TestNew_InvalidPath(t *testing.T) { + // 不可写的路径 + logger, err := New(Config{ + Level: "info", + Path: "/nonexistent/deeply/nested/path/logs", + }) + // 应能创建 logger(错误在写入时发生) + // 实际上 lumberjack 会尝试创建目录 + _ = logger + _ = err +} + +func TestParseLevel(t *testing.T) { + tests := []struct { + input string + valid bool + }{ + {"debug", true}, + {"info", true}, + {"warn", true}, + {"error", true}, + {"", true}, // 默认为 info + {"invalid", true}, // 默认为 info + } + for _, tt := range tests { + _, err := parseLevel(tt.input) + assert.NoError(t, err, "parseLevel(%q) 不应报错", tt.input) + } +} + +func TestLogFilePath(t *testing.T) { + result := logFilePath(filepath.Join("var", "log")) + assert.Contains(t, result, "nex-") + assert.Contains(t, result, ".log") +} + +func TestLogFileName(t *testing.T) { + name := logFileName() + assert.Contains(t, name, "nex-") + assert.Contains(t, name, ".log") + assert.Len(t, name, len("nex-2006-01-02.log")) +} + +func TestNewRotateWriter_Defaults(t *testing.T) { + cfg := Config{ + Path: t.TempDir(), + MaxSize: 0, + MaxAge: 0, + Compress: true, + } + writer := newRotateWriter(cfg) + require.NotNil(t, writer) + assert.Equal(t, 100, writer.MaxSize) + assert.Equal(t, 10, writer.MaxBackups) + assert.Equal(t, 30, writer.MaxAge) +} + +func TestWithRequestID(t *testing.T) { + logger, err := New(Config{Level: "info"}) + require.NoError(t, err) + + contextLogger := WithRequestID(logger, "test-request-123") + assert.NotNil(t, contextLogger) + assert.IsType(t, &zap.Logger{}, contextLogger) +} + +func TestWithContext(t *testing.T) { + logger, err := New(Config{Level: "info"}) + require.NoError(t, err) + + contextLogger := WithContext(logger, map[string]interface{}{ + "key1": "value1", + "key2": 42, + }) + assert.NotNil(t, contextLogger) +} diff --git a/backend/pkg/logger/rotate.go b/backend/pkg/logger/rotate.go new file mode 100644 index 0000000..32ee88c --- /dev/null +++ b/backend/pkg/logger/rotate.go @@ -0,0 +1,30 @@ +package logger + +import "gopkg.in/lumberjack.v2" + +// newRotateWriter 根据配置创建 lumberjack.Logger 作为日志轮转写入器 +// 日志文件位于 cfg.Path 目录下,文件名格式为 nex-YYYY-MM-DD.log +func newRotateWriter(cfg Config) *lumberjack.Logger { + maxSize := cfg.MaxSize + if maxSize <= 0 { + maxSize = 100 + } + + maxBackups := cfg.MaxBackups + if maxBackups <= 0 { + maxBackups = 10 + } + + maxAge := cfg.MaxAge + if maxAge <= 0 { + maxAge = 30 + } + + return &lumberjack.Logger{ + Filename: logFilePath(cfg.Path), + MaxSize: maxSize, // MB + MaxBackups: maxBackups, + MaxAge: maxAge, // days + Compress: cfg.Compress, + } +} diff --git a/backend/pkg/validator/validator.go b/backend/pkg/validator/validator.go new file mode 100644 index 0000000..c6f7791 --- /dev/null +++ b/backend/pkg/validator/validator.go @@ -0,0 +1,22 @@ +package validator + +import ( + "github.com/go-playground/validator/v10" +) + +// validate 全局验证器实例 +var validate *validator.Validate + +func init() { + validate = validator.New(validator.WithRequiredStructEnabled()) +} + +// Get 返回全局验证器实例 +func Get() *validator.Validate { + return validate +} + +// Validate 验证结构体 +func Validate(s interface{}) error { + return validate.Struct(s) +} diff --git a/backend/pkg/validator/validator_test.go b/backend/pkg/validator/validator_test.go new file mode 100644 index 0000000..ca27fe4 --- /dev/null +++ b/backend/pkg/validator/validator_test.go @@ -0,0 +1,45 @@ +package validator + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type TestStruct struct { + Name string `validate:"required"` + Email string `validate:"required,email"` + Age int `validate:"min=0,max=150"` +} + +func TestValidate_ValidStruct(t *testing.T) { + s := TestStruct{Name: "John", Email: "john@example.com", Age: 25} + err := Validate(s) + assert.NoError(t, err) +} + +func TestValidate_MissingRequired(t *testing.T) { + s := TestStruct{Email: "john@example.com", Age: 25} + err := Validate(s) + assert.Error(t, err) +} + +func TestValidate_InvalidEmail(t *testing.T) { + s := TestStruct{Name: "John", Email: "not-an-email", Age: 25} + err := Validate(s) + assert.Error(t, err) +} + +func TestValidate_AgeOutOfRange(t *testing.T) { + s := TestStruct{Name: "John", Email: "john@example.com", Age: 200} + err := Validate(s) + assert.Error(t, err) +} + +func TestGet_ReturnsInstance(t *testing.T) { + v := Get() + assert.NotNil(t, v) + // 多次调用应返回相同实例 + v2 := Get() + assert.Equal(t, v, v2) +} diff --git a/backend/tests/helpers.go b/backend/tests/helpers.go new file mode 100644 index 0000000..125bc63 --- /dev/null +++ b/backend/tests/helpers.go @@ -0,0 +1,74 @@ +package tests + +import ( + "fmt" + "testing" + + "nex/backend/internal/config" + + "github.com/stretchr/testify/assert" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +// SetupTestDB initializes a temporary SQLite database with auto-migration. +func SetupTestDB(t *testing.T) *gorm.DB { + t.Helper() + + dir := t.TempDir() + dsn := dir + "/test.db" + + db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{}) + assert.NoError(t, err, "failed to open test database") + + err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{}) + assert.NoError(t, err, "failed to auto-migrate test database") + + return db +} + +// CleanupTestDB closes the database and removes the temp database file. +func CleanupTestDB(t *testing.T, db *gorm.DB) { + t.Helper() + + sqlDB, err := db.DB() + assert.NoError(t, err, "failed to get underlying sql.DB") + + err = sqlDB.Close() + assert.NoError(t, err, "failed to close test database") +} + +// CreateTestProvider creates a test provider and returns it. +func CreateTestProvider(t *testing.T, db *gorm.DB, id string) config.Provider { + t.Helper() + + provider := config.Provider{ + ID: id, + Name: fmt.Sprintf("test-provider-%s", id), + APIKey: fmt.Sprintf("test-api-key-%s", id), + BaseURL: fmt.Sprintf("https://api.test-%s.com", id), + Enabled: true, + } + + err := db.Create(&provider).Error + assert.NoError(t, err, "failed to create test provider") + + return provider +} + +// CreateTestModel creates a test model and returns it. +func CreateTestModel(t *testing.T, db *gorm.DB, id string, providerID string, modelName string) config.Model { + t.Helper() + + model := config.Model{ + ID: id, + ProviderID: providerID, + ModelName: modelName, + Enabled: true, + } + + err := db.Create(&model).Error + assert.NoError(t, err, "failed to create test model") + + return model +} diff --git a/backend/tests/integration/integration_test.go b/backend/tests/integration/integration_test.go new file mode 100644 index 0000000..ad754d6 --- /dev/null +++ b/backend/tests/integration/integration_test.go @@ -0,0 +1,263 @@ +package integration + +import ( + "bytes" + "encoding/json" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + + "nex/backend/internal/config" + "nex/backend/internal/domain" + "nex/backend/internal/handler" + "nex/backend/internal/handler/middleware" + "nex/backend/internal/repository" + "nex/backend/internal/service" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func setupIntegrationTest(t *testing.T) (*gin.Engine, *gorm.DB) { + t.Helper() + dir := t.TempDir() + db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{}) + require.NoError(t, err) + err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{}) + require.NoError(t, err) + t.Cleanup(func() { + sqlDB, _ := db.DB() + if sqlDB != nil { + sqlDB.Close() + } + }) + + providerRepo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + statsRepo := repository.NewStatsRepository(db) + + providerService := service.NewProviderService(providerRepo) + modelService := service.NewModelService(modelRepo, providerRepo) + _ = service.NewRoutingService(modelRepo, providerRepo) + statsService := service.NewStatsService(statsRepo) + + providerHandler := handler.NewProviderHandler(providerService) + modelHandler := handler.NewModelHandler(modelService) + statsHandler := handler.NewStatsHandler(statsService) + + r := gin.New() + r.Use(middleware.CORS()) + + providers := r.Group("/api/providers") + { + providers.GET("", providerHandler.ListProviders) + providers.POST("", providerHandler.CreateProvider) + providers.GET("/:id", providerHandler.GetProvider) + providers.PUT("/:id", providerHandler.UpdateProvider) + providers.DELETE("/:id", providerHandler.DeleteProvider) + } + + models := r.Group("/api/models") + { + models.GET("", modelHandler.ListModels) + models.POST("", modelHandler.CreateModel) + models.GET("/:id", modelHandler.GetModel) + models.PUT("/:id", modelHandler.UpdateModel) + models.DELETE("/:id", modelHandler.DeleteModel) + } + + stats := r.Group("/api/stats") + { + stats.GET("", statsHandler.GetStats) + stats.GET("/aggregate", statsHandler.AggregateStats) + } + + return r, db +} + +func TestOpenAI_CompleteFlow(t *testing.T) { + r, _ := setupIntegrationTest(t) + + // 1. 创建 Provider + providerBody, _ := json.Marshal(map[string]string{ + "id": "openai", "name": "OpenAI", "api_key": "sk-test-key", "base_url": "https://api.openai.com/v1", + }) + w := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/providers", bytes.NewReader(providerBody)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + assert.Equal(t, 201, w.Code) + + // 2. 创建 Model + modelBody, _ := json.Marshal(map[string]string{ + "id": "gpt4", "provider_id": "openai", "model_name": "gpt-4", + }) + w = httptest.NewRecorder() + req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + assert.Equal(t, 201, w.Code) + + // 3. 列出 Provider + w = httptest.NewRecorder() + req = httptest.NewRequest("GET", "/api/providers", nil) + r.ServeHTTP(w, req) + assert.Equal(t, 200, w.Code) + var providers []domain.Provider + json.Unmarshal(w.Body.Bytes(), &providers) + assert.Len(t, providers, 1) + assert.Contains(t, providers[0].APIKey, "***") // 已掩码 + + // 4. 列出 Model + w = httptest.NewRecorder() + req = httptest.NewRequest("GET", "/api/models?provider_id=openai", nil) + r.ServeHTTP(w, req) + assert.Equal(t, 200, w.Code) + var models []domain.Model + json.Unmarshal(w.Body.Bytes(), &models) + assert.Len(t, models, 1) + assert.Equal(t, "gpt-4", models[0].ModelName) + + // 5. 更新 Provider + updateBody, _ := json.Marshal(map[string]string{"name": "OpenAI Updated"}) + w = httptest.NewRecorder() + req = httptest.NewRequest("PUT", "/api/providers/openai", bytes.NewReader(updateBody)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + assert.Equal(t, 200, w.Code) + + // 6. 删除 Model + w = httptest.NewRecorder() + req = httptest.NewRequest("DELETE", "/api/models/gpt4", nil) + r.ServeHTTP(w, req) + assert.Equal(t, 204, w.Code) + + // 7. 删除 Provider + w = httptest.NewRecorder() + req = httptest.NewRequest("DELETE", "/api/providers/openai", nil) + r.ServeHTTP(w, req) + assert.Equal(t, 204, w.Code) +} + +func TestAnthropic_ModelCreation(t *testing.T) { + r, _ := setupIntegrationTest(t) + + // 创建 Provider 和 Model 用于 Anthropic 代理 + providerBody, _ := json.Marshal(map[string]string{ + "id": "anthropic", "name": "Anthropic", "api_key": "sk-ant-test", "base_url": "https://api.anthropic.com/v1", + }) + w := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/providers", bytes.NewReader(providerBody)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + assert.Equal(t, 201, w.Code) + + modelBody, _ := json.Marshal(map[string]string{ + "id": "claude3", "provider_id": "anthropic", "model_name": "claude-3-opus-20240229", + }) + w = httptest.NewRecorder() + req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + assert.Equal(t, 201, w.Code) + + // 验证创建成功 + w = httptest.NewRecorder() + req = httptest.NewRequest("GET", "/api/models/claude3", nil) + r.ServeHTTP(w, req) + assert.Equal(t, 200, w.Code) +} + +func TestStats_RecordingAndQuery(t *testing.T) { + r, db := setupIntegrationTest(t) + + // 创建 Provider 和 Model + providerBody, _ := json.Marshal(map[string]string{ + "id": "p1", "name": "Provider1", "api_key": "key", "base_url": "https://test.com", + }) + w := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/providers", bytes.NewReader(providerBody)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + modelBody, _ := json.Marshal(map[string]string{ + "id": "m1", "provider_id": "p1", "model_name": "gpt-4", + }) + w = httptest.NewRecorder() + req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + // 直接通过 repository 记录统计(模拟代理请求后的统计记录) + statsRepo := repository.NewStatsRepository(db) + statsRepo.Record("p1", "gpt-4") + statsRepo.Record("p1", "gpt-4") + statsRepo.Record("p1", "gpt-4") + + // 查询统计 + w = httptest.NewRecorder() + req = httptest.NewRequest("GET", "/api/stats?provider_id=p1", nil) + r.ServeHTTP(w, req) + assert.Equal(t, 200, w.Code) + + var stats []domain.UsageStats + json.Unmarshal(w.Body.Bytes(), &stats) + assert.Len(t, stats, 1) + assert.Equal(t, 3, stats[0].RequestCount) + + // 聚合统计 + w = httptest.NewRecorder() + req = httptest.NewRequest("GET", "/api/stats/aggregate?group_by=provider", nil) + r.ServeHTTP(w, req) + assert.Equal(t, 200, w.Code) +} + +func TestProvider_DuplicateCreation(t *testing.T) { + r, _ := setupIntegrationTest(t) + + providerBody, _ := json.Marshal(map[string]string{ + "id": "p1", "name": "P1", "api_key": "key", "base_url": "https://test.com", + }) + + // 第一次创建成功 + w := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/providers", bytes.NewReader(providerBody)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + assert.Equal(t, 201, w.Code) + + // 第二次创建应失败(UNIQUE 约束) + w = httptest.NewRecorder() + req = httptest.NewRequest("POST", "/api/providers", bytes.NewReader(providerBody)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + assert.Equal(t, 409, w.Code) +} + +func TestProvider_NotFound(t *testing.T) { + r, _ := setupIntegrationTest(t) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/api/providers/nonexistent", nil) + r.ServeHTTP(w, req) + assert.Equal(t, 404, w.Code) +} + +func TestStats_InvalidDate(t *testing.T) { + r, _ := setupIntegrationTest(t) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/api/stats?start_date=not-a-date", nil) + r.ServeHTTP(w, req) + assert.Equal(t, 400, w.Code) +} + +// Suppress unused import warning +var _ = time.Second diff --git a/openspec/specs/anthropic-protocol-proxy/spec.md b/openspec/specs/anthropic-protocol-proxy/spec.md index 86458e3..5c98762 100644 --- a/openspec/specs/anthropic-protocol-proxy/spec.md +++ b/openspec/specs/anthropic-protocol-proxy/spec.md @@ -1,10 +1,6 @@ -# Anthropic 协议代理 +# Anthropic Protocol Proxy -## Purpose - -TBD - 提供 Anthropic Messages API 的代理功能,通过协议转换实现与 OpenAI 兼容供应商的互操作 - -## Requirements +## MODIFIED Requirements ### Requirement: 支持 Anthropic Messages API 端点 @@ -26,6 +22,8 @@ TBD - 提供 Anthropic Messages API 的代理功能,通过协议转换实现 - **THEN** 网关 SHALL 将 OpenAI 流事件转换为 Anthropic 流事件 - **THEN** 网关 SHALL 使用 SSE 格式将转换后的事件流式返回给应用 +**变更说明:** handler 通过 service 层调用,而非直接调用 config 和 provider 包。API 接口保持不变。 + ### Requirement: 将 Anthropic 请求转换为 OpenAI 格式 网关 SHALL 将 Anthropic Messages API 请求转换为 OpenAI Chat Completions API 格式。 @@ -41,138 +39,38 @@ TBD - 提供 Anthropic Messages API 的代理功能,通过协议转换实现 - **THEN** 网关 SHALL 在转换后的 OpenAI 请求中保留这些消息 - **THEN** 网关 SHALL 保留每条消息的 role 和 content -#### Scenario: Tools 转换 +**变更说明:** 协议转换逻辑保持不变,仅调用方式改为通过 service 层。 -- **WHEN** Anthropic 请求包含带有 `input_schema` 的 `tools` -- **THEN** 网关 SHALL 将每个工具转换为 OpenAI 格式,使用 `function.parameters` 替代 `input_schema` -- **THEN** 网关 SHALL 保留工具名称和描述 +## ADDED Requirements -#### Scenario: Tool choice 转换 +### Requirement: 使用 service 层处理请求 -- **WHEN** Anthropic 请求包含 `type: "auto"` 的 `tool_choice` -- **THEN** 网关 SHALL 将其转换为 OpenAI 格式的 `"auto"` -- **WHEN** Anthropic 请求包含 `type: "any"` 的 `tool_choice` -- **THEN** 网关 SHALL 将其转换为 OpenAI 格式的 `"auto"` -- **WHEN** Anthropic 请求包含 `type: "tool"` 和 `name` 的 `tool_choice` -- **THEN** 网关 SHALL 将其转换为 OpenAI 格式的 `{"type": "function", "function": {"name": }}` +Handler SHALL 通过 service 层处理业务逻辑。 -#### Scenario: Tool result 转换 +#### Scenario: 调用 routing service -- **WHEN** Anthropic 请求包含用户消息,其 `content` 数组包含 `type: "tool_result"` 块 -- **THEN** 网关 SHALL 将每个工具结果转换为 `role: "tool"` 的消息 -- **THEN** 网关 SHALL 从 `tool_use_id` 设置 `tool_call_id` -- **THEN** 网关 SHALL 保留 content +- **WHEN** handler 收到请求并转换为 OpenAI 格式 +- **THEN** SHALL 调用 RoutingService.Route() 获取路由结果 +- **THEN** SHALL 使用路由结果中的供应商信息 -#### Scenario: Max tokens 处理 +#### Scenario: 调用 stats service -- **WHEN** Anthropic 请求包含 `max_tokens` -- **THEN** 网关 SHALL 在 OpenAI 请求中包含它作为 `max_tokens` -- **WHEN** Anthropic 请求不包含 `max_tokens` -- **THEN** 网关 SHALL 设置默认值(4096)以满足 Anthropic 的要求 +- **WHEN** 请求成功完成 +- **THEN** SHALL 调用 StatsService.Record() 记录统计 +- **THEN** SHALL 异步记录统计(不阻塞响应) -### Requirement: 将 OpenAI 响应转换为 Anthropic 格式 +### Requirement: 使用结构化错误处理 -网关 SHALL 将 OpenAI Chat Completions API 响应转换为 Anthropic Messages API 格式。 +Handler SHALL 使用结构化错误处理。 -#### Scenario: Content 转换 +#### Scenario: 协议转换错误 -- **WHEN** OpenAI 响应包含 `choices[0].message.content` -- **THEN** 网关 SHALL 将其转换为 Anthropic 格式的 `content: [{"type": "text", "text": }]` +- **WHEN** 协议转换失败 +- **THEN** SHALL 返回结构化错误响应 +- **THEN** SHALL 包含详细的错误信息 -#### Scenario: Tool calls 转换 +#### Scenario: 路由错误处理 -- **WHEN** OpenAI 响应包含 `choices[0].message.tool_calls` -- **THEN** 网关 SHALL 将每个工具调用转换为 `type: "tool_use"` 的内容块 -- **THEN** 网关 SHALL 从 `tool_calls[].id` 设置 `id` -- **THEN** 网关 SHALL 从 `tool_calls[].function.name` 设置 `name` -- **THEN** 网关 SHALL 解析 `arguments` JSON 字符串并将其设置为 `input` 对象 - -#### Scenario: Finish reason 转换 - -- **WHEN** OpenAI 响应的 `finish_reason` 为 `"stop"` -- **THEN** 网关 SHALL 在 Anthropic 响应中设置 `stop_reason: "end_turn"` -- **WHEN** OpenAI 响应的 `finish_reason` 为 `"tool_calls"` -- **THEN** 网关 SHALL 在 Anthropic 响应中设置 `stop_reason: "tool_use"` - -#### Scenario: Usage 转换 - -- **WHEN** OpenAI 响应包含带有 `prompt_tokens` 和 `completion_tokens` 的 `usage` -- **THEN** 网关 SHALL 转换为 Anthropic 格式,使用 `input_tokens` 和 `output_tokens` - -### Requirement: 转换流式事件 - -网关 SHALL 实时将 OpenAI 流事件转换为 Anthropic 流事件。 - -#### Scenario: Message start 事件 - -- **WHEN** 网关开始流式传输 Anthropic 响应 -- **THEN** 网关 SHALL 发送带有消息元数据的 `message_start` 事件 - -#### Scenario: Content block start 事件 - -- **WHEN** OpenAI 流开始返回内容 -- **THEN** 网关 SHALL 发送带有 `type: "text"` 的 `content_block_start` 事件 - -#### Scenario: Content delta 事件 - -- **WHEN** OpenAI 流发送带有内容的 delta -- **THEN** 网关 SHALL 发送带有 `type: "text_delta"` 的 `content_block_delta` 事件,包含文本 - -#### Scenario: Tool use 流式传输 - -- **WHEN** OpenAI 流发送工具调用 delta -- **THEN** 网关 SHALL 缓冲 `arguments` 块 -- **THEN** 网关 SHALL 在工具调用开始时发送带有 `type: "tool_use"` 的 `content_block_start` -- **THEN** 网关 SHALL 发送带有部分 JSON 的 `input_delta` 事件 - -#### Scenario: Content block stop 事件 - -- **WHEN** 内容块完成 -- **THEN** 网关 SHALL 发送 `content_block_stop` 事件 - -#### Scenario: Message stop 事件 - -- **WHEN** OpenAI 流完成 -- **THEN** 网关 SHALL 发送 `message_stop` 事件 - -### Requirement: 支持 Anthropic 特有功能 - -网关 SHALL 支持映射到 OpenAI 能力的 Anthropic 特有功能。 - -#### Scenario: System prompt 作为独立字段 - -- **WHEN** Anthropic 请求包含 `system` 字段 -- **THEN** 网关 SHALL 将其作为 OpenAI 格式的 system 消息处理 - -#### Scenario: 必需的 max_tokens - -- **WHEN** 收到 Anthropic 请求 -- **THEN** 网关 SHALL 确保 `max_tokens` 存在(如果未提供则使用默认值) - -### Requirement: 处理纯文本内容 - -网关 SHALL 在 Anthropic 请求和响应中支持纯文本内容。 - -#### Scenario: 消息中的文本内容 - -- **WHEN** Anthropic 请求在消息中包含文本内容 -- **THEN** 网关 SHALL 正确处理和转发文本内容 - -#### Scenario: 拒绝多模态内容 - -- **WHEN** Anthropic 请求包含多模态内容(图片、文档) -- **THEN** 网关 SHALL 返回错误,指示 MVP 不支持多模态内容 - -### Requirement: 保留请求元数据 - -网关 SHALL 在转换过程中保留请求元数据。 - -#### Scenario: 模型名称保留 - -- **WHEN** Anthropic 请求指定模型名称 -- **THEN** 网关 SHALL 在转换后的 OpenAI 请求中保留模型名称 - -#### Scenario: 自定义参数 - -- **WHEN** Anthropic 请求包含自定义参数(temperature, top_p 等) -- **THEN** 网关 SHALL 在转换后的请求中保留这些参数 +- **WHEN** RoutingService 返回错误 +- **THEN** SHALL 转换为对应的 AppError +- **THEN** SHALL 返回统一的错误响应 diff --git a/openspec/specs/config-management/spec.md b/openspec/specs/config-management/spec.md new file mode 100644 index 0000000..47ecbab --- /dev/null +++ b/openspec/specs/config-management/spec.md @@ -0,0 +1,123 @@ +# Config Management + +## ADDED Requirements + +### Requirement: 使用 YAML 配置文件 + +系统 SHALL 使用 YAML 格式的配置文件。 + +#### Scenario: 配置文件路径 + +- **WHEN** 应用启动 +- **THEN** SHALL 从 `~/.nex/config.yaml` 加载配置 +- **THEN** SHALL 解析 YAML 格式 + +#### Scenario: 配置文件结构 + +- **WHEN** 加载配置文件 +- **THEN** SHALL 包含 server、database、log 等配置节 +- **THEN** SHALL 支持嵌套配置结构 + +### Requirement: 自动生成默认配置 + +系统 SHALL 在首次使用时自动生成默认配置。 + +#### Scenario: 配置文件不存在 + +- **WHEN** 应用启动且 `~/.nex/config.yaml` 不存在 +- **THEN** SHALL 自动创建配置文件 +- **THEN** SHALL 写入默认配置值 +- **THEN** SHALL 记录日志提示已创建 + +#### Scenario: 配置文件已存在 + +- **WHEN** 应用启动且 `~/.nex/config.yaml` 已存在 +- **THEN** SHALL 直接加载配置文件 +- **THEN** SHALL NOT 覆盖现有配置 + +### Requirement: 配置验证 + +系统 SHALL 验证配置的有效性。 + +#### Scenario: 必需字段验证 + +- **WHEN** 加载配置 +- **THEN** SHALL 验证必需字段存在 +- **THEN** SHALL 在字段缺失时返回错误 + +#### Scenario: 字段值验证 + +- **WHEN** 加载配置 +- **THEN** SHALL 验证端口号范围(1-65535) +- **THEN** SHALL 验证日志级别有效性 +- **THEN** SHALL 验证路径有效性 + +#### Scenario: 配置错误处理 + +- **WHEN** 配置验证失败 +- **THEN** SHALL 返回详细的错误信息 +- **THEN** SHALL 指示哪些字段无效 +- **THEN** SHALL 应用 SHALL NOT 启动 + +### Requirement: 配置结构定义 + +系统 SHALL 定义清晰的配置结构。 + +#### Scenario: Server 配置 + +- **WHEN** 加载 server 配置 +- **THEN** SHALL 包含 port、read_timeout、write_timeout 字段 +- **THEN** SHALL 使用合理的默认值 + +#### Scenario: Database 配置 + +- **WHEN** 加载 database 配置 +- **THEN** SHALL 包含 path、max_idle_conns、max_open_conns、conn_max_lifetime 字段 +- **THEN** SHALL 使用合理的默认值 + +#### Scenario: Log 配置 + +- **WHEN** 加载 log 配置 +- **THEN** SHALL 包含 level、path、max_size、max_backups、max_age、compress 字段 +- **THEN** SHALL 使用合理的默认值 + +### Requirement: 默认配置值 + +系统 SHALL 提供合理的默认配置值。 + +#### Scenario: Server 默认值 + +- **WHEN** 使用默认配置 +- **THEN** server.port SHALL 为 9826 +- **THEN** server.read_timeout SHALL 为 30s +- **THEN** server.write_timeout SHALL 为 30s + +#### Scenario: Database 默认值 + +- **WHEN** 使用默认配置 +- **THEN** database.path SHALL 为 `~/.nex/config.db` +- **THEN** database.max_idle_conns SHALL 为 10 +- **THEN** database.max_open_conns SHALL 为 100 +- **THEN** database.conn_max_lifetime SHALL 为 1h + +#### Scenario: Log 默认值 + +- **WHEN** 使用默认配置 +- **THEN** log.level SHALL 为 info +- **THEN** log.path SHALL 为 `~/.nex/log` +- **THEN** log.max_size SHALL 为 100 (MB) +- **THEN** log.max_backups SHALL 为 10 +- **THEN** log.max_age SHALL 为 30 (days) +- **THEN** log.compress SHALL 为 true + +### Requirement: 配置重载支持 + +系统 SHALL 支持配置重载(未来扩展)。 + +#### Scenario: 配置热重载 + +- **WHEN** 配置文件修改(未来功能) +- **THEN** SHALL 支持重新加载配置 +- **THEN** SHALL 应用新配置到可动态调整的参数 + +注:当前版本不支持,仅为未来扩展预留接口。 diff --git a/openspec/specs/database-migration/spec.md b/openspec/specs/database-migration/spec.md new file mode 100644 index 0000000..cc0a50a --- /dev/null +++ b/openspec/specs/database-migration/spec.md @@ -0,0 +1,156 @@ +# Database Migration + +## ADDED Requirements + +### Requirement: 使用 goose 迁移工具 + +系统 SHALL 使用 goose 管理数据库迁移。 + +#### Scenario: goose 安装 + +- **WHEN** 开发环境设置 +- **THEN** SHALL 安装 goose CLI 工具 +- **THEN** SHALL 支持通过 CLI 执行迁移 + +#### Scenario: 迁移文件格式 + +- **WHEN** 创建迁移文件 +- **THEN** SHALL 使用 SQL 格式(.sql 文件) +- **THEN** SHALL 包含 -- +goose Up 和 -- +goose Down 注释 +- **THEN** SHALL 支持事务性迁移 + +### Requirement: 创建初始迁移 + +系统 SHALL 创建初始 schema 迁移。 + +#### Scenario: 初始迁移文件 + +- **WHEN** 创建初始迁移 +- **THEN** SHALL 创建 001_initial_schema.sql +- **THEN** SHALL 包含 providers、models、usage_stats 表的创建语句 +- **THEN** SHALL 包含外键约束 + +#### Scenario: Up 迁移 + +- **WHEN** 执行 up 迁移 +- **THEN** SHALL 创建所有表 +- **THEN** SHALL 创建索引 +- **THEN** SHALL 创建外键约束 + +#### Scenario: Down 迁移 + +- **WHEN** 执行 down 迁移 +- **THEN** SHALL 删除所有表 +- **THEN** SHALL 按正确顺序删除(避免外键约束错误) + +### Requirement: 添加索引迁移 + +系统 SHALL 创建索引迁移。 + +#### Scenario: 索引迁移文件 + +- **WHEN** 创建索引迁移 +- **THEN** SHALL 创建 002_add_indexes.sql +- **THEN** SHALL 为常用查询字段添加索引 + +#### Scenario: 索引定义 + +- **WHEN** 添加索引 +- **THEN** SHALL 为 models(provider_id) 添加索引 +- **THEN** SHALL 为 models(model_name) 添加索引 +- **THEN** SHALL 为 usage_stats(provider_id, model_name, date) 添加复合索引 + +### Requirement: 迁移命令集成 + +迁移 SHALL 集成到 Makefile。 + +#### Scenario: 迁移 up 命令 + +- **WHEN** 执行 `make migrate-up` +- **THEN** SHALL 执行所有待执行的迁移 +- **THEN** SHALL 显示迁移进度 + +#### Scenario: 迁移 down 命令 + +- **WHEN** 执行 `make migrate-down` +- **THEN** SHALL 回滚最后一个迁移 +- **THEN** SHALL 显示回滚进度 + +#### Scenario: 迁移状态命令 + +- **WHEN** 执行 `make migrate-status` +- **THEN** SHALL 显示当前迁移状态 +- **THEN** SHALL 显示已执行和待执行的迁移 + +#### Scenario: 创建迁移命令 + +- **WHEN** 执行 `make migrate-create name=` +- **THEN** SHALL 创建新的迁移文件模板 +- **THEN** SHALL 使用递增的版本号 + +### Requirement: 应用启动时迁移 + +应用 SHALL 在启动时执行迁移。 + +#### Scenario: 自动迁移 + +- **WHEN** 应用启动 +- **THEN** SHALL 自动执行待执行的迁移 +- **THEN** SHALL 在迁移失败时拒绝启动 +- **THEN** SHALL 记录迁移日志 + +#### Scenario: 迁移版本检查 + +- **WHEN** 应用启动 +- **THEN** SHALL 检查数据库迁移版本 +- **THEN** SHALL 在版本不匹配时执行迁移 + +### Requirement: 连接池配置 + +系统 SHALL 配置数据库连接池。 + +#### Scenario: 连接池参数 + +- **WHEN** 初始化数据库连接 +- **THEN** SHALL 设置 MaxIdleConns(默认 10) +- **THEN** SHALL 设置 MaxOpenConns(默认 100) +- **THEN** SHALL 设置 ConnMaxLifetime(默认 1h) + +#### Scenario: 连接池监控 + +- **WHEN** 应用运行 +- **THEN** SHALL 定期记录连接池状态(可选) +- **THEN** SHALL 监控连接池使用情况 + +### Requirement: 迁移回滚支持 + +系统 SHALL 支持迁移回滚。 + +#### Scenario: 回滚到指定版本 + +- **WHEN** 执行 `goose down-to ` +- **THEN** SHALL 回滚到指定版本 +- **THEN** SHALL 按顺序执行 down 迁移 + +#### Scenario: 完全回滚 + +- **WHEN** 执行 `goose reset` +- **THEN** SHALL 回滚所有迁移 +- **THEN** SHALL 清空数据库 + +### Requirement: 迁移文件管理 + +迁移文件 SHALL 版本化管理。 + +#### Scenario: 迁移文件命名 + +- **WHEN** 创建迁移文件 +- **THEN** SHALL 使用格式 `_.sql` +- **THEN** SHALL 版本号递增 +- **THEN** SHALL 名称使用 snake_case + +#### Scenario: 迁移文件存储 + +- **WHEN** 创建迁移文件 +- **THEN** SHALL 存储在 migrations/ 目录 +- **THEN** SHALL 提交到版本控制系统 diff --git a/openspec/specs/error-handling/spec.md b/openspec/specs/error-handling/spec.md new file mode 100644 index 0000000..b569e77 --- /dev/null +++ b/openspec/specs/error-handling/spec.md @@ -0,0 +1,122 @@ +# Error Handling + +## ADDED Requirements + +### Requirement: 定义结构化错误类型 + +系统 SHALL 定义 AppError 结构体。 + +#### Scenario: AppError 结构 + +- **WHEN** 定义错误 +- **THEN** SHALL 包含 Code(错误码)、Message(错误消息)、HTTPStatus(HTTP 状态码)字段 +- **THEN** SHALL 可选包含 Cause(原始错误)、Context(上下文信息)字段 + +#### Scenario: 错误码定义 + +- **WHEN** 定义错误码 +- **THEN** SHALL 使用 kebab-case 格式(如 model_not_found) +- **THEN** SHALL 定义清晰的错误码语义 +- **THEN** SHALL 为常见错误预定义错误码 + +### Requirement: 预定义常见错误 + +系统 SHALL 预定义常见错误。 + +#### Scenario: 资源不存在错误 + +- **WHEN** 资源不存在 +- **THEN** SHALL 使用 ErrModelNotFound、ErrProviderNotFound 等预定义错误 +- **THEN** SHALL 设置 HTTP 状态码为 404 + +#### Scenario: 验证错误 + +- **WHEN** 请求验证失败 +- **THEN** SHALL 使用 ErrInvalidRequest 等预定义错误 +- **THEN** SHALL 设置 HTTP 状态码为 400 + +#### Scenario: 内部错误 + +- **WHEN** 发生内部错误 +- **THEN** SHALL 使用 ErrInternal 等预定义错误 +- **THEN** SHALL 设置 HTTP 状态码为 500 + +### Requirement: 支持错误包装 + +系统 SHALL 支持错误包装。 + +#### Scenario: 包装原始错误 + +- **WHEN** 发生错误 +- **THEN** SHALL 能够包装原始错误(设置 Cause 字段) +- **THEN** SHALL 能够添加上下文信息(设置 Context 字段) +- **THEN** SHALL 保留错误链 + +#### Scenario: 错误链追踪 + +- **WHEN** 记录错误日志 +- **THEN** SHALL 记录完整的错误链 +- **THEN** SHALL 包含每一层错误的信息 + +### Requirement: 统一错误响应 + +系统 SHALL 统一错误响应格式。 + +#### Scenario: OpenAI 协议错误响应 + +- **WHEN** OpenAI 协议发生错误 +- **THEN** SHALL 返回标准 OpenAI 错误响应格式 +- **THEN** SHALL 包含 error.message、error.type、error.code 字段 + +#### Scenario: Anthropic 协议错误响应 + +- **WHEN** Anthropic 协议发生错误 +- **THEN** SHALL 返回标准 Anthropic 错误响应格式 +- **THEN** SHALL 包含 type、error.type、error.message 字段 + +#### Scenario: 管理 API 错误响应 + +- **WHEN** 管理 API 发生错误 +- **THEN** SHALL 返回统一的错误响应格式 +- **THEN** SHALL 包含 code、message 字段 +- **THEN** SHALL 可选包含 details 字段(验证错误详情) + +### Requirement: 错误处理中间件 + +系统 SHALL 提供错误处理中间件。 + +#### Scenario: 捕获 panic + +- **WHEN** handler 发生 panic +- **THEN** 中间件 SHALL 捕获 panic +- **THEN** SHALL 记录堆栈信息 +- **THEN** SHALL 返回 500 错误响应 + +#### Scenario: 统一错误响应 + +- **WHEN** handler 返回 AppError +- **THEN** 中间件 SHALL 转换为对应的 HTTP 响应 +- **THEN** SHALL 设置正确的 HTTP 状态码 +- **THEN** SHALL 设置正确的响应体 + +### Requirement: 替换现有错误处理 + +系统 SHALL 替换所有 errors.New 为结构化错误。 + +#### Scenario: config 包错误 + +- **WHEN** config 包发生错误 +- **THEN** SHALL 使用结构化错误 +- **THEN** SHALL 设置适当的错误码和 HTTP 状态码 + +#### Scenario: service 层错误 + +- **WHEN** service 层发生错误 +- **THEN** SHALL 使用结构化错误 +- **THEN** SHALL 包含业务上下文信息 + +#### Scenario: repository 层错误 + +- **WHEN** repository 层发生错误 +- **THEN** SHALL 包装数据库错误 +- **THEN** SHALL 转换为应用错误 diff --git a/openspec/specs/layered-architecture/spec.md b/openspec/specs/layered-architecture/spec.md new file mode 100644 index 0000000..3cea38f --- /dev/null +++ b/openspec/specs/layered-architecture/spec.md @@ -0,0 +1,115 @@ +# Layered Architecture + +## ADDED Requirements + +### Requirement: 实现三层架构 + +系统 SHALL 实现 handler → service → repository 三层架构。 + +#### Scenario: Handler 层职责 + +- **WHEN** 处理 HTTP 请求 +- **THEN** handler 层 SHALL 仅负责 HTTP 请求解析和响应 +- **THEN** handler 层 SHALL 调用 service 层处理业务逻辑 +- **THEN** handler 层 SHALL NOT 直接访问数据库 + +#### Scenario: Service 层职责 + +- **WHEN** 处理业务逻辑 +- **THEN** service 层 SHALL 包含业务规则和验证 +- **THEN** service 层 SHALL 调用 repository 层访问数据 +- **THEN** service 层 SHALL 协调多个 repository 的操作 + +#### Scenario: Repository 层职责 + +- **WHEN** 访问数据 +- **THEN** repository 层 SHALL 仅负责数据访问 +- **THEN** repository 层 SHALL 封装数据库操作 +- **THEN** repository 层 SHALL NOT 包含业务逻辑 + +### Requirement: 定义核心接口 + +系统 SHALL 定义清晰的接口边界。 + +#### Scenario: Service 接口定义 + +- **WHEN** 定义 service 接口 +- **THEN** SHALL 定义 ProviderService、ModelService、RoutingService、StatsService 接口 +- **THEN** SHALL 定义清晰的业务方法签名 +- **THEN** SHALL 使用 domain 类型作为参数和返回值 + +#### Scenario: Repository 接口定义 + +- **WHEN** 定义 repository 接口 +- **THEN** SHALL 定义 ProviderRepository、ModelRepository、StatsRepository 接口 +- **THEN** SHALL 定义清晰的数据访问方法签名 +- **THEN** SHALL 使用 domain 类型作为参数和返回值 + +#### Scenario: Provider Client 接口定义 + +- **WHEN** 定义 provider client 接口 +- **THEN** SHALL 定义 ProviderClient 接口 +- **THEN** SHALL 包含 SendRequest 和 SendStreamRequest 方法 +- **THEN** SHALL 支持接口 Mock + +### Requirement: 实现依赖注入 + +系统 SHALL 使用手动依赖注入。 + +#### Scenario: Repository 注入 + +- **WHEN** 初始化 service +- **THEN** SHALL 通过构造函数注入 repository 依赖 +- **THEN** SHALL 使用接口类型而非具体类型 + +#### Scenario: Service 注入 + +- **WHEN** 初始化 handler +- **THEN** SHALL 通过构造函数注入 service 依赖 +- **THEN** SHALL 使用接口类型而非具体类型 + +#### Scenario: 主函数组装 + +- **WHEN** 应用启动 +- **THEN** main.go SHALL 按顺序构造所有依赖 +- **THEN** SHALL 先构造基础设施(logger、database) +- **THEN** SHALL 再构造 repository、service、handler + +### Requirement: 定义 Domain 模型 + +系统 SHALL 定义独立的 domain 模型。 + +#### Scenario: Domain 模型定义 + +- **WHEN** 定义领域模型 +- **THEN** SHALL 在 internal/domain/ 包中定义 +- **THEN** SHALL 包含 Provider、Model、UsageStats 等模型 +- **THEN** SHALL 与数据库模型分离 + +#### Scenario: Domain 模型使用 + +- **WHEN** service 和 repository 处理数据 +- **THEN** SHALL 使用 domain 模型 +- **THEN** SHALL NOT 使用数据库模型(GORM 模型) + +### Requirement: 提高可测试性 + +架构 SHALL 提高代码可测试性。 + +#### Scenario: Service 层测试 + +- **WHEN** 测试 service 层 +- **THEN** SHALL 能够 Mock repository 依赖 +- **THEN** SHALL 能够独立测试业务逻辑 + +#### Scenario: Handler 层测试 + +- **WHEN** 测试 handler 层 +- **THEN** SHALL 能够 Mock service 依赖 +- **THEN** SHALL 能够独立测试 HTTP 处理逻辑 + +#### Scenario: Repository 层测试 + +- **WHEN** 测试 repository 层 +- **THEN** SHALL 使用测试数据库 +- **THEN** SHALL 能够独立测试数据访问逻辑 diff --git a/openspec/specs/middleware-system/spec.md b/openspec/specs/middleware-system/spec.md new file mode 100644 index 0000000..4d8a6e6 --- /dev/null +++ b/openspec/specs/middleware-system/spec.md @@ -0,0 +1,136 @@ +# Middleware System + +## ADDED Requirements + +### Requirement: 实现请求 ID 中间件 + +系统 SHALL 实现请求 ID 中间件。 + +#### Scenario: 生成请求 ID + +- **WHEN** 收到 HTTP 请求且 header 中无 X-Request-ID +- **THEN** SHALL 生成新的 UUID 作为请求 ID +- **THEN** SHALL 设置到响应 header 的 X-Request-ID +- **THEN** SHALL 设置到 gin.Context 中 + +#### Scenario: 复用请求 ID + +- **WHEN** 收到 HTTP 请求且 header 中已有 X-Request-ID +- **THEN** SHALL 复用该请求 ID +- **THEN** SHALL 设置到响应 header +- **THEN** SHALL 设置到 gin.Context 中 + +### Requirement: 实现日志中间件 + +系统 SHALL 实现日志中间件。 + +#### Scenario: 记录请求开始 + +- **WHEN** 收到 HTTP 请求 +- **THEN** SHALL 记录请求开始日志 +- **THEN** SHALL 包含请求方法、路径、客户端 IP、请求 ID + +#### Scenario: 记录请求结束 + +- **WHEN** HTTP 请求处理完成 +- **THEN** SHALL 记录请求结束日志 +- **THEN** SHALL 包含响应状态码、响应大小、请求耗时、请求 ID + +#### Scenario: 记录错误 + +- **WHEN** 请求处理过程中发生错误 +- **THEN** SHALL 记录错误日志 +- **THEN** SHALL 包含错误详情和请求 ID + +### Requirement: 实现错误恢复中间件 + +系统 SHALL 实现错误恢复中间件。 + +#### Scenario: 捕获 panic + +- **WHEN** handler 发生 panic +- **THEN** SHALL 捕获 panic +- **THEN** SHALL 记录堆栈信息 +- **THEN** SHALL 返回 500 错误响应 + +#### Scenario: 记录堆栈 + +- **WHEN** 发生 panic +- **THEN** SHALL 记录完整的堆栈信息 +- **THEN** SHALL 包含 panic 原因和请求 ID + +#### Scenario: 防止服务崩溃 + +- **WHEN** handler panic +- **THEN** SHALL 恢复并继续处理其他请求 +- **THEN** SHALL NOT 导致服务崩溃 + +### Requirement: 实现 CORS 中间件 + +系统 SHALL 实现 CORS 中间件。 + +#### Scenario: 允许所有来源 + +- **WHEN** 收到 CORS 预检请求 +- **THEN** SHALL 设置 Access-Control-Allow-Origin 为 * +- **THEN** SHALL 设置 Access-Control-Allow-Methods +- **THEN** SHALL 设置 Access-Control-Allow-Headers + +#### Scenario: 处理预检请求 + +- **WHEN** 收到 OPTIONS 请求 +- **THEN** SHALL 返回 204 状态码 +- **THEN** SHALL 设置 CORS headers + +注:当前配置允许所有来源,适合个人使用。 + +### Requirement: 中间件注册顺序 + +系统 SHALL 按正确顺序注册中间件。 + +#### Scenario: 全局中间件顺序 + +- **WHEN** 注册全局中间件 +- **THEN** SHALL 按以下顺序注册: + 1. RequestID(生成请求 ID) + 2. Recovery(错误恢复) + 3. Logging(日志记录) + 4. CORS(跨域处理) + +#### Scenario: 中间件执行顺序 + +- **WHEN** 处理请求 +- **THEN** SHALL 按注册顺序执行中间件 +- **THEN** SHALL 确保请求 ID 在其他中间件之前生成 + +### Requirement: 中间件配置 + +中间件 SHALL 支持配置。 + +#### Scenario: 日志中间件配置 + +- **WHEN** 初始化日志中间件 +- **THEN** SHALL 注入 logger 实例 +- **THEN** SHALL 使用配置的日志级别 + +#### Scenario: Recovery 中间件配置 + +- **WHEN** 初始化 recovery 中间件 +- **THEN** SHALL 注入 logger 实例 +- **THEN** SHALL 配置堆栈打印深度 + +### Requirement: 中间件上下文传递 + +中间件 SHALL 支持上下文传递。 + +#### Scenario: 请求 ID 传递 + +- **WHEN** 中间件设置请求 ID +- **THEN** SHALL 通过 gin.Context 传递 +- **THEN** SHALL 在后续中间件和 handler 中可访问 + +#### Scenario: 日志上下文传递 + +- **WHEN** 日志中间件记录日志 +- **THEN** SHALL 包含请求 ID +- **THEN** SHALL 支持添加其他上下文信息 diff --git a/openspec/specs/model-management/spec.md b/openspec/specs/model-management/spec.md index d7f8655..8b55c86 100644 --- a/openspec/specs/model-management/spec.md +++ b/openspec/specs/model-management/spec.md @@ -1,10 +1,6 @@ -# 模型管理 +# Model Management -## Purpose - -TBD - 提供模型配置的管理功能,模型关联到供应商 - -## Requirements +## MODIFIED Requirements ### Requirement: 创建模型配置 @@ -23,16 +19,7 @@ TBD - 提供模型配置的管理功能,模型关联到供应商 - **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request) - **THEN** 错误 SHALL 指示供应商不存在 -#### Scenario: 使用重复 ID 创建模型 - -- **WHEN** 向 `/api/models` 发送 POST 请求,携带已存在的 ID -- **THEN** 网关 SHALL 返回错误,状态码为 409 (Conflict) - -#### Scenario: 创建模型时缺少必需字段 - -- **WHEN** 向 `/api/models` 发送 POST 请求,缺少必需字段(id, provider_id 或 model_name) -- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request) -- **THEN** 错误 SHALL 指示缺少哪些字段 +**变更说明:** handler 通过 ModelService 调用,数据访问通过 ModelRepository 和 ProviderRepository。API 接口保持不变。 ### Requirement: 列出所有模型 @@ -44,10 +31,7 @@ TBD - 提供模型配置的管理功能,模型关联到供应商 - **THEN** 网关 SHALL 返回所有模型的列表 - **THEN** 每个模型 SHALL 包含 id, provider_id, model_name, enabled, created_at -#### Scenario: 列出模型时为空 - -- **WHEN** 向 `/api/models` 发送 GET 请求,且不存在模型 -- **THEN** 网关 SHALL 返回空列表 +**变更说明:** 数据访问从 config 包迁移到 ModelRepository。API 接口保持不变。 ### Requirement: 按供应商列出模型 @@ -58,24 +42,7 @@ TBD - 提供模型配置的管理功能,模型关联到供应商 - **WHEN** 向 `/api/models?provider_id=` 发送 GET 请求 - **THEN** 网关 SHALL 返回指定供应商的模型列表 -#### Scenario: 列出不存在供应商的模型 - -- **WHEN** 向 `/api/models?provider_id=` 发送 GET 请求 -- **THEN** 网关 SHALL 返回空列表 - -### Requirement: 获取特定模型 - -网关 SHALL 允许通过 ID 获取特定模型。 - -#### Scenario: 获取存在的模型 - -- **WHEN** 向 `/api/models/:id` 发送 GET 请求,携带有效的模型 ID -- **THEN** 网关 SHALL 返回模型详情 - -#### Scenario: 获取不存在的模型 - -- **WHEN** 向 `/api/models/:id` 发送 GET 请求,携带不存在的 ID -- **THEN** 网关 SHALL 返回错误,状态码为 404 (Not Found) +**变更说明:** 通过 ModelService 和 ModelRepository 实现。API 接口保持不变。 ### Requirement: 更新模型配置 @@ -87,22 +54,13 @@ TBD - 提供模型配置的管理功能,模型关联到供应商 - **THEN** 网关 SHALL 更新数据库中的模型记录 - **THEN** 网关 SHALL 返回更新后的模型 -#### Scenario: 更新不存在的模型 - -- **WHEN** 向 `/api/models/:id` 发送 PUT 请求,携带不存在的 ID -- **THEN** 网关 SHALL 返回错误,状态码为 404 (Not Found) - #### Scenario: 更新模型供应商 - **WHEN** 向 `/api/models/:id` 发送 PUT 请求,携带新的 provider_id - **THEN** 网关 SHALL 验证新供应商是否存在 - **THEN** 网关 SHALL 更新模型的供应商关联 -#### Scenario: 部分更新 - -- **WHEN** 向 `/api/models/:id` 发送 PUT 请求,仅包含部分字段 -- **THEN** 网关 SHALL 仅更新提供的字段 -- **THEN** 网关 SHALL 保留未更改的字段 +**变更说明:** 通过 ModelService、ModelRepository 和 ProviderRepository 实现。API 接口保持不变。 ### Requirement: 删除模型配置 @@ -114,61 +72,38 @@ TBD - 提供模型配置的管理功能,模型关联到供应商 - **THEN** 网关 SHALL 删除模型记录 - **THEN** 网关 SHALL 返回状态码 204 (No Content) -#### Scenario: 删除不存在的模型 +**变更说明:** 通过 ModelService 和 ModelRepository 实现。API 接口保持不变。 -- **WHEN** 向 `/api/models/:id` 发送 DELETE 请求,携带不存在的 ID -- **THEN** 网关 SHALL 返回错误,状态码为 404 (Not Found) +## ADDED Requirements -### Requirement: 启用和禁用模型 +### Requirement: 使用 service 层处理业务逻辑 -网关 SHALL 支持启用和禁用模型。 +Handler SHALL 通过 ModelService 处理业务逻辑。 -#### Scenario: 禁用模型 +#### Scenario: 调用 service 方法 -- **WHEN** 模型的 `enabled` 字段设置为 false -- **THEN** 网关 SHALL 不向该模型路由请求 -- **THEN** 模型 SHALL 保留在数据库中 +- **WHEN** handler 收到请求 +- **THEN** SHALL 调用对应的 ModelService 方法(Create、Get、List、Update、Delete) +- **THEN** SHALL 使用 domain.Model 类型 -#### Scenario: 启用模型 - -- **WHEN** 已禁用模型的 `enabled` 字段设置为 true -- **THEN** 网关 SHALL 恢复向该模型路由请求 - -### Requirement: 验证模型配置 - -网关 SHALL 验证模型配置数据。 - -#### Scenario: 验证供应商存在 - -- **WHEN** 创建或更新模型时携带 provider_id -- **THEN** 网关 SHALL 验证供应商存在于数据库中 - -#### Scenario: 验证必需字段 +#### Scenario: 供应商验证 - **WHEN** 创建或更新模型 -- **THEN** 网关 SHALL 验证 id, provider_id 和 model_name 存在且非空 +- **THEN** SHALL 在 service 层验证供应商存在 +- **THEN** SHALL 通过 ProviderRepository 查询供应商 -### Requirement: 支持透明的模型名称 +### Requirement: 使用 repository 层访问数据 -网关 SHALL 使用模型名称透明传输,不做转换。 +Service SHALL 通过 ModelRepository 访问数据。 -#### Scenario: 模型名称保留 +#### Scenario: 调用 repository 方法 -- **WHEN** 模型配置了 model_name -- **THEN** 网关 SHALL 在路由请求时使用该确切名称 -- **THEN** 网关 SHALL 不修改或转换模型名称 +- **WHEN** service 处理业务逻辑 +- **THEN** SHALL 调用对应的 ModelRepository 方法 +- **THEN** SHALL 使用 domain.Model 类型 -#### Scenario: 不同供应商的同名模型 +#### Scenario: 数据验证 -- **WHEN** 多个供应商拥有相同 model_name 的模型 -- **THEN** 每个模型 SHALL 通过其唯一 ID 和 provider_id 区分 -- **THEN** 网关 SHALL 基于模型名称和供应商关联的组合进行路由 - -### Requirement: 随供应商级联删除 - -网关 SHALL 在删除关联供应商时删除模型。 - -#### Scenario: 供应商删除级联到模型 - -- **WHEN** 供应商被删除 -- **THEN** 该供应商关联的所有模型 SHALL 自动删除 +- **WHEN** 创建或更新模型 +- **THEN** SHALL 在 service 层验证业务规则 +- **THEN** SHALL 在 repository 层执行数据库操作 diff --git a/openspec/specs/openai-protocol-proxy/spec.md b/openspec/specs/openai-protocol-proxy/spec.md index 9a030d4..141eecc 100644 --- a/openspec/specs/openai-protocol-proxy/spec.md +++ b/openspec/specs/openai-protocol-proxy/spec.md @@ -1,10 +1,6 @@ -# OpenAI 协议代理 +# OpenAI Protocol Proxy -## Purpose - -TBD - 提供 OpenAI Chat Completions API 的代理功能 - -## Requirements +## MODIFIED Requirements ### Requirement: 支持 OpenAI Chat Completions API 端点 @@ -23,27 +19,7 @@ TBD - 提供 OpenAI Chat Completions API 的代理功能 - **THEN** 网关 SHALL 使用 SSE 格式将响应流式返回给应用 - **THEN** 网关 SHALL 在流完成时发送 `data: [DONE]` -### Requirement: 支持 Function Calling - -网关 SHALL 在非流式和流式模式下都支持 OpenAI Function Calling。 - -#### Scenario: 非流式函数调用 - -- **WHEN** 应用发送包含 `tools` 定义的请求 -- **AND** 供应商返回包含 `tool_calls` 的响应 -- **THEN** 网关 SHALL 在响应中原样转发 `tool_calls` - -#### Scenario: 流式函数调用 - -- **WHEN** 应用发送包含 `tools` 定义的流式请求 -- **AND** 供应商在 delta 块中流式返回 `tool_calls` -- **THEN** 网关 SHALL 将 `tool_calls` 块流式发送给应用 -- **THEN** 网关 SHALL 在完成时设置 `finish_reason: "tool_calls"` - -#### Scenario: 工具结果提交 - -- **WHEN** 应用发送包含 `role: "tool"` 消息的后续请求,携带函数结果 -- **THEN** 网关 SHALL 将工具结果原样转发给供应商 +**变更说明:** handler 通过 service 层调用,而非直接调用 config 和 provider 包。API 接口保持不变。 ### Requirement: 根据模型名称路由请求 @@ -65,6 +41,8 @@ TBD - 提供 OpenAI Chat Completions API 的代理功能 - **WHEN** 请求包含已禁用模型的 `model` 字段 - **THEN** 网关 SHALL 返回错误响应,指示模型不可用 +**变更说明:** 路由逻辑从 router 包迁移到 RoutingService,通过 service 层调用。API 接口保持不变。 + ### Requirement: 对 OpenAI 兼容供应商透明代理 网关 SHALL 对 OpenAI 兼容供应商的请求和响应进行透明转发,不做修改。 @@ -83,47 +61,38 @@ TBD - 提供 OpenAI Chat Completions API 的代理功能 - **THEN** 网关 SHALL 将响应体原样返回给应用 - **THEN** 网关 SHALL 保留所有响应头和状态码 -### Requirement: 处理供应商错误 +**变更说明:** provider client 通过接口注入到 handler,便于测试和替换实现。API 接口保持不变。 -网关 SHALL 将供应商错误透明返回给应用。 +## ADDED Requirements -#### Scenario: 供应商返回错误 +### Requirement: 使用 service 层处理请求 -- **WHEN** 供应商返回错误响应(4xx 或 5xx) -- **THEN** 网关 SHALL 将相同的错误响应返回给应用 -- **THEN** 网关 SHALL 保留错误消息和状态码 +Handler SHALL 通过 service 层处理业务逻辑。 -#### Scenario: 供应商超时 +#### Scenario: 调用 routing service -- **WHEN** 供应商在超时时间内未响应 -- **THEN** 网关 SHALL 向应用返回超时错误 +- **WHEN** handler 收到请求 +- **THEN** SHALL 调用 RoutingService.Route() 获取路由结果 +- **THEN** SHALL 使用路由结果中的供应商信息 -#### Scenario: 供应商连接失败 +#### Scenario: 调用 stats service -- **WHEN** 网关无法连接到供应商 -- **THEN** 网关 SHALL 向应用返回连接错误 +- **WHEN** 请求成功完成 +- **THEN** SHALL 调用 StatsService.Record() 记录统计 +- **THEN** SHALL 异步记录统计(不阻塞响应) -### Requirement: 支持标准 OpenAI 请求字段 +### Requirement: 使用结构化错误处理 -网关 SHALL 支持所有标准 OpenAI Chat Completions API 请求字段。 +Handler SHALL 使用结构化错误处理。 -#### Scenario: 支持标准字段 +#### Scenario: 路由错误处理 -- **WHEN** 请求包含标准字段(model, messages, temperature, max_tokens, top_p, frequency_penalty, presence_penalty, stop, n, stream, tools, tool_choice, user) -- **THEN** 网关 SHALL 接受并将所有字段转发给供应商 +- **WHEN** RoutingService 返回错误 +- **THEN** SHALL 转换为对应的 AppError +- **THEN** SHALL 返回统一的错误响应 -### Requirement: 维护流式连接稳定性 +#### Scenario: 供应商错误处理 -网关 SHALL 维护稳定的流式连接并优雅处理中断。 - -#### Scenario: 流中断 - -- **WHEN** 供应商流在传输过程中中断 -- **THEN** 网关 SHALL 优雅关闭客户端连接 -- **THEN** 网关 SHALL 记录中断日志以便调试 - -#### Scenario: 客户端提前断开 - -- **WHEN** 客户端在流完成前断开连接 -- **THEN** 网关 SHALL 取消供应商请求 -- **THEN** 网关 SHALL 释放相关资源 +- **WHEN** ProviderClient 返回错误 +- **THEN** SHALL 包装为 AppError +- **THEN** SHALL 包含请求上下文信息 diff --git a/openspec/specs/provider-management/spec.md b/openspec/specs/provider-management/spec.md index c5228a7..f46def7 100644 --- a/openspec/specs/provider-management/spec.md +++ b/openspec/specs/provider-management/spec.md @@ -1,10 +1,6 @@ -# 供应商管理 +# Provider Management -## Purpose - -TBD - 提供供应商配置的管理功能(创建、查询、更新、删除) - -## Requirements +## MODIFIED Requirements ### Requirement: 创建供应商配置 @@ -28,6 +24,8 @@ TBD - 提供供应商配置的管理功能(创建、查询、更新、删除 - **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request) - **THEN** 错误 SHALL 指示缺少哪些字段 +**变更说明:** handler 通过 ProviderService 调用,数据访问通过 ProviderRepository。API 接口保持不变。 + ### Requirement: 列出所有供应商 网关 SHALL 允许获取所有供应商配置。 @@ -39,10 +37,7 @@ TBD - 提供供应商配置的管理功能(创建、查询、更新、删除 - **THEN** 每个供应商 SHALL 包含 id, name, api_key(已掩码), base_url, enabled, created_at, updated_at - **THEN** api_key SHALL 被掩码(仅显示最后 4 个字符) -#### Scenario: 列出供应商时为空 - -- **WHEN** 向 `/api/providers` 发送 GET 请求,且不存在供应商 -- **THEN** 网关 SHALL 返回空列表 +**变更说明:** 数据访问从 config 包迁移到 ProviderRepository。API 接口保持不变。 ### Requirement: 获取特定供应商 @@ -59,6 +54,8 @@ TBD - 提供供应商配置的管理功能(创建、查询、更新、删除 - **WHEN** 向 `/api/providers/:id` 发送 GET 请求,携带不存在的 ID - **THEN** 网关 SHALL 返回错误,状态码为 404 (Not Found) +**变更说明:** 通过 ProviderService 和 ProviderRepository 实现。API 接口保持不变。 + ### Requirement: 更新供应商配置 网关 SHALL 允许更新现有供应商配置。 @@ -70,16 +67,7 @@ TBD - 提供供应商配置的管理功能(创建、查询、更新、删除 - **THEN** 网关 SHALL 返回更新后的供应商 - **THEN** updated_at 时间戳 SHALL 被更新 -#### Scenario: 更新不存在的供应商 - -- **WHEN** 向 `/api/providers/:id` 发送 PUT 请求,携带不存在的 ID -- **THEN** 网关 SHALL 返回错误,状态码为 404 (Not Found) - -#### Scenario: 部分更新 - -- **WHEN** 向 `/api/providers/:id` 发送 PUT 请求,仅包含部分字段 -- **THEN** 网关 SHALL 仅更新提供的字段 -- **THEN** 网关 SHALL 保留未更改的字段 +**变更说明:** 通过 ProviderService 和 ProviderRepository 实现。API 接口保持不变。 ### Requirement: 删除供应商配置 @@ -92,60 +80,38 @@ TBD - 提供供应商配置的管理功能(创建、查询、更新、删除 - **THEN** 网关 SHALL 删除所有关联的模型(CASCADE) - **THEN** 网关 SHALL 返回状态码 204 (No Content) -#### Scenario: 删除不存在的供应商 +**变更说明:** 通过 ProviderService 和 ProviderRepository 实现。API 接口保持不变。 -- **WHEN** 向 `/api/providers/:id` 发送 DELETE 请求,携带不存在的 ID -- **THEN** 网关 SHALL 返回错误,状态码为 404 (Not Found) +## ADDED Requirements -### Requirement: 启用和禁用供应商 +### Requirement: 使用 service 层处理业务逻辑 -网关 SHALL 支持启用和禁用供应商。 +Handler SHALL 通过 ProviderService 处理业务逻辑。 -#### Scenario: 禁用供应商 +#### Scenario: 调用 service 方法 -- **WHEN** 供应商的 `enabled` 字段设置为 false -- **THEN** 网关 SHALL 不向该供应商路由请求 -- **THEN** 供应商 SHALL 保留在数据库中 +- **WHEN** handler 收到请求 +- **THEN** SHALL 调用对应的 ProviderService 方法(Create、Get、List、Update、Delete) +- **THEN** SHALL 使用 domain.Provider 类型 -#### Scenario: 启用供应商 +#### Scenario: 错误处理 -- **WHEN** 已禁用供应商的 `enabled` 字段设置为 true -- **THEN** 网关 SHALL 恢复向该供应商路由请求 +- **WHEN** service 返回错误 +- **THEN** SHALL 转换为 HTTP 错误响应 +- **THEN** SHALL 使用结构化错误处理 -### Requirement: 验证供应商配置 +### Requirement: 使用 repository 层访问数据 -网关 SHALL 验证供应商配置数据。 +Service SHALL 通过 ProviderRepository 访问数据。 -#### Scenario: 验证 base_url 格式 +#### Scenario: 调用 repository 方法 -- **WHEN** 创建或更新供应商时使用无效的 base_url 格式 -- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request) +- **WHEN** service 处理业务逻辑 +- **THEN** SHALL 调用对应的 ProviderRepository 方法 +- **THEN** SHALL 使用 domain.Provider 类型 -#### Scenario: 验证必需字段 +#### Scenario: 数据验证 - **WHEN** 创建或更新供应商 -- **THEN** 网关 SHALL 验证 id, name, api_key 和 base_url 存在且非空 - -### Requirement: 安全存储供应商配置 - -网关 SHALL 安全存储供应商 API Key。 - -#### Scenario: 存储 API Key - -- **WHEN** 创建或更新供应商时携带 API Key -- **THEN** 网关 SHALL 将 API Key 存储在数据库中 - -#### Scenario: 在响应中掩码 API Key - -- **WHEN** 在 API 响应中返回供应商数据 -- **THEN** API Key SHALL 被掩码(仅显示最后 4 个字符) - -### Requirement: 仅支持 OpenAI 兼容供应商 - -网关 SHALL 在 MVP 中仅支持 OpenAI 兼容供应商。 - -#### Scenario: 供应商类型验证 - -- **WHEN** 创建供应商 -- **THEN** 供应商类型 SHALL 隐式设置为 "openai-compatible" -- **THEN** MVP 中 SHALL 不支持其他供应商类型 +- **THEN** SHALL 在 service 层验证业务规则 +- **THEN** SHALL 在 repository 层执行数据库操作 diff --git a/openspec/specs/request-validation/spec.md b/openspec/specs/request-validation/spec.md new file mode 100644 index 0000000..444f911 --- /dev/null +++ b/openspec/specs/request-validation/spec.md @@ -0,0 +1,132 @@ +# Request Validation + +## ADDED Requirements + +### Requirement: 使用 validator 库 + +系统 SHALL 使用 go-playground/validator 进行请求验证。 + +#### Scenario: 验证器初始化 + +- **WHEN** 应用启动 +- **THEN** SHALL 初始化 validator 实例 +- **THEN** SHALL 注册自定义验证规则 + +#### Scenario: 验证规则定义 + +- **WHEN** 定义请求结构体 +- **THEN** SHALL 使用 struct tag 定义验证规则 +- **THEN** SHALL 支持必需字段、范围、格式等验证 + +### Requirement: 验证 OpenAI 请求 + +系统 SHALL 验证 OpenAI ChatCompletionRequest。 + +#### Scenario: 必需字段验证 + +- **WHEN** 收到 OpenAI 请求 +- **THEN** SHALL 验证 model 字段不为空 +- **THEN** SHALL 验证 messages 字段不为空且至少有一条消息 + +#### Scenario: 参数范围验证 + +- **WHEN** 收到 OpenAI 请求 +- **THEN** SHALL 验证 temperature 范围在 [0, 2] +- **THEN** SHALL 验证 max_tokens 大于 0 +- **THEN** SHALL 验证 top_p 范围在 (0, 1] +- **THEN** SHALL 验证 frequency_penalty 范围在 [-2, 2] +- **THEN** SHALL 验证 presence_penalty 范围在 [-2, 2] + +#### Scenario: 消息内容验证 + +- **WHEN** 验证 messages 字段 +- **THEN** SHALL 验证每条消息的 role 有效(system、user、assistant、tool) +- **THEN** SHALL 验证 content 不为空 + +### Requirement: 验证 Anthropic 请求 + +系统 SHALL 验证 Anthropic MessagesRequest。 + +#### Scenario: 必需字段验证 + +- **WHEN** 收到 Anthropic 请求 +- **THEN** SHALL 验证 model 字段不为空 +- **THEN** SHALL 验证 messages 字段不为空且至少有一条消息 +- **THEN** SHALL 验证 max_tokens 大于 0(或使用默认值) + +#### Scenario: 参数范围验证 + +- **WHEN** 收到 Anthropic 请求 +- **THEN** SHALL 验证 temperature 范围在 [0, 1] +- **THEN** SHALL 验证 top_p 范围在 (0, 1] + +#### Scenario: 消息内容验证 + +- **WHEN** 验证 messages 字段 +- **THEN** SHALL 验证每条消息的 role 有效(user、assistant) +- **THEN** SHALL 验证 content 数组不为空 + +### Requirement: 验证管理 API 请求 + +系统 SHALL 验证管理 API 请求。 + +#### Scenario: Provider 创建验证 + +- **WHEN** 创建 Provider +- **THEN** SHALL 验证 id、name、api_key、base_url 字段不为空 +- **THEN** SHALL 验证 base_url 格式有效(URL 格式) +- **THEN** SHALL 验证 id 格式(字母、数字、下划线、连字符) + +#### Scenario: Model 创建验证 + +- **WHEN** 创建 Model +- **THEN** SHALL 验证 id、provider_id、model_name 字段不为空 +- **THEN** SHALL 验证 provider_id 存在 + +#### Scenario: 更新请求验证 + +- **WHEN** 更新资源 +- **THEN** SHALL 验证至少提供一个可更新字段 +- **THEN** SHALL 验证字段值有效性 + +### Requirement: 返回友好的验证错误 + +系统 SHALL 返回友好的验证错误响应。 + +#### Scenario: 错误消息格式 + +- **WHEN** 验证失败 +- **THEN** SHALL 返回 400 状态码 +- **THEN** SHALL 返回详细的错误消息 +- **THEN** SHALL 指示哪些字段验证失败 + +#### Scenario: 多字段错误 + +- **WHEN** 多个字段验证失败 +- **THEN** SHALL 返回所有验证错误 +- **THEN** SHALL 使用结构化格式(字段名 → 错误消息) + +#### Scenario: 国际化支持 + +- **WHEN** 返回验证错误(未来) +- **THEN** SHALL 支持错误消息国际化 +- **THEN** SHALL 使用错误码作为国际化 key + +注:当前版本使用中文错误消息。 + +### Requirement: 在 handler 中应用验证 + +系统 SHALL 在 handler 中应用验证。 + +#### Scenario: 验证中间件 + +- **WHEN** 使用验证中间件 +- **THEN** SHALL 在请求解析后立即验证 +- **THEN** SHALL 在验证失败时提前返回错误 +- **THEN** SHALL 避免执行后续处理逻辑 + +#### Scenario: 验证时机 + +- **WHEN** 处理请求 +- **THEN** SHALL 在 handler 函数开始时验证 +- **THEN** SHALL 在验证通过后才执行业务逻辑 diff --git a/openspec/specs/structured-logging/spec.md b/openspec/specs/structured-logging/spec.md new file mode 100644 index 0000000..fd4c216 --- /dev/null +++ b/openspec/specs/structured-logging/spec.md @@ -0,0 +1,124 @@ +# Structured Logging + +## ADDED Requirements + +### Requirement: 使用 zap 结构化日志 + +系统 SHALL 使用 zap 作为结构化日志库。 + +#### Scenario: 日志初始化 + +- **WHEN** 应用启动 +- **THEN** SHALL 初始化 zap logger +- **THEN** SHALL 根据配置设置日志级别 +- **THEN** SHALL 配置日志输出格式为 JSON + +#### Scenario: 日志字段 + +- **WHEN** 记录日志 +- **THEN** SHALL 支持结构化字段(key-value) +- **THEN** SHALL 支持嵌套字段 +- **THEN** SHALL 自动包含时间戳和日志级别 + +### Requirement: 支持日志滚动 + +系统 SHALL 支持日志文件滚动,使用 lumberjack。 + +#### Scenario: 按大小滚动 + +- **WHEN** 日志文件大小达到配置的最大值(默认 100 MB) +- **THEN** SHALL 创建新的日志文件 +- **THEN** SHALL 重命名旧文件(添加序号) + +#### Scenario: 按数量清理 + +- **WHEN** 日志文件数量超过配置的最大备份数(默认 10 个) +- **THEN** SHALL 删除最旧的日志文件 + +#### Scenario: 按时间清理 + +- **WHEN** 日志文件超过配置的最大保留天数(默认 30 天) +- **THEN** SHALL 自动删除过期文件 + +#### Scenario: 压缩旧文件 + +- **WHEN** 配置启用压缩(默认启用) +- **THEN** SHALL 压缩旧的日志文件为 .gz 格式 + +### Requirement: 支持请求 ID 追踪 + +系统 SHALL 支持请求 ID 追踪。 + +#### Scenario: 生成请求 ID + +- **WHEN** 收到 HTTP 请求 +- **THEN** SHALL 生成唯一的请求 ID(UUID) +- **THEN** SHALL 设置到响应 header 中(X-Request-ID) +- **THEN** SHALL 添加到日志上下文中 + +#### Scenario: 复用请求 ID + +- **WHEN** 请求 header 中已包含 X-Request-ID +- **THEN** SHALL 复用该请求 ID +- **THEN** SHALL 在整个请求生命周期中使用该 ID + +#### Scenario: 日志关联请求 ID + +- **WHEN** 记录请求相关的日志 +- **THEN** SHALL 自动包含请求 ID 字段 +- **THEN** SHALL 支持通过请求 ID 检索日志 + +### Requirement: 记录请求日志 + +系统 SHALL 记录 HTTP 请求日志。 + +#### Scenario: 请求开始日志 + +- **WHEN** 收到 HTTP 请求 +- **THEN** SHALL 记录请求方法、路径、客户端 IP +- **THEN** SHALL 包含请求 ID + +#### Scenario: 请求结束日志 + +- **WHEN** HTTP 请求处理完成 +- **THEN** SHALL 记录响应状态码、响应大小 +- **THEN** SHALL 记录请求耗时 +- **THEN** SHALL 包含请求 ID + +### Requirement: 支持日志级别 + +系统 SHALL 支持日志级别控制。 + +#### Scenario: 日志级别配置 + +- **WHEN** 配置日志级别 +- **THEN** SHALL 支持 debug、info、warn、error 级别 +- **THEN** SHALL 只记录大于等于配置级别的日志 + +#### Scenario: 开发环境日志 + +- **WHEN** 配置为开发模式 +- **THEN** SHALL 使用 debug 级别 +- **THEN** SHALL 输出到控制台和文件 + +#### Scenario: 生产环境日志 + +- **WHEN** 配置为生产模式 +- **THEN** SHALL 使用 info 级别 +- **THEN** SHALL 仅输出到文件 + +### Requirement: 日志存储位置 + +日志 SHALL 存储在 `~/.nex/log/` 目录。 + +#### Scenario: 日志文件路径 + +- **WHEN** 初始化日志系统 +- **THEN** SHALL 使用 `~/.nex/log/` 作为日志目录 +- **THEN** SHALL 自动创建目录(如果不存在) + +#### Scenario: 日志文件命名 + +- **WHEN** 创建日志文件 +- **THEN** SHALL 使用 `nex-YYYY-MM-DD.log` 格式命名 +- **THEN** SHALL 按日期创建新文件 diff --git a/openspec/specs/test-coverage/spec.md b/openspec/specs/test-coverage/spec.md new file mode 100644 index 0000000..09cf130 --- /dev/null +++ b/openspec/specs/test-coverage/spec.md @@ -0,0 +1,106 @@ +# Test Coverage + +## ADDED Requirements + +### Requirement: 建立单元测试体系 + +系统 SHALL 建立完整的单元测试体系,覆盖核心业务逻辑。 + +#### Scenario: config 包测试覆盖 + +- **WHEN** 运行 config 包的单元测试 +- **THEN** SHALL 覆盖 Provider、Model、Stats 的 CRUD 操作 +- **THEN** SHALL 测试正常场景和错误场景 +- **THEN** SHALL 验证数据库操作的准确性 + +#### Scenario: router 包测试覆盖 + +- **WHEN** 运行 router 包的单元测试 +- **THEN** SHALL 覆盖模型路由逻辑 +- **THEN** SHALL 测试模型不存在、模型禁用、供应商禁用等场景 +- **THEN** SHALL 验证路由结果的正确性 + +#### Scenario: protocol 包测试覆盖 + +- **WHEN** 运行 protocol 包的单元测试 +- **THEN** SHALL 覆盖 OpenAI 和 Anthropic 协议转换逻辑 +- **THEN** SHALL 测试请求转换、响应转换、流式转换 +- **THEN** SHALL 验证转换的准确性和完整性 + +### Requirement: 建立集成测试体系 + +系统 SHALL 建立集成测试体系,覆盖 API 端到端流程。 + +#### Scenario: OpenAI 协议集成测试 + +- **WHEN** 运行 OpenAI 协议的集成测试 +- **THEN** SHALL 测试完整的请求-响应流程 +- **THEN** SHALL 测试流式响应流程 +- **THEN** SHALL 测试错误处理流程 + +#### Scenario: Anthropic 协议集成测试 + +- **WHEN** 运行 Anthropic 协议的集成测试 +- **THEN** SHALL 测试完整的请求-响应流程 +- **THEN** SHALL 测试流式响应流程 +- **THEN** SHALL 测试协议转换的准确性 + +#### Scenario: 管理接口集成测试 + +- **WHEN** 运行管理接口的集成测试 +- **THEN** SHALL 测试 Provider、Model、Stats 的 CRUD 操作 +- **THEN** SHALL 验证 API 响应格式 +- **THEN** SHALL 测试错误场景 + +### Requirement: 提供测试工具函数 + +系统 SHALL 提供测试工具函数,简化测试编写。 + +#### Scenario: 测试数据库初始化 + +- **WHEN** 编写需要数据库的测试 +- **THEN** SHALL 提供测试数据库初始化函数 +- **THEN** SHALL 使用临时数据库文件 +- **THEN** SHALL 在测试结束后自动清理 + +#### Scenario: Mock 工具 + +- **WHEN** 编写需要 Mock 的测试 +- **THEN** SHALL 提供 Mock 接口实现 +- **THEN** SHALL 支持常见 Mock 场景 +- **THEN** SHALL 易于使用和扩展 + +### Requirement: 达到测试覆盖率目标 + +系统 SHALL 达到 > 80% 的测试覆盖率。 + +#### Scenario: 总体覆盖率 + +- **WHEN** 运行所有测试并生成覆盖率报告 +- **THEN** 总体覆盖率 SHALL 大于 80% +- **THEN** 核心包覆盖率 SHALL 大于 85% + +#### Scenario: 覆盖率报告生成 + +- **WHEN** 运行测试覆盖率命令 +- **THEN** SHALL 生成覆盖率报告文件 +- **THEN** SHALL 支持生成 HTML 格式报告 +- **THEN** SHALL 显示每个文件的覆盖率 + +### Requirement: 集成到构建流程 + +测试 SHALL 集成到构建流程中。 + +#### Scenario: 运行测试命令 + +- **WHEN** 执行 `make test` 命令 +- **THEN** SHALL 运行所有单元测试和集成测试 +- **THEN** SHALL 显示测试结果 +- **THEN** SHALL 在测试失败时返回非零退出码 + +#### Scenario: 覆盖率检查命令 + +- **WHEN** 执行 `make test-coverage` 命令 +- **THEN** SHALL 运行测试并生成覆盖率报告 +- **THEN** SHALL 检查覆盖率是否达标 +- **THEN** SHALL 在覆盖率不足时返回非零退出码 diff --git a/openspec/specs/usage-statistics/spec.md b/openspec/specs/usage-statistics/spec.md index 2bb8181..040283b 100644 --- a/openspec/specs/usage-statistics/spec.md +++ b/openspec/specs/usage-statistics/spec.md @@ -1,10 +1,6 @@ -# 用量统计 +# Usage Statistics -## Purpose - -TBD - 提供请求用量统计的记录和查询功能 - -## Requirements +## MODIFIED Requirements ### Requirement: 记录请求统计 @@ -22,15 +18,7 @@ TBD - 提供请求用量统计的记录和查询功能 - **THEN** 网关 SHALL 增加该供应商和模型的请求计数 - **THEN** 网关 SHALL 在流结束后记录统计 -#### Scenario: 不记录失败请求 - -- **WHEN** 请求在到达供应商前失败(路由错误、验证错误) -- **THEN** 网关 SHALL NOT 增加请求计数 - -#### Scenario: 记录供应商错误 - -- **WHEN** 请求到达供应商但供应商返回错误 -- **THEN** 网关 SHALL 仍然增加请求计数(请求已被处理) +**变更说明:** 统计记录通过 StatsService 调用,数据访问通过 StatsRepository。API 接口保持不变。 ### Requirement: 按供应商查询统计 @@ -41,10 +29,7 @@ TBD - 提供请求用量统计的记录和查询功能 - **WHEN** 向 `/api/stats?provider_id=` 发送 GET 请求 - **THEN** 网关 SHALL 仅返回指定供应商的统计 -#### Scenario: 查询不存在供应商的统计 - -- **WHEN** 向 `/api/stats?provider_id=` 发送 GET 请求 -- **THEN** 网关 SHALL 返回空结果或零计数 +**变更说明:** 通过 StatsService 和 StatsRepository 实现。API 接口保持不变。 ### Requirement: 按模型查询统计 @@ -55,10 +40,7 @@ TBD - 提供请求用量统计的记录和查询功能 - **WHEN** 向 `/api/stats?model_name=` 发送 GET 请求 - **THEN** 网关 SHALL 仅返回指定模型的统计 -#### Scenario: 查询不存在模型的统计 - -- **WHEN** 向 `/api/stats?model_name=` 发送 GET 请求 -- **THEN** 网关 SHALL 返回空结果或零计数 +**变更说明:** 通过 StatsService 和 StatsRepository 实现。API 接口保持不变。 ### Requirement: 按日期范围查询统计 @@ -70,20 +52,7 @@ TBD - 提供请求用量统计的记录和查询功能 - **THEN** 网关 SHALL 仅返回指定范围内的日期统计 - **THEN** 日期格式 SHALL 为 YYYY-MM-DD -#### Scenario: 不使用日期范围查询统计 - -- **WHEN** 向 `/api/stats` 发送 GET 请求,不带 start 和 end 参数 -- **THEN** 网关 SHALL 返回所有可用日期的统计 - -#### Scenario: 仅使用开始日期查询统计 - -- **WHEN** 向 `/api/stats?start=` 发送 GET 请求 -- **THEN** 网关 SHALL 返回从开始日期到当前日期的统计 - -#### Scenario: 仅使用结束日期查询统计 - -- **WHEN** 向 `/api/stats?end=` 发送 GET 请求 -- **THEN** 网关 SHALL 返回从最早可用日期到结束日期的统计 +**变更说明:** 通过 StatsService 和 StatsRepository 实现。API 接口保持不变。 ### Requirement: 聚合统计 @@ -95,25 +64,7 @@ TBD - 提供请求用量统计的记录和查询功能 - **THEN** 网关 SHALL 为该天维护单条统计记录 - **THEN** 请求计数 SHALL 为所有请求的总和 -#### Scenario: 跨多天请求 - -- **WHEN** 跨不同天发起请求 -- **THEN** 网关 SHALL 为每一天维护独立的统计记录 - -### Requirement: 以结构化格式返回统计 - -网关 SHALL 以结构化 JSON 格式返回统计。 - -#### Scenario: 统计响应格式 - -- **WHEN** 查询统计 -- **THEN** 响应 SHALL 为统计对象数组 -- **THEN** 每个对象 SHALL 包含 provider_id, model_name, request_count 和 date - -#### Scenario: 空统计 - -- **WHEN** 没有统计匹配查询条件 -- **THEN** 网关 SHALL 返回空数组 +**变更说明:** 聚合逻辑在 StatsRepository 中实现。API 接口保持不变。 ### Requirement: 支持并发统计记录 @@ -125,31 +76,48 @@ TBD - 提供请求用量统计的记录和查询功能 - **THEN** 网关 SHALL 正确为每个请求增加请求计数 - **THEN** 不 SHALL 因并发写入而丢失统计 -### Requirement: 仅将统计限制为请求计数 +**变更说明:** 并发控制在 StatsRepository 中通过数据库事务实现。API 接口保持不变。 -网关 SHALL 在 MVP 中仅记录请求计数,不记录其他指标。 +## ADDED Requirements -#### Scenario: 仅请求计数 +### Requirement: 使用 service 层处理业务逻辑 + +Handler SHALL 通过 StatsService 处理业务逻辑。 + +#### Scenario: 调用 service 方法 + +- **WHEN** handler 收到统计查询请求 +- **THEN** SHALL 调用对应的 StatsService 方法(Get、Aggregate) +- **THEN** SHALL 使用 domain.UsageStats 类型 + +#### Scenario: 异步记录统计 + +- **WHEN** 请求完成需要记录统计 +- **THEN** SHALL 异步调用 StatsService.Record() +- **THEN** SHALL 不阻塞响应返回 + +### Requirement: 使用 repository 层访问数据 + +Service SHALL 通过 StatsRepository 访问数据。 + +#### Scenario: 调用 repository 方法 + +- **WHEN** service 处理业务逻辑 +- **THEN** SHALL 调用对应的 StatsRepository 方法 +- **THEN** SHALL 使用 domain.UsageStats 类型 + +#### Scenario: 事务处理 - **WHEN** 记录统计 -- **THEN** 网关 SHALL 仅跟踪请求数量 -- **THEN** 网关 SHALL NOT 在 MVP 中跟踪 token 使用、成本、延迟或其他指标 +- **THEN** SHALL 在 repository 层使用数据库事务 +- **THEN** SHALL 确保并发安全 -### Requirement: 为新组合初始化统计 +### Requirement: 统计查询优化 -网关 SHALL 为新的供应商-模型-日期组合自动创建统计记录。 +统计查询 SHALL 使用索引优化性能。 -#### Scenario: 组合的首次请求 +#### Scenario: 使用索引 -- **WHEN** 在新日期首次对供应商-模型组合发起请求 -- **THEN** 网关 SHALL 创建新的统计记录,request_count = 1 - -### Requirement: 查询所有统计 - -网关 SHALL 允许不带过滤条件查询所有统计。 - -#### Scenario: 查询所有统计 - -- **WHEN** 向 `/api/stats` 发送 GET 请求,不带任何查询参数 -- **THEN** 网关 SHALL 返回所有可用统计 -- **THEN** 结果 SHALL 按日期排序(最近的在前) +- **WHEN** 查询统计 +- **THEN** SHALL 使用 (provider_id, model_name, date) 复合索引 +- **THEN** SHALL 优化查询性能