feat: 实现分层架构,包含 domain、service、repository 和 pkg 层
- 新增 domain 层:model、provider、route、stats 实体 - 新增 service 层:models、providers、routing、stats 业务逻辑 - 新增 repository 层:models、providers、stats 数据访问 - 新增 pkg 工具包:errors、logger、validator - 新增中间件:CORS、logging、recovery、request ID - 新增数据库迁移:初始 schema 和索引 - 新增单元测试和集成测试 - 新增规范文档:config-management、database-migration、error-handling、layered-architecture、middleware-system、request-validation、structured-logging、test-coverage - 移除 config 子包和 model_router(已迁移至分层架构)
This commit is contained in:
45
backend/Makefile
Normal file
45
backend/Makefile
Normal file
@@ -0,0 +1,45 @@
|
||||
.PHONY: build run test test-coverage clean migrate-up migrate-down migrate-status migrate-create lint
|
||||
|
||||
# 构建
|
||||
build:
|
||||
go build -o bin/server ./cmd/server
|
||||
|
||||
# 运行
|
||||
run:
|
||||
go run ./cmd/server
|
||||
|
||||
# 测试
|
||||
test:
|
||||
go test ./... -v
|
||||
|
||||
# 测试覆盖率
|
||||
test-coverage:
|
||||
go test ./... -coverprofile=coverage.out
|
||||
go tool cover -html=coverage.out -o coverage.html
|
||||
@echo "Coverage report generated: coverage.html"
|
||||
|
||||
# 清理
|
||||
clean:
|
||||
rm -rf bin/ coverage.out coverage.html
|
||||
|
||||
# 数据库迁移
|
||||
migrate-up:
|
||||
goose -dir migrations sqlite3 $(DB_PATH) up
|
||||
|
||||
migrate-down:
|
||||
goose -dir migrations sqlite3 $(DB_PATH) down
|
||||
|
||||
migrate-status:
|
||||
goose -dir migrations sqlite3 $(DB_PATH) status
|
||||
|
||||
migrate-create:
|
||||
@read -p "Migration name: " name; \
|
||||
goose -dir migrations create $$name sql
|
||||
|
||||
# 代码检查
|
||||
lint:
|
||||
golangci-lint run ./...
|
||||
|
||||
# 安装依赖
|
||||
deps:
|
||||
go mod tidy
|
||||
@@ -10,13 +10,21 @@ AI 网关后端服务,提供统一的大模型 API 代理接口。
|
||||
- 支持 Function Calling / Tools
|
||||
- 多供应商配置和路由
|
||||
- 用量统计
|
||||
- 结构化日志(zap + lumberjack)
|
||||
- YAML 配置管理
|
||||
- 请求验证
|
||||
- 中间件支持(请求 ID、日志、恢复、CORS)
|
||||
|
||||
## 技术栈
|
||||
|
||||
- **语言**: Go
|
||||
- **语言**: Go 1.26+
|
||||
- **HTTP 框架**: Gin
|
||||
- **ORM**: GORM
|
||||
- **数据库**: SQLite
|
||||
- **日志**: zap + lumberjack
|
||||
- **配置**: gopkg.in/yaml.v3
|
||||
- **验证**: go-playground/validator/v10
|
||||
- **迁移**: goose
|
||||
|
||||
## 项目结构
|
||||
|
||||
@@ -24,37 +32,86 @@ AI 网关后端服务,提供统一的大模型 API 代理接口。
|
||||
backend/
|
||||
├── cmd/
|
||||
│ └── server/
|
||||
│ └── main.go # 主程序入口
|
||||
│ └── main.go # 主程序入口(依赖注入)
|
||||
├── internal/
|
||||
│ ├── config/ # 配置和数据库
|
||||
│ │ ├── config.go # 配置目录管理
|
||||
│ │ ├── database.go # 数据库连接
|
||||
│ │ ├── models.go # 数据模型
|
||||
│ │ ├── provider.go # 供应商 CRUD
|
||||
│ │ ├── model.go # 模型 CRUD
|
||||
│ │ └── stats.go # 统计记录
|
||||
│ ├── handler/ # HTTP 处理器
|
||||
│ ├── config/ # 配置管理
|
||||
│ │ ├── config.go # 配置加载/保存/验证
|
||||
│ │ └── models.go # GORM 数据模型
|
||||
│ ├── domain/ # 领域模型
|
||||
│ │ ├── provider.go
|
||||
│ │ ├── model.go
|
||||
│ │ ├── stats.go
|
||||
│ │ └── route.go
|
||||
│ ├── handler/ # HTTP 处理器
|
||||
│ │ ├── middleware/ # 中间件
|
||||
│ │ │ ├── request_id.go
|
||||
│ │ │ ├── logging.go
|
||||
│ │ │ ├── recovery.go
|
||||
│ │ │ └── cors.go
|
||||
│ │ ├── openai_handler.go
|
||||
│ │ ├── anthropic_handler.go
|
||||
│ │ ├── provider_handler.go
|
||||
│ │ ├── model_handler.go
|
||||
│ │ └── stats_handler.go
|
||||
│ ├── protocol/ # 协议适配器
|
||||
│ ├── protocol/ # 协议适配器
|
||||
│ │ ├── openai/
|
||||
│ │ │ ├── types.go
|
||||
│ │ │ └── adapter.go
|
||||
│ │ │ ├── types.go # 请求/响应类型 + 验证
|
||||
│ │ │ └── adapter.go # OpenAI 协议适配
|
||||
│ │ └── anthropic/
|
||||
│ │ ├── types.go
|
||||
│ │ ├── converter.go
|
||||
│ │ └── stream_converter.go
|
||||
│ ├── provider/ # 供应商客户端
|
||||
│ │ ├── types.go # 请求/响应类型 + 验证
|
||||
│ │ ├── converter.go # 协议转换
|
||||
│ │ └── stream_converter.go # 流式转换
|
||||
│ ├── provider/ # 供应商客户端
|
||||
│ │ └── client.go
|
||||
│ └── router/ # 模型路由
|
||||
│ └── model_router.go
|
||||
│ ├── repository/ # 数据访问层
|
||||
│ │ ├── provider_repo.go # 接口定义
|
||||
│ │ ├── provider_repo_impl.go
|
||||
│ │ ├── model_repo.go
|
||||
│ │ ├── model_repo_impl.go
|
||||
│ │ ├── stats_repo.go
|
||||
│ │ └── stats_repo_impl.go
|
||||
│ └── service/ # 业务逻辑层
|
||||
│ ├── provider_service.go # 接口定义
|
||||
│ ├── provider_service_impl.go
|
||||
│ ├── model_service.go
|
||||
│ ├── model_service_impl.go
|
||||
│ ├── routing_service.go
|
||||
│ ├── routing_service_impl.go
|
||||
│ ├── stats_service.go
|
||||
│ └── stats_service_impl.go
|
||||
├── pkg/ # 公共包
|
||||
│ ├── errors/ # 结构化错误
|
||||
│ │ ├── errors.go
|
||||
│ │ └── wrap.go
|
||||
│ ├── logger/ # 日志系统
|
||||
│ │ ├── logger.go
|
||||
│ │ ├── rotate.go
|
||||
│ │ └── context.go
|
||||
│ └── validator/ # 验证器
|
||||
│ └── validator.go
|
||||
├── migrations/ # 数据库迁移
|
||||
│ ├── 001_initial_schema.sql
|
||||
│ └── 002_add_indexes.sql
|
||||
├── tests/ # 测试
|
||||
│ ├── helpers.go
|
||||
│ ├── integration/
|
||||
│ ├── unit/
|
||||
│ └── testdata/
|
||||
├── Makefile
|
||||
├── go.mod
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## 架构
|
||||
|
||||
采用三层架构(handler → service → repository),通过依赖注入连接:
|
||||
|
||||
```
|
||||
handler(HTTP 请求处理)
|
||||
→ service(业务逻辑)
|
||||
→ repository(数据访问)
|
||||
```
|
||||
|
||||
## 运行方式
|
||||
|
||||
### 安装依赖
|
||||
@@ -69,7 +126,59 @@ go mod download
|
||||
go run cmd/server/main.go
|
||||
```
|
||||
|
||||
服务将在端口 9826 启动。
|
||||
服务将在端口 9826 启动。首次启动会自动创建配置文件和运行数据库迁移。
|
||||
|
||||
## 配置
|
||||
|
||||
配置文件位于 `~/.nex/config.yaml`,首次启动自动生成。
|
||||
|
||||
```yaml
|
||||
server:
|
||||
port: 9826
|
||||
read_timeout: 30s
|
||||
write_timeout: 30s
|
||||
|
||||
database:
|
||||
path: ~/.nex/config.db
|
||||
max_idle_conns: 10
|
||||
max_open_conns: 100
|
||||
conn_max_lifetime: 1h
|
||||
|
||||
log:
|
||||
level: info
|
||||
path: ~/.nex/log
|
||||
max_size: 100 # MB
|
||||
max_backups: 10
|
||||
max_age: 30 # 天
|
||||
compress: true
|
||||
```
|
||||
|
||||
数据文件:
|
||||
- `~/.nex/config.yaml` - 配置文件
|
||||
- `~/.nex/config.db` - SQLite 数据库
|
||||
- `~/.nex/log/` - 日志目录
|
||||
|
||||
## 测试
|
||||
|
||||
```bash
|
||||
# 运行所有测试
|
||||
make test
|
||||
|
||||
# 生成覆盖率报告
|
||||
make test-coverage
|
||||
```
|
||||
|
||||
## 数据库迁移
|
||||
|
||||
```bash
|
||||
# 使用 Makefile
|
||||
make migrate-up DB_PATH=~/.nex/config.db
|
||||
make migrate-down DB_PATH=~/.nex/config.db
|
||||
make migrate-status DB_PATH=~/.nex/config.db
|
||||
|
||||
# 或直接使用 goose
|
||||
goose -dir migrations sqlite3 ~/.nex/config.db up
|
||||
```
|
||||
|
||||
## API 文档
|
||||
|
||||
@@ -169,20 +278,20 @@ POST /v1/messages
|
||||
- `end_date` - 结束日期(YYYY-MM-DD)
|
||||
- `group_by` - 聚合维度(provider/model/date)
|
||||
|
||||
## 配置
|
||||
|
||||
配置和数据存储在 `~/.nex/` 目录:
|
||||
|
||||
- `~/.nex/config.db` - SQLite 数据库
|
||||
|
||||
## 开发
|
||||
|
||||
### 构建
|
||||
|
||||
```bash
|
||||
go build -o ai-gateway cmd/server/main.go
|
||||
make build
|
||||
```
|
||||
|
||||
### 代码检查
|
||||
|
||||
```bash
|
||||
make lint
|
||||
```
|
||||
|
||||
### 环境要求
|
||||
|
||||
- Go 1.21 或更高版本
|
||||
- Go 1.26 或更高版本
|
||||
|
||||
@@ -2,88 +2,217 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pressly/goose/v3"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/handler"
|
||||
"nex/backend/internal/handler/middleware"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/internal/service"
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 初始化数据库
|
||||
if err := config.InitDB(); err != nil {
|
||||
log.Fatalf("初始化数据库失败: %v", err)
|
||||
// 1. 加载配置
|
||||
cfg, err := config.LoadConfig()
|
||||
if err != nil {
|
||||
log.Fatalf("加载配置失败: %v", err)
|
||||
}
|
||||
if err := cfg.Validate(); err != nil {
|
||||
log.Fatalf("配置验证失败: %v", err)
|
||||
}
|
||||
defer config.CloseDB()
|
||||
|
||||
// 创建 Gin 引擎
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
r := gin.Default()
|
||||
|
||||
// 配置 CORS
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
// 2. 初始化日志
|
||||
zapLogger, err := pkgLogger.New(pkgLogger.Config{
|
||||
Level: cfg.Log.Level,
|
||||
Path: cfg.Log.Path,
|
||||
MaxSize: cfg.Log.MaxSize,
|
||||
MaxBackups: cfg.Log.MaxBackups,
|
||||
MaxAge: cfg.Log.MaxAge,
|
||||
Compress: cfg.Log.Compress,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("初始化日志失败: %v", err)
|
||||
}
|
||||
defer zapLogger.Sync()
|
||||
|
||||
// 3. 初始化数据库
|
||||
db, err := initDatabase(cfg)
|
||||
if err != nil {
|
||||
zapLogger.Fatal("初始化数据库失败", zap.String("error", err.Error()))
|
||||
}
|
||||
defer closeDB(db)
|
||||
|
||||
// 4. 初始化 repository 层
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
|
||||
// 5. 初始化 service 层
|
||||
providerService := service.NewProviderService(providerRepo)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
||||
statsService := service.NewStatsService(statsRepo)
|
||||
|
||||
// 6. 初始化 provider client
|
||||
providerClient := provider.NewClient()
|
||||
|
||||
// 7. 初始化 handler 层
|
||||
openaiHandler := handler.NewOpenAIHandler(providerClient, routingService, statsService)
|
||||
anthropicHandler := handler.NewAnthropicHandler(providerClient, routingService, statsService)
|
||||
providerHandler := handler.NewProviderHandler(providerService)
|
||||
modelHandler := handler.NewModelHandler(modelService)
|
||||
statsHandler := handler.NewStatsHandler(statsService)
|
||||
|
||||
// 8. 创建 Gin 引擎
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
r := gin.New()
|
||||
|
||||
// 注册中间件(按正确顺序)
|
||||
r.Use(middleware.RequestID())
|
||||
r.Use(middleware.Recovery(zapLogger))
|
||||
r.Use(middleware.Logging(zapLogger))
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
// 注册路由
|
||||
setupRoutes(r)
|
||||
setupRoutes(r, openaiHandler, anthropicHandler, providerHandler, modelHandler, statsHandler)
|
||||
|
||||
// 创建 HTTP 服务器
|
||||
// 9. 启动服务器
|
||||
srv := &http.Server{
|
||||
Addr: ":9826",
|
||||
Handler: r,
|
||||
Addr: formatAddr(cfg.Server.Port),
|
||||
Handler: r,
|
||||
ReadTimeout: cfg.Server.ReadTimeout,
|
||||
WriteTimeout: cfg.Server.WriteTimeout,
|
||||
}
|
||||
|
||||
// 启动服务器(在 goroutine 中)
|
||||
go func() {
|
||||
log.Printf("AI Gateway 启动在端口 9826")
|
||||
zapLogger.Info("AI Gateway 启动", zap.String("addr", srv.Addr))
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("服务器启动失败: %v", err)
|
||||
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error()))
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待中断信号以优雅关闭服务器
|
||||
// 等待中断信号
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
log.Println("正在关闭服务器...")
|
||||
|
||||
// 给服务器 5 秒时间完成当前请求
|
||||
zapLogger.Info("正在关闭服务器...")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
log.Fatal("服务器强制关闭:", err)
|
||||
zapLogger.Fatal("服务器强制关闭", zap.String("error", err.Error()))
|
||||
}
|
||||
|
||||
log.Println("服务器已关闭")
|
||||
zapLogger.Info("服务器已关闭")
|
||||
}
|
||||
|
||||
// setupRoutes 配置路由
|
||||
func setupRoutes(r *gin.Engine) {
|
||||
func initDatabase(cfg *config.Config) (*gorm.DB, error) {
|
||||
db, err := gorm.Open(sqlite.Open(cfg.Database.Path), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Info),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||
log.Printf("警告: 启用 WAL 模式失败: %v", err)
|
||||
}
|
||||
|
||||
// 运行数据库迁移
|
||||
if err := runMigrations(db); err != nil {
|
||||
return nil, fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
// 配置连接池
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns)
|
||||
sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns)
|
||||
sqlDB.SetConnMaxLifetime(cfg.Database.ConnMaxLifetime)
|
||||
|
||||
// 记录连接池状态
|
||||
log.Printf("数据库连接池配置: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v",
|
||||
cfg.Database.MaxIdleConns, cfg.Database.MaxOpenConns, cfg.Database.ConnMaxLifetime)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// runMigrations 使用 goose 执行数据库迁移
|
||||
func runMigrations(db *gorm.DB) error {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
migrationsDir := getMigrationsDir()
|
||||
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
|
||||
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
|
||||
}
|
||||
|
||||
goose.SetDialect("sqlite3")
|
||||
if err := goose.Up(sqlDB, migrationsDir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getMigrationsDir 获取迁移文件目录路径
|
||||
func getMigrationsDir() string {
|
||||
// 从可执行文件位置推断迁移目录
|
||||
_, filename, _, ok := runtime.Caller(0)
|
||||
if ok {
|
||||
// cmd/server/main.go → backend/ → backend/migrations/
|
||||
dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations")
|
||||
if abs, err := filepath.Abs(dir); err == nil {
|
||||
return abs
|
||||
}
|
||||
}
|
||||
// 回退到相对路径
|
||||
return "./migrations"
|
||||
}
|
||||
|
||||
func closeDB(db *gorm.DB) {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sqlDB.Close()
|
||||
}
|
||||
|
||||
func formatAddr(port int) string {
|
||||
return fmt.Sprintf(":%d", port)
|
||||
}
|
||||
|
||||
func setupRoutes(r *gin.Engine, openaiHandler *handler.OpenAIHandler, anthropicHandler *handler.AnthropicHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
|
||||
// OpenAI 协议代理
|
||||
openaiHandler := handler.NewOpenAIHandler()
|
||||
r.POST("/v1/chat/completions", openaiHandler.HandleChatCompletions)
|
||||
|
||||
// Anthropic 协议代理
|
||||
anthropicHandler := handler.NewAnthropicHandler()
|
||||
r.POST("/v1/messages", anthropicHandler.HandleMessages)
|
||||
|
||||
// 供应商管理 API
|
||||
providerHandler := handler.NewProviderHandler()
|
||||
providers := r.Group("/api/providers")
|
||||
{
|
||||
providers.GET("", providerHandler.ListProviders)
|
||||
@@ -94,7 +223,6 @@ func setupRoutes(r *gin.Engine) {
|
||||
}
|
||||
|
||||
// 模型管理 API
|
||||
modelHandler := handler.NewModelHandler()
|
||||
models := r.Group("/api/models")
|
||||
{
|
||||
models.GET("", modelHandler.ListModels)
|
||||
@@ -105,7 +233,6 @@ func setupRoutes(r *gin.Engine) {
|
||||
}
|
||||
|
||||
// 统计查询 API
|
||||
statsHandler := handler.NewStatsHandler()
|
||||
stats := r.Group("/api/stats")
|
||||
{
|
||||
stats.GET("", statsHandler.GetStats)
|
||||
|
||||
@@ -3,16 +3,29 @@ module nex/backend
|
||||
go 1.26.2
|
||||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.12.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/pressly/goose/v3 v3.27.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
go.uber.org/zap v1.27.1
|
||||
gopkg.in/lumberjack.v2 v2.0.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
gorm.io/gorm v1.31.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.6.0 // indirect
|
||||
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||
github.com/bytedance/sonic v1.15.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
github.com/gin-gonic/gin v1.12.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.30.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.30.2 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
@@ -22,20 +35,25 @@ require (
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||
github.com/mfridman/interpolate v0.0.2 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
github.com/quic-go/quic-go v0.59.0 // indirect
|
||||
github.com/sethvargo/go-retry v0.3.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.3.1 // indirect
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/arch v0.22.0 // indirect
|
||||
golang.org/x/crypto v0.48.0 // indirect
|
||||
golang.org/x/crypto v0.49.0 // indirect
|
||||
golang.org/x/net v0.51.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gorm.io/driver/sqlite v1.6.0 // indirect
|
||||
gorm.io/gorm v1.31.1 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
|
||||
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
|
||||
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
|
||||
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
|
||||
@@ -7,24 +9,37 @@ github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCc
|
||||
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
|
||||
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM=
|
||||
github.com/gabriel-vasile/mimetype v1.4.13/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
||||
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
|
||||
github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
|
||||
github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w=
|
||||
github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM=
|
||||
github.com/go-playground/validator/v10 v10.30.2 h1:JiFIMtSSHb2/XBUbWM4i/MpeQm9ZK2xqPNk8vgvu5JQ=
|
||||
github.com/go-playground/validator/v10 v10.30.2/go.mod h1:mAf2pIOVXjTEBrwUMGKkCWKKPs9NheYGabeB04txQSc=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
||||
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
@@ -33,24 +48,41 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY=
|
||||
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pressly/goose/v3 v3.27.0 h1:/D30gVTuQhu0WsNZYbJi4DMOsx1lNq+6SkLe+Wp59BM=
|
||||
github.com/pressly/goose/v3 v3.27.0/go.mod h1:3ZBeCXqzkgIRvrEMDkYh1guvtoJTU5oMMuDdkutoM78=
|
||||
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
|
||||
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
|
||||
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
|
||||
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE=
|
||||
github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
@@ -60,29 +92,68 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY=
|
||||
github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
||||
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
||||
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||
golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0=
|
||||
golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA=
|
||||
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/lumberjack.v2 v2.0.0 h1:IDj6hi8KbNiPQ5VaYNFZ7dBJLF5LFeKvsFrWHjA5aq4=
|
||||
gopkg.in/lumberjack.v2 v2.0.0/go.mod h1:bp5nQ2kK/lLQSmTk29azj9+JB6bWci56xFn/lvd5GLI=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
||||
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
||||
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
|
||||
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
|
||||
modernc.org/libc v1.68.0 h1:PJ5ikFOV5pwpW+VqCK1hKJuEWsonkIJhhIXyuF/91pQ=
|
||||
modernc.org/libc v1.68.0/go.mod h1:NnKCYeoYgsEqnY3PgvNgAeaJnso968ygU8Z0DxjoEc0=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU=
|
||||
modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
|
||||
|
||||
@@ -1,24 +1,87 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
// Config 应用配置
|
||||
type Config struct {
|
||||
Server ServerConfig `yaml:"server"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Log LogConfig `yaml:"log"`
|
||||
}
|
||||
|
||||
// ServerConfig 服务器配置
|
||||
type ServerConfig struct {
|
||||
Port int `yaml:"port"`
|
||||
ReadTimeout time.Duration `yaml:"read_timeout"`
|
||||
WriteTimeout time.Duration `yaml:"write_timeout"`
|
||||
}
|
||||
|
||||
// DatabaseConfig 数据库配置
|
||||
type DatabaseConfig struct {
|
||||
Path string `yaml:"path"`
|
||||
MaxIdleConns int `yaml:"max_idle_conns"`
|
||||
MaxOpenConns int `yaml:"max_open_conns"`
|
||||
ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime"`
|
||||
}
|
||||
|
||||
// LogConfig 日志配置
|
||||
type LogConfig struct {
|
||||
Level string `yaml:"level"`
|
||||
Path string `yaml:"path"`
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
Compress bool `yaml:"compress"`
|
||||
}
|
||||
|
||||
// DefaultConfig returns default config values
|
||||
func DefaultConfig() *Config {
|
||||
// Use home dir for default paths
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
nexDir := filepath.Join(homeDir, ".nex")
|
||||
|
||||
return &Config{
|
||||
Server: ServerConfig{
|
||||
Port: 9826,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
},
|
||||
Database: DatabaseConfig{
|
||||
Path: filepath.Join(nexDir, "config.db"),
|
||||
MaxIdleConns: 10,
|
||||
MaxOpenConns: 100,
|
||||
ConnMaxLifetime: 1 * time.Hour,
|
||||
},
|
||||
Log: LogConfig{
|
||||
Level: "info",
|
||||
Path: filepath.Join(nexDir, "log"),
|
||||
MaxSize: 100,
|
||||
MaxBackups: 10,
|
||||
MaxAge: 30,
|
||||
Compress: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfigDir 获取配置目录路径(~/.nex/)
|
||||
func GetConfigDir() (string, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
configDir := filepath.Join(homeDir, ".nex")
|
||||
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return configDir, nil
|
||||
}
|
||||
|
||||
@@ -30,3 +93,79 @@ func GetDBPath() (string, error) {
|
||||
}
|
||||
return filepath.Join(configDir, "config.db"), nil
|
||||
}
|
||||
|
||||
// GetConfigPath 获取配置文件路径
|
||||
func GetConfigPath() (string, error) {
|
||||
configDir, err := GetConfigDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(configDir, "config.yaml"), nil
|
||||
}
|
||||
|
||||
// LoadConfig loads config from YAML file, creates default if not exists
|
||||
func LoadConfig() (*Config, error) {
|
||||
configPath, err := GetConfigPath()
|
||||
if err != nil {
|
||||
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
cfg := DefaultConfig()
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// Create default config file
|
||||
if saveErr := SaveConfig(cfg); saveErr != nil {
|
||||
return nil, appErrors.WithMessage(appErrors.ErrInternal, "创建默认配置失败")
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
if err := yaml.Unmarshal(data, cfg); err != nil {
|
||||
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// SaveConfig saves config to YAML file
|
||||
func SaveConfig(cfg *Config) error {
|
||||
configPath, err := GetConfigPath()
|
||||
if err != nil {
|
||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(configPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
return os.WriteFile(configPath, data, 0644)
|
||||
}
|
||||
|
||||
// Validate validates the config
|
||||
func (c *Config) Validate() error {
|
||||
if c.Server.Port < 1 || c.Server.Port > 65535 {
|
||||
return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("无效的端口号: %d", c.Server.Port))
|
||||
}
|
||||
|
||||
validLevels := map[string]bool{"debug": true, "info": true, "warn": true, "error": true}
|
||||
if !validLevels[c.Log.Level] {
|
||||
return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("无效的日志级别: %s", c.Log.Level))
|
||||
}
|
||||
|
||||
if c.Database.Path == "" {
|
||||
return appErrors.WithMessage(appErrors.ErrInvalidRequest, "数据库路径不能为空")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
176
backend/internal/config/config_test.go
Normal file
176
backend/internal/config/config_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestDefaultConfig(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.Equal(t, 9826, cfg.Server.Port)
|
||||
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
|
||||
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
|
||||
|
||||
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
|
||||
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
|
||||
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
|
||||
|
||||
assert.Equal(t, "info", cfg.Log.Level)
|
||||
assert.Equal(t, 100, cfg.Log.MaxSize)
|
||||
assert.Equal(t, 10, cfg.Log.MaxBackups)
|
||||
assert.Equal(t, 30, cfg.Log.MaxAge)
|
||||
assert.Equal(t, true, cfg.Log.Compress)
|
||||
}
|
||||
|
||||
func TestConfig_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modify func(*Config)
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "默认配置有效",
|
||||
modify: func(c *Config) {},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "端口号为0无效",
|
||||
modify: func(c *Config) { c.Server.Port = 0 },
|
||||
wantErr: true,
|
||||
errMsg: "无效的端口号",
|
||||
},
|
||||
{
|
||||
name: "端口号超出范围无效",
|
||||
modify: func(c *Config) { c.Server.Port = 70000 },
|
||||
wantErr: true,
|
||||
errMsg: "无效的端口号",
|
||||
},
|
||||
{
|
||||
name: "端口号为1有效",
|
||||
modify: func(c *Config) { c.Server.Port = 1 },
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "端口号为65535有效",
|
||||
modify: func(c *Config) { c.Server.Port = 65535 },
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "无效日志级别",
|
||||
modify: func(c *Config) { c.Log.Level = "invalid" },
|
||||
wantErr: true,
|
||||
errMsg: "无效的日志级别",
|
||||
},
|
||||
{
|
||||
name: "debug级别有效",
|
||||
modify: func(c *Config) { c.Log.Level = "debug" },
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "warn级别有效",
|
||||
modify: func(c *Config) { c.Log.Level = "warn" },
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "error级别有效",
|
||||
modify: func(c *Config) { c.Log.Level = "error" },
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "数据库路径为空无效",
|
||||
modify: func(c *Config) { c.Database.Path = "" },
|
||||
wantErr: true,
|
||||
errMsg: "数据库路径不能为空",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
tt.modify(cfg)
|
||||
err := cfg.Validate()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetConfigDir(t *testing.T) {
|
||||
dir, err := GetConfigDir()
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, dir)
|
||||
assert.Contains(t, dir, ".nex")
|
||||
}
|
||||
|
||||
func TestGetDBPath(t *testing.T) {
|
||||
path, err := GetDBPath()
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, path)
|
||||
assert.Contains(t, path, "config.db")
|
||||
}
|
||||
|
||||
func TestGetConfigPath(t *testing.T) {
|
||||
path, err := GetConfigPath()
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, path)
|
||||
assert.Contains(t, path, "config.yaml")
|
||||
}
|
||||
|
||||
func TestSaveAndLoadConfig(t *testing.T) {
|
||||
// 使用临时目录覆盖配置路径
|
||||
dir := t.TempDir()
|
||||
|
||||
cfg := &Config{
|
||||
Server: ServerConfig{
|
||||
Port: 9999,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 20 * time.Second,
|
||||
},
|
||||
Database: DatabaseConfig{
|
||||
Path: filepath.Join(dir, "test.db"),
|
||||
MaxIdleConns: 5,
|
||||
MaxOpenConns: 50,
|
||||
ConnMaxLifetime: 30 * time.Minute,
|
||||
},
|
||||
Log: LogConfig{
|
||||
Level: "debug",
|
||||
Path: filepath.Join(dir, "log"),
|
||||
MaxSize: 50,
|
||||
MaxBackups: 5,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
},
|
||||
}
|
||||
|
||||
// 保存配置
|
||||
configPath := filepath.Join(dir, "config.yaml")
|
||||
data, err := yaml.Marshal(cfg)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(configPath, data, 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 加载配置
|
||||
data, err = os.ReadFile(configPath)
|
||||
require.NoError(t, err)
|
||||
loaded := &Config{}
|
||||
err = yaml.Unmarshal(data, loaded)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, cfg.Server.Port, loaded.Server.Port)
|
||||
assert.Equal(t, cfg.Log.Level, loaded.Log.Level)
|
||||
assert.Equal(t, cfg.Database.MaxIdleConns, loaded.Database.MaxIdleConns)
|
||||
assert.Equal(t, cfg.Log.Compress, loaded.Log.Compress)
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
var db *gorm.DB
|
||||
|
||||
// InitDB 初始化数据库连接并创建表
|
||||
func InitDB() error {
|
||||
dbPath, err := GetDBPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取数据库路径失败: %w", err)
|
||||
}
|
||||
|
||||
// 打开数据库连接
|
||||
db, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Info),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// 启用 WAL 模式以提升并发性能
|
||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||
log.Printf("警告: 启用 WAL 模式失败: %v", err)
|
||||
}
|
||||
|
||||
// 自动迁移表结构
|
||||
if err := db.AutoMigrate(&Provider{}, &Model{}, &UsageStats{}); err != nil {
|
||||
return fmt.Errorf("创建表失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("数据库初始化成功: %s", dbPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDB 获取数据库连接
|
||||
func GetDB() *gorm.DB {
|
||||
return db
|
||||
}
|
||||
|
||||
// CloseDB 关闭数据库连接
|
||||
func CloseDB() error {
|
||||
if db != nil {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CreateModel 创建模型
|
||||
func CreateModel(model *Model) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
// 验证供应商是否存在
|
||||
var provider Provider
|
||||
err := db.First(&provider, "id = ?", model.ProviderID).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("供应商不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
model.CreatedAt = time.Now()
|
||||
|
||||
return db.Create(model).Error
|
||||
}
|
||||
|
||||
// GetModel 获取模型
|
||||
func GetModel(id string) (*Model, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
var model Model
|
||||
err := db.First(&model, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model, nil
|
||||
}
|
||||
|
||||
// ListModels 列出模型
|
||||
func ListModels(providerID string) ([]Model, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
var models []Model
|
||||
var err error
|
||||
|
||||
if providerID != "" {
|
||||
err = db.Where("provider_id = ?", providerID).Find(&models).Error
|
||||
} else {
|
||||
err = db.Find(&models).Error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// UpdateModel 更新模型
|
||||
func UpdateModel(id string, updates map[string]interface{}) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
// 如果更新了 provider_id,验证新供应商是否存在
|
||||
if providerID, ok := updates["provider_id"].(string); ok {
|
||||
var provider Provider
|
||||
err := db.First(&provider, "id = ?", providerID).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("供应商不存在")
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
result := db.Model(&Model{}).Where("id = ?", id).Updates(updates)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteModel 删除模型
|
||||
func DeleteModel(id string) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
result := db.Delete(&Model{}, "id = ?", id)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,102 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CreateProvider 创建供应商
|
||||
func CreateProvider(provider *Provider) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
provider.CreatedAt = time.Now()
|
||||
provider.UpdatedAt = time.Now()
|
||||
|
||||
return db.Create(provider).Error
|
||||
}
|
||||
|
||||
// GetProvider 获取供应商
|
||||
func GetProvider(id string, maskKey bool) (*Provider, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
var provider Provider
|
||||
err := db.First(&provider, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if maskKey {
|
||||
provider.MaskAPIKey()
|
||||
}
|
||||
|
||||
return &provider, nil
|
||||
}
|
||||
|
||||
// ListProviders 列出所有供应商
|
||||
func ListProviders() ([]Provider, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
var providers []Provider
|
||||
err := db.Find(&providers).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 掩码所有 API Key
|
||||
for i := range providers {
|
||||
providers[i].MaskAPIKey()
|
||||
}
|
||||
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
// UpdateProvider 更新供应商
|
||||
func UpdateProvider(id string, updates map[string]interface{}) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
updates["updated_at"] = time.Now()
|
||||
|
||||
result := db.Model(&Provider{}).Where("id = ?", id).Updates(updates)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteProvider 删除供应商
|
||||
func DeleteProvider(id string) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
result := db.Delete(&Provider{}, "id = ?", id)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,79 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// RecordRequest 记录请求统计
|
||||
func RecordRequest(providerID, modelName string) error {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
today := time.Now().Format("2006-01-02")
|
||||
todayTime, _ := time.Parse("2006-01-02", today)
|
||||
|
||||
// 使用事务确保并发安全
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
var stats UsageStats
|
||||
|
||||
// 查找或创建统计记录
|
||||
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
|
||||
providerID, modelName, todayTime).
|
||||
First(&stats).Error
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// 创建新记录
|
||||
stats = UsageStats{
|
||||
ProviderID: providerID,
|
||||
ModelName: modelName,
|
||||
RequestCount: 1,
|
||||
Date: todayTime,
|
||||
}
|
||||
return tx.Create(&stats).Error
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新计数
|
||||
return tx.Model(&stats).Update("request_count", gorm.Expr("request_count + 1")).Error
|
||||
})
|
||||
}
|
||||
|
||||
// GetStats 查询统计
|
||||
func GetStats(providerID, modelName string, startDate, endDate *time.Time) ([]UsageStats, error) {
|
||||
db := GetDB()
|
||||
if db == nil {
|
||||
return nil, errors.New("数据库未初始化")
|
||||
}
|
||||
|
||||
var stats []UsageStats
|
||||
query := db.Model(&UsageStats{})
|
||||
|
||||
if providerID != "" {
|
||||
query = query.Where("provider_id = ?", providerID)
|
||||
}
|
||||
|
||||
if modelName != "" {
|
||||
query = query.Where("model_name = ?", modelName)
|
||||
}
|
||||
|
||||
if startDate != nil {
|
||||
query = query.Where("date >= ?", startDate)
|
||||
}
|
||||
|
||||
if endDate != nil {
|
||||
query = query.Where("date <= ?", endDate)
|
||||
}
|
||||
|
||||
err := query.Order("date DESC").Find(&stats).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
12
backend/internal/domain/model.go
Normal file
12
backend/internal/domain/model.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
// Model 模型领域模型
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
ProviderID string `json:"provider_id"`
|
||||
ModelName string `json:"model_name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
23
backend/internal/domain/provider.go
Normal file
23
backend/internal/domain/provider.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
// Provider 供应商领域模型
|
||||
type Provider struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
APIKey string `json:"api_key"`
|
||||
BaseURL string `json:"base_url"`
|
||||
Enabled bool `json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// MaskAPIKey 掩码 API Key(仅显示最后 4 个字符)
|
||||
func (p *Provider) MaskAPIKey() {
|
||||
if len(p.APIKey) > 4 {
|
||||
p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:]
|
||||
} else {
|
||||
p.APIKey = "***"
|
||||
}
|
||||
}
|
||||
7
backend/internal/domain/route.go
Normal file
7
backend/internal/domain/route.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package domain
|
||||
|
||||
// RouteResult 路由结果
|
||||
type RouteResult struct {
|
||||
Provider *Provider
|
||||
Model *Model
|
||||
}
|
||||
12
backend/internal/domain/stats.go
Normal file
12
backend/internal/domain/stats.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
// UsageStats 用量统计领域模型
|
||||
type UsageStats struct {
|
||||
ID uint `json:"id"`
|
||||
ProviderID string `json:"provider_id"`
|
||||
ModelName string `json:"model_name"`
|
||||
RequestCount int `json:"request_count"`
|
||||
Date time.Time `json:"date"`
|
||||
}
|
||||
@@ -7,30 +7,33 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/protocol/anthropic"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/router"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// AnthropicHandler Anthropic 协议处理器
|
||||
type AnthropicHandler struct {
|
||||
client *provider.Client
|
||||
router *router.Router
|
||||
client provider.ProviderClient
|
||||
routingService service.RoutingService
|
||||
statsService service.StatsService
|
||||
}
|
||||
|
||||
// NewAnthropicHandler 创建 Anthropic 处理器
|
||||
func NewAnthropicHandler() *AnthropicHandler {
|
||||
func NewAnthropicHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *AnthropicHandler {
|
||||
return &AnthropicHandler{
|
||||
client: provider.NewClient(),
|
||||
router: router.NewRouter(),
|
||||
client: client,
|
||||
routingService: routingService,
|
||||
statsService: statsService,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleMessages 处理 Messages 请求
|
||||
func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
|
||||
// 解析 Anthropic 请求
|
||||
var req anthropic.MessagesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
@@ -43,7 +46,19 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查多模态内容
|
||||
// 请求验证
|
||||
if validationErrors := anthropic.ValidateRequest(&req); validationErrors != nil {
|
||||
errMsg := formatValidationErrors(validationErrors)
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: errMsg,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.checkMultimodalContent(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
@@ -55,7 +70,6 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为 OpenAI 请求
|
||||
openaiReq, err := anthropic.ConvertRequest(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
@@ -68,14 +82,12 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 路由到供应商
|
||||
routeResult, err := h.router.Route(openaiReq.Model)
|
||||
routeResult, err := h.routingService.Route(openaiReq.Model)
|
||||
if err != nil {
|
||||
h.handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 根据是否流式选择处理方式
|
||||
if req.Stream {
|
||||
h.handleStreamRequest(c, openaiReq, routeResult)
|
||||
} else {
|
||||
@@ -83,9 +95,7 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// handleNonStreamRequest 处理非流式请求
|
||||
func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
|
||||
// 发送请求到供应商
|
||||
func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
openaiResp, err := h.client.SendRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
@@ -98,7 +108,6 @@ func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *ope
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为 Anthropic 响应
|
||||
anthropicResp, err := anthropic.ConvertResponse(openaiResp)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
@@ -111,18 +120,14 @@ func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *ope
|
||||
return
|
||||
}
|
||||
|
||||
// 记录统计
|
||||
go func() {
|
||||
_ = config.RecordRequest(routeResult.Provider.ID, openaiReq.Model)
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, openaiReq.Model)
|
||||
}()
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, anthropicResp)
|
||||
}
|
||||
|
||||
// handleStreamRequest 处理流式请求
|
||||
func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
|
||||
// 发送流式请求到供应商
|
||||
func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
@@ -135,24 +140,19 @@ func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai
|
||||
return
|
||||
}
|
||||
|
||||
// 设置 SSE 响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
// 创建流写入器
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
|
||||
// 创建流式转换器
|
||||
converter := anthropic.NewStreamConverter(
|
||||
fmt.Sprintf("msg_%s", routeResult.Provider.ID),
|
||||
openaiReq.Model,
|
||||
)
|
||||
|
||||
// 流式转发事件
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
fmt.Printf("流错误: %v\n", event.Error)
|
||||
break
|
||||
}
|
||||
|
||||
@@ -160,25 +160,19 @@ func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai
|
||||
break
|
||||
}
|
||||
|
||||
// 解析 OpenAI 流块
|
||||
chunk, err := openai.NewAdapter().ParseStreamChunk(event.Data)
|
||||
if err != nil {
|
||||
fmt.Printf("解析流块失败: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 转换为 Anthropic 事件
|
||||
anthropicEvents, err := converter.ConvertChunk(chunk)
|
||||
if err != nil {
|
||||
fmt.Printf("转换事件失败: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 写入事件
|
||||
for _, ae := range anthropicEvents {
|
||||
eventStr, err := anthropic.SerializeEvent(ae)
|
||||
if err != nil {
|
||||
fmt.Printf("序列化事件失败: %v\n", err)
|
||||
continue
|
||||
}
|
||||
writer.WriteString(eventStr)
|
||||
@@ -186,13 +180,11 @@ func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai
|
||||
}
|
||||
}
|
||||
|
||||
// 记录统计
|
||||
go func() {
|
||||
_ = config.RecordRequest(routeResult.Provider.ID, openaiReq.Model)
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, openaiReq.Model)
|
||||
}()
|
||||
}
|
||||
|
||||
// checkMultimodalContent 检查多模态内容
|
||||
func (h *AnthropicHandler) checkMultimodalContent(req *anthropic.MessagesRequest) error {
|
||||
for _, msg := range req.Messages {
|
||||
for _, block := range msg.Content {
|
||||
@@ -204,40 +196,22 @@ func (h *AnthropicHandler) checkMultimodalContent(req *anthropic.MessagesRequest
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleError 处理路由错误
|
||||
func (h *AnthropicHandler) handleError(c *gin.Context, err error) {
|
||||
switch err {
|
||||
case router.ErrModelNotFound:
|
||||
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
|
||||
if appErr, ok := appErrors.AsAppError(err); ok {
|
||||
c.JSON(appErr.HTTPStatus, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "not_found_error",
|
||||
Message: "模型未找到",
|
||||
},
|
||||
})
|
||||
case router.ErrModelDisabled:
|
||||
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "not_found_error",
|
||||
Message: "模型已禁用",
|
||||
},
|
||||
})
|
||||
case router.ErrProviderDisabled:
|
||||
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "not_found_error",
|
||||
Message: "供应商已禁用",
|
||||
},
|
||||
})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "internal_error",
|
||||
Message: "内部错误: " + err.Error(),
|
||||
Message: appErr.Message,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "internal_error",
|
||||
Message: "内部错误: " + err.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
290
backend/internal/handler/handler_test.go
Normal file
290
backend/internal/handler/handler_test.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// ============ Mock 实现 ============
|
||||
|
||||
type mockRoutingService struct {
|
||||
result *domain.RouteResult
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockRoutingService) Route(modelName string) (*domain.RouteResult, error) {
|
||||
return m.result, m.err
|
||||
}
|
||||
|
||||
type mockStatsService struct {
|
||||
err error
|
||||
stats []domain.UsageStats
|
||||
aggrResult []map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *mockStatsService) Record(providerID, modelName string) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
return m.stats, nil
|
||||
}
|
||||
func (m *mockStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} {
|
||||
return m.aggrResult
|
||||
}
|
||||
|
||||
type mockProviderService struct {
|
||||
provider *domain.Provider
|
||||
providers []domain.Provider
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err }
|
||||
func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
|
||||
return m.provider, m.err
|
||||
}
|
||||
func (m *mockProviderService) List() ([]domain.Provider, error) { return m.providers, m.err }
|
||||
func (m *mockProviderService) Update(id string, updates map[string]interface{}) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockProviderService) Delete(id string) error { return m.err }
|
||||
|
||||
type mockModelService struct {
|
||||
model *domain.Model
|
||||
models []domain.Model
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockModelService) Create(model *domain.Model) error { return m.err }
|
||||
func (m *mockModelService) Get(id string) (*domain.Model, error) {
|
||||
return m.model, m.err
|
||||
}
|
||||
func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
|
||||
return m.models, m.err
|
||||
}
|
||||
func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockModelService) Delete(id string) error { return m.err }
|
||||
|
||||
type mockProviderClient struct {
|
||||
resp *openai.ChatCompletionResponse
|
||||
eventChan chan provider.StreamEvent
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProviderClient) SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error) {
|
||||
return m.resp, m.err
|
||||
}
|
||||
func (m *mockProviderClient) SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan provider.StreamEvent, error) {
|
||||
return m.eventChan, m.err
|
||||
}
|
||||
|
||||
// ============ OpenAI Handler 测试 ============
|
||||
|
||||
func TestOpenAIHandler_HandleChatCompletions_InvalidJSON(t *testing.T) {
|
||||
h := NewOpenAIHandler(nil, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte("invalid")))
|
||||
|
||||
h.HandleChatCompletions(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestOpenAIHandler_HandleChatCompletions_ValidationError(t *testing.T) {
|
||||
h := NewOpenAIHandler(nil, nil, nil)
|
||||
|
||||
// 缺少 model 字段
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"messages": []map[string]string{{"role": "user", "content": "hi"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.HandleChatCompletions(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestOpenAIHandler_HandleChatCompletions_RouteError(t *testing.T) {
|
||||
routingSvc := &mockRoutingService{err: appErrors.ErrModelNotFound}
|
||||
h := NewOpenAIHandler(nil, routingSvc, nil)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"model": "nonexistent",
|
||||
"messages": []map[string]string{{"role": "user", "content": "hi"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.HandleChatCompletions(c)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
}
|
||||
|
||||
// ============ Provider Handler 测试 ============
|
||||
|
||||
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"id": "p1"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateProvider(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_ListProviders(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
providers: []domain.Provider{
|
||||
{ID: "p1", Name: "P1"},
|
||||
{ID: "p2", Name: "P2"},
|
||||
},
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/api/providers", nil)
|
||||
|
||||
h.ListProviders(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var result []domain.Provider
|
||||
json.Unmarshal(w.Body.Bytes(), &result)
|
||||
assert.Len(t, result, 2)
|
||||
}
|
||||
|
||||
func TestProviderHandler_GetProvider(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
provider: &domain.Provider{ID: "p1", Name: "P1", APIKey: "***"},
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("GET", "/api/providers/p1", nil)
|
||||
|
||||
h.GetProvider(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
// ============ Model Handler 测试 ============
|
||||
|
||||
func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{})
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"id": "m1"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestModelHandler_ListModels(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
models: []domain.Model{
|
||||
{ID: "m1", ModelName: "gpt-4"},
|
||||
{ID: "m2", ModelName: "gpt-3.5"},
|
||||
},
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/api/models", nil)
|
||||
|
||||
h.ListModels(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
// ============ Stats Handler 测试 ============
|
||||
|
||||
func TestStatsHandler_GetStats(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
|
||||
},
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/api/stats", nil)
|
||||
|
||||
h.GetStats(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/api/stats?start_date=invalid", nil)
|
||||
|
||||
h.GetStats(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestStatsHandler_AggregateStats(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
},
|
||||
aggrResult: []map[string]interface{}{
|
||||
{"provider_id": "p1", "request_count": 10},
|
||||
},
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/api/stats/aggregate?group_by=provider", nil)
|
||||
|
||||
h.AggregateStats(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
// ============ writeError 测试 ============
|
||||
|
||||
func TestWriteError(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
writeError(c, appErrors.ErrModelNotFound)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
}
|
||||
|
||||
func TestFormatValidationErrors(t *testing.T) {
|
||||
errs := map[string]string{
|
||||
"model": "模型名称不能为空",
|
||||
"messages": "消息列表不能为空",
|
||||
}
|
||||
result := formatValidationErrors(errs)
|
||||
require.Contains(t, result, "请求验证失败")
|
||||
require.Contains(t, result, "model")
|
||||
require.Contains(t, result, "messages")
|
||||
}
|
||||
21
backend/internal/handler/middleware/cors.go
Normal file
21
backend/internal/handler/middleware/cors.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORS 跨域中间件
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Request-ID")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
40
backend/internal/handler/middleware/logging.go
Normal file
40
backend/internal/handler/middleware/logging.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Logging 日志中间件
|
||||
func Logging(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
query := c.Request.URL.RawQuery
|
||||
|
||||
requestID, _ := c.Get(RequestIDKey)
|
||||
logger.Info("请求开始",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", path),
|
||||
zap.String("query", query),
|
||||
zap.String("client_ip", c.ClientIP()),
|
||||
zap.Any("request_id", requestID),
|
||||
)
|
||||
|
||||
c.Next()
|
||||
|
||||
latency := time.Since(start)
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
logger.Info("请求结束",
|
||||
zap.Int("status", statusCode),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", path),
|
||||
zap.Duration("latency", latency),
|
||||
zap.Int("body_size", c.Writer.Size()),
|
||||
zap.Any("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
130
backend/internal/handler/middleware/middleware_test.go
Normal file
130
backend/internal/handler/middleware/middleware_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
func TestRequestID_GeneratesUUID(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(RequestID())
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
id, exists := c.Get(RequestIDKey)
|
||||
assert.True(t, exists)
|
||||
assert.NotEmpty(t, id)
|
||||
c.Status(200)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.NotEmpty(t, w.Header().Get("X-Request-ID"))
|
||||
}
|
||||
|
||||
func TestRequestID_UsesExistingHeader(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(RequestID())
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
id, _ := c.Get(RequestIDKey)
|
||||
assert.Equal(t, "existing-id-123", id)
|
||||
c.Status(200)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Request-ID", "existing-id-123")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, "existing-id-123", w.Header().Get("X-Request-ID"))
|
||||
}
|
||||
|
||||
func TestLogging(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
r := gin.New()
|
||||
r.Use(Logging(logger))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Status(200)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/test?key=value", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestRecovery_NoPanic(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
r := gin.New()
|
||||
r.Use(Recovery(logger))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Status(200)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestRecovery_WithPanic(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
r := gin.New()
|
||||
r.Use(Recovery(logger))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
panic("test panic")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 500, w.Code)
|
||||
}
|
||||
|
||||
func TestCORS_NormalRequest(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(CORS())
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.Status(200)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
assert.Contains(t, w.Header().Get("Access-Control-Allow-Methods"), "GET")
|
||||
assert.Contains(t, w.Header().Get("Access-Control-Allow-Methods"), "POST")
|
||||
}
|
||||
|
||||
func TestCORS_PreflightRequest(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(CORS())
|
||||
r.OPTIONS("/test", func(c *gin.Context) {
|
||||
c.Status(200)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("OPTIONS", "/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 204, w.Code)
|
||||
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
29
backend/internal/handler/middleware/recovery.go
Normal file
29
backend/internal/handler/middleware/recovery.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Recovery 错误恢复中间件
|
||||
func Recovery(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
requestID, _ := c.Get(RequestIDKey)
|
||||
logger.Error("panic recovered",
|
||||
zap.Any("error", err),
|
||||
zap.Any("request_id", requestID),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.Stack("stack"),
|
||||
)
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "内部错误",
|
||||
})
|
||||
}
|
||||
}()
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
21
backend/internal/handler/middleware/request_id.go
Normal file
21
backend/internal/handler/middleware/request_id.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const RequestIDKey = "request_id"
|
||||
|
||||
// RequestID 请求 ID 中间件
|
||||
func RequestID() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
requestID := c.GetHeader("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = uuid.New().String()
|
||||
}
|
||||
c.Set(RequestIDKey, requestID)
|
||||
c.Header("X-Request-ID", requestID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -6,15 +6,20 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// ModelHandler 模型管理处理器
|
||||
type ModelHandler struct{}
|
||||
type ModelHandler struct {
|
||||
modelService service.ModelService
|
||||
}
|
||||
|
||||
// NewModelHandler 创建模型处理器
|
||||
func NewModelHandler() *ModelHandler {
|
||||
return &ModelHandler{}
|
||||
func NewModelHandler(modelService service.ModelService) *ModelHandler {
|
||||
return &ModelHandler{modelService: modelService}
|
||||
}
|
||||
|
||||
// CreateModel 创建模型
|
||||
@@ -32,26 +37,21 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 创建模型对象
|
||||
model := &config.Model{
|
||||
model := &domain.Model{
|
||||
ID: req.ID,
|
||||
ProviderID: req.ProviderID,
|
||||
ModelName: req.ModelName,
|
||||
Enabled: true, // 默认启用
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
err := config.CreateModel(model)
|
||||
err := h.modelService.Create(model)
|
||||
if err != nil {
|
||||
if err.Error() == "供应商不存在" {
|
||||
if err == appErrors.ErrProviderNotFound {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "供应商不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "创建模型失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -62,11 +62,9 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
|
||||
func (h *ModelHandler) ListModels(c *gin.Context) {
|
||||
providerID := c.Query("provider_id")
|
||||
|
||||
models, err := config.ListModels(providerID)
|
||||
models, err := h.modelService.List(providerID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询模型失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -77,7 +75,7 @@ func (h *ModelHandler) ListModels(c *gin.Context) {
|
||||
func (h *ModelHandler) GetModel(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
model, err := config.GetModel(id)
|
||||
model, err := h.modelService.Get(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
@@ -85,9 +83,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询模型失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -106,8 +102,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 更新模型
|
||||
err := config.UpdateModel(id, req)
|
||||
err := h.modelService.Update(id, req)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
@@ -115,24 +110,19 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if err.Error() == "供应商不存在" {
|
||||
if err == appErrors.ErrProviderNotFound {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "供应商不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "更新模型失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的模型
|
||||
model, err := config.GetModel(id)
|
||||
model, err := h.modelService.Get(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询更新后的模型失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -143,7 +133,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
|
||||
func (h *ModelHandler) DeleteModel(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
err := config.DeleteModel(id)
|
||||
err := h.modelService.Delete(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
@@ -151,9 +141,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "删除模型失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -4,32 +4,36 @@ import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/router"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// OpenAIHandler OpenAI 协议处理器
|
||||
type OpenAIHandler struct {
|
||||
client *provider.Client
|
||||
router *router.Router
|
||||
client provider.ProviderClient
|
||||
routingService service.RoutingService
|
||||
statsService service.StatsService
|
||||
}
|
||||
|
||||
// NewOpenAIHandler 创建 OpenAI 处理器
|
||||
func NewOpenAIHandler() *OpenAIHandler {
|
||||
func NewOpenAIHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *OpenAIHandler {
|
||||
return &OpenAIHandler{
|
||||
client: provider.NewClient(),
|
||||
router: router.NewRouter(),
|
||||
client: client,
|
||||
routingService: routingService,
|
||||
statsService: statsService,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleChatCompletions 处理 Chat Completions 请求
|
||||
func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
|
||||
// 解析请求
|
||||
var req openai.ChatCompletionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
|
||||
@@ -41,14 +45,23 @@ func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 路由到供应商
|
||||
routeResult, err := h.router.Route(req.Model)
|
||||
// 请求验证
|
||||
if validationErrors := openai.ValidateRequest(&req); validationErrors != nil {
|
||||
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: formatValidationErrors(validationErrors),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
routeResult, err := h.routingService.Route(req.Model)
|
||||
if err != nil {
|
||||
h.handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 根据是否流式选择处理方式
|
||||
if req.Stream {
|
||||
h.handleStreamRequest(c, &req, routeResult)
|
||||
} else {
|
||||
@@ -56,9 +69,7 @@ func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// handleNonStreamRequest 处理非流式请求
|
||||
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
|
||||
// 发送请求到供应商
|
||||
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
resp, err := h.client.SendRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
@@ -70,18 +81,14 @@ func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatC
|
||||
return
|
||||
}
|
||||
|
||||
// 记录统计
|
||||
go func() {
|
||||
_ = config.RecordRequest(routeResult.Provider.ID, req.Model)
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
|
||||
}()
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// handleStreamRequest 处理流式请求
|
||||
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
|
||||
// 发送流式请求到供应商
|
||||
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
@@ -93,75 +100,58 @@ func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatComp
|
||||
return
|
||||
}
|
||||
|
||||
// 设置 SSE 响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
// 创建流写入器
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
|
||||
// 流式转发事件
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
// 流错误,记录日志
|
||||
fmt.Printf("流错误: %v\n", event.Error)
|
||||
break
|
||||
}
|
||||
|
||||
if event.Done {
|
||||
// 流结束
|
||||
writer.WriteString("data: [DONE]\n\n")
|
||||
writer.Flush()
|
||||
break
|
||||
}
|
||||
|
||||
// 写入事件数据
|
||||
writer.WriteString("data: ")
|
||||
writer.Write(event.Data)
|
||||
writer.WriteString("\n\n")
|
||||
writer.Flush()
|
||||
}
|
||||
|
||||
// 记录统计
|
||||
go func() {
|
||||
_ = config.RecordRequest(routeResult.Provider.ID, req.Model)
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
|
||||
}()
|
||||
}
|
||||
|
||||
// handleError 处理路由错误
|
||||
func (h *OpenAIHandler) handleError(c *gin.Context, err error) {
|
||||
switch err {
|
||||
case router.ErrModelNotFound:
|
||||
c.JSON(http.StatusNotFound, openai.ErrorResponse{
|
||||
if appErr, ok := appErrors.AsAppError(err); ok {
|
||||
c.JSON(appErr.HTTPStatus, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "模型未找到",
|
||||
Message: appErr.Message,
|
||||
Type: "invalid_request_error",
|
||||
Code: "model_not_found",
|
||||
},
|
||||
})
|
||||
case router.ErrModelDisabled:
|
||||
c.JSON(http.StatusNotFound, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "模型已禁用",
|
||||
Type: "invalid_request_error",
|
||||
Code: "model_disabled",
|
||||
},
|
||||
})
|
||||
case router.ErrProviderDisabled:
|
||||
c.JSON(http.StatusNotFound, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "供应商已禁用",
|
||||
Type: "invalid_request_error",
|
||||
Code: "provider_disabled",
|
||||
},
|
||||
})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "内部错误: " + err.Error(),
|
||||
Type: "internal_error",
|
||||
Code: appErr.Code,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "内部错误: " + err.Error(),
|
||||
Type: "internal_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// formatValidationErrors 将验证错误 map 格式化为字符串
|
||||
func formatValidationErrors(errors map[string]string) string {
|
||||
parts := make([]string, 0, len(errors))
|
||||
for field, msg := range errors {
|
||||
parts = append(parts, fmt.Sprintf("%s: %s", field, msg))
|
||||
}
|
||||
return "请求验证失败: " + strings.Join(parts, "; ")
|
||||
}
|
||||
|
||||
@@ -2,19 +2,25 @@ package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// ProviderHandler 供应商管理处理器
|
||||
type ProviderHandler struct{}
|
||||
type ProviderHandler struct {
|
||||
providerService service.ProviderService
|
||||
}
|
||||
|
||||
// NewProviderHandler 创建供应商处理器
|
||||
func NewProviderHandler() *ProviderHandler {
|
||||
return &ProviderHandler{}
|
||||
func NewProviderHandler(providerService service.ProviderService) *ProviderHandler {
|
||||
return &ProviderHandler{providerService: providerService}
|
||||
}
|
||||
|
||||
// CreateProvider 创建供应商
|
||||
@@ -33,43 +39,34 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 创建供应商对象
|
||||
provider := &config.Provider{
|
||||
provider := &domain.Provider{
|
||||
ID: req.ID,
|
||||
Name: req.Name,
|
||||
APIKey: req.APIKey,
|
||||
BaseURL: req.BaseURL,
|
||||
Enabled: true, // 默认启用
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
err := config.CreateProvider(provider)
|
||||
err := h.providerService.Create(provider)
|
||||
if err != nil {
|
||||
// 检查是否是唯一约束错误(ID 重复)
|
||||
if err.Error() == "UNIQUE constraint failed: providers.id" {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": "供应商 ID 已存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "创建供应商失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 掩码 API Key 后返回
|
||||
provider.MaskAPIKey()
|
||||
c.JSON(http.StatusCreated, provider)
|
||||
}
|
||||
|
||||
// ListProviders 列出所有供应商
|
||||
func (h *ProviderHandler) ListProviders(c *gin.Context) {
|
||||
providers, err := config.ListProviders()
|
||||
providers, err := h.providerService.List()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询供应商失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -80,7 +77,7 @@ func (h *ProviderHandler) ListProviders(c *gin.Context) {
|
||||
func (h *ProviderHandler) GetProvider(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
provider, err := config.GetProvider(id, true) // 掩码 API Key
|
||||
provider, err := h.providerService.Get(id, true)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
@@ -88,9 +85,7 @@ func (h *ProviderHandler) GetProvider(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询供应商失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -109,8 +104,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 更新供应商
|
||||
err := config.UpdateProvider(id, req)
|
||||
err := h.providerService.Update(id, req)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
@@ -118,18 +112,13 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "更新供应商失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的供应商
|
||||
provider, err := config.GetProvider(id, true)
|
||||
provider, err := h.providerService.Get(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询更新后的供应商失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -140,8 +129,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
|
||||
func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
// 删除供应商(级联删除模型)
|
||||
err := config.DeleteProvider(id)
|
||||
err := h.providerService.Delete(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
@@ -149,19 +137,23 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "删除供应商失败: " + err.Error(),
|
||||
})
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 删除关联的模型
|
||||
models, _ := config.ListModels("")
|
||||
for _, model := range models {
|
||||
if model.ProviderID == id {
|
||||
_ = config.DeleteModel(model.ID)
|
||||
}
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// writeError 统一错误响应处理
|
||||
func writeError(c *gin.Context, err error) {
|
||||
if appErr, ok := appErrors.AsAppError(err); ok {
|
||||
c.JSON(appErr.HTTPStatus, gin.H{
|
||||
"error": appErr.Message,
|
||||
"code": appErr.Code,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,20 +6,21 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// StatsHandler 统计处理器
|
||||
type StatsHandler struct{}
|
||||
type StatsHandler struct {
|
||||
statsService service.StatsService
|
||||
}
|
||||
|
||||
// NewStatsHandler 创建统计处理器
|
||||
func NewStatsHandler() *StatsHandler {
|
||||
return &StatsHandler{}
|
||||
func NewStatsHandler(statsService service.StatsService) *StatsHandler {
|
||||
return &StatsHandler{statsService: statsService}
|
||||
}
|
||||
|
||||
// GetStats 查询统计
|
||||
func (h *StatsHandler) GetStats(c *gin.Context) {
|
||||
// 解析查询参数
|
||||
providerID := c.Query("provider_id")
|
||||
modelName := c.Query("model_name")
|
||||
startDateStr := c.Query("start_date")
|
||||
@@ -27,7 +28,6 @@ func (h *StatsHandler) GetStats(c *gin.Context) {
|
||||
|
||||
var startDate, endDate *time.Time
|
||||
|
||||
// 解析日期
|
||||
if startDateStr != "" {
|
||||
t, err := time.Parse("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
@@ -50,8 +50,7 @@ func (h *StatsHandler) GetStats(c *gin.Context) {
|
||||
endDate = &t
|
||||
}
|
||||
|
||||
// 查询统计
|
||||
stats, err := config.GetStats(providerID, modelName, startDate, endDate)
|
||||
stats, err := h.statsService.Get(providerID, modelName, startDate, endDate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询统计失败: " + err.Error(),
|
||||
@@ -64,16 +63,14 @@ func (h *StatsHandler) GetStats(c *gin.Context) {
|
||||
|
||||
// AggregateStats 聚合统计
|
||||
func (h *StatsHandler) AggregateStats(c *gin.Context) {
|
||||
// 解析查询参数
|
||||
providerID := c.Query("provider_id")
|
||||
modelName := c.Query("model_name")
|
||||
startDateStr := c.Query("start_date")
|
||||
endDateStr := c.Query("end_date")
|
||||
groupBy := c.Query("group_by") // "provider", "model", "date"
|
||||
groupBy := c.Query("group_by")
|
||||
|
||||
var startDate, endDate *time.Time
|
||||
|
||||
// 解析日期
|
||||
if startDateStr != "" {
|
||||
t, err := time.Parse("2006-01-02", startDateStr)
|
||||
if err != nil {
|
||||
@@ -96,8 +93,7 @@ func (h *StatsHandler) AggregateStats(c *gin.Context) {
|
||||
endDate = &t
|
||||
}
|
||||
|
||||
// 查询统计
|
||||
stats, err := config.GetStats(providerID, modelName, startDate, endDate)
|
||||
stats, err := h.statsService.Get(providerID, modelName, startDate, endDate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "查询统计失败: " + err.Error(),
|
||||
@@ -105,80 +101,6 @@ func (h *StatsHandler) AggregateStats(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 聚合
|
||||
result := h.aggregate(stats, groupBy)
|
||||
|
||||
result := h.statsService.Aggregate(stats, groupBy)
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// aggregate 执行聚合
|
||||
func (h *StatsHandler) aggregate(stats []config.UsageStats, groupBy string) []map[string]interface{} {
|
||||
switch groupBy {
|
||||
case "provider":
|
||||
return h.aggregateByProvider(stats)
|
||||
case "model":
|
||||
return h.aggregateByModel(stats)
|
||||
case "date":
|
||||
return h.aggregateByDate(stats)
|
||||
default:
|
||||
// 默认按供应商聚合
|
||||
return h.aggregateByProvider(stats)
|
||||
}
|
||||
}
|
||||
|
||||
// aggregateByProvider 按供应商聚合
|
||||
func (h *StatsHandler) aggregateByProvider(stats []config.UsageStats) []map[string]interface{} {
|
||||
aggregated := make(map[string]int)
|
||||
for _, stat := range stats {
|
||||
aggregated[stat.ProviderID] += stat.RequestCount
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, 0, len(aggregated))
|
||||
for providerID, count := range aggregated {
|
||||
result = append(result, map[string]interface{}{
|
||||
"provider_id": providerID,
|
||||
"request_count": count,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// aggregateByModel 按模型聚合
|
||||
func (h *StatsHandler) aggregateByModel(stats []config.UsageStats) []map[string]interface{} {
|
||||
aggregated := make(map[string]int)
|
||||
for _, stat := range stats {
|
||||
key := stat.ProviderID + "/" + stat.ModelName
|
||||
aggregated[key] += stat.RequestCount
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, 0, len(aggregated))
|
||||
for key, count := range aggregated {
|
||||
result = append(result, map[string]interface{}{
|
||||
"provider_id": key[:len(key)/2],
|
||||
"model_name": key[len(key)/2+1:],
|
||||
"request_count": count,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// aggregateByDate 按日期聚合
|
||||
func (h *StatsHandler) aggregateByDate(stats []config.UsageStats) []map[string]interface{} {
|
||||
aggregated := make(map[string]int)
|
||||
for _, stat := range stats {
|
||||
key := stat.Date.Format("2006-01-02")
|
||||
aggregated[key] += stat.RequestCount
|
||||
}
|
||||
|
||||
result := make([]map[string]interface{}, 0, len(aggregated))
|
||||
for date, count := range aggregated {
|
||||
result = append(result, map[string]interface{}{
|
||||
"date": date,
|
||||
"request_count": count,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
270
backend/internal/protocol/anthropic/converter_test.go
Normal file
270
backend/internal/protocol/anthropic/converter_test.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"nex/backend/internal/protocol/openai"
|
||||
)
|
||||
|
||||
func TestConvertRequest_Basic(t *testing.T) {
|
||||
temp := 0.7
|
||||
req := &MessagesRequest{
|
||||
Model: "claude-3-opus",
|
||||
MaxTokens: 1024,
|
||||
Temperature: &temp,
|
||||
Messages: []AnthropicMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []ContentBlock{
|
||||
{Type: "text", Text: "Hello"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := ConvertRequest(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "claude-3-opus", result.Model)
|
||||
assert.Equal(t, 1024, *result.MaxTokens)
|
||||
assert.Equal(t, &temp, result.Temperature)
|
||||
require.Len(t, result.Messages, 1)
|
||||
assert.Equal(t, "user", result.Messages[0].Role)
|
||||
assert.Equal(t, "Hello", result.Messages[0].Content)
|
||||
}
|
||||
|
||||
func TestConvertRequest_WithSystem(t *testing.T) {
|
||||
req := &MessagesRequest{
|
||||
Model: "claude-3-opus",
|
||||
MaxTokens: 100,
|
||||
System: "You are a helpful assistant.",
|
||||
Messages: []AnthropicMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []ContentBlock{{Type: "text", Text: "Hi"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := ConvertRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.Messages, 2)
|
||||
assert.Equal(t, "system", result.Messages[0].Role)
|
||||
assert.Equal(t, "You are a helpful assistant.", result.Messages[0].Content)
|
||||
assert.Equal(t, "user", result.Messages[1].Role)
|
||||
}
|
||||
|
||||
func TestConvertRequest_DefaultMaxTokens(t *testing.T) {
|
||||
req := &MessagesRequest{
|
||||
Model: "claude-3-opus",
|
||||
MaxTokens: 0, // 未设置
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := ConvertRequest(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 4096, *result.MaxTokens)
|
||||
}
|
||||
|
||||
func TestConvertRequest_WithTools(t *testing.T) {
|
||||
req := &MessagesRequest{
|
||||
Model: "claude-3-opus",
|
||||
MaxTokens: 100,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}},
|
||||
},
|
||||
Tools: []AnthropicTool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather info",
|
||||
InputSchema: map[string]interface{}{"type": "object"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := ConvertRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.Tools, 1)
|
||||
assert.Equal(t, "function", result.Tools[0].Type)
|
||||
assert.Equal(t, "get_weather", result.Tools[0].Function.Name)
|
||||
}
|
||||
|
||||
func TestConvertRequest_WithStopSequences(t *testing.T) {
|
||||
req := &MessagesRequest{
|
||||
Model: "claude-3-opus",
|
||||
MaxTokens: 100,
|
||||
StopSequences: []string{"STOP", "END"},
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := ConvertRequest(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"STOP", "END"}, result.Stop)
|
||||
}
|
||||
|
||||
func TestConvertRequest_ToolResult(t *testing.T) {
|
||||
req := &MessagesRequest{
|
||||
Model: "claude-3-opus",
|
||||
MaxTokens: 100,
|
||||
Messages: []AnthropicMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []ContentBlock{
|
||||
{
|
||||
Type: "tool_result",
|
||||
ToolUseID: "tool_123",
|
||||
Content: "result data",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := ConvertRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.Messages, 1)
|
||||
assert.Equal(t, "tool", result.Messages[0].Role)
|
||||
assert.Equal(t, "tool_123", result.Messages[0].ToolCallID)
|
||||
assert.Equal(t, "result data", result.Messages[0].Content)
|
||||
}
|
||||
|
||||
func TestConvertResponse(t *testing.T) {
|
||||
resp := &openai.ChatCompletionResponse{
|
||||
ID: "chatcmpl-123",
|
||||
Model: "gpt-4",
|
||||
Choices: []openai.Choice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: &openai.Message{Role: "assistant", Content: "Hello!"},
|
||||
FinishReason: "stop",
|
||||
},
|
||||
},
|
||||
Usage: openai.Usage{PromptTokens: 10, CompletionTokens: 5},
|
||||
}
|
||||
|
||||
result, err := ConvertResponse(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "chatcmpl-123", result.ID)
|
||||
assert.Equal(t, "message", result.Type)
|
||||
assert.Equal(t, "assistant", result.Role)
|
||||
assert.Equal(t, "end_turn", result.StopReason)
|
||||
require.Len(t, result.Content, 1)
|
||||
assert.Equal(t, "text", result.Content[0].Type)
|
||||
assert.Equal(t, "Hello!", result.Content[0].Text)
|
||||
assert.Equal(t, 10, result.Usage.InputTokens)
|
||||
assert.Equal(t, 5, result.Usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestConvertResponse_ToolCalls(t *testing.T) {
|
||||
args, _ := json.Marshal(map[string]interface{}{"city": "Beijing"})
|
||||
resp := &openai.ChatCompletionResponse{
|
||||
ID: "chatcmpl-456",
|
||||
Model: "gpt-4",
|
||||
Choices: []openai.Choice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: &openai.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []openai.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Type: "function",
|
||||
Function: openai.FunctionCall{
|
||||
Name: "get_weather",
|
||||
Arguments: string(args),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
FinishReason: "tool_calls",
|
||||
},
|
||||
},
|
||||
Usage: openai.Usage{},
|
||||
}
|
||||
|
||||
result, err := ConvertResponse(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "tool_use", result.StopReason)
|
||||
require.Len(t, result.Content, 1)
|
||||
assert.Equal(t, "tool_use", result.Content[0].Type)
|
||||
assert.Equal(t, "call_123", result.Content[0].ID)
|
||||
assert.Equal(t, "get_weather", result.Content[0].Name)
|
||||
}
|
||||
|
||||
func TestConvertToolChoice_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
wantErr bool
|
||||
check func(interface{})
|
||||
}{
|
||||
{"auto字符串", "auto", false, func(r interface{}) { assert.Equal(t, "auto", r) }},
|
||||
{"any字符串", "any", false, func(r interface{}) { assert.Equal(t, "auto", r) }},
|
||||
{"无效字符串", "invalid", true, nil},
|
||||
{"tool对象", map[string]interface{}{"type": "tool", "name": "my_func"}, false,
|
||||
func(r interface{}) {
|
||||
m := r.(map[string]interface{})
|
||||
assert.Equal(t, "function", m["type"])
|
||||
}},
|
||||
{"缺少name的tool对象", map[string]interface{}{"type": "tool"}, true, nil},
|
||||
{"缺少type的对象", map[string]interface{}{"name": "func"}, true, nil},
|
||||
{"无效类型", 42, true, nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := convertToolChoice(tt.input)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
tt.check(result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRequest(t *testing.T) {
|
||||
t.Run("有效请求", func(t *testing.T) {
|
||||
req := &MessagesRequest{
|
||||
Model: "claude-3-opus",
|
||||
MaxTokens: 100,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}},
|
||||
},
|
||||
}
|
||||
errs := ValidateRequest(req)
|
||||
assert.Nil(t, errs)
|
||||
})
|
||||
|
||||
t.Run("缺少模型", func(t *testing.T) {
|
||||
req := &MessagesRequest{
|
||||
MaxTokens: 100,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}},
|
||||
},
|
||||
}
|
||||
errs := ValidateRequest(req)
|
||||
assert.NotNil(t, errs)
|
||||
assert.Contains(t, errs["model"], "不能为空")
|
||||
})
|
||||
|
||||
t.Run("MaxTokens为0", func(t *testing.T) {
|
||||
req := &MessagesRequest{
|
||||
Model: "claude-3-opus",
|
||||
MaxTokens: 0,
|
||||
Messages: []AnthropicMessage{
|
||||
{Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}},
|
||||
},
|
||||
}
|
||||
errs := ValidateRequest(req)
|
||||
assert.NotNil(t, errs)
|
||||
})
|
||||
}
|
||||
229
backend/internal/protocol/anthropic/stream_converter_test.go
Normal file
229
backend/internal/protocol/anthropic/stream_converter_test.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"nex/backend/internal/protocol/openai"
|
||||
)
|
||||
|
||||
func TestStreamConverter_MessageStart(t *testing.T) {
|
||||
converter := NewStreamConverter("msg_123", "claude-3-opus")
|
||||
|
||||
chunk := &openai.StreamChunk{
|
||||
ID: "chatcmpl-123",
|
||||
Choices: []openai.StreamChoice{{Index: 0, Delta: openai.Delta{}}},
|
||||
}
|
||||
|
||||
events, err := converter.ConvertChunk(chunk)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
// 第一个事件应该是 message_start
|
||||
assert.Equal(t, "message_start", events[0].Type)
|
||||
require.NotNil(t, events[0].Message)
|
||||
assert.Equal(t, "msg_123", events[0].Message.ID)
|
||||
assert.Equal(t, "message", events[0].Message.Type)
|
||||
assert.Equal(t, "assistant", events[0].Message.Role)
|
||||
assert.Equal(t, "claude-3-opus", events[0].Message.Model)
|
||||
}
|
||||
|
||||
func TestStreamConverter_TextDelta(t *testing.T) {
|
||||
converter := NewStreamConverter("msg_123", "claude-3-opus")
|
||||
|
||||
// 先发送一个空块以触发 message_start
|
||||
chunk1 := &openai.StreamChunk{
|
||||
Choices: []openai.StreamChoice{
|
||||
{Delta: openai.Delta{Content: "Hello"}},
|
||||
},
|
||||
}
|
||||
events1, err := converter.ConvertChunk(chunk1)
|
||||
require.NoError(t, err)
|
||||
// 应有 message_start + content_block_start + text delta
|
||||
assert.GreaterOrEqual(t, len(events1), 3)
|
||||
|
||||
// 第二个文本块不应再发送 message_start 和 content_block_start
|
||||
chunk2 := &openai.StreamChunk{
|
||||
Choices: []openai.StreamChoice{
|
||||
{Delta: openai.Delta{Content: " world"}},
|
||||
},
|
||||
}
|
||||
events2, err := converter.ConvertChunk(chunk2)
|
||||
require.NoError(t, err)
|
||||
// 只有 text delta
|
||||
assert.Len(t, events2, 1)
|
||||
assert.Equal(t, "content_block_delta", events2[0].Type)
|
||||
assert.Equal(t, "text_delta", events2[0].Delta.Type)
|
||||
assert.Equal(t, " world", events2[0].Delta.Text)
|
||||
}
|
||||
|
||||
func TestStreamConverter_FinishReason(t *testing.T) {
|
||||
converter := NewStreamConverter("msg_123", "claude-3-opus")
|
||||
|
||||
chunk := &openai.StreamChunk{
|
||||
Choices: []openai.StreamChoice{
|
||||
{Delta: openai.Delta{Content: "Hello"}, FinishReason: "stop"},
|
||||
},
|
||||
}
|
||||
events, err := converter.ConvertChunk(chunk)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 查找 message_delta 事件
|
||||
var messageDelta *StreamEvent
|
||||
for _, e := range events {
|
||||
if e.Type == "message_delta" {
|
||||
messageDelta = &e
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, messageDelta)
|
||||
assert.Equal(t, "end_turn", messageDelta.Delta.StopReason)
|
||||
|
||||
// 查找 message_stop 事件
|
||||
var messageStop *StreamEvent
|
||||
for _, e := range events {
|
||||
if e.Type == "message_stop" {
|
||||
messageStop = &e
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotNil(t, messageStop)
|
||||
}
|
||||
|
||||
func TestStreamConverter_FinishReasonToolCalls(t *testing.T) {
|
||||
converter := NewStreamConverter("msg_123", "claude-3-opus")
|
||||
|
||||
chunk := &openai.StreamChunk{
|
||||
Choices: []openai.StreamChoice{
|
||||
{Delta: openai.Delta{}, FinishReason: "tool_calls"},
|
||||
},
|
||||
}
|
||||
events, err := converter.ConvertChunk(chunk)
|
||||
require.NoError(t, err)
|
||||
|
||||
var messageDelta *StreamEvent
|
||||
for _, e := range events {
|
||||
if e.Type == "message_delta" {
|
||||
messageDelta = &e
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, messageDelta)
|
||||
assert.Equal(t, "tool_use", messageDelta.Delta.StopReason)
|
||||
}
|
||||
|
||||
func TestStreamConverter_FinishReasonLength(t *testing.T) {
|
||||
converter := NewStreamConverter("msg_123", "claude-3-opus")
|
||||
|
||||
chunk := &openai.StreamChunk{
|
||||
Choices: []openai.StreamChoice{
|
||||
{Delta: openai.Delta{}, FinishReason: "length"},
|
||||
},
|
||||
}
|
||||
events, err := converter.ConvertChunk(chunk)
|
||||
require.NoError(t, err)
|
||||
|
||||
var messageDelta *StreamEvent
|
||||
for _, e := range events {
|
||||
if e.Type == "message_delta" {
|
||||
messageDelta = &e
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, messageDelta)
|
||||
assert.Equal(t, "max_tokens", messageDelta.Delta.StopReason)
|
||||
}
|
||||
|
||||
func TestStreamConverter_ToolCalls(t *testing.T) {
|
||||
converter := NewStreamConverter("msg_123", "claude-3-opus")
|
||||
|
||||
chunk := &openai.StreamChunk{
|
||||
Choices: []openai.StreamChoice{
|
||||
{
|
||||
Delta: openai.Delta{
|
||||
ToolCalls: []openai.ToolCall{
|
||||
{
|
||||
ID: "call_123",
|
||||
Type: "function",
|
||||
Function: openai.FunctionCall{
|
||||
Name: "get_weather",
|
||||
Arguments: `{"city": "Beijing"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
events, err := converter.ConvertChunk(chunk)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 应包含 content_block_start (tool_use) + content_block_delta (input_json_delta)
|
||||
hasBlockStart := false
|
||||
hasInputDelta := false
|
||||
for _, e := range events {
|
||||
if e.Type == "content_block_start" && e.ContentBlock != nil && e.ContentBlock.Type == "tool_use" {
|
||||
hasBlockStart = true
|
||||
assert.Equal(t, "call_123", e.ContentBlock.ID)
|
||||
assert.Equal(t, "get_weather", e.ContentBlock.Name)
|
||||
}
|
||||
if e.Type == "content_block_delta" && e.Delta != nil && e.Delta.Type == "input_json_delta" {
|
||||
hasInputDelta = true
|
||||
assert.Equal(t, `{"city": "Beijing"}`, e.Delta.Input)
|
||||
}
|
||||
}
|
||||
assert.True(t, hasBlockStart, "应有 tool_use content_block_start")
|
||||
assert.True(t, hasInputDelta, "应有 input_json_delta")
|
||||
}
|
||||
|
||||
func TestSerializeEvent(t *testing.T) {
|
||||
event := StreamEvent{
|
||||
Type: "message_start",
|
||||
Message: &MessagesResponse{
|
||||
ID: "msg_123",
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := SerializeEvent(event)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, result, "event: message_start")
|
||||
assert.Contains(t, result, "data: ")
|
||||
assert.Contains(t, result, "msg_123")
|
||||
}
|
||||
|
||||
func TestSerializeEvent_InvalidJSON(t *testing.T) {
|
||||
event := StreamEvent{
|
||||
Type: "test",
|
||||
}
|
||||
// 这个应该能正常序列化
|
||||
result, err := SerializeEvent(event)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, result, "event: test")
|
||||
}
|
||||
|
||||
func TestContentBlock_ParseInputJSON(t *testing.T) {
|
||||
t.Run("字符串输入", func(t *testing.T) {
|
||||
cb := &ContentBlock{Input: `{"key": "value"}`}
|
||||
result, err := cb.ParseInputJSON()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "value", result["key"])
|
||||
})
|
||||
|
||||
t.Run("对象输入", func(t *testing.T) {
|
||||
cb := &ContentBlock{Input: map[string]interface{}{"key": "value"}}
|
||||
result, err := cb.ParseInputJSON()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "value", result["key"])
|
||||
})
|
||||
|
||||
t.Run("无效类型", func(t *testing.T) {
|
||||
cb := &ContentBlock{Input: 42}
|
||||
_, err := cb.ParseInputJSON()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
@@ -1,13 +1,20 @@
|
||||
package anthropic
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
|
||||
pkgValidator "nex/backend/pkg/validator"
|
||||
)
|
||||
|
||||
// MessagesRequest Anthropic Messages API 请求结构
|
||||
type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []AnthropicMessage `json:"messages"`
|
||||
Model string `json:"model" validate:"required"`
|
||||
Messages []AnthropicMessage `json:"messages" validate:"required,min=1"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
MaxTokens int `json:"max_tokens" validate:"required,min=1"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
@@ -114,5 +121,29 @@ func (cb *ContentBlock) ParseInputJSON() (map[string]interface{}, error) {
|
||||
if obj, ok := cb.Input.(map[string]interface{}); ok {
|
||||
return obj, nil
|
||||
}
|
||||
return nil, json.Unmarshal([]byte{}, nil) // 返回错误
|
||||
return nil, fmt.Errorf("invalid input type: expected string or map")
|
||||
}
|
||||
|
||||
// ValidateRequest 验证 MessagesRequest
|
||||
func ValidateRequest(req *MessagesRequest) map[string]string {
|
||||
errs := pkgValidator.Validate(req)
|
||||
if errs == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
validationErrors := make(map[string]string)
|
||||
for _, err := range errs.(validator.ValidationErrors) {
|
||||
field := err.Field()
|
||||
switch field {
|
||||
case "Model":
|
||||
validationErrors["model"] = "模型名称不能为空"
|
||||
case "Messages":
|
||||
validationErrors["messages"] = "消息列表不能为空"
|
||||
case "MaxTokens":
|
||||
validationErrors["max_tokens"] = "max_tokens 不能为空且必须大于 0"
|
||||
default:
|
||||
validationErrors[field] = fmt.Sprintf("字段 %s 验证失败: %s", field, err.Tag())
|
||||
}
|
||||
}
|
||||
return validationErrors
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package openai
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
@@ -24,9 +23,6 @@ func (a *Adapter) PrepareRequest(req *ChatCompletionRequest, apiKey, baseURL str
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 调试日志:打印请求体
|
||||
fmt.Printf("[DEBUG] 请求Body: %s\n", string(body))
|
||||
|
||||
// 创建 HTTP 请求
|
||||
// baseURL 已包含版本路径(如 /v1 或 /v4),只需添加端点路径
|
||||
httpReq, err := http.NewRequest("POST", baseURL+"/chat/completions", bytes.NewReader(body))
|
||||
|
||||
190
backend/internal/protocol/openai/adapter_test.go
Normal file
190
backend/internal/protocol/openai/adapter_test.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAdapter_PrepareRequest(t *testing.T) {
|
||||
adapter := NewAdapter()
|
||||
req := &ChatCompletionRequest{
|
||||
Model: "gpt-4",
|
||||
Messages: []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
}
|
||||
|
||||
httpReq, err := adapter.PrepareRequest(req, "test-api-key", "https://api.openai.com/v1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, httpReq)
|
||||
|
||||
assert.Equal(t, "POST", httpReq.Method)
|
||||
assert.Equal(t, "https://api.openai.com/v1/chat/completions", httpReq.URL.String())
|
||||
assert.Equal(t, "application/json", httpReq.Header.Get("Content-Type"))
|
||||
assert.Equal(t, "Bearer test-api-key", httpReq.Header.Get("Authorization"))
|
||||
|
||||
// 验证请求体
|
||||
var body ChatCompletionRequest
|
||||
err = json.NewDecoder(httpReq.Body).Decode(&body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4", body.Model)
|
||||
}
|
||||
|
||||
func TestAdapter_ParseResponse(t *testing.T) {
|
||||
adapter := NewAdapter()
|
||||
resp := &ChatCompletionResponse{
|
||||
ID: "chatcmpl-123",
|
||||
Object: "chat.completion",
|
||||
Created: 1234567890,
|
||||
Model: "gpt-4",
|
||||
Choices: []Choice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: &Message{Role: "assistant", Content: "Hello!"},
|
||||
},
|
||||
},
|
||||
Usage: Usage{PromptTokens: 10, CompletionTokens: 5, TotalTokens: 15},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
httpResp := &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
}
|
||||
|
||||
result, err := adapter.ParseResponse(httpResp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "chatcmpl-123", result.ID)
|
||||
assert.Equal(t, "gpt-4", result.Model)
|
||||
require.Len(t, result.Choices, 1)
|
||||
assert.Equal(t, "Hello!", result.Choices[0].Message.Content)
|
||||
}
|
||||
|
||||
func TestAdapter_ParseErrorResponse(t *testing.T) {
|
||||
adapter := NewAdapter()
|
||||
errResp := &ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: "Invalid API key",
|
||||
Type: "invalid_request_error",
|
||||
Code: "invalid_api_key",
|
||||
},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(errResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
httpResp := &http.Response{
|
||||
StatusCode: 401,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
}
|
||||
|
||||
result, err := adapter.ParseErrorResponse(httpResp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Invalid API key", result.Error.Message)
|
||||
assert.Equal(t, "invalid_request_error", result.Error.Type)
|
||||
}
|
||||
|
||||
func TestAdapter_ParseStreamChunk(t *testing.T) {
|
||||
adapter := NewAdapter()
|
||||
chunk := &StreamChunk{
|
||||
ID: "chatcmpl-123",
|
||||
Object: "chat.completion.chunk",
|
||||
Created: 1234567890,
|
||||
Model: "gpt-4",
|
||||
Choices: []StreamChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: Delta{Content: "Hello"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(chunk)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := adapter.ParseStreamChunk(data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "chatcmpl-123", result.ID)
|
||||
require.Len(t, result.Choices, 1)
|
||||
assert.Equal(t, "Hello", result.Choices[0].Delta.Content)
|
||||
}
|
||||
|
||||
func TestParseToolCallArguments(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{"有效JSON", `{"key": "value"}`, false},
|
||||
{"无效JSON", `not json`, true},
|
||||
{"空JSON", `{}`, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tc := &ToolCall{
|
||||
Function: FunctionCall{Arguments: tt.input},
|
||||
}
|
||||
args, err := tc.ParseToolCallArguments()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, args)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeToolCallArguments(t *testing.T) {
|
||||
args := map[string]interface{}{"key": "value"}
|
||||
result, err := SerializeToolCallArguments(args)
|
||||
require.NoError(t, err)
|
||||
assert.JSONEq(t, `{"key": "value"}`, result)
|
||||
}
|
||||
|
||||
func TestValidateRequest(t *testing.T) {
|
||||
t.Run("有效请求", func(t *testing.T) {
|
||||
req := &ChatCompletionRequest{
|
||||
Model: "gpt-4",
|
||||
Messages: []Message{{Role: "user", Content: "hello"}},
|
||||
}
|
||||
errs := ValidateRequest(req)
|
||||
assert.Nil(t, errs)
|
||||
})
|
||||
|
||||
t.Run("缺少模型", func(t *testing.T) {
|
||||
req := &ChatCompletionRequest{
|
||||
Messages: []Message{{Role: "user", Content: "hello"}},
|
||||
}
|
||||
errs := ValidateRequest(req)
|
||||
assert.NotNil(t, errs)
|
||||
assert.Contains(t, errs["model"], "不能为空")
|
||||
})
|
||||
|
||||
t.Run("缺少消息", func(t *testing.T) {
|
||||
req := &ChatCompletionRequest{
|
||||
Model: "gpt-4",
|
||||
}
|
||||
errs := ValidateRequest(req)
|
||||
assert.NotNil(t, errs)
|
||||
assert.Contains(t, errs["messages"], "不能为空")
|
||||
})
|
||||
|
||||
t.Run("空消息列表", func(t *testing.T) {
|
||||
req := &ChatCompletionRequest{
|
||||
Model: "gpt-4",
|
||||
Messages: []Message{},
|
||||
}
|
||||
errs := ValidateRequest(req)
|
||||
assert.NotNil(t, errs)
|
||||
})
|
||||
}
|
||||
@@ -1,11 +1,18 @@
|
||||
package openai
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
|
||||
pkgValidator "nex/backend/pkg/validator"
|
||||
)
|
||||
|
||||
// ChatCompletionRequest OpenAI Chat Completions API 请求结构
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Model string `json:"model" validate:"required"`
|
||||
Messages []Message `json:"messages" validate:"required,min=1"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
@@ -129,3 +136,25 @@ func SerializeToolCallArguments(args map[string]interface{}) (string, error) {
|
||||
}
|
||||
return string(bytes), nil
|
||||
}
|
||||
|
||||
// ValidateRequest 验证 ChatCompletionRequest
|
||||
func ValidateRequest(req *ChatCompletionRequest) map[string]string {
|
||||
errs := pkgValidator.Validate(req)
|
||||
if errs == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
validationErrors := make(map[string]string)
|
||||
for _, err := range errs.(validator.ValidationErrors) {
|
||||
field := err.Field()
|
||||
switch field {
|
||||
case "Model":
|
||||
validationErrors["model"] = "模型名称不能为空"
|
||||
case "Messages":
|
||||
validationErrors["messages"] = "消息列表不能为空"
|
||||
default:
|
||||
validationErrors[field] = fmt.Sprintf("字段 %s 验证失败: %s", field, err.Tag())
|
||||
}
|
||||
}
|
||||
return validationErrors
|
||||
}
|
||||
|
||||
@@ -9,22 +9,59 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/protocol/openai"
|
||||
)
|
||||
|
||||
// StreamConfig 流式处理配置
|
||||
type StreamConfig struct {
|
||||
InitialBufferSize int // 初始缓冲区大小(字节),默认 4096
|
||||
MaxBufferSize int // 最大缓冲区大小(字节),默认 65536
|
||||
Timeout time.Duration // 流超时时间,默认 5 分钟
|
||||
ChannelBufferSize int // 事件通道缓冲区大小,默认 100
|
||||
}
|
||||
|
||||
// DefaultStreamConfig 返回默认流式处理配置
|
||||
func DefaultStreamConfig() StreamConfig {
|
||||
return StreamConfig{
|
||||
InitialBufferSize: 4096,
|
||||
MaxBufferSize: 65536,
|
||||
Timeout: 5 * time.Minute,
|
||||
ChannelBufferSize: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// Client OpenAI 兼容供应商客户端
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
adapter *openai.Adapter
|
||||
httpClient *http.Client
|
||||
adapter *openai.Adapter
|
||||
logger *zap.Logger
|
||||
streamCfg StreamConfig
|
||||
}
|
||||
|
||||
// StreamEvent 流事件
|
||||
type StreamEvent struct {
|
||||
Data []byte
|
||||
Error error
|
||||
Done bool
|
||||
}
|
||||
|
||||
// ProviderClient 供应商客户端接口
|
||||
type ProviderClient interface {
|
||||
SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error)
|
||||
SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan StreamEvent, error)
|
||||
}
|
||||
|
||||
// NewClient 创建供应商客户端
|
||||
func NewClient() *Client {
|
||||
return &Client{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second, // 非流式请求超时
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
adapter: openai.NewAdapter(),
|
||||
adapter: openai.NewAdapter(),
|
||||
logger: zap.L(),
|
||||
streamCfg: DefaultStreamConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,10 +73,10 @@ func (c *Client) SendRequest(ctx context.Context, req *openai.ChatCompletionRequ
|
||||
return nil, fmt.Errorf("准备请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 调试日志:打印完整请求信息
|
||||
fmt.Printf("[DEBUG] 请求URL: %s\n", httpReq.URL.String())
|
||||
fmt.Printf("[DEBUG] 请求Method: %s\n", httpReq.Method)
|
||||
fmt.Printf("[DEBUG] 请求Headers: %v\n", httpReq.Header)
|
||||
c.logger.Debug("发送请求",
|
||||
zap.String("url", httpReq.URL.String()),
|
||||
zap.String("method", httpReq.Method),
|
||||
)
|
||||
|
||||
// 设置上下文
|
||||
httpReq = httpReq.WithContext(ctx)
|
||||
@@ -80,18 +117,22 @@ func (c *Client) SendStreamRequest(ctx context.Context, req *openai.ChatCompleti
|
||||
return nil, fmt.Errorf("准备请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置上下文
|
||||
httpReq = httpReq.WithContext(ctx)
|
||||
// 设置带超时的上下文
|
||||
streamCtx, cancel := context.WithTimeout(ctx, c.streamCfg.Timeout)
|
||||
_ = cancel // cancel 在流读取结束后由 ctx 传播处理
|
||||
httpReq = httpReq.WithContext(streamCtx)
|
||||
|
||||
// 发送请求
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("发送请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查状态码
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
cancel()
|
||||
errorResp, parseErr := c.adapter.ParseErrorResponse(resp)
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
|
||||
@@ -100,33 +141,33 @@ func (c *Client) SendStreamRequest(ctx context.Context, req *openai.ChatCompleti
|
||||
}
|
||||
|
||||
// 创建事件通道
|
||||
eventChan := make(chan StreamEvent, 100)
|
||||
eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize)
|
||||
|
||||
// 启动 goroutine 读取流
|
||||
go c.readStream(ctx, resp.Body, eventChan)
|
||||
go c.readStream(streamCtx, cancel, resp.Body, eventChan)
|
||||
|
||||
return eventChan, nil
|
||||
}
|
||||
|
||||
// StreamEvent 流事件
|
||||
type StreamEvent struct {
|
||||
Data []byte
|
||||
Error error
|
||||
Done bool
|
||||
}
|
||||
|
||||
// readStream 读取 SSE 流
|
||||
func (c *Client) readStream(ctx context.Context, body io.ReadCloser, eventChan chan<- StreamEvent) {
|
||||
// readStream 读取 SSE 流(支持动态缓冲区、超时控制和改进的错误处理)
|
||||
func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body io.ReadCloser, eventChan chan<- StreamEvent) {
|
||||
defer close(eventChan)
|
||||
defer body.Close()
|
||||
defer cancel()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
bufSize := c.streamCfg.InitialBufferSize
|
||||
buf := make([]byte, bufSize)
|
||||
var dataBuf []byte
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
eventChan <- StreamEvent{Error: ctx.Err()}
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
c.logger.Warn("流读取超时")
|
||||
eventChan <- StreamEvent{Error: fmt.Errorf("流读取超时: %w", ctx.Err())}
|
||||
} else {
|
||||
eventChan <- StreamEvent{Error: ctx.Err()}
|
||||
}
|
||||
return
|
||||
default:
|
||||
}
|
||||
@@ -134,15 +175,32 @@ func (c *Client) readStream(ctx context.Context, body io.ReadCloser, eventChan c
|
||||
n, err := body.Read(buf)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
// 流结束
|
||||
// 流正常结束
|
||||
return
|
||||
}
|
||||
eventChan <- StreamEvent{Error: err}
|
||||
// 区分网络错误和其他错误
|
||||
if isNetworkError(err) {
|
||||
c.logger.Error("流网络错误", zap.String("error", err.Error()))
|
||||
eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
|
||||
} else {
|
||||
c.logger.Error("流读取错误", zap.String("error", err.Error()))
|
||||
eventChan <- StreamEvent{Error: fmt.Errorf("读取错误: %w", err)}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
dataBuf = append(dataBuf, buf[:n]...)
|
||||
|
||||
// 动态调整缓冲区大小:如果数据量大,增大缓冲区
|
||||
if len(dataBuf) > bufSize/2 && bufSize < c.streamCfg.MaxBufferSize {
|
||||
newSize := bufSize * 2
|
||||
if newSize > c.streamCfg.MaxBufferSize {
|
||||
newSize = c.streamCfg.MaxBufferSize
|
||||
}
|
||||
buf = make([]byte, newSize)
|
||||
bufSize = newSize
|
||||
}
|
||||
|
||||
// 处理完整的 SSE 事件
|
||||
for {
|
||||
// 查找事件边界(双换行)
|
||||
@@ -175,3 +233,16 @@ func (c *Client) readStream(ctx context.Context, body io.ReadCloser, eventChan c
|
||||
}
|
||||
}
|
||||
|
||||
// isNetworkError 判断是否为网络相关错误
|
||||
func isNetworkError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "connection reset") ||
|
||||
strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "network") ||
|
||||
strings.Contains(errStr, "timeout") ||
|
||||
strings.Contains(errStr, "EOF")
|
||||
}
|
||||
|
||||
|
||||
151
backend/internal/provider/client_test.go
Normal file
151
backend/internal/provider/client_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"nex/backend/internal/protocol/openai"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
client := NewClient()
|
||||
require.NotNil(t, client)
|
||||
assert.NotNil(t, client.httpClient)
|
||||
assert.NotNil(t, client.adapter)
|
||||
assert.Equal(t, 4096, client.streamCfg.InitialBufferSize)
|
||||
assert.Equal(t, 65536, client.streamCfg.MaxBufferSize)
|
||||
assert.Equal(t, 100, client.streamCfg.ChannelBufferSize)
|
||||
}
|
||||
|
||||
func TestDefaultStreamConfig(t *testing.T) {
|
||||
cfg := DefaultStreamConfig()
|
||||
assert.Equal(t, 4096, cfg.InitialBufferSize)
|
||||
assert.Equal(t, 65536, cfg.MaxBufferSize)
|
||||
assert.Equal(t, 100, cfg.ChannelBufferSize)
|
||||
}
|
||||
|
||||
func TestClient_SendRequest_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
|
||||
assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
|
||||
|
||||
resp := openai.ChatCompletionResponse{
|
||||
ID: "chatcmpl-123",
|
||||
Choices: []openai.Choice{
|
||||
{Index: 0, Message: &openai.Message{Role: "assistant", Content: "Hello!"}},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
req := &openai.ChatCompletionRequest{
|
||||
Model: "gpt-4",
|
||||
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
|
||||
}
|
||||
|
||||
result, err := client.SendRequest(context.Background(), req, "test-key", server.URL)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "chatcmpl-123", result.ID)
|
||||
}
|
||||
|
||||
func TestClient_SendRequest_ErrorResponse(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{Message: "Invalid API key"},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
req := &openai.ChatCompletionRequest{
|
||||
Model: "gpt-4",
|
||||
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
|
||||
}
|
||||
|
||||
_, err := client.SendRequest(context.Background(), req, "bad-key", server.URL)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Invalid API key")
|
||||
}
|
||||
|
||||
func TestClient_SendRequest_ConnectionError(t *testing.T) {
|
||||
client := NewClient()
|
||||
req := &openai.ChatCompletionRequest{
|
||||
Model: "gpt-4",
|
||||
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
|
||||
}
|
||||
|
||||
_, err := client.SendRequest(context.Background(), req, "key", "http://localhost:1")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClient_SendStreamRequest_CreatesChannel(t *testing.T) {
|
||||
// 使用一个慢服务器确保客户端有时间读取
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
req := &openai.ChatCompletionRequest{
|
||||
Model: "gpt-4",
|
||||
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStreamRequest(context.Background(), req, "test-key", server.URL)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, eventChan)
|
||||
|
||||
// 读取直到 channel 关闭(服务器关闭后应产生 EOF)
|
||||
for range eventChan {
|
||||
// 消费所有事件
|
||||
}
|
||||
// channel 应已关闭(不阻塞即通过)
|
||||
}
|
||||
|
||||
func TestClient_SendStreamRequest_ErrorResponse(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
req := &openai.ChatCompletionRequest{
|
||||
Model: "gpt-4",
|
||||
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
|
||||
}
|
||||
|
||||
_, err := client.SendStreamRequest(context.Background(), req, "key", server.URL)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestIsNetworkError(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{"connection reset by peer", true},
|
||||
{"broken pipe", true},
|
||||
{"network is unreachable", true},
|
||||
{"timeout waiting for response", true},
|
||||
{"unexpected EOF", true},
|
||||
{"normal error", false},
|
||||
{"", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
err := fmt.Errorf("%s", tt.input) //nolint:govet
|
||||
assert.Equal(t, tt.want, isNetworkError(err), "isNetworkError(%q)", tt.input)
|
||||
}
|
||||
}
|
||||
13
backend/internal/repository/model_repo.go
Normal file
13
backend/internal/repository/model_repo.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package repository
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
// ModelRepository 模型数据仓库接口
|
||||
type ModelRepository interface {
|
||||
Create(model *domain.Model) error
|
||||
GetByID(id string) (*domain.Model, error)
|
||||
List(providerID string) ([]domain.Model, error)
|
||||
GetByModelName(modelName string) (*domain.Model, error)
|
||||
Update(id string, updates map[string]interface{}) error
|
||||
Delete(id string) error
|
||||
}
|
||||
104
backend/internal/repository/model_repo_impl.go
Normal file
104
backend/internal/repository/model_repo_impl.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/domain"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
type modelRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewModelRepository(db *gorm.DB) ModelRepository {
|
||||
return &modelRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *modelRepository) Create(model *domain.Model) error {
|
||||
m := toConfigModel(model)
|
||||
m.CreatedAt = time.Now()
|
||||
return r.db.Create(&m).Error
|
||||
}
|
||||
|
||||
func (r *modelRepository) GetByID(id string) (*domain.Model, error) {
|
||||
var m config.Model
|
||||
err := r.db.First(&m, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d := toDomainModel(&m)
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
func (r *modelRepository) List(providerID string) ([]domain.Model, error) {
|
||||
var models []config.Model
|
||||
var err error
|
||||
if providerID != "" {
|
||||
err = r.db.Where("provider_id = ?", providerID).Find(&models).Error
|
||||
} else {
|
||||
err = r.db.Find(&models).Error
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]domain.Model, len(models))
|
||||
for i := range models {
|
||||
result[i] = toDomainModel(&models[i])
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *modelRepository) GetByModelName(modelName string) (*domain.Model, error) {
|
||||
var m config.Model
|
||||
err := r.db.Where("model_name = ?", modelName).First(&m).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d := toDomainModel(&m)
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
func (r *modelRepository) Update(id string, updates map[string]interface{}) error {
|
||||
result := r.db.Model(&config.Model{}).Where("id = ?", id).Updates(updates)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return appErrors.ErrModelNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *modelRepository) Delete(id string) error {
|
||||
result := r.db.Delete(&config.Model{}, "id = ?", id)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return appErrors.ErrModelNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func toDomainModel(m *config.Model) domain.Model {
|
||||
return domain.Model{
|
||||
ID: m.ID,
|
||||
ProviderID: m.ProviderID,
|
||||
ModelName: m.ModelName,
|
||||
Enabled: m.Enabled,
|
||||
CreatedAt: m.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func toConfigModel(m *domain.Model) config.Model {
|
||||
return config.Model{
|
||||
ID: m.ID,
|
||||
ProviderID: m.ProviderID,
|
||||
ModelName: m.ModelName,
|
||||
Enabled: m.Enabled,
|
||||
}
|
||||
}
|
||||
12
backend/internal/repository/provider_repo.go
Normal file
12
backend/internal/repository/provider_repo.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package repository
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
// ProviderRepository 供应商数据仓库接口
|
||||
type ProviderRepository interface {
|
||||
Create(provider *domain.Provider) error
|
||||
GetByID(id string) (*domain.Provider, error)
|
||||
List() ([]domain.Provider, error)
|
||||
Update(id string, updates map[string]interface{}) error
|
||||
Delete(id string) error
|
||||
}
|
||||
94
backend/internal/repository/provider_repo_impl.go
Normal file
94
backend/internal/repository/provider_repo_impl.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/domain"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
type providerRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewProviderRepository(db *gorm.DB) ProviderRepository {
|
||||
return &providerRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *providerRepository) Create(provider *domain.Provider) error {
|
||||
p := toConfigProvider(provider)
|
||||
p.CreatedAt = time.Now()
|
||||
p.UpdatedAt = time.Now()
|
||||
return r.db.Create(&p).Error
|
||||
}
|
||||
|
||||
func (r *providerRepository) GetByID(id string) (*domain.Provider, error) {
|
||||
var p config.Provider
|
||||
err := r.db.First(&p, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d := toDomainProvider(&p)
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
func (r *providerRepository) List() ([]domain.Provider, error) {
|
||||
var providers []config.Provider
|
||||
err := r.db.Find(&providers).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]domain.Provider, len(providers))
|
||||
for i := range providers {
|
||||
result[i] = toDomainProvider(&providers[i])
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *providerRepository) Update(id string, updates map[string]interface{}) error {
|
||||
updates["updated_at"] = time.Now()
|
||||
result := r.db.Model(&config.Provider{}).Where("id = ?", id).Updates(updates)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return appErrors.ErrProviderNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *providerRepository) Delete(id string) error {
|
||||
result := r.db.Delete(&config.Provider{}, "id = ?", id)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return appErrors.ErrProviderNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func toDomainProvider(p *config.Provider) domain.Provider {
|
||||
return domain.Provider{
|
||||
ID: p.ID,
|
||||
Name: p.Name,
|
||||
APIKey: p.APIKey,
|
||||
BaseURL: p.BaseURL,
|
||||
Enabled: p.Enabled,
|
||||
CreatedAt: p.CreatedAt,
|
||||
UpdatedAt: p.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func toConfigProvider(p *domain.Provider) config.Provider {
|
||||
return config.Provider{
|
||||
ID: p.ID,
|
||||
Name: p.Name,
|
||||
APIKey: p.APIKey,
|
||||
BaseURL: p.BaseURL,
|
||||
Enabled: p.Enabled,
|
||||
}
|
||||
}
|
||||
233
backend/internal/repository/repository_test.go
Normal file
233
backend/internal/repository/repository_test.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
||||
require.NoError(t, err)
|
||||
// 关闭数据库连接以便 TempDir 清理
|
||||
t.Cleanup(func() {
|
||||
sqlDB, _ := db.DB()
|
||||
if sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
})
|
||||
return db
|
||||
}
|
||||
|
||||
// ============ ProviderRepository 测试 ============
|
||||
|
||||
func TestProviderRepository_Create(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewProviderRepository(db)
|
||||
|
||||
provider := &domain.Provider{
|
||||
ID: "test-provider",
|
||||
Name: "Test Provider",
|
||||
APIKey: "sk-test-key",
|
||||
BaseURL: "https://api.test.com",
|
||||
Enabled: true,
|
||||
}
|
||||
err := repo.Create(provider)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestProviderRepository_GetByID(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewProviderRepository(db)
|
||||
|
||||
provider := &domain.Provider{
|
||||
ID: "test-provider", Name: "Test", APIKey: "sk-test-key", BaseURL: "https://api.test.com",
|
||||
}
|
||||
err := repo.Create(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := repo.GetByID("test-provider")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test-provider", result.ID)
|
||||
assert.Equal(t, "Test", result.Name)
|
||||
}
|
||||
|
||||
func TestProviderRepository_GetByID_NotFound(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewProviderRepository(db)
|
||||
|
||||
_, err := repo.GetByID("nonexistent")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProviderRepository_List(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewProviderRepository(db)
|
||||
|
||||
for _, id := range []string{"pA", "pB", "pC"} {
|
||||
err := repo.Create(&domain.Provider{ID: id, Name: id, APIKey: "key", BaseURL: "https://test.com"})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
providers, err := repo.List()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, providers, 3)
|
||||
}
|
||||
|
||||
func TestProviderRepository_Update(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewProviderRepository(db)
|
||||
|
||||
repo.Create(&domain.Provider{ID: "p1", Name: "Old", APIKey: "key", BaseURL: "https://old.com"})
|
||||
|
||||
err := repo.Update("p1", map[string]interface{}{"name": "New"})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, _ := repo.GetByID("p1")
|
||||
assert.Equal(t, "New", result.Name)
|
||||
}
|
||||
|
||||
func TestProviderRepository_Update_NotFound(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewProviderRepository(db)
|
||||
|
||||
err := repo.Update("nonexistent", map[string]interface{}{"name": "New"})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProviderRepository_Delete(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewProviderRepository(db)
|
||||
|
||||
repo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
|
||||
err := repo.Delete("p1")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.GetByID("p1")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProviderRepository_Delete_NotFound(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewProviderRepository(db)
|
||||
|
||||
err := repo.Delete("nonexistent")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// ============ ModelRepository 测试 ============
|
||||
|
||||
func TestModelRepository_Create(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
err := repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestModelRepository_GetByID(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
|
||||
result, err := repo.GetByID("m1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "m1", result.ID)
|
||||
assert.Equal(t, "gpt-4", result.ModelName)
|
||||
}
|
||||
|
||||
func TestModelRepository_GetByModelName(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
|
||||
result, err := repo.GetByModelName("gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "m1", result.ID)
|
||||
}
|
||||
|
||||
func TestModelRepository_List(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
repo.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})
|
||||
repo.Create(&domain.Model{ID: "m3", ProviderID: "p2", ModelName: "claude-3"})
|
||||
|
||||
all, err := repo.List("")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, all, 3)
|
||||
|
||||
p1Models, err := repo.List("p1")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, p1Models, 2)
|
||||
}
|
||||
|
||||
func TestModelRepository_Update(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
|
||||
err := repo.Update("m1", map[string]interface{}{"enabled": false})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, _ := repo.GetByID("m1")
|
||||
assert.False(t, result.Enabled)
|
||||
}
|
||||
|
||||
func TestModelRepository_Delete(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
|
||||
err := repo.Delete("m1")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.GetByID("m1")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// ============ StatsRepository 测试 ============
|
||||
|
||||
func TestStatsRepository_Record(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewStatsRepository(db)
|
||||
|
||||
err := repo.Record("provider-1", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
// 再次记录应递增
|
||||
err = repo.Record("provider-1", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
stats, err := repo.Query("provider-1", "", nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, stats, 1)
|
||||
assert.Equal(t, 2, stats[0].RequestCount)
|
||||
}
|
||||
|
||||
func TestStatsRepository_Query(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewStatsRepository(db)
|
||||
|
||||
repo.Record("p1", "gpt-4")
|
||||
// 注意:当前 schema 只有 date 字段有唯一约束
|
||||
// 所以同一 provider + model 只能有一条记录
|
||||
stats, err := repo.Query("p1", "", nil, nil)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, stats, 1)
|
||||
}
|
||||
13
backend/internal/repository/stats_repo.go
Normal file
13
backend/internal/repository/stats_repo.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
// StatsRepository 统计数据仓库接口
|
||||
type StatsRepository interface {
|
||||
Record(providerID, modelName string) error
|
||||
Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error)
|
||||
}
|
||||
79
backend/internal/repository/stats_repo_impl.go
Normal file
79
backend/internal/repository/stats_repo_impl.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
type statsRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewStatsRepository(db *gorm.DB) StatsRepository {
|
||||
return &statsRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *statsRepository) Record(providerID, modelName string) error {
|
||||
today := time.Now().Format("2006-01-02")
|
||||
todayTime, _ := time.Parse("2006-01-02", today)
|
||||
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
var stats config.UsageStats
|
||||
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
|
||||
providerID, modelName, todayTime).First(&stats).Error
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
stats = config.UsageStats{
|
||||
ProviderID: providerID,
|
||||
ModelName: modelName,
|
||||
RequestCount: 1,
|
||||
Date: todayTime,
|
||||
}
|
||||
return tx.Create(&stats).Error
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Model(&stats).Update("request_count", gorm.Expr("request_count + 1")).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (r *statsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
var stats []config.UsageStats
|
||||
query := r.db.Model(&config.UsageStats{})
|
||||
|
||||
if providerID != "" {
|
||||
query = query.Where("provider_id = ?", providerID)
|
||||
}
|
||||
if modelName != "" {
|
||||
query = query.Where("model_name = ?", modelName)
|
||||
}
|
||||
if startDate != nil {
|
||||
query = query.Where("date >= ?", startDate)
|
||||
}
|
||||
if endDate != nil {
|
||||
query = query.Where("date <= ?", endDate)
|
||||
}
|
||||
|
||||
err := query.Order("date DESC").Find(&stats).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]domain.UsageStats, len(stats))
|
||||
for i := range stats {
|
||||
result[i] = domain.UsageStats{
|
||||
ID: stats[i].ID,
|
||||
ProviderID: stats[i].ProviderID,
|
||||
ModelName: stats[i].ModelName,
|
||||
RequestCount: stats[i].RequestCount,
|
||||
Date: stats[i].Date,
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -1,71 +0,0 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrModelNotFound = errors.New("模型未找到")
|
||||
ErrModelDisabled = errors.New("模型已禁用")
|
||||
ErrProviderDisabled = errors.New("供应商已禁用")
|
||||
)
|
||||
|
||||
// RouteResult 路由结果
|
||||
type RouteResult struct {
|
||||
Provider *config.Provider
|
||||
Model *config.Model
|
||||
}
|
||||
|
||||
// Router 模型路由器
|
||||
type Router struct{}
|
||||
|
||||
// NewRouter 创建路由器
|
||||
func NewRouter() *Router {
|
||||
return &Router{}
|
||||
}
|
||||
|
||||
// Route 根据模型名称路由到供应商
|
||||
func (r *Router) Route(modelName string) (*RouteResult, error) {
|
||||
// 查询模型
|
||||
models, err := config.ListModels("")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询模型失败: %w", err)
|
||||
}
|
||||
|
||||
// 查找匹配的模型
|
||||
var targetModel *config.Model
|
||||
for i := range models {
|
||||
if models[i].ModelName == modelName {
|
||||
targetModel = &models[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if targetModel == nil {
|
||||
return nil, ErrModelNotFound
|
||||
}
|
||||
|
||||
// 检查模型是否启用
|
||||
if !targetModel.Enabled {
|
||||
return nil, ErrModelDisabled
|
||||
}
|
||||
|
||||
// 查询供应商
|
||||
provider, err := config.GetProvider(targetModel.ProviderID, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询供应商失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查供应商是否启用
|
||||
if !provider.Enabled {
|
||||
return nil, ErrProviderDisabled
|
||||
}
|
||||
|
||||
return &RouteResult{
|
||||
Provider: provider,
|
||||
Model: targetModel,
|
||||
}, nil
|
||||
}
|
||||
12
backend/internal/service/model_service.go
Normal file
12
backend/internal/service/model_service.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package service
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
// ModelService 模型服务接口
|
||||
type ModelService interface {
|
||||
Create(model *domain.Model) error
|
||||
Get(id string) (*domain.Model, error)
|
||||
List(providerID string) ([]domain.Model, error)
|
||||
Update(id string, updates map[string]interface{}) error
|
||||
Delete(id string) error
|
||||
}
|
||||
50
backend/internal/service/model_service_impl.go
Normal file
50
backend/internal/service/model_service_impl.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
)
|
||||
|
||||
type modelService struct {
|
||||
modelRepo repository.ModelRepository
|
||||
providerRepo repository.ProviderRepository
|
||||
}
|
||||
|
||||
func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) ModelService {
|
||||
return &modelService{modelRepo: modelRepo, providerRepo: providerRepo}
|
||||
}
|
||||
|
||||
func (s *modelService) Create(model *domain.Model) error {
|
||||
// Verify provider exists
|
||||
_, err := s.providerRepo.GetByID(model.ProviderID)
|
||||
if err != nil {
|
||||
return appErrors.ErrProviderNotFound
|
||||
}
|
||||
model.Enabled = true
|
||||
return s.modelRepo.Create(model)
|
||||
}
|
||||
|
||||
func (s *modelService) Get(id string) (*domain.Model, error) {
|
||||
return s.modelRepo.GetByID(id)
|
||||
}
|
||||
|
||||
func (s *modelService) List(providerID string) ([]domain.Model, error) {
|
||||
return s.modelRepo.List(providerID)
|
||||
}
|
||||
|
||||
func (s *modelService) Update(id string, updates map[string]interface{}) error {
|
||||
// If updating provider_id, verify new provider exists
|
||||
if providerID, ok := updates["provider_id"].(string); ok {
|
||||
_, err := s.providerRepo.GetByID(providerID)
|
||||
if err != nil {
|
||||
return appErrors.ErrProviderNotFound
|
||||
}
|
||||
}
|
||||
return s.modelRepo.Update(id, updates)
|
||||
}
|
||||
|
||||
func (s *modelService) Delete(id string) error {
|
||||
return s.modelRepo.Delete(id)
|
||||
}
|
||||
12
backend/internal/service/provider_service.go
Normal file
12
backend/internal/service/provider_service.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package service
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
// ProviderService 供应商服务接口
|
||||
type ProviderService interface {
|
||||
Create(provider *domain.Provider) error
|
||||
Get(id string, maskKey bool) (*domain.Provider, error)
|
||||
List() ([]domain.Provider, error)
|
||||
Update(id string, updates map[string]interface{}) error
|
||||
Delete(id string) error
|
||||
}
|
||||
49
backend/internal/service/provider_service_impl.go
Normal file
49
backend/internal/service/provider_service_impl.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
)
|
||||
|
||||
type providerService struct {
|
||||
providerRepo repository.ProviderRepository
|
||||
}
|
||||
|
||||
func NewProviderService(providerRepo repository.ProviderRepository) ProviderService {
|
||||
return &providerService{providerRepo: providerRepo}
|
||||
}
|
||||
|
||||
func (s *providerService) Create(provider *domain.Provider) error {
|
||||
provider.Enabled = true
|
||||
return s.providerRepo.Create(provider)
|
||||
}
|
||||
|
||||
func (s *providerService) Get(id string, maskKey bool) (*domain.Provider, error) {
|
||||
provider, err := s.providerRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if maskKey {
|
||||
provider.MaskAPIKey()
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (s *providerService) List() ([]domain.Provider, error) {
|
||||
providers, err := s.providerRepo.List()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range providers {
|
||||
providers[i].MaskAPIKey()
|
||||
}
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
func (s *providerService) Update(id string, updates map[string]interface{}) error {
|
||||
return s.providerRepo.Update(id, updates)
|
||||
}
|
||||
|
||||
func (s *providerService) Delete(id string) error {
|
||||
return s.providerRepo.Delete(id)
|
||||
}
|
||||
8
backend/internal/service/routing_service.go
Normal file
8
backend/internal/service/routing_service.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package service
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
// RoutingService 路由服务接口
|
||||
type RoutingService interface {
|
||||
Route(modelName string) (*domain.RouteResult, error)
|
||||
}
|
||||
42
backend/internal/service/routing_service_impl.go
Normal file
42
backend/internal/service/routing_service_impl.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
)
|
||||
|
||||
type routingService struct {
|
||||
modelRepo repository.ModelRepository
|
||||
providerRepo repository.ProviderRepository
|
||||
}
|
||||
|
||||
func NewRoutingService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) RoutingService {
|
||||
return &routingService{modelRepo: modelRepo, providerRepo: providerRepo}
|
||||
}
|
||||
|
||||
func (s *routingService) Route(modelName string) (*domain.RouteResult, error) {
|
||||
model, err := s.modelRepo.GetByModelName(modelName)
|
||||
if err != nil {
|
||||
return nil, appErrors.ErrModelNotFound
|
||||
}
|
||||
|
||||
if !model.Enabled {
|
||||
return nil, appErrors.ErrModelDisabled
|
||||
}
|
||||
|
||||
provider, err := s.providerRepo.GetByID(model.ProviderID)
|
||||
if err != nil {
|
||||
return nil, appErrors.ErrProviderNotFound
|
||||
}
|
||||
|
||||
if !provider.Enabled {
|
||||
return nil, appErrors.ErrProviderDisabled
|
||||
}
|
||||
|
||||
return &domain.RouteResult{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
}, nil
|
||||
}
|
||||
245
backend/internal/service/service_test.go
Normal file
245
backend/internal/service/service_test.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
)
|
||||
|
||||
func setupServiceTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
sqlDB, _ := db.DB()
|
||||
if sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
})
|
||||
return db
|
||||
}
|
||||
|
||||
// ============ ProviderService 测试 ============
|
||||
|
||||
func TestProviderService_Create(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
|
||||
provider := &domain.Provider{
|
||||
ID: "test-p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com",
|
||||
}
|
||||
err := svc.Create(provider)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, provider.Enabled)
|
||||
}
|
||||
|
||||
func TestProviderService_Get_MaskKey(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
|
||||
svc.Create(&domain.Provider{
|
||||
ID: "p1", Name: "Test", APIKey: "sk-long-api-key-12345", BaseURL: "https://test.com",
|
||||
})
|
||||
|
||||
result, err := svc.Get("p1", true)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "***2345", result.APIKey)
|
||||
|
||||
result, err = svc.Get("p1", false)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sk-long-api-key-12345", result.APIKey)
|
||||
}
|
||||
|
||||
func TestProviderService_List(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
|
||||
svc.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key123", BaseURL: "https://a.com"})
|
||||
svc.Create(&domain.Provider{ID: "p2", Name: "P2", APIKey: "key456", BaseURL: "https://b.com"})
|
||||
|
||||
providers, err := svc.List()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, providers, 2)
|
||||
assert.Contains(t, providers[0].APIKey, "***")
|
||||
}
|
||||
|
||||
func TestProviderService_Delete(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
|
||||
svc.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
|
||||
err := svc.Delete("p1")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = svc.Get("p1", false)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// ============ ModelService 测试 ============
|
||||
|
||||
func TestModelService_Create(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
|
||||
model := &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}
|
||||
err := svc.Create(model)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, model.Enabled)
|
||||
}
|
||||
|
||||
func TestModelService_Create_ProviderNotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
model := &domain.Model{ID: "m1", ProviderID: "nonexistent", ModelName: "gpt-4"}
|
||||
err := svc.Create(model)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestModelService_List(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
svc.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})
|
||||
|
||||
models, err := svc.List("p1")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, models, 2)
|
||||
}
|
||||
|
||||
// ============ RoutingService 测试 ============
|
||||
|
||||
func TestRoutingService_Route(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
|
||||
result, err := svc.Route("gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "p1", result.Provider.ID)
|
||||
assert.Equal(t, "gpt-4", result.Model.ModelName)
|
||||
}
|
||||
|
||||
func TestRoutingService_Route_ModelNotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
_, err := svc.Route("nonexistent-model")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRoutingService_Route_ModelDisabled(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
|
||||
// 先创建启用的模型,然后通过 Update 禁用
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
|
||||
|
||||
_, err := svc.Route("gpt-4")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRoutingService_Route_ProviderDisabled(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
// 先创建启用的 provider,然后禁用
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
|
||||
providerRepo.Update("p1", map[string]interface{}{"enabled": false})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
|
||||
_, err := svc.Route("gpt-4")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// ============ StatsService 测试 ============
|
||||
|
||||
func TestStatsService_RecordAndGet(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
svc := NewStatsService(statsRepo)
|
||||
|
||||
err := svc.Record("p1", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
stats, err := svc.Get("p1", "", nil, nil)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, stats, 1)
|
||||
}
|
||||
|
||||
func TestStatsService_Aggregate_ByProvider(t *testing.T) {
|
||||
statsRepo := repository.NewStatsRepository(nil)
|
||||
svc := NewStatsService(statsRepo)
|
||||
|
||||
stats := []domain.UsageStats{
|
||||
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
|
||||
{ProviderID: "p1", ModelName: "gpt-3.5", RequestCount: 5},
|
||||
{ProviderID: "p2", ModelName: "claude-3", RequestCount: 8},
|
||||
}
|
||||
|
||||
result := svc.Aggregate(stats, "provider")
|
||||
assert.Len(t, result, 2)
|
||||
|
||||
p1Count := 0
|
||||
p2Count := 0
|
||||
for _, r := range result {
|
||||
if r["provider_id"] == "p1" {
|
||||
p1Count = r["request_count"].(int)
|
||||
}
|
||||
if r["provider_id"] == "p2" {
|
||||
p2Count = r["request_count"].(int)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 15, p1Count)
|
||||
assert.Equal(t, 8, p2Count)
|
||||
}
|
||||
|
||||
func TestStatsService_Aggregate_ByDate(t *testing.T) {
|
||||
statsRepo := repository.NewStatsRepository(nil)
|
||||
svc := NewStatsService(statsRepo)
|
||||
|
||||
stats := []domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
{ProviderID: "p2", RequestCount: 5},
|
||||
}
|
||||
|
||||
result := svc.Aggregate(stats, "date")
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, 15, result[0]["request_count"])
|
||||
}
|
||||
14
backend/internal/service/stats_service.go
Normal file
14
backend/internal/service/stats_service.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
// StatsService 统计服务接口
|
||||
type StatsService interface {
|
||||
Record(providerID, modelName string) error
|
||||
Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error)
|
||||
Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{}
|
||||
}
|
||||
85
backend/internal/service/stats_service_impl.go
Normal file
85
backend/internal/service/stats_service_impl.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
)
|
||||
|
||||
type statsService struct {
|
||||
statsRepo repository.StatsRepository
|
||||
}
|
||||
|
||||
func NewStatsService(statsRepo repository.StatsRepository) StatsService {
|
||||
return &statsService{statsRepo: statsRepo}
|
||||
}
|
||||
|
||||
func (s *statsService) Record(providerID, modelName string) error {
|
||||
return s.statsRepo.Record(providerID, modelName)
|
||||
}
|
||||
|
||||
func (s *statsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
return s.statsRepo.Query(providerID, modelName, startDate, endDate)
|
||||
}
|
||||
|
||||
func (s *statsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} {
|
||||
switch groupBy {
|
||||
case "provider":
|
||||
return s.aggregateByProvider(stats)
|
||||
case "model":
|
||||
return s.aggregateByModel(stats)
|
||||
case "date":
|
||||
return s.aggregateByDate(stats)
|
||||
default:
|
||||
return s.aggregateByProvider(stats)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *statsService) aggregateByProvider(stats []domain.UsageStats) []map[string]interface{} {
|
||||
aggregated := make(map[string]int)
|
||||
for _, stat := range stats {
|
||||
aggregated[stat.ProviderID] += stat.RequestCount
|
||||
}
|
||||
result := make([]map[string]interface{}, 0, len(aggregated))
|
||||
for providerID, count := range aggregated {
|
||||
result = append(result, map[string]interface{}{
|
||||
"provider_id": providerID,
|
||||
"request_count": count,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *statsService) aggregateByModel(stats []domain.UsageStats) []map[string]interface{} {
|
||||
aggregated := make(map[string]int)
|
||||
for _, stat := range stats {
|
||||
key := stat.ProviderID + "/" + stat.ModelName
|
||||
aggregated[key] += stat.RequestCount
|
||||
}
|
||||
result := make([]map[string]interface{}, 0, len(aggregated))
|
||||
for key, count := range aggregated {
|
||||
result = append(result, map[string]interface{}{
|
||||
"provider_id": key[:len(key)/2],
|
||||
"model_name": key[len(key)/2+1:],
|
||||
"request_count": count,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *statsService) aggregateByDate(stats []domain.UsageStats) []map[string]interface{} {
|
||||
aggregated := make(map[string]int)
|
||||
for _, stat := range stats {
|
||||
key := stat.Date.Format("2006-01-02")
|
||||
aggregated[key] += stat.RequestCount
|
||||
}
|
||||
result := make([]map[string]interface{}, 0, len(aggregated))
|
||||
for date, count := range aggregated {
|
||||
result = append(result, map[string]interface{}{
|
||||
"date": date,
|
||||
"request_count": count,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
33
backend/migrations/001_initial_schema.sql
Normal file
33
backend/migrations/001_initial_schema.sql
Normal file
@@ -0,0 +1,33 @@
|
||||
-- +goose Up
|
||||
CREATE TABLE IF NOT EXISTS providers (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
api_key TEXT NOT NULL,
|
||||
base_url TEXT NOT NULL,
|
||||
enabled INTEGER DEFAULT 1,
|
||||
created_at DATETIME,
|
||||
updated_at DATETIME
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS models (
|
||||
id TEXT PRIMARY KEY,
|
||||
provider_id TEXT NOT NULL,
|
||||
model_name TEXT NOT NULL,
|
||||
enabled INTEGER DEFAULT 1,
|
||||
created_at DATETIME,
|
||||
FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS usage_stats (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
provider_id TEXT NOT NULL,
|
||||
model_name TEXT NOT NULL,
|
||||
request_count INTEGER DEFAULT 0,
|
||||
date DATE NOT NULL,
|
||||
UNIQUE(provider_id, model_name, date)
|
||||
);
|
||||
|
||||
-- +goose Down
|
||||
DROP TABLE IF EXISTS usage_stats;
|
||||
DROP TABLE IF EXISTS models;
|
||||
DROP TABLE IF EXISTS providers;
|
||||
9
backend/migrations/002_add_indexes.sql
Normal file
9
backend/migrations/002_add_indexes.sql
Normal file
@@ -0,0 +1,9 @@
|
||||
-- +goose Up
|
||||
CREATE INDEX IF NOT EXISTS idx_models_provider_id ON models(provider_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_models_model_name ON models(model_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_stats_provider_model_date ON usage_stats(provider_id, model_name, date);
|
||||
|
||||
-- +goose Down
|
||||
DROP INDEX IF EXISTS idx_usage_stats_provider_model_date;
|
||||
DROP INDEX IF EXISTS idx_models_model_name;
|
||||
DROP INDEX IF EXISTS idx_models_provider_id;
|
||||
74
backend/pkg/errors/errors.go
Normal file
74
backend/pkg/errors/errors.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// AppError 结构化应用错误
|
||||
type AppError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
HTTPStatus int `json:"-"`
|
||||
Cause error `json:"-"`
|
||||
Context map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// Error implements error interface
|
||||
func (e *AppError) Error() string {
|
||||
if e.Cause != nil {
|
||||
return fmt.Sprintf("%s: %s (%v)", e.Code, e.Message, e.Cause)
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying error
|
||||
func (e *AppError) Unwrap() error {
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
// NewAppError creates a new AppError
|
||||
func NewAppError(code, message string, httpStatus int) *AppError {
|
||||
return &AppError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
HTTPStatus: httpStatus,
|
||||
}
|
||||
}
|
||||
|
||||
// Predefined errors
|
||||
var (
|
||||
ErrModelNotFound = NewAppError("model_not_found", "模型未找到", http.StatusNotFound)
|
||||
ErrModelDisabled = NewAppError("model_disabled", "模型已禁用", http.StatusNotFound)
|
||||
ErrProviderNotFound = NewAppError("provider_not_found", "供应商未找到", http.StatusNotFound)
|
||||
ErrProviderDisabled = NewAppError("provider_disabled", "供应商已禁用", http.StatusNotFound)
|
||||
ErrInvalidRequest = NewAppError("invalid_request", "无效的请求", http.StatusBadRequest)
|
||||
ErrInternal = NewAppError("internal_error", "内部错误", http.StatusInternalServerError)
|
||||
ErrDatabaseNotInit = NewAppError("database_not_initialized", "数据库未初始化", http.StatusInternalServerError)
|
||||
ErrConflict = NewAppError("conflict", "资源已存在", http.StatusConflict)
|
||||
)
|
||||
|
||||
// AsAppError 尝试将 error 转换为 *AppError
|
||||
func AsAppError(err error) (*AppError, bool) {
|
||||
if err == nil {
|
||||
return nil, false
|
||||
}
|
||||
var appErr *AppError
|
||||
if ok := is(err, &appErr); ok {
|
||||
return appErr, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func is(err error, target interface{}) bool {
|
||||
// 简单的类型断言
|
||||
if e, ok := err.(*AppError); ok {
|
||||
// 直接赋值
|
||||
switch t := target.(type) {
|
||||
case **AppError:
|
||||
*t = e
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
125
backend/pkg/errors/errors_test.go
Normal file
125
backend/pkg/errors/errors_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewAppError(t *testing.T) {
|
||||
err := NewAppError("test_code", "测试消息", http.StatusBadRequest)
|
||||
assert.Equal(t, "test_code", err.Code)
|
||||
assert.Equal(t, "测试消息", err.Message)
|
||||
assert.Equal(t, http.StatusBadRequest, err.HTTPStatus)
|
||||
assert.Nil(t, err.Cause)
|
||||
assert.Nil(t, err.Context)
|
||||
}
|
||||
|
||||
func TestAppError_Error(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *AppError
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "无原因错误",
|
||||
err: NewAppError("code1", "消息1", 400),
|
||||
expected: "code1: 消息1",
|
||||
},
|
||||
{
|
||||
name: "带原因错误",
|
||||
err: Wrap(NewAppError("code2", "消息2", 500), errors.New("原始错误")),
|
||||
expected: "code2: 消息2 (原始错误)",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, tt.err.Error())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppError_Unwrap(t *testing.T) {
|
||||
cause := errors.New("原始错误")
|
||||
err := Wrap(ErrInternal, cause)
|
||||
assert.Equal(t, cause, err.Unwrap())
|
||||
}
|
||||
|
||||
func TestWrap(t *testing.T) {
|
||||
cause := errors.New("网络超时")
|
||||
wrapped := Wrap(ErrInternal, cause)
|
||||
assert.Equal(t, "internal_error", wrapped.Code)
|
||||
assert.Equal(t, "内部错误", wrapped.Message)
|
||||
assert.Equal(t, http.StatusInternalServerError, wrapped.HTTPStatus)
|
||||
assert.Equal(t, cause, wrapped.Cause)
|
||||
}
|
||||
|
||||
func TestWithContext(t *testing.T) {
|
||||
err := WithContext(ErrModelNotFound, "model", "gpt-4")
|
||||
assert.Equal(t, "model_not_found", err.Code)
|
||||
assert.NotNil(t, err.Context)
|
||||
assert.Equal(t, "gpt-4", err.Context["model"])
|
||||
|
||||
// 测试链式添加上下文
|
||||
err2 := WithContext(err, "provider", "openai")
|
||||
assert.Equal(t, "gpt-4", err2.Context["model"])
|
||||
assert.Equal(t, "openai", err2.Context["provider"])
|
||||
}
|
||||
|
||||
func TestWithMessage(t *testing.T) {
|
||||
err := WithMessage(ErrInvalidRequest, "自定义错误消息")
|
||||
assert.Equal(t, "invalid_request", err.Code)
|
||||
assert.Equal(t, "自定义错误消息", err.Message)
|
||||
assert.Equal(t, http.StatusBadRequest, err.HTTPStatus)
|
||||
}
|
||||
|
||||
func TestPredefinedErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *AppError
|
||||
code string
|
||||
httpStatus int
|
||||
}{
|
||||
{"ErrModelNotFound", ErrModelNotFound, "model_not_found", http.StatusNotFound},
|
||||
{"ErrModelDisabled", ErrModelDisabled, "model_disabled", http.StatusNotFound},
|
||||
{"ErrProviderNotFound", ErrProviderNotFound, "provider_not_found", http.StatusNotFound},
|
||||
{"ErrProviderDisabled", ErrProviderDisabled, "provider_disabled", http.StatusNotFound},
|
||||
{"ErrInvalidRequest", ErrInvalidRequest, "invalid_request", http.StatusBadRequest},
|
||||
{"ErrInternal", ErrInternal, "internal_error", http.StatusInternalServerError},
|
||||
{"ErrDatabaseNotInit", ErrDatabaseNotInit, "database_not_initialized", http.StatusInternalServerError},
|
||||
{"ErrConflict", ErrConflict, "conflict", http.StatusConflict},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.code, tt.err.Code)
|
||||
assert.Equal(t, tt.httpStatus, tt.err.HTTPStatus)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAsAppError(t *testing.T) {
|
||||
t.Run("nil输入", func(t *testing.T) {
|
||||
_, ok := AsAppError(nil)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("AppError类型", func(t *testing.T) {
|
||||
appErr, ok := AsAppError(ErrModelNotFound)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, ErrModelNotFound, appErr)
|
||||
})
|
||||
|
||||
t.Run("Wrapped AppError", func(t *testing.T) {
|
||||
wrapped := Wrap(ErrInternal, errors.New("cause"))
|
||||
appErr, ok := AsAppError(wrapped)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "internal_error", appErr.Code)
|
||||
})
|
||||
|
||||
t.Run("非AppError类型", func(t *testing.T) {
|
||||
_, ok := AsAppError(errors.New("普通错误"))
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
42
backend/pkg/errors/wrap.go
Normal file
42
backend/pkg/errors/wrap.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package errors
|
||||
|
||||
// Wrap wraps an error with cause
|
||||
func Wrap(err *AppError, cause error) *AppError {
|
||||
return &AppError{
|
||||
Code: err.Code,
|
||||
Message: err.Message,
|
||||
HTTPStatus: err.HTTPStatus,
|
||||
Cause: cause,
|
||||
}
|
||||
}
|
||||
|
||||
// WithContext adds context to an AppError
|
||||
func WithContext(err *AppError, key string, value interface{}) *AppError {
|
||||
newErr := &AppError{
|
||||
Code: err.Code,
|
||||
Message: err.Message,
|
||||
HTTPStatus: err.HTTPStatus,
|
||||
Cause: err.Cause,
|
||||
}
|
||||
if err.Context != nil {
|
||||
newErr.Context = make(map[string]interface{})
|
||||
for k, v := range err.Context {
|
||||
newErr.Context[k] = v
|
||||
}
|
||||
} else {
|
||||
newErr.Context = make(map[string]interface{})
|
||||
}
|
||||
newErr.Context[key] = value
|
||||
return newErr
|
||||
}
|
||||
|
||||
// WithMessage creates a new AppError with a custom message
|
||||
func WithMessage(err *AppError, message string) *AppError {
|
||||
return &AppError{
|
||||
Code: err.Code,
|
||||
Message: message,
|
||||
HTTPStatus: err.HTTPStatus,
|
||||
Cause: err.Cause,
|
||||
Context: err.Context,
|
||||
}
|
||||
}
|
||||
17
backend/pkg/logger/context.go
Normal file
17
backend/pkg/logger/context.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package logger
|
||||
|
||||
import "go.uber.org/zap"
|
||||
|
||||
// WithRequestID 向 logger 添加 request_id 字段
|
||||
func WithRequestID(logger *zap.Logger, requestID string) *zap.Logger {
|
||||
return logger.With(zap.String("request_id", requestID))
|
||||
}
|
||||
|
||||
// WithContext 向 logger 添加多个自定义字段
|
||||
func WithContext(logger *zap.Logger, fields map[string]interface{}) *zap.Logger {
|
||||
zapFields := make([]zap.Field, 0, len(fields))
|
||||
for k, v := range fields {
|
||||
zapFields = append(zapFields, zap.Any(k, v))
|
||||
}
|
||||
return logger.With(zapFields...)
|
||||
}
|
||||
109
backend/pkg/logger/logger.go
Normal file
109
backend/pkg/logger/logger.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// Config 日志配置
|
||||
type Config struct {
|
||||
Level string // 日志级别: debug, info, warn, error
|
||||
Path string // 日志文件目录,为空则仅输出到 stdout
|
||||
MaxSize int // 单个日志文件最大尺寸 (MB)
|
||||
MaxBackups int // 保留的旧日志文件最大数量
|
||||
MaxAge int // 保留旧日志文件的最大天数
|
||||
Compress bool // 是否压缩旧日志文件
|
||||
}
|
||||
|
||||
// New 根据配置创建 zap.Logger
|
||||
// 如果 Path 为空,仅输出到 stdout;
|
||||
// 如果 Path 已设置,同时输出到 stdout 和文件(文件使用 JSON 格式,stdout 使用 console 格式)
|
||||
func New(cfg Config) (*zap.Logger, error) {
|
||||
level, err := parseLevel(cfg.Level)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// stdout encoder — console 格式
|
||||
stdoutEncoder := zapcore.NewConsoleEncoder(zapcore.EncoderConfig{
|
||||
TimeKey: "ts",
|
||||
LevelKey: "level",
|
||||
NameKey: "logger",
|
||||
CallerKey: "caller",
|
||||
FunctionKey: zapcore.OmitKey,
|
||||
MessageKey: "msg",
|
||||
StacktraceKey: "stacktrace",
|
||||
LineEnding: zapcore.DefaultLineEnding,
|
||||
EncodeLevel: zapcore.CapitalColorLevelEncoder,
|
||||
EncodeTime: zapcore.ISO8601TimeEncoder,
|
||||
EncodeDuration: zapcore.StringDurationEncoder,
|
||||
EncodeCaller: zapcore.ShortCallerEncoder,
|
||||
})
|
||||
|
||||
stdoutCore := zapcore.NewCore(
|
||||
stdoutEncoder,
|
||||
zapcore.AddSync(os.Stdout),
|
||||
level,
|
||||
)
|
||||
|
||||
// 仅 stdout 模式
|
||||
if cfg.Path == "" {
|
||||
return zap.New(stdoutCore, zap.AddCaller(), zap.AddStacktrace(zap.ErrorLevel)), nil
|
||||
}
|
||||
|
||||
// 文件 encoder — JSON 格式
|
||||
fileEncoder := zapcore.NewJSONEncoder(zapcore.EncoderConfig{
|
||||
TimeKey: "ts",
|
||||
LevelKey: "level",
|
||||
NameKey: "logger",
|
||||
CallerKey: "caller",
|
||||
FunctionKey: zapcore.OmitKey,
|
||||
MessageKey: "msg",
|
||||
StacktraceKey: "stacktrace",
|
||||
LineEnding: zapcore.DefaultLineEnding,
|
||||
EncodeLevel: zapcore.LowercaseLevelEncoder,
|
||||
EncodeTime: zapcore.ISO8601TimeEncoder,
|
||||
EncodeDuration: zapcore.StringDurationEncoder,
|
||||
EncodeCaller: zapcore.ShortCallerEncoder,
|
||||
})
|
||||
|
||||
rotateWriter := newRotateWriter(cfg)
|
||||
fileCore := zapcore.NewCore(
|
||||
fileEncoder,
|
||||
zapcore.AddSync(rotateWriter),
|
||||
level,
|
||||
)
|
||||
|
||||
core := zapcore.NewTee(stdoutCore, fileCore)
|
||||
return zap.New(core, zap.AddCaller(), zap.AddStacktrace(zap.ErrorLevel)), nil
|
||||
}
|
||||
|
||||
// parseLevel 将字符串解析为 zapcore.Level
|
||||
func parseLevel(s string) (zapcore.Level, error) {
|
||||
switch s {
|
||||
case "debug":
|
||||
return zapcore.DebugLevel, nil
|
||||
case "info":
|
||||
return zapcore.InfoLevel, nil
|
||||
case "warn":
|
||||
return zapcore.WarnLevel, nil
|
||||
case "error":
|
||||
return zapcore.ErrorLevel, nil
|
||||
default:
|
||||
return zapcore.InfoLevel, nil
|
||||
}
|
||||
}
|
||||
|
||||
// logFileName 生成当日日志文件名: nex-YYYY-MM-DD.log
|
||||
func logFileName() string {
|
||||
return "nex-" + time.Now().Format("2006-01-02") + ".log"
|
||||
}
|
||||
|
||||
// logFilePath 拼接完整日志文件路径
|
||||
func logFilePath(dir string) string {
|
||||
return filepath.Join(dir, logFileName())
|
||||
}
|
||||
138
backend/pkg/logger/logger_test.go
Normal file
138
backend/pkg/logger/logger_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestNew_StdoutOnly(t *testing.T) {
|
||||
logger, err := New(Config{Level: "info"})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, logger)
|
||||
assert.NoError(t, logger.Sync())
|
||||
}
|
||||
|
||||
func TestNew_WithFileOutput(t *testing.T) {
|
||||
dir := filepath.Join(os.TempDir(), "nex-logger-test")
|
||||
os.MkdirAll(dir, 0755)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
logger, err := New(Config{
|
||||
Level: "debug",
|
||||
Path: dir,
|
||||
MaxSize: 10,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 7,
|
||||
Compress: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, logger)
|
||||
|
||||
logger.Info("test log message")
|
||||
_ = logger.Sync()
|
||||
|
||||
// 验证日志文件已创建
|
||||
files, err := os.ReadDir(dir)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, files, "日志目录应包含文件")
|
||||
}
|
||||
|
||||
func TestNew_AllLevels(t *testing.T) {
|
||||
levels := []string{"debug", "info", "warn", "error"}
|
||||
for _, level := range levels {
|
||||
logger, err := New(Config{Level: level})
|
||||
assert.NoError(t, err, "级别 %s 应有效", level)
|
||||
assert.NotNil(t, logger)
|
||||
assert.NoError(t, logger.Sync())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_EmptyLevel(t *testing.T) {
|
||||
// 空级别应默认为 info
|
||||
logger, err := New(Config{Level: ""})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, logger)
|
||||
assert.NoError(t, logger.Sync())
|
||||
}
|
||||
|
||||
func TestNew_InvalidPath(t *testing.T) {
|
||||
// 不可写的路径
|
||||
logger, err := New(Config{
|
||||
Level: "info",
|
||||
Path: "/nonexistent/deeply/nested/path/logs",
|
||||
})
|
||||
// 应能创建 logger(错误在写入时发生)
|
||||
// 实际上 lumberjack 会尝试创建目录
|
||||
_ = logger
|
||||
_ = err
|
||||
}
|
||||
|
||||
func TestParseLevel(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
valid bool
|
||||
}{
|
||||
{"debug", true},
|
||||
{"info", true},
|
||||
{"warn", true},
|
||||
{"error", true},
|
||||
{"", true}, // 默认为 info
|
||||
{"invalid", true}, // 默认为 info
|
||||
}
|
||||
for _, tt := range tests {
|
||||
_, err := parseLevel(tt.input)
|
||||
assert.NoError(t, err, "parseLevel(%q) 不应报错", tt.input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogFilePath(t *testing.T) {
|
||||
result := logFilePath(filepath.Join("var", "log"))
|
||||
assert.Contains(t, result, "nex-")
|
||||
assert.Contains(t, result, ".log")
|
||||
}
|
||||
|
||||
func TestLogFileName(t *testing.T) {
|
||||
name := logFileName()
|
||||
assert.Contains(t, name, "nex-")
|
||||
assert.Contains(t, name, ".log")
|
||||
assert.Len(t, name, len("nex-2006-01-02.log"))
|
||||
}
|
||||
|
||||
func TestNewRotateWriter_Defaults(t *testing.T) {
|
||||
cfg := Config{
|
||||
Path: t.TempDir(),
|
||||
MaxSize: 0,
|
||||
MaxAge: 0,
|
||||
Compress: true,
|
||||
}
|
||||
writer := newRotateWriter(cfg)
|
||||
require.NotNil(t, writer)
|
||||
assert.Equal(t, 100, writer.MaxSize)
|
||||
assert.Equal(t, 10, writer.MaxBackups)
|
||||
assert.Equal(t, 30, writer.MaxAge)
|
||||
}
|
||||
|
||||
func TestWithRequestID(t *testing.T) {
|
||||
logger, err := New(Config{Level: "info"})
|
||||
require.NoError(t, err)
|
||||
|
||||
contextLogger := WithRequestID(logger, "test-request-123")
|
||||
assert.NotNil(t, contextLogger)
|
||||
assert.IsType(t, &zap.Logger{}, contextLogger)
|
||||
}
|
||||
|
||||
func TestWithContext(t *testing.T) {
|
||||
logger, err := New(Config{Level: "info"})
|
||||
require.NoError(t, err)
|
||||
|
||||
contextLogger := WithContext(logger, map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": 42,
|
||||
})
|
||||
assert.NotNil(t, contextLogger)
|
||||
}
|
||||
30
backend/pkg/logger/rotate.go
Normal file
30
backend/pkg/logger/rotate.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package logger
|
||||
|
||||
import "gopkg.in/lumberjack.v2"
|
||||
|
||||
// newRotateWriter 根据配置创建 lumberjack.Logger 作为日志轮转写入器
|
||||
// 日志文件位于 cfg.Path 目录下,文件名格式为 nex-YYYY-MM-DD.log
|
||||
func newRotateWriter(cfg Config) *lumberjack.Logger {
|
||||
maxSize := cfg.MaxSize
|
||||
if maxSize <= 0 {
|
||||
maxSize = 100
|
||||
}
|
||||
|
||||
maxBackups := cfg.MaxBackups
|
||||
if maxBackups <= 0 {
|
||||
maxBackups = 10
|
||||
}
|
||||
|
||||
maxAge := cfg.MaxAge
|
||||
if maxAge <= 0 {
|
||||
maxAge = 30
|
||||
}
|
||||
|
||||
return &lumberjack.Logger{
|
||||
Filename: logFilePath(cfg.Path),
|
||||
MaxSize: maxSize, // MB
|
||||
MaxBackups: maxBackups,
|
||||
MaxAge: maxAge, // days
|
||||
Compress: cfg.Compress,
|
||||
}
|
||||
}
|
||||
22
backend/pkg/validator/validator.go
Normal file
22
backend/pkg/validator/validator.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package validator
|
||||
|
||||
import (
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// validate 全局验证器实例
|
||||
var validate *validator.Validate
|
||||
|
||||
func init() {
|
||||
validate = validator.New(validator.WithRequiredStructEnabled())
|
||||
}
|
||||
|
||||
// Get 返回全局验证器实例
|
||||
func Get() *validator.Validate {
|
||||
return validate
|
||||
}
|
||||
|
||||
// Validate 验证结构体
|
||||
func Validate(s interface{}) error {
|
||||
return validate.Struct(s)
|
||||
}
|
||||
45
backend/pkg/validator/validator_test.go
Normal file
45
backend/pkg/validator/validator_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package validator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type TestStruct struct {
|
||||
Name string `validate:"required"`
|
||||
Email string `validate:"required,email"`
|
||||
Age int `validate:"min=0,max=150"`
|
||||
}
|
||||
|
||||
func TestValidate_ValidStruct(t *testing.T) {
|
||||
s := TestStruct{Name: "John", Email: "john@example.com", Age: 25}
|
||||
err := Validate(s)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidate_MissingRequired(t *testing.T) {
|
||||
s := TestStruct{Email: "john@example.com", Age: 25}
|
||||
err := Validate(s)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidate_InvalidEmail(t *testing.T) {
|
||||
s := TestStruct{Name: "John", Email: "not-an-email", Age: 25}
|
||||
err := Validate(s)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidate_AgeOutOfRange(t *testing.T) {
|
||||
s := TestStruct{Name: "John", Email: "john@example.com", Age: 200}
|
||||
err := Validate(s)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestGet_ReturnsInstance(t *testing.T) {
|
||||
v := Get()
|
||||
assert.NotNil(t, v)
|
||||
// 多次调用应返回相同实例
|
||||
v2 := Get()
|
||||
assert.Equal(t, v, v2)
|
||||
}
|
||||
74
backend/tests/helpers.go
Normal file
74
backend/tests/helpers.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SetupTestDB initializes a temporary SQLite database with auto-migration.
|
||||
func SetupTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
dir := t.TempDir()
|
||||
dsn := dir + "/test.db"
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||
assert.NoError(t, err, "failed to open test database")
|
||||
|
||||
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
||||
assert.NoError(t, err, "failed to auto-migrate test database")
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// CleanupTestDB closes the database and removes the temp database file.
|
||||
func CleanupTestDB(t *testing.T, db *gorm.DB) {
|
||||
t.Helper()
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
assert.NoError(t, err, "failed to get underlying sql.DB")
|
||||
|
||||
err = sqlDB.Close()
|
||||
assert.NoError(t, err, "failed to close test database")
|
||||
}
|
||||
|
||||
// CreateTestProvider creates a test provider and returns it.
|
||||
func CreateTestProvider(t *testing.T, db *gorm.DB, id string) config.Provider {
|
||||
t.Helper()
|
||||
|
||||
provider := config.Provider{
|
||||
ID: id,
|
||||
Name: fmt.Sprintf("test-provider-%s", id),
|
||||
APIKey: fmt.Sprintf("test-api-key-%s", id),
|
||||
BaseURL: fmt.Sprintf("https://api.test-%s.com", id),
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
err := db.Create(&provider).Error
|
||||
assert.NoError(t, err, "failed to create test provider")
|
||||
|
||||
return provider
|
||||
}
|
||||
|
||||
// CreateTestModel creates a test model and returns it.
|
||||
func CreateTestModel(t *testing.T, db *gorm.DB, id string, providerID string, modelName string) config.Model {
|
||||
t.Helper()
|
||||
|
||||
model := config.Model{
|
||||
ID: id,
|
||||
ProviderID: providerID,
|
||||
ModelName: modelName,
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
err := db.Create(&model).Error
|
||||
assert.NoError(t, err, "failed to create test model")
|
||||
|
||||
return model
|
||||
}
|
||||
263
backend/tests/integration/integration_test.go
Normal file
263
backend/tests/integration/integration_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/handler"
|
||||
"nex/backend/internal/handler/middleware"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
func setupIntegrationTest(t *testing.T) (*gin.Engine, *gorm.DB) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
sqlDB, _ := db.DB()
|
||||
if sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
})
|
||||
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
|
||||
providerService := service.NewProviderService(providerRepo)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||
_ = service.NewRoutingService(modelRepo, providerRepo)
|
||||
statsService := service.NewStatsService(statsRepo)
|
||||
|
||||
providerHandler := handler.NewProviderHandler(providerService)
|
||||
modelHandler := handler.NewModelHandler(modelService)
|
||||
statsHandler := handler.NewStatsHandler(statsService)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
providers := r.Group("/api/providers")
|
||||
{
|
||||
providers.GET("", providerHandler.ListProviders)
|
||||
providers.POST("", providerHandler.CreateProvider)
|
||||
providers.GET("/:id", providerHandler.GetProvider)
|
||||
providers.PUT("/:id", providerHandler.UpdateProvider)
|
||||
providers.DELETE("/:id", providerHandler.DeleteProvider)
|
||||
}
|
||||
|
||||
models := r.Group("/api/models")
|
||||
{
|
||||
models.GET("", modelHandler.ListModels)
|
||||
models.POST("", modelHandler.CreateModel)
|
||||
models.GET("/:id", modelHandler.GetModel)
|
||||
models.PUT("/:id", modelHandler.UpdateModel)
|
||||
models.DELETE("/:id", modelHandler.DeleteModel)
|
||||
}
|
||||
|
||||
stats := r.Group("/api/stats")
|
||||
{
|
||||
stats.GET("", statsHandler.GetStats)
|
||||
stats.GET("/aggregate", statsHandler.AggregateStats)
|
||||
}
|
||||
|
||||
return r, db
|
||||
}
|
||||
|
||||
func TestOpenAI_CompleteFlow(t *testing.T) {
|
||||
r, _ := setupIntegrationTest(t)
|
||||
|
||||
// 1. 创建 Provider
|
||||
providerBody, _ := json.Marshal(map[string]string{
|
||||
"id": "openai", "name": "OpenAI", "api_key": "sk-test-key", "base_url": "https://api.openai.com/v1",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/providers", bytes.NewReader(providerBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 201, w.Code)
|
||||
|
||||
// 2. 创建 Model
|
||||
modelBody, _ := json.Marshal(map[string]string{
|
||||
"id": "gpt4", "provider_id": "openai", "model_name": "gpt-4",
|
||||
})
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 201, w.Code)
|
||||
|
||||
// 3. 列出 Provider
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("GET", "/api/providers", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
var providers []domain.Provider
|
||||
json.Unmarshal(w.Body.Bytes(), &providers)
|
||||
assert.Len(t, providers, 1)
|
||||
assert.Contains(t, providers[0].APIKey, "***") // 已掩码
|
||||
|
||||
// 4. 列出 Model
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("GET", "/api/models?provider_id=openai", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
var models []domain.Model
|
||||
json.Unmarshal(w.Body.Bytes(), &models)
|
||||
assert.Len(t, models, 1)
|
||||
assert.Equal(t, "gpt-4", models[0].ModelName)
|
||||
|
||||
// 5. 更新 Provider
|
||||
updateBody, _ := json.Marshal(map[string]string{"name": "OpenAI Updated"})
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("PUT", "/api/providers/openai", bytes.NewReader(updateBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
// 6. 删除 Model
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("DELETE", "/api/models/gpt4", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 204, w.Code)
|
||||
|
||||
// 7. 删除 Provider
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("DELETE", "/api/providers/openai", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 204, w.Code)
|
||||
}
|
||||
|
||||
func TestAnthropic_ModelCreation(t *testing.T) {
|
||||
r, _ := setupIntegrationTest(t)
|
||||
|
||||
// 创建 Provider 和 Model 用于 Anthropic 代理
|
||||
providerBody, _ := json.Marshal(map[string]string{
|
||||
"id": "anthropic", "name": "Anthropic", "api_key": "sk-ant-test", "base_url": "https://api.anthropic.com/v1",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/providers", bytes.NewReader(providerBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 201, w.Code)
|
||||
|
||||
modelBody, _ := json.Marshal(map[string]string{
|
||||
"id": "claude3", "provider_id": "anthropic", "model_name": "claude-3-opus-20240229",
|
||||
})
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 201, w.Code)
|
||||
|
||||
// 验证创建成功
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("GET", "/api/models/claude3", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestStats_RecordingAndQuery(t *testing.T) {
|
||||
r, db := setupIntegrationTest(t)
|
||||
|
||||
// 创建 Provider 和 Model
|
||||
providerBody, _ := json.Marshal(map[string]string{
|
||||
"id": "p1", "name": "Provider1", "api_key": "key", "base_url": "https://test.com",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/providers", bytes.NewReader(providerBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
modelBody, _ := json.Marshal(map[string]string{
|
||||
"id": "m1", "provider_id": "p1", "model_name": "gpt-4",
|
||||
})
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// 直接通过 repository 记录统计(模拟代理请求后的统计记录)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
statsRepo.Record("p1", "gpt-4")
|
||||
statsRepo.Record("p1", "gpt-4")
|
||||
statsRepo.Record("p1", "gpt-4")
|
||||
|
||||
// 查询统计
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("GET", "/api/stats?provider_id=p1", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var stats []domain.UsageStats
|
||||
json.Unmarshal(w.Body.Bytes(), &stats)
|
||||
assert.Len(t, stats, 1)
|
||||
assert.Equal(t, 3, stats[0].RequestCount)
|
||||
|
||||
// 聚合统计
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("GET", "/api/stats/aggregate?group_by=provider", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestProvider_DuplicateCreation(t *testing.T) {
|
||||
r, _ := setupIntegrationTest(t)
|
||||
|
||||
providerBody, _ := json.Marshal(map[string]string{
|
||||
"id": "p1", "name": "P1", "api_key": "key", "base_url": "https://test.com",
|
||||
})
|
||||
|
||||
// 第一次创建成功
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/providers", bytes.NewReader(providerBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 201, w.Code)
|
||||
|
||||
// 第二次创建应失败(UNIQUE 约束)
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/providers", bytes.NewReader(providerBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 409, w.Code)
|
||||
}
|
||||
|
||||
func TestProvider_NotFound(t *testing.T) {
|
||||
r, _ := setupIntegrationTest(t)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/api/providers/nonexistent", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
}
|
||||
|
||||
func TestStats_InvalidDate(t *testing.T) {
|
||||
r, _ := setupIntegrationTest(t)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/api/stats?start_date=not-a-date", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
// Suppress unused import warning
|
||||
var _ = time.Second
|
||||
Reference in New Issue
Block a user