1
0

feat: 实现分层架构,包含 domain、service、repository 和 pkg 层

- 新增 domain 层:model、provider、route、stats 实体
- 新增 service 层:models、providers、routing、stats 业务逻辑
- 新增 repository 层:models、providers、stats 数据访问
- 新增 pkg 工具包:errors、logger、validator
- 新增中间件:CORS、logging、recovery、request ID
- 新增数据库迁移:初始 schema 和索引
- 新增单元测试和集成测试
- 新增规范文档:config-management、database-migration、error-handling、layered-architecture、middleware-system、request-validation、structured-logging、test-coverage
- 移除 config 子包和 model_router(已迁移至分层架构)
This commit is contained in:
2026-04-16 00:47:20 +08:00
parent 915b004924
commit f18904af1e
77 changed files with 5727 additions and 1257 deletions

View File

@@ -1,24 +1,87 @@
package config
import (
"fmt"
"os"
"path/filepath"
"time"
"gopkg.in/yaml.v3"
appErrors "nex/backend/pkg/errors"
)
// Config 应用配置
type Config struct {
Server ServerConfig `yaml:"server"`
Database DatabaseConfig `yaml:"database"`
Log LogConfig `yaml:"log"`
}
// ServerConfig 服务器配置
type ServerConfig struct {
Port int `yaml:"port"`
ReadTimeout time.Duration `yaml:"read_timeout"`
WriteTimeout time.Duration `yaml:"write_timeout"`
}
// DatabaseConfig 数据库配置
type DatabaseConfig struct {
Path string `yaml:"path"`
MaxIdleConns int `yaml:"max_idle_conns"`
MaxOpenConns int `yaml:"max_open_conns"`
ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime"`
}
// LogConfig 日志配置
type LogConfig struct {
Level string `yaml:"level"`
Path string `yaml:"path"`
MaxSize int `yaml:"max_size"`
MaxBackups int `yaml:"max_backups"`
MaxAge int `yaml:"max_age"`
Compress bool `yaml:"compress"`
}
// DefaultConfig returns default config values
func DefaultConfig() *Config {
// Use home dir for default paths
homeDir, _ := os.UserHomeDir()
nexDir := filepath.Join(homeDir, ".nex")
return &Config{
Server: ServerConfig{
Port: 9826,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
},
Database: DatabaseConfig{
Path: filepath.Join(nexDir, "config.db"),
MaxIdleConns: 10,
MaxOpenConns: 100,
ConnMaxLifetime: 1 * time.Hour,
},
Log: LogConfig{
Level: "info",
Path: filepath.Join(nexDir, "log"),
MaxSize: 100,
MaxBackups: 10,
MaxAge: 30,
Compress: true,
},
}
}
// GetConfigDir 获取配置目录路径(~/.nex/
func GetConfigDir() (string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return "", err
}
configDir := filepath.Join(homeDir, ".nex")
// 确保目录存在
if err := os.MkdirAll(configDir, 0755); err != nil {
return "", err
}
return configDir, nil
}
@@ -30,3 +93,79 @@ func GetDBPath() (string, error) {
}
return filepath.Join(configDir, "config.db"), nil
}
// GetConfigPath 获取配置文件路径
func GetConfigPath() (string, error) {
configDir, err := GetConfigDir()
if err != nil {
return "", err
}
return filepath.Join(configDir, "config.yaml"), nil
}
// LoadConfig loads config from YAML file, creates default if not exists
func LoadConfig() (*Config, error) {
configPath, err := GetConfigPath()
if err != nil {
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
}
cfg := DefaultConfig()
data, err := os.ReadFile(configPath)
if err != nil {
if os.IsNotExist(err) {
// Create default config file
if saveErr := SaveConfig(cfg); saveErr != nil {
return nil, appErrors.WithMessage(appErrors.ErrInternal, "创建默认配置失败")
}
return cfg, nil
}
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
}
if err := yaml.Unmarshal(data, cfg); err != nil {
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
}
return cfg, nil
}
// SaveConfig saves config to YAML file
func SaveConfig(cfg *Config) error {
configPath, err := GetConfigPath()
if err != nil {
return appErrors.Wrap(appErrors.ErrInternal, err)
}
data, err := yaml.Marshal(cfg)
if err != nil {
return appErrors.Wrap(appErrors.ErrInternal, err)
}
// Ensure directory exists
dir := filepath.Dir(configPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return appErrors.Wrap(appErrors.ErrInternal, err)
}
return os.WriteFile(configPath, data, 0644)
}
// Validate validates the config
func (c *Config) Validate() error {
if c.Server.Port < 1 || c.Server.Port > 65535 {
return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("无效的端口号: %d", c.Server.Port))
}
validLevels := map[string]bool{"debug": true, "info": true, "warn": true, "error": true}
if !validLevels[c.Log.Level] {
return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("无效的日志级别: %s", c.Log.Level))
}
if c.Database.Path == "" {
return appErrors.WithMessage(appErrors.ErrInvalidRequest, "数据库路径不能为空")
}
return nil
}

View File

@@ -0,0 +1,176 @@
package config
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
)
func TestDefaultConfig(t *testing.T) {
cfg := DefaultConfig()
require.NotNil(t, cfg)
assert.Equal(t, 9826, cfg.Server.Port)
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
assert.Equal(t, "info", cfg.Log.Level)
assert.Equal(t, 100, cfg.Log.MaxSize)
assert.Equal(t, 10, cfg.Log.MaxBackups)
assert.Equal(t, 30, cfg.Log.MaxAge)
assert.Equal(t, true, cfg.Log.Compress)
}
func TestConfig_Validate(t *testing.T) {
tests := []struct {
name string
modify func(*Config)
wantErr bool
errMsg string
}{
{
name: "默认配置有效",
modify: func(c *Config) {},
wantErr: false,
},
{
name: "端口号为0无效",
modify: func(c *Config) { c.Server.Port = 0 },
wantErr: true,
errMsg: "无效的端口号",
},
{
name: "端口号超出范围无效",
modify: func(c *Config) { c.Server.Port = 70000 },
wantErr: true,
errMsg: "无效的端口号",
},
{
name: "端口号为1有效",
modify: func(c *Config) { c.Server.Port = 1 },
wantErr: false,
},
{
name: "端口号为65535有效",
modify: func(c *Config) { c.Server.Port = 65535 },
wantErr: false,
},
{
name: "无效日志级别",
modify: func(c *Config) { c.Log.Level = "invalid" },
wantErr: true,
errMsg: "无效的日志级别",
},
{
name: "debug级别有效",
modify: func(c *Config) { c.Log.Level = "debug" },
wantErr: false,
},
{
name: "warn级别有效",
modify: func(c *Config) { c.Log.Level = "warn" },
wantErr: false,
},
{
name: "error级别有效",
modify: func(c *Config) { c.Log.Level = "error" },
wantErr: false,
},
{
name: "数据库路径为空无效",
modify: func(c *Config) { c.Database.Path = "" },
wantErr: true,
errMsg: "数据库路径不能为空",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := DefaultConfig()
tt.modify(cfg)
err := cfg.Validate()
if tt.wantErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
} else {
assert.NoError(t, err)
}
})
}
}
func TestGetConfigDir(t *testing.T) {
dir, err := GetConfigDir()
require.NoError(t, err)
assert.NotEmpty(t, dir)
assert.Contains(t, dir, ".nex")
}
func TestGetDBPath(t *testing.T) {
path, err := GetDBPath()
require.NoError(t, err)
assert.NotEmpty(t, path)
assert.Contains(t, path, "config.db")
}
func TestGetConfigPath(t *testing.T) {
path, err := GetConfigPath()
require.NoError(t, err)
assert.NotEmpty(t, path)
assert.Contains(t, path, "config.yaml")
}
func TestSaveAndLoadConfig(t *testing.T) {
// 使用临时目录覆盖配置路径
dir := t.TempDir()
cfg := &Config{
Server: ServerConfig{
Port: 9999,
ReadTimeout: 10 * time.Second,
WriteTimeout: 20 * time.Second,
},
Database: DatabaseConfig{
Path: filepath.Join(dir, "test.db"),
MaxIdleConns: 5,
MaxOpenConns: 50,
ConnMaxLifetime: 30 * time.Minute,
},
Log: LogConfig{
Level: "debug",
Path: filepath.Join(dir, "log"),
MaxSize: 50,
MaxBackups: 5,
MaxAge: 7,
Compress: false,
},
}
// 保存配置
configPath := filepath.Join(dir, "config.yaml")
data, err := yaml.Marshal(cfg)
require.NoError(t, err)
err = os.WriteFile(configPath, data, 0644)
require.NoError(t, err)
// 加载配置
data, err = os.ReadFile(configPath)
require.NoError(t, err)
loaded := &Config{}
err = yaml.Unmarshal(data, loaded)
require.NoError(t, err)
assert.Equal(t, cfg.Server.Port, loaded.Server.Port)
assert.Equal(t, cfg.Log.Level, loaded.Log.Level)
assert.Equal(t, cfg.Database.MaxIdleConns, loaded.Database.MaxIdleConns)
assert.Equal(t, cfg.Log.Compress, loaded.Log.Compress)
}

View File

@@ -1,58 +0,0 @@
package config
import (
"fmt"
"log"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
var db *gorm.DB
// InitDB 初始化数据库连接并创建表
func InitDB() error {
dbPath, err := GetDBPath()
if err != nil {
return fmt.Errorf("获取数据库路径失败: %w", err)
}
// 打开数据库连接
db, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
})
if err != nil {
return fmt.Errorf("连接数据库失败: %w", err)
}
// 启用 WAL 模式以提升并发性能
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
log.Printf("警告: 启用 WAL 模式失败: %v", err)
}
// 自动迁移表结构
if err := db.AutoMigrate(&Provider{}, &Model{}, &UsageStats{}); err != nil {
return fmt.Errorf("创建表失败: %w", err)
}
log.Printf("数据库初始化成功: %s", dbPath)
return nil
}
// GetDB 获取数据库连接
func GetDB() *gorm.DB {
return db
}
// CloseDB 关闭数据库连接
func CloseDB() error {
if db != nil {
sqlDB, err := db.DB()
if err != nil {
return err
}
return sqlDB.Close()
}
return nil
}

View File

@@ -1,119 +0,0 @@
package config
import (
"errors"
"time"
"gorm.io/gorm"
)
// CreateModel 创建模型
func CreateModel(model *Model) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
// 验证供应商是否存在
var provider Provider
err := db.First(&provider, "id = ?", model.ProviderID).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("供应商不存在")
}
return err
}
model.CreatedAt = time.Now()
return db.Create(model).Error
}
// GetModel 获取模型
func GetModel(id string) (*Model, error) {
db := GetDB()
if db == nil {
return nil, errors.New("数据库未初始化")
}
var model Model
err := db.First(&model, "id = ?", id).Error
if err != nil {
return nil, err
}
return &model, nil
}
// ListModels 列出模型
func ListModels(providerID string) ([]Model, error) {
db := GetDB()
if db == nil {
return nil, errors.New("数据库未初始化")
}
var models []Model
var err error
if providerID != "" {
err = db.Where("provider_id = ?", providerID).Find(&models).Error
} else {
err = db.Find(&models).Error
}
if err != nil {
return nil, err
}
return models, nil
}
// UpdateModel 更新模型
func UpdateModel(id string, updates map[string]interface{}) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
// 如果更新了 provider_id验证新供应商是否存在
if providerID, ok := updates["provider_id"].(string); ok {
var provider Provider
err := db.First(&provider, "id = ?", providerID).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("供应商不存在")
}
return err
}
}
result := db.Model(&Model{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}
// DeleteModel 删除模型
func DeleteModel(id string) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
result := db.Delete(&Model{}, "id = ?", id)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}

View File

@@ -1,102 +0,0 @@
package config
import (
"errors"
"time"
"gorm.io/gorm"
)
// CreateProvider 创建供应商
func CreateProvider(provider *Provider) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
provider.CreatedAt = time.Now()
provider.UpdatedAt = time.Now()
return db.Create(provider).Error
}
// GetProvider 获取供应商
func GetProvider(id string, maskKey bool) (*Provider, error) {
db := GetDB()
if db == nil {
return nil, errors.New("数据库未初始化")
}
var provider Provider
err := db.First(&provider, "id = ?", id).Error
if err != nil {
return nil, err
}
if maskKey {
provider.MaskAPIKey()
}
return &provider, nil
}
// ListProviders 列出所有供应商
func ListProviders() ([]Provider, error) {
db := GetDB()
if db == nil {
return nil, errors.New("数据库未初始化")
}
var providers []Provider
err := db.Find(&providers).Error
if err != nil {
return nil, err
}
// 掩码所有 API Key
for i := range providers {
providers[i].MaskAPIKey()
}
return providers, nil
}
// UpdateProvider 更新供应商
func UpdateProvider(id string, updates map[string]interface{}) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
updates["updated_at"] = time.Now()
result := db.Model(&Provider{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}
// DeleteProvider 删除供应商
func DeleteProvider(id string) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
result := db.Delete(&Provider{}, "id = ?", id)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}

View File

@@ -1,79 +0,0 @@
package config
import (
"errors"
"time"
"gorm.io/gorm"
)
// RecordRequest 记录请求统计
func RecordRequest(providerID, modelName string) error {
db := GetDB()
if db == nil {
return errors.New("数据库未初始化")
}
today := time.Now().Format("2006-01-02")
todayTime, _ := time.Parse("2006-01-02", today)
// 使用事务确保并发安全
return db.Transaction(func(tx *gorm.DB) error {
var stats UsageStats
// 查找或创建统计记录
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
providerID, modelName, todayTime).
First(&stats).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
// 创建新记录
stats = UsageStats{
ProviderID: providerID,
ModelName: modelName,
RequestCount: 1,
Date: todayTime,
}
return tx.Create(&stats).Error
} else if err != nil {
return err
}
// 更新计数
return tx.Model(&stats).Update("request_count", gorm.Expr("request_count + 1")).Error
})
}
// GetStats 查询统计
func GetStats(providerID, modelName string, startDate, endDate *time.Time) ([]UsageStats, error) {
db := GetDB()
if db == nil {
return nil, errors.New("数据库未初始化")
}
var stats []UsageStats
query := db.Model(&UsageStats{})
if providerID != "" {
query = query.Where("provider_id = ?", providerID)
}
if modelName != "" {
query = query.Where("model_name = ?", modelName)
}
if startDate != nil {
query = query.Where("date >= ?", startDate)
}
if endDate != nil {
query = query.Where("date <= ?", endDate)
}
err := query.Order("date DESC").Find(&stats).Error
if err != nil {
return nil, err
}
return stats, nil
}

View File

@@ -0,0 +1,12 @@
package domain
import "time"
// Model 模型领域模型
type Model struct {
ID string `json:"id"`
ProviderID string `json:"provider_id"`
ModelName string `json:"model_name"`
Enabled bool `json:"enabled"`
CreatedAt time.Time `json:"created_at"`
}

View File

@@ -0,0 +1,23 @@
package domain
import "time"
// Provider 供应商领域模型
type Provider struct {
ID string `json:"id"`
Name string `json:"name"`
APIKey string `json:"api_key"`
BaseURL string `json:"base_url"`
Enabled bool `json:"enabled"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// MaskAPIKey 掩码 API Key仅显示最后 4 个字符)
func (p *Provider) MaskAPIKey() {
if len(p.APIKey) > 4 {
p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:]
} else {
p.APIKey = "***"
}
}

View File

@@ -0,0 +1,7 @@
package domain
// RouteResult 路由结果
type RouteResult struct {
Provider *Provider
Model *Model
}

View File

@@ -0,0 +1,12 @@
package domain
import "time"
// UsageStats 用量统计领域模型
type UsageStats struct {
ID uint `json:"id"`
ProviderID string `json:"provider_id"`
ModelName string `json:"model_name"`
RequestCount int `json:"request_count"`
Date time.Time `json:"date"`
}

View File

@@ -7,30 +7,33 @@ import (
"github.com/gin-gonic/gin"
"nex/backend/internal/config"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/protocol/anthropic"
"nex/backend/internal/protocol/openai"
"nex/backend/internal/provider"
"nex/backend/internal/router"
"nex/backend/internal/service"
)
// AnthropicHandler Anthropic 协议处理器
type AnthropicHandler struct {
client *provider.Client
router *router.Router
client provider.ProviderClient
routingService service.RoutingService
statsService service.StatsService
}
// NewAnthropicHandler 创建 Anthropic 处理器
func NewAnthropicHandler() *AnthropicHandler {
func NewAnthropicHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *AnthropicHandler {
return &AnthropicHandler{
client: provider.NewClient(),
router: router.NewRouter(),
client: client,
routingService: routingService,
statsService: statsService,
}
}
// HandleMessages 处理 Messages 请求
func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
// 解析 Anthropic 请求
var req anthropic.MessagesRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
@@ -43,7 +46,19 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
return
}
// 检查多模态内容
// 请求验证
if validationErrors := anthropic.ValidateRequest(&req); validationErrors != nil {
errMsg := formatValidationErrors(validationErrors)
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "invalid_request_error",
Message: errMsg,
},
})
return
}
if err := h.checkMultimodalContent(&req); err != nil {
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
Type: "error",
@@ -55,7 +70,6 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
return
}
// 转换为 OpenAI 请求
openaiReq, err := anthropic.ConvertRequest(&req)
if err != nil {
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
@@ -68,14 +82,12 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
return
}
// 路由到供应商
routeResult, err := h.router.Route(openaiReq.Model)
routeResult, err := h.routingService.Route(openaiReq.Model)
if err != nil {
h.handleError(c, err)
return
}
// 根据是否流式选择处理方式
if req.Stream {
h.handleStreamRequest(c, openaiReq, routeResult)
} else {
@@ -83,9 +95,7 @@ func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
}
}
// handleNonStreamRequest 处理非流式请求
func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
// 发送请求到供应商
func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
openaiResp, err := h.client.SendRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
if err != nil {
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
@@ -98,7 +108,6 @@ func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *ope
return
}
// 转换为 Anthropic 响应
anthropicResp, err := anthropic.ConvertResponse(openaiResp)
if err != nil {
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
@@ -111,18 +120,14 @@ func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *ope
return
}
// 记录统计
go func() {
_ = config.RecordRequest(routeResult.Provider.ID, openaiReq.Model)
_ = h.statsService.Record(routeResult.Provider.ID, openaiReq.Model)
}()
// 返回响应
c.JSON(http.StatusOK, anthropicResp)
}
// handleStreamRequest 处理流式请求
func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
// 发送流式请求到供应商
func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
if err != nil {
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
@@ -135,24 +140,19 @@ func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai
return
}
// 设置 SSE 响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
// 创建流写入器
writer := bufio.NewWriter(c.Writer)
// 创建流式转换器
converter := anthropic.NewStreamConverter(
fmt.Sprintf("msg_%s", routeResult.Provider.ID),
openaiReq.Model,
)
// 流式转发事件
for event := range eventChan {
if event.Error != nil {
fmt.Printf("流错误: %v\n", event.Error)
break
}
@@ -160,25 +160,19 @@ func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai
break
}
// 解析 OpenAI 流块
chunk, err := openai.NewAdapter().ParseStreamChunk(event.Data)
if err != nil {
fmt.Printf("解析流块失败: %v\n", err)
continue
}
// 转换为 Anthropic 事件
anthropicEvents, err := converter.ConvertChunk(chunk)
if err != nil {
fmt.Printf("转换事件失败: %v\n", err)
continue
}
// 写入事件
for _, ae := range anthropicEvents {
eventStr, err := anthropic.SerializeEvent(ae)
if err != nil {
fmt.Printf("序列化事件失败: %v\n", err)
continue
}
writer.WriteString(eventStr)
@@ -186,13 +180,11 @@ func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai
}
}
// 记录统计
go func() {
_ = config.RecordRequest(routeResult.Provider.ID, openaiReq.Model)
_ = h.statsService.Record(routeResult.Provider.ID, openaiReq.Model)
}()
}
// checkMultimodalContent 检查多模态内容
func (h *AnthropicHandler) checkMultimodalContent(req *anthropic.MessagesRequest) error {
for _, msg := range req.Messages {
for _, block := range msg.Content {
@@ -204,40 +196,22 @@ func (h *AnthropicHandler) checkMultimodalContent(req *anthropic.MessagesRequest
return nil
}
// handleError 处理路由错误
func (h *AnthropicHandler) handleError(c *gin.Context, err error) {
switch err {
case router.ErrModelNotFound:
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
if appErr, ok := appErrors.AsAppError(err); ok {
c.JSON(appErr.HTTPStatus, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "not_found_error",
Message: "模型未找到",
},
})
case router.ErrModelDisabled:
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "not_found_error",
Message: "模型已禁用",
},
})
case router.ErrProviderDisabled:
c.JSON(http.StatusNotFound, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "not_found_error",
Message: "供应商已禁用",
},
})
default:
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "internal_error",
Message: "内部错误: " + err.Error(),
Message: appErr.Message,
},
})
return
}
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
Type: "error",
Error: anthropic.ErrorDetail{
Type: "internal_error",
Message: "内部错误: " + err.Error(),
},
})
}

View File

@@ -0,0 +1,290 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/domain"
"nex/backend/internal/protocol/openai"
"nex/backend/internal/provider"
appErrors "nex/backend/pkg/errors"
)
func init() {
gin.SetMode(gin.TestMode)
}
// ============ Mock 实现 ============
type mockRoutingService struct {
result *domain.RouteResult
err error
}
func (m *mockRoutingService) Route(modelName string) (*domain.RouteResult, error) {
return m.result, m.err
}
type mockStatsService struct {
err error
stats []domain.UsageStats
aggrResult []map[string]interface{}
}
func (m *mockStatsService) Record(providerID, modelName string) error {
return m.err
}
func (m *mockStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
return m.stats, nil
}
func (m *mockStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} {
return m.aggrResult
}
type mockProviderService struct {
provider *domain.Provider
providers []domain.Provider
err error
}
func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err }
func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
return m.provider, m.err
}
func (m *mockProviderService) List() ([]domain.Provider, error) { return m.providers, m.err }
func (m *mockProviderService) Update(id string, updates map[string]interface{}) error {
return m.err
}
func (m *mockProviderService) Delete(id string) error { return m.err }
type mockModelService struct {
model *domain.Model
models []domain.Model
err error
}
func (m *mockModelService) Create(model *domain.Model) error { return m.err }
func (m *mockModelService) Get(id string) (*domain.Model, error) {
return m.model, m.err
}
func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
return m.models, m.err
}
func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
return m.err
}
func (m *mockModelService) Delete(id string) error { return m.err }
type mockProviderClient struct {
resp *openai.ChatCompletionResponse
eventChan chan provider.StreamEvent
err error
}
func (m *mockProviderClient) SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error) {
return m.resp, m.err
}
func (m *mockProviderClient) SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan provider.StreamEvent, error) {
return m.eventChan, m.err
}
// ============ OpenAI Handler 测试 ============
func TestOpenAIHandler_HandleChatCompletions_InvalidJSON(t *testing.T) {
h := NewOpenAIHandler(nil, nil, nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte("invalid")))
h.HandleChatCompletions(c)
assert.Equal(t, 400, w.Code)
}
func TestOpenAIHandler_HandleChatCompletions_ValidationError(t *testing.T) {
h := NewOpenAIHandler(nil, nil, nil)
// 缺少 model 字段
body, _ := json.Marshal(map[string]interface{}{
"messages": []map[string]string{{"role": "user", "content": "hi"}},
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.HandleChatCompletions(c)
assert.Equal(t, 400, w.Code)
}
func TestOpenAIHandler_HandleChatCompletions_RouteError(t *testing.T) {
routingSvc := &mockRoutingService{err: appErrors.ErrModelNotFound}
h := NewOpenAIHandler(nil, routingSvc, nil)
body, _ := json.Marshal(map[string]interface{}{
"model": "nonexistent",
"messages": []map[string]string{{"role": "user", "content": "hi"}},
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.HandleChatCompletions(c)
assert.Equal(t, 404, w.Code)
}
// ============ Provider Handler 测试 ============
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
h := NewProviderHandler(&mockProviderService{})
body, _ := json.Marshal(map[string]string{"id": "p1"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateProvider(c)
assert.Equal(t, 400, w.Code)
}
func TestProviderHandler_ListProviders(t *testing.T) {
h := NewProviderHandler(&mockProviderService{
providers: []domain.Provider{
{ID: "p1", Name: "P1"},
{ID: "p2", Name: "P2"},
},
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "/api/providers", nil)
h.ListProviders(c)
assert.Equal(t, 200, w.Code)
var result []domain.Provider
json.Unmarshal(w.Body.Bytes(), &result)
assert.Len(t, result, 2)
}
func TestProviderHandler_GetProvider(t *testing.T) {
h := NewProviderHandler(&mockProviderService{
provider: &domain.Provider{ID: "p1", Name: "P1", APIKey: "***"},
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "p1"}}
c.Request = httptest.NewRequest("GET", "/api/providers/p1", nil)
h.GetProvider(c)
assert.Equal(t, 200, w.Code)
}
// ============ Model Handler 测试 ============
func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
h := NewModelHandler(&mockModelService{})
body, _ := json.Marshal(map[string]string{"id": "m1"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateModel(c)
assert.Equal(t, 400, w.Code)
}
func TestModelHandler_ListModels(t *testing.T) {
h := NewModelHandler(&mockModelService{
models: []domain.Model{
{ID: "m1", ModelName: "gpt-4"},
{ID: "m2", ModelName: "gpt-3.5"},
},
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "/api/models", nil)
h.ListModels(c)
assert.Equal(t, 200, w.Code)
}
// ============ Stats Handler 测试 ============
func TestStatsHandler_GetStats(t *testing.T) {
h := NewStatsHandler(&mockStatsService{
stats: []domain.UsageStats{
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
},
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "/api/stats", nil)
h.GetStats(c)
assert.Equal(t, 200, w.Code)
}
func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
h := NewStatsHandler(&mockStatsService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "/api/stats?start_date=invalid", nil)
h.GetStats(c)
assert.Equal(t, 400, w.Code)
}
func TestStatsHandler_AggregateStats(t *testing.T) {
h := NewStatsHandler(&mockStatsService{
stats: []domain.UsageStats{
{ProviderID: "p1", RequestCount: 10},
},
aggrResult: []map[string]interface{}{
{"provider_id": "p1", "request_count": 10},
},
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "/api/stats/aggregate?group_by=provider", nil)
h.AggregateStats(c)
assert.Equal(t, 200, w.Code)
}
// ============ writeError 测试 ============
func TestWriteError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("GET", "/", nil)
writeError(c, appErrors.ErrModelNotFound)
assert.Equal(t, 404, w.Code)
}
func TestFormatValidationErrors(t *testing.T) {
errs := map[string]string{
"model": "模型名称不能为空",
"messages": "消息列表不能为空",
}
result := formatValidationErrors(errs)
require.Contains(t, result, "请求验证失败")
require.Contains(t, result, "model")
require.Contains(t, result, "messages")
}

View File

@@ -0,0 +1,21 @@
package middleware
import (
"github.com/gin-gonic/gin"
)
// CORS 跨域中间件
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Request-ID")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
}
}

View File

@@ -0,0 +1,40 @@
package middleware
import (
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// Logging 日志中间件
func Logging(logger *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
query := c.Request.URL.RawQuery
requestID, _ := c.Get(RequestIDKey)
logger.Info("请求开始",
zap.String("method", c.Request.Method),
zap.String("path", path),
zap.String("query", query),
zap.String("client_ip", c.ClientIP()),
zap.Any("request_id", requestID),
)
c.Next()
latency := time.Since(start)
statusCode := c.Writer.Status()
logger.Info("请求结束",
zap.Int("status", statusCode),
zap.String("method", c.Request.Method),
zap.String("path", path),
zap.Duration("latency", latency),
zap.Int("body_size", c.Writer.Size()),
zap.Any("request_id", requestID),
)
}
}

View File

@@ -0,0 +1,130 @@
package middleware
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
)
func init() {
gin.SetMode(gin.TestMode)
}
func TestRequestID_GeneratesUUID(t *testing.T) {
r := gin.New()
r.Use(RequestID())
r.GET("/test", func(c *gin.Context) {
id, exists := c.Get(RequestIDKey)
assert.True(t, exists)
assert.NotEmpty(t, id)
c.Status(200)
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/test", nil)
r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.NotEmpty(t, w.Header().Get("X-Request-ID"))
}
func TestRequestID_UsesExistingHeader(t *testing.T) {
r := gin.New()
r.Use(RequestID())
r.GET("/test", func(c *gin.Context) {
id, _ := c.Get(RequestIDKey)
assert.Equal(t, "existing-id-123", id)
c.Status(200)
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Request-ID", "existing-id-123")
r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "existing-id-123", w.Header().Get("X-Request-ID"))
}
func TestLogging(t *testing.T) {
logger := zap.NewNop()
r := gin.New()
r.Use(Logging(logger))
r.GET("/test", func(c *gin.Context) {
c.Status(200)
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/test?key=value", nil)
r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
}
func TestRecovery_NoPanic(t *testing.T) {
logger := zap.NewNop()
r := gin.New()
r.Use(Recovery(logger))
r.GET("/test", func(c *gin.Context) {
c.Status(200)
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/test", nil)
r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
}
func TestRecovery_WithPanic(t *testing.T) {
logger := zap.NewNop()
r := gin.New()
r.Use(Recovery(logger))
r.GET("/test", func(c *gin.Context) {
panic("test panic")
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/test", nil)
r.ServeHTTP(w, req)
assert.Equal(t, 500, w.Code)
}
func TestCORS_NormalRequest(t *testing.T) {
r := gin.New()
r.Use(CORS())
r.GET("/test", func(c *gin.Context) {
c.Status(200)
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/test", nil)
r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
assert.Contains(t, w.Header().Get("Access-Control-Allow-Methods"), "GET")
assert.Contains(t, w.Header().Get("Access-Control-Allow-Methods"), "POST")
}
func TestCORS_PreflightRequest(t *testing.T) {
r := gin.New()
r.Use(CORS())
r.OPTIONS("/test", func(c *gin.Context) {
c.Status(200)
})
w := httptest.NewRecorder()
req := httptest.NewRequest("OPTIONS", "/test", nil)
r.ServeHTTP(w, req)
assert.Equal(t, 204, w.Code)
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
}

View File

@@ -0,0 +1,29 @@
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// Recovery 错误恢复中间件
func Recovery(logger *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
requestID, _ := c.Get(RequestIDKey)
logger.Error("panic recovered",
zap.Any("error", err),
zap.Any("request_id", requestID),
zap.String("path", c.Request.URL.Path),
zap.Stack("stack"),
)
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
"error": "内部错误",
})
}
}()
c.Next()
}
}

View File

@@ -0,0 +1,21 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
const RequestIDKey = "request_id"
// RequestID 请求 ID 中间件
func RequestID() gin.HandlerFunc {
return func(c *gin.Context) {
requestID := c.GetHeader("X-Request-ID")
if requestID == "" {
requestID = uuid.New().String()
}
c.Set(RequestIDKey, requestID)
c.Header("X-Request-ID", requestID)
c.Next()
}
}

View File

@@ -6,15 +6,20 @@ import (
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"nex/backend/internal/config"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/service"
)
// ModelHandler 模型管理处理器
type ModelHandler struct{}
type ModelHandler struct {
modelService service.ModelService
}
// NewModelHandler 创建模型处理器
func NewModelHandler() *ModelHandler {
return &ModelHandler{}
func NewModelHandler(modelService service.ModelService) *ModelHandler {
return &ModelHandler{modelService: modelService}
}
// CreateModel 创建模型
@@ -32,26 +37,21 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
return
}
// 创建模型对象
model := &config.Model{
model := &domain.Model{
ID: req.ID,
ProviderID: req.ProviderID,
ModelName: req.ModelName,
Enabled: true, // 默认启用
}
// 保存到数据库
err := config.CreateModel(model)
err := h.modelService.Create(model)
if err != nil {
if err.Error() == "供应商不存在" {
if err == appErrors.ErrProviderNotFound {
c.JSON(http.StatusBadRequest, gin.H{
"error": "供应商不存在",
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "创建模型失败: " + err.Error(),
})
writeError(c, err)
return
}
@@ -62,11 +62,9 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
func (h *ModelHandler) ListModels(c *gin.Context) {
providerID := c.Query("provider_id")
models, err := config.ListModels(providerID)
models, err := h.modelService.List(providerID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询模型失败: " + err.Error(),
})
writeError(c, err)
return
}
@@ -77,7 +75,7 @@ func (h *ModelHandler) ListModels(c *gin.Context) {
func (h *ModelHandler) GetModel(c *gin.Context) {
id := c.Param("id")
model, err := config.GetModel(id)
model, err := h.modelService.Get(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
@@ -85,9 +83,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询模型失败: " + err.Error(),
})
writeError(c, err)
return
}
@@ -106,8 +102,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
return
}
// 更新模型
err := config.UpdateModel(id, req)
err := h.modelService.Update(id, req)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
@@ -115,24 +110,19 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
})
return
}
if err.Error() == "供应商不存在" {
if err == appErrors.ErrProviderNotFound {
c.JSON(http.StatusBadRequest, gin.H{
"error": "供应商不存在",
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "更新模型失败: " + err.Error(),
})
writeError(c, err)
return
}
// 返回更新后的模型
model, err := config.GetModel(id)
model, err := h.modelService.Get(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询更新后的模型失败: " + err.Error(),
})
writeError(c, err)
return
}
@@ -143,7 +133,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
func (h *ModelHandler) DeleteModel(c *gin.Context) {
id := c.Param("id")
err := config.DeleteModel(id)
err := h.modelService.Delete(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
@@ -151,9 +141,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) {
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "删除模型失败: " + err.Error(),
})
writeError(c, err)
return
}

View File

@@ -4,32 +4,36 @@ import (
"bufio"
"fmt"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"nex/backend/internal/config"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/protocol/openai"
"nex/backend/internal/provider"
"nex/backend/internal/router"
"nex/backend/internal/service"
)
// OpenAIHandler OpenAI 协议处理器
type OpenAIHandler struct {
client *provider.Client
router *router.Router
client provider.ProviderClient
routingService service.RoutingService
statsService service.StatsService
}
// NewOpenAIHandler 创建 OpenAI 处理器
func NewOpenAIHandler() *OpenAIHandler {
func NewOpenAIHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *OpenAIHandler {
return &OpenAIHandler{
client: provider.NewClient(),
router: router.NewRouter(),
client: client,
routingService: routingService,
statsService: statsService,
}
}
// HandleChatCompletions 处理 Chat Completions 请求
func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
// 解析请求
var req openai.ChatCompletionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
@@ -41,14 +45,23 @@ func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
return
}
// 路由到供应商
routeResult, err := h.router.Route(req.Model)
// 请求验证
if validationErrors := openai.ValidateRequest(&req); validationErrors != nil {
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: formatValidationErrors(validationErrors),
Type: "invalid_request_error",
},
})
return
}
routeResult, err := h.routingService.Route(req.Model)
if err != nil {
h.handleError(c, err)
return
}
// 根据是否流式选择处理方式
if req.Stream {
h.handleStreamRequest(c, &req, routeResult)
} else {
@@ -56,9 +69,7 @@ func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
}
}
// handleNonStreamRequest 处理非流式请求
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
// 发送请求到供应商
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
resp, err := h.client.SendRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
if err != nil {
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
@@ -70,18 +81,14 @@ func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatC
return
}
// 记录统计
go func() {
_ = config.RecordRequest(routeResult.Provider.ID, req.Model)
_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
}()
// 返回响应
c.JSON(http.StatusOK, resp)
}
// handleStreamRequest 处理流式请求
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *router.RouteResult) {
// 发送流式请求到供应商
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
if err != nil {
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
@@ -93,75 +100,58 @@ func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatComp
return
}
// 设置 SSE 响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
// 创建流写入器
writer := bufio.NewWriter(c.Writer)
// 流式转发事件
for event := range eventChan {
if event.Error != nil {
// 流错误,记录日志
fmt.Printf("流错误: %v\n", event.Error)
break
}
if event.Done {
// 流结束
writer.WriteString("data: [DONE]\n\n")
writer.Flush()
break
}
// 写入事件数据
writer.WriteString("data: ")
writer.Write(event.Data)
writer.WriteString("\n\n")
writer.Flush()
}
// 记录统计
go func() {
_ = config.RecordRequest(routeResult.Provider.ID, req.Model)
_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
}()
}
// handleError 处理路由错误
func (h *OpenAIHandler) handleError(c *gin.Context, err error) {
switch err {
case router.ErrModelNotFound:
c.JSON(http.StatusNotFound, openai.ErrorResponse{
if appErr, ok := appErrors.AsAppError(err); ok {
c.JSON(appErr.HTTPStatus, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: "模型未找到",
Message: appErr.Message,
Type: "invalid_request_error",
Code: "model_not_found",
},
})
case router.ErrModelDisabled:
c.JSON(http.StatusNotFound, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: "模型已禁用",
Type: "invalid_request_error",
Code: "model_disabled",
},
})
case router.ErrProviderDisabled:
c.JSON(http.StatusNotFound, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: "供应商已禁用",
Type: "invalid_request_error",
Code: "provider_disabled",
},
})
default:
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: "内部错误: " + err.Error(),
Type: "internal_error",
Code: appErr.Code,
},
})
return
}
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
Error: openai.ErrorDetail{
Message: "内部错误: " + err.Error(),
Type: "internal_error",
},
})
}
// formatValidationErrors 将验证错误 map 格式化为字符串
func formatValidationErrors(errors map[string]string) string {
parts := make([]string, 0, len(errors))
for field, msg := range errors {
parts = append(parts, fmt.Sprintf("%s: %s", field, msg))
}
return "请求验证失败: " + strings.Join(parts, "; ")
}

View File

@@ -2,19 +2,25 @@ package handler
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"nex/backend/internal/config"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/service"
)
// ProviderHandler 供应商管理处理器
type ProviderHandler struct{}
type ProviderHandler struct {
providerService service.ProviderService
}
// NewProviderHandler 创建供应商处理器
func NewProviderHandler() *ProviderHandler {
return &ProviderHandler{}
func NewProviderHandler(providerService service.ProviderService) *ProviderHandler {
return &ProviderHandler{providerService: providerService}
}
// CreateProvider 创建供应商
@@ -33,43 +39,34 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
return
}
// 创建供应商对象
provider := &config.Provider{
provider := &domain.Provider{
ID: req.ID,
Name: req.Name,
APIKey: req.APIKey,
BaseURL: req.BaseURL,
Enabled: true, // 默认启用
}
// 保存到数据库
err := config.CreateProvider(provider)
err := h.providerService.Create(provider)
if err != nil {
// 检查是否是唯一约束错误ID 重复)
if err.Error() == "UNIQUE constraint failed: providers.id" {
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
c.JSON(http.StatusConflict, gin.H{
"error": "供应商 ID 已存在",
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "创建供应商失败: " + err.Error(),
})
writeError(c, err)
return
}
// 掩码 API Key 后返回
provider.MaskAPIKey()
c.JSON(http.StatusCreated, provider)
}
// ListProviders 列出所有供应商
func (h *ProviderHandler) ListProviders(c *gin.Context) {
providers, err := config.ListProviders()
providers, err := h.providerService.List()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询供应商失败: " + err.Error(),
})
writeError(c, err)
return
}
@@ -80,7 +77,7 @@ func (h *ProviderHandler) ListProviders(c *gin.Context) {
func (h *ProviderHandler) GetProvider(c *gin.Context) {
id := c.Param("id")
provider, err := config.GetProvider(id, true) // 掩码 API Key
provider, err := h.providerService.Get(id, true)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
@@ -88,9 +85,7 @@ func (h *ProviderHandler) GetProvider(c *gin.Context) {
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询供应商失败: " + err.Error(),
})
writeError(c, err)
return
}
@@ -109,8 +104,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
return
}
// 更新供应商
err := config.UpdateProvider(id, req)
err := h.providerService.Update(id, req)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
@@ -118,18 +112,13 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "更新供应商失败: " + err.Error(),
})
writeError(c, err)
return
}
// 返回更新后的供应商
provider, err := config.GetProvider(id, true)
provider, err := h.providerService.Get(id, true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询更新后的供应商失败: " + err.Error(),
})
writeError(c, err)
return
}
@@ -140,8 +129,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
id := c.Param("id")
// 删除供应商(级联删除模型)
err := config.DeleteProvider(id)
err := h.providerService.Delete(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
@@ -149,19 +137,23 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": "删除供应商失败: " + err.Error(),
})
writeError(c, err)
return
}
// 删除关联的模型
models, _ := config.ListModels("")
for _, model := range models {
if model.ProviderID == id {
_ = config.DeleteModel(model.ID)
}
}
c.Status(http.StatusNoContent)
}
// writeError 统一错误响应处理
func writeError(c *gin.Context, err error) {
if appErr, ok := appErrors.AsAppError(err); ok {
c.JSON(appErr.HTTPStatus, gin.H{
"error": appErr.Message,
"code": appErr.Code,
})
return
}
c.JSON(http.StatusInternalServerError, gin.H{
"error": err.Error(),
})
}

View File

@@ -6,20 +6,21 @@ import (
"github.com/gin-gonic/gin"
"nex/backend/internal/config"
"nex/backend/internal/service"
)
// StatsHandler 统计处理器
type StatsHandler struct{}
type StatsHandler struct {
statsService service.StatsService
}
// NewStatsHandler 创建统计处理器
func NewStatsHandler() *StatsHandler {
return &StatsHandler{}
func NewStatsHandler(statsService service.StatsService) *StatsHandler {
return &StatsHandler{statsService: statsService}
}
// GetStats 查询统计
func (h *StatsHandler) GetStats(c *gin.Context) {
// 解析查询参数
providerID := c.Query("provider_id")
modelName := c.Query("model_name")
startDateStr := c.Query("start_date")
@@ -27,7 +28,6 @@ func (h *StatsHandler) GetStats(c *gin.Context) {
var startDate, endDate *time.Time
// 解析日期
if startDateStr != "" {
t, err := time.Parse("2006-01-02", startDateStr)
if err != nil {
@@ -50,8 +50,7 @@ func (h *StatsHandler) GetStats(c *gin.Context) {
endDate = &t
}
// 查询统计
stats, err := config.GetStats(providerID, modelName, startDate, endDate)
stats, err := h.statsService.Get(providerID, modelName, startDate, endDate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询统计失败: " + err.Error(),
@@ -64,16 +63,14 @@ func (h *StatsHandler) GetStats(c *gin.Context) {
// AggregateStats 聚合统计
func (h *StatsHandler) AggregateStats(c *gin.Context) {
// 解析查询参数
providerID := c.Query("provider_id")
modelName := c.Query("model_name")
startDateStr := c.Query("start_date")
endDateStr := c.Query("end_date")
groupBy := c.Query("group_by") // "provider", "model", "date"
groupBy := c.Query("group_by")
var startDate, endDate *time.Time
// 解析日期
if startDateStr != "" {
t, err := time.Parse("2006-01-02", startDateStr)
if err != nil {
@@ -96,8 +93,7 @@ func (h *StatsHandler) AggregateStats(c *gin.Context) {
endDate = &t
}
// 查询统计
stats, err := config.GetStats(providerID, modelName, startDate, endDate)
stats, err := h.statsService.Get(providerID, modelName, startDate, endDate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "查询统计失败: " + err.Error(),
@@ -105,80 +101,6 @@ func (h *StatsHandler) AggregateStats(c *gin.Context) {
return
}
// 聚合
result := h.aggregate(stats, groupBy)
result := h.statsService.Aggregate(stats, groupBy)
c.JSON(http.StatusOK, result)
}
// aggregate 执行聚合
func (h *StatsHandler) aggregate(stats []config.UsageStats, groupBy string) []map[string]interface{} {
switch groupBy {
case "provider":
return h.aggregateByProvider(stats)
case "model":
return h.aggregateByModel(stats)
case "date":
return h.aggregateByDate(stats)
default:
// 默认按供应商聚合
return h.aggregateByProvider(stats)
}
}
// aggregateByProvider 按供应商聚合
func (h *StatsHandler) aggregateByProvider(stats []config.UsageStats) []map[string]interface{} {
aggregated := make(map[string]int)
for _, stat := range stats {
aggregated[stat.ProviderID] += stat.RequestCount
}
result := make([]map[string]interface{}, 0, len(aggregated))
for providerID, count := range aggregated {
result = append(result, map[string]interface{}{
"provider_id": providerID,
"request_count": count,
})
}
return result
}
// aggregateByModel 按模型聚合
func (h *StatsHandler) aggregateByModel(stats []config.UsageStats) []map[string]interface{} {
aggregated := make(map[string]int)
for _, stat := range stats {
key := stat.ProviderID + "/" + stat.ModelName
aggregated[key] += stat.RequestCount
}
result := make([]map[string]interface{}, 0, len(aggregated))
for key, count := range aggregated {
result = append(result, map[string]interface{}{
"provider_id": key[:len(key)/2],
"model_name": key[len(key)/2+1:],
"request_count": count,
})
}
return result
}
// aggregateByDate 按日期聚合
func (h *StatsHandler) aggregateByDate(stats []config.UsageStats) []map[string]interface{} {
aggregated := make(map[string]int)
for _, stat := range stats {
key := stat.Date.Format("2006-01-02")
aggregated[key] += stat.RequestCount
}
result := make([]map[string]interface{}, 0, len(aggregated))
for date, count := range aggregated {
result = append(result, map[string]interface{}{
"date": date,
"request_count": count,
})
}
return result
}

View File

@@ -0,0 +1,270 @@
package anthropic
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/protocol/openai"
)
func TestConvertRequest_Basic(t *testing.T) {
temp := 0.7
req := &MessagesRequest{
Model: "claude-3-opus",
MaxTokens: 1024,
Temperature: &temp,
Messages: []AnthropicMessage{
{
Role: "user",
Content: []ContentBlock{
{Type: "text", Text: "Hello"},
},
},
},
}
result, err := ConvertRequest(req)
require.NoError(t, err)
assert.Equal(t, "claude-3-opus", result.Model)
assert.Equal(t, 1024, *result.MaxTokens)
assert.Equal(t, &temp, result.Temperature)
require.Len(t, result.Messages, 1)
assert.Equal(t, "user", result.Messages[0].Role)
assert.Equal(t, "Hello", result.Messages[0].Content)
}
func TestConvertRequest_WithSystem(t *testing.T) {
req := &MessagesRequest{
Model: "claude-3-opus",
MaxTokens: 100,
System: "You are a helpful assistant.",
Messages: []AnthropicMessage{
{
Role: "user",
Content: []ContentBlock{{Type: "text", Text: "Hi"}},
},
},
}
result, err := ConvertRequest(req)
require.NoError(t, err)
require.Len(t, result.Messages, 2)
assert.Equal(t, "system", result.Messages[0].Role)
assert.Equal(t, "You are a helpful assistant.", result.Messages[0].Content)
assert.Equal(t, "user", result.Messages[1].Role)
}
func TestConvertRequest_DefaultMaxTokens(t *testing.T) {
req := &MessagesRequest{
Model: "claude-3-opus",
MaxTokens: 0, // 未设置
Messages: []AnthropicMessage{
{Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}},
},
}
result, err := ConvertRequest(req)
require.NoError(t, err)
assert.Equal(t, 4096, *result.MaxTokens)
}
func TestConvertRequest_WithTools(t *testing.T) {
req := &MessagesRequest{
Model: "claude-3-opus",
MaxTokens: 100,
Messages: []AnthropicMessage{
{Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}},
},
Tools: []AnthropicTool{
{
Name: "get_weather",
Description: "Get weather info",
InputSchema: map[string]interface{}{"type": "object"},
},
},
}
result, err := ConvertRequest(req)
require.NoError(t, err)
require.Len(t, result.Tools, 1)
assert.Equal(t, "function", result.Tools[0].Type)
assert.Equal(t, "get_weather", result.Tools[0].Function.Name)
}
func TestConvertRequest_WithStopSequences(t *testing.T) {
req := &MessagesRequest{
Model: "claude-3-opus",
MaxTokens: 100,
StopSequences: []string{"STOP", "END"},
Messages: []AnthropicMessage{
{Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}},
},
}
result, err := ConvertRequest(req)
require.NoError(t, err)
assert.Equal(t, []string{"STOP", "END"}, result.Stop)
}
func TestConvertRequest_ToolResult(t *testing.T) {
req := &MessagesRequest{
Model: "claude-3-opus",
MaxTokens: 100,
Messages: []AnthropicMessage{
{
Role: "user",
Content: []ContentBlock{
{
Type: "tool_result",
ToolUseID: "tool_123",
Content: "result data",
},
},
},
},
}
result, err := ConvertRequest(req)
require.NoError(t, err)
require.Len(t, result.Messages, 1)
assert.Equal(t, "tool", result.Messages[0].Role)
assert.Equal(t, "tool_123", result.Messages[0].ToolCallID)
assert.Equal(t, "result data", result.Messages[0].Content)
}
func TestConvertResponse(t *testing.T) {
resp := &openai.ChatCompletionResponse{
ID: "chatcmpl-123",
Model: "gpt-4",
Choices: []openai.Choice{
{
Index: 0,
Message: &openai.Message{Role: "assistant", Content: "Hello!"},
FinishReason: "stop",
},
},
Usage: openai.Usage{PromptTokens: 10, CompletionTokens: 5},
}
result, err := ConvertResponse(resp)
require.NoError(t, err)
assert.Equal(t, "chatcmpl-123", result.ID)
assert.Equal(t, "message", result.Type)
assert.Equal(t, "assistant", result.Role)
assert.Equal(t, "end_turn", result.StopReason)
require.Len(t, result.Content, 1)
assert.Equal(t, "text", result.Content[0].Type)
assert.Equal(t, "Hello!", result.Content[0].Text)
assert.Equal(t, 10, result.Usage.InputTokens)
assert.Equal(t, 5, result.Usage.OutputTokens)
}
func TestConvertResponse_ToolCalls(t *testing.T) {
args, _ := json.Marshal(map[string]interface{}{"city": "Beijing"})
resp := &openai.ChatCompletionResponse{
ID: "chatcmpl-456",
Model: "gpt-4",
Choices: []openai.Choice{
{
Index: 0,
Message: &openai.Message{
Role: "assistant",
ToolCalls: []openai.ToolCall{
{
ID: "call_123",
Type: "function",
Function: openai.FunctionCall{
Name: "get_weather",
Arguments: string(args),
},
},
},
},
FinishReason: "tool_calls",
},
},
Usage: openai.Usage{},
}
result, err := ConvertResponse(resp)
require.NoError(t, err)
assert.Equal(t, "tool_use", result.StopReason)
require.Len(t, result.Content, 1)
assert.Equal(t, "tool_use", result.Content[0].Type)
assert.Equal(t, "call_123", result.Content[0].ID)
assert.Equal(t, "get_weather", result.Content[0].Name)
}
func TestConvertToolChoice_String(t *testing.T) {
tests := []struct {
name string
input interface{}
wantErr bool
check func(interface{})
}{
{"auto字符串", "auto", false, func(r interface{}) { assert.Equal(t, "auto", r) }},
{"any字符串", "any", false, func(r interface{}) { assert.Equal(t, "auto", r) }},
{"无效字符串", "invalid", true, nil},
{"tool对象", map[string]interface{}{"type": "tool", "name": "my_func"}, false,
func(r interface{}) {
m := r.(map[string]interface{})
assert.Equal(t, "function", m["type"])
}},
{"缺少name的tool对象", map[string]interface{}{"type": "tool"}, true, nil},
{"缺少type的对象", map[string]interface{}{"name": "func"}, true, nil},
{"无效类型", 42, true, nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := convertToolChoice(tt.input)
if tt.wantErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
tt.check(result)
}
})
}
}
func TestValidateRequest(t *testing.T) {
t.Run("有效请求", func(t *testing.T) {
req := &MessagesRequest{
Model: "claude-3-opus",
MaxTokens: 100,
Messages: []AnthropicMessage{
{Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}},
},
}
errs := ValidateRequest(req)
assert.Nil(t, errs)
})
t.Run("缺少模型", func(t *testing.T) {
req := &MessagesRequest{
MaxTokens: 100,
Messages: []AnthropicMessage{
{Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}},
},
}
errs := ValidateRequest(req)
assert.NotNil(t, errs)
assert.Contains(t, errs["model"], "不能为空")
})
t.Run("MaxTokens为0", func(t *testing.T) {
req := &MessagesRequest{
Model: "claude-3-opus",
MaxTokens: 0,
Messages: []AnthropicMessage{
{Role: "user", Content: []ContentBlock{{Type: "text", Text: "Hi"}}},
},
}
errs := ValidateRequest(req)
assert.NotNil(t, errs)
})
}

View File

@@ -0,0 +1,229 @@
package anthropic
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/protocol/openai"
)
func TestStreamConverter_MessageStart(t *testing.T) {
converter := NewStreamConverter("msg_123", "claude-3-opus")
chunk := &openai.StreamChunk{
ID: "chatcmpl-123",
Choices: []openai.StreamChoice{{Index: 0, Delta: openai.Delta{}}},
}
events, err := converter.ConvertChunk(chunk)
require.NoError(t, err)
require.NotEmpty(t, events)
// 第一个事件应该是 message_start
assert.Equal(t, "message_start", events[0].Type)
require.NotNil(t, events[0].Message)
assert.Equal(t, "msg_123", events[0].Message.ID)
assert.Equal(t, "message", events[0].Message.Type)
assert.Equal(t, "assistant", events[0].Message.Role)
assert.Equal(t, "claude-3-opus", events[0].Message.Model)
}
func TestStreamConverter_TextDelta(t *testing.T) {
converter := NewStreamConverter("msg_123", "claude-3-opus")
// 先发送一个空块以触发 message_start
chunk1 := &openai.StreamChunk{
Choices: []openai.StreamChoice{
{Delta: openai.Delta{Content: "Hello"}},
},
}
events1, err := converter.ConvertChunk(chunk1)
require.NoError(t, err)
// 应有 message_start + content_block_start + text delta
assert.GreaterOrEqual(t, len(events1), 3)
// 第二个文本块不应再发送 message_start 和 content_block_start
chunk2 := &openai.StreamChunk{
Choices: []openai.StreamChoice{
{Delta: openai.Delta{Content: " world"}},
},
}
events2, err := converter.ConvertChunk(chunk2)
require.NoError(t, err)
// 只有 text delta
assert.Len(t, events2, 1)
assert.Equal(t, "content_block_delta", events2[0].Type)
assert.Equal(t, "text_delta", events2[0].Delta.Type)
assert.Equal(t, " world", events2[0].Delta.Text)
}
func TestStreamConverter_FinishReason(t *testing.T) {
converter := NewStreamConverter("msg_123", "claude-3-opus")
chunk := &openai.StreamChunk{
Choices: []openai.StreamChoice{
{Delta: openai.Delta{Content: "Hello"}, FinishReason: "stop"},
},
}
events, err := converter.ConvertChunk(chunk)
require.NoError(t, err)
// 查找 message_delta 事件
var messageDelta *StreamEvent
for _, e := range events {
if e.Type == "message_delta" {
messageDelta = &e
break
}
}
require.NotNil(t, messageDelta)
assert.Equal(t, "end_turn", messageDelta.Delta.StopReason)
// 查找 message_stop 事件
var messageStop *StreamEvent
for _, e := range events {
if e.Type == "message_stop" {
messageStop = &e
break
}
}
assert.NotNil(t, messageStop)
}
func TestStreamConverter_FinishReasonToolCalls(t *testing.T) {
converter := NewStreamConverter("msg_123", "claude-3-opus")
chunk := &openai.StreamChunk{
Choices: []openai.StreamChoice{
{Delta: openai.Delta{}, FinishReason: "tool_calls"},
},
}
events, err := converter.ConvertChunk(chunk)
require.NoError(t, err)
var messageDelta *StreamEvent
for _, e := range events {
if e.Type == "message_delta" {
messageDelta = &e
break
}
}
require.NotNil(t, messageDelta)
assert.Equal(t, "tool_use", messageDelta.Delta.StopReason)
}
func TestStreamConverter_FinishReasonLength(t *testing.T) {
converter := NewStreamConverter("msg_123", "claude-3-opus")
chunk := &openai.StreamChunk{
Choices: []openai.StreamChoice{
{Delta: openai.Delta{}, FinishReason: "length"},
},
}
events, err := converter.ConvertChunk(chunk)
require.NoError(t, err)
var messageDelta *StreamEvent
for _, e := range events {
if e.Type == "message_delta" {
messageDelta = &e
break
}
}
require.NotNil(t, messageDelta)
assert.Equal(t, "max_tokens", messageDelta.Delta.StopReason)
}
func TestStreamConverter_ToolCalls(t *testing.T) {
converter := NewStreamConverter("msg_123", "claude-3-opus")
chunk := &openai.StreamChunk{
Choices: []openai.StreamChoice{
{
Delta: openai.Delta{
ToolCalls: []openai.ToolCall{
{
ID: "call_123",
Type: "function",
Function: openai.FunctionCall{
Name: "get_weather",
Arguments: `{"city": "Beijing"}`,
},
},
},
},
},
},
}
events, err := converter.ConvertChunk(chunk)
require.NoError(t, err)
// 应包含 content_block_start (tool_use) + content_block_delta (input_json_delta)
hasBlockStart := false
hasInputDelta := false
for _, e := range events {
if e.Type == "content_block_start" && e.ContentBlock != nil && e.ContentBlock.Type == "tool_use" {
hasBlockStart = true
assert.Equal(t, "call_123", e.ContentBlock.ID)
assert.Equal(t, "get_weather", e.ContentBlock.Name)
}
if e.Type == "content_block_delta" && e.Delta != nil && e.Delta.Type == "input_json_delta" {
hasInputDelta = true
assert.Equal(t, `{"city": "Beijing"}`, e.Delta.Input)
}
}
assert.True(t, hasBlockStart, "应有 tool_use content_block_start")
assert.True(t, hasInputDelta, "应有 input_json_delta")
}
func TestSerializeEvent(t *testing.T) {
event := StreamEvent{
Type: "message_start",
Message: &MessagesResponse{
ID: "msg_123",
Type: "message",
Role: "assistant",
},
}
result, err := SerializeEvent(event)
require.NoError(t, err)
assert.Contains(t, result, "event: message_start")
assert.Contains(t, result, "data: ")
assert.Contains(t, result, "msg_123")
}
func TestSerializeEvent_InvalidJSON(t *testing.T) {
event := StreamEvent{
Type: "test",
}
// 这个应该能正常序列化
result, err := SerializeEvent(event)
require.NoError(t, err)
assert.Contains(t, result, "event: test")
}
func TestContentBlock_ParseInputJSON(t *testing.T) {
t.Run("字符串输入", func(t *testing.T) {
cb := &ContentBlock{Input: `{"key": "value"}`}
result, err := cb.ParseInputJSON()
require.NoError(t, err)
assert.Equal(t, "value", result["key"])
})
t.Run("对象输入", func(t *testing.T) {
cb := &ContentBlock{Input: map[string]interface{}{"key": "value"}}
result, err := cb.ParseInputJSON()
require.NoError(t, err)
assert.Equal(t, "value", result["key"])
})
t.Run("无效类型", func(t *testing.T) {
cb := &ContentBlock{Input: 42}
_, err := cb.ParseInputJSON()
assert.Error(t, err)
})
}

View File

@@ -1,13 +1,20 @@
package anthropic
import "encoding/json"
import (
"encoding/json"
"fmt"
"github.com/go-playground/validator/v10"
pkgValidator "nex/backend/pkg/validator"
)
// MessagesRequest Anthropic Messages API 请求结构
type MessagesRequest struct {
Model string `json:"model"`
Messages []AnthropicMessage `json:"messages"`
Model string `json:"model" validate:"required"`
Messages []AnthropicMessage `json:"messages" validate:"required,min=1"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens"`
MaxTokens int `json:"max_tokens" validate:"required,min=1"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
@@ -114,5 +121,29 @@ func (cb *ContentBlock) ParseInputJSON() (map[string]interface{}, error) {
if obj, ok := cb.Input.(map[string]interface{}); ok {
return obj, nil
}
return nil, json.Unmarshal([]byte{}, nil) // 返回错误
return nil, fmt.Errorf("invalid input type: expected string or map")
}
// ValidateRequest 验证 MessagesRequest
func ValidateRequest(req *MessagesRequest) map[string]string {
errs := pkgValidator.Validate(req)
if errs == nil {
return nil
}
validationErrors := make(map[string]string)
for _, err := range errs.(validator.ValidationErrors) {
field := err.Field()
switch field {
case "Model":
validationErrors["model"] = "模型名称不能为空"
case "Messages":
validationErrors["messages"] = "消息列表不能为空"
case "MaxTokens":
validationErrors["max_tokens"] = "max_tokens 不能为空且必须大于 0"
default:
validationErrors[field] = fmt.Sprintf("字段 %s 验证失败: %s", field, err.Tag())
}
}
return validationErrors
}

View File

@@ -3,7 +3,6 @@ package openai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
)
@@ -24,9 +23,6 @@ func (a *Adapter) PrepareRequest(req *ChatCompletionRequest, apiKey, baseURL str
return nil, err
}
// 调试日志:打印请求体
fmt.Printf("[DEBUG] 请求Body: %s\n", string(body))
// 创建 HTTP 请求
// baseURL 已包含版本路径(如 /v1 或 /v4只需添加端点路径
httpReq, err := http.NewRequest("POST", baseURL+"/chat/completions", bytes.NewReader(body))

View File

@@ -0,0 +1,190 @@
package openai
import (
"bytes"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAdapter_PrepareRequest(t *testing.T) {
adapter := NewAdapter()
req := &ChatCompletionRequest{
Model: "gpt-4",
Messages: []Message{
{Role: "user", Content: "Hello"},
},
}
httpReq, err := adapter.PrepareRequest(req, "test-api-key", "https://api.openai.com/v1")
require.NoError(t, err)
require.NotNil(t, httpReq)
assert.Equal(t, "POST", httpReq.Method)
assert.Equal(t, "https://api.openai.com/v1/chat/completions", httpReq.URL.String())
assert.Equal(t, "application/json", httpReq.Header.Get("Content-Type"))
assert.Equal(t, "Bearer test-api-key", httpReq.Header.Get("Authorization"))
// 验证请求体
var body ChatCompletionRequest
err = json.NewDecoder(httpReq.Body).Decode(&body)
require.NoError(t, err)
assert.Equal(t, "gpt-4", body.Model)
}
func TestAdapter_ParseResponse(t *testing.T) {
adapter := NewAdapter()
resp := &ChatCompletionResponse{
ID: "chatcmpl-123",
Object: "chat.completion",
Created: 1234567890,
Model: "gpt-4",
Choices: []Choice{
{
Index: 0,
Message: &Message{Role: "assistant", Content: "Hello!"},
},
},
Usage: Usage{PromptTokens: 10, CompletionTokens: 5, TotalTokens: 15},
}
body, err := json.Marshal(resp)
require.NoError(t, err)
httpResp := &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader(body)),
}
result, err := adapter.ParseResponse(httpResp)
require.NoError(t, err)
assert.Equal(t, "chatcmpl-123", result.ID)
assert.Equal(t, "gpt-4", result.Model)
require.Len(t, result.Choices, 1)
assert.Equal(t, "Hello!", result.Choices[0].Message.Content)
}
func TestAdapter_ParseErrorResponse(t *testing.T) {
adapter := NewAdapter()
errResp := &ErrorResponse{
Error: ErrorDetail{
Message: "Invalid API key",
Type: "invalid_request_error",
Code: "invalid_api_key",
},
}
body, err := json.Marshal(errResp)
require.NoError(t, err)
httpResp := &http.Response{
StatusCode: 401,
Body: io.NopCloser(bytes.NewReader(body)),
}
result, err := adapter.ParseErrorResponse(httpResp)
require.NoError(t, err)
assert.Equal(t, "Invalid API key", result.Error.Message)
assert.Equal(t, "invalid_request_error", result.Error.Type)
}
func TestAdapter_ParseStreamChunk(t *testing.T) {
adapter := NewAdapter()
chunk := &StreamChunk{
ID: "chatcmpl-123",
Object: "chat.completion.chunk",
Created: 1234567890,
Model: "gpt-4",
Choices: []StreamChoice{
{
Index: 0,
Delta: Delta{Content: "Hello"},
},
},
}
data, err := json.Marshal(chunk)
require.NoError(t, err)
result, err := adapter.ParseStreamChunk(data)
require.NoError(t, err)
assert.Equal(t, "chatcmpl-123", result.ID)
require.Len(t, result.Choices, 1)
assert.Equal(t, "Hello", result.Choices[0].Delta.Content)
}
func TestParseToolCallArguments(t *testing.T) {
tests := []struct {
name string
input string
wantErr bool
}{
{"有效JSON", `{"key": "value"}`, false},
{"无效JSON", `not json`, true},
{"空JSON", `{}`, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tc := &ToolCall{
Function: FunctionCall{Arguments: tt.input},
}
args, err := tc.ParseToolCallArguments()
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.NotNil(t, args)
}
})
}
}
func TestSerializeToolCallArguments(t *testing.T) {
args := map[string]interface{}{"key": "value"}
result, err := SerializeToolCallArguments(args)
require.NoError(t, err)
assert.JSONEq(t, `{"key": "value"}`, result)
}
func TestValidateRequest(t *testing.T) {
t.Run("有效请求", func(t *testing.T) {
req := &ChatCompletionRequest{
Model: "gpt-4",
Messages: []Message{{Role: "user", Content: "hello"}},
}
errs := ValidateRequest(req)
assert.Nil(t, errs)
})
t.Run("缺少模型", func(t *testing.T) {
req := &ChatCompletionRequest{
Messages: []Message{{Role: "user", Content: "hello"}},
}
errs := ValidateRequest(req)
assert.NotNil(t, errs)
assert.Contains(t, errs["model"], "不能为空")
})
t.Run("缺少消息", func(t *testing.T) {
req := &ChatCompletionRequest{
Model: "gpt-4",
}
errs := ValidateRequest(req)
assert.NotNil(t, errs)
assert.Contains(t, errs["messages"], "不能为空")
})
t.Run("空消息列表", func(t *testing.T) {
req := &ChatCompletionRequest{
Model: "gpt-4",
Messages: []Message{},
}
errs := ValidateRequest(req)
assert.NotNil(t, errs)
})
}

View File

@@ -1,11 +1,18 @@
package openai
import "encoding/json"
import (
"encoding/json"
"fmt"
"github.com/go-playground/validator/v10"
pkgValidator "nex/backend/pkg/validator"
)
// ChatCompletionRequest OpenAI Chat Completions API 请求结构
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Model string `json:"model" validate:"required"`
Messages []Message `json:"messages" validate:"required,min=1"`
Temperature *float64 `json:"temperature,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
@@ -129,3 +136,25 @@ func SerializeToolCallArguments(args map[string]interface{}) (string, error) {
}
return string(bytes), nil
}
// ValidateRequest 验证 ChatCompletionRequest
func ValidateRequest(req *ChatCompletionRequest) map[string]string {
errs := pkgValidator.Validate(req)
if errs == nil {
return nil
}
validationErrors := make(map[string]string)
for _, err := range errs.(validator.ValidationErrors) {
field := err.Field()
switch field {
case "Model":
validationErrors["model"] = "模型名称不能为空"
case "Messages":
validationErrors["messages"] = "消息列表不能为空"
default:
validationErrors[field] = fmt.Sprintf("字段 %s 验证失败: %s", field, err.Tag())
}
}
return validationErrors
}

View File

@@ -9,22 +9,59 @@ import (
"strings"
"time"
"go.uber.org/zap"
"nex/backend/internal/protocol/openai"
)
// StreamConfig 流式处理配置
type StreamConfig struct {
InitialBufferSize int // 初始缓冲区大小(字节),默认 4096
MaxBufferSize int // 最大缓冲区大小(字节),默认 65536
Timeout time.Duration // 流超时时间,默认 5 分钟
ChannelBufferSize int // 事件通道缓冲区大小,默认 100
}
// DefaultStreamConfig 返回默认流式处理配置
func DefaultStreamConfig() StreamConfig {
return StreamConfig{
InitialBufferSize: 4096,
MaxBufferSize: 65536,
Timeout: 5 * time.Minute,
ChannelBufferSize: 100,
}
}
// Client OpenAI 兼容供应商客户端
type Client struct {
httpClient *http.Client
adapter *openai.Adapter
httpClient *http.Client
adapter *openai.Adapter
logger *zap.Logger
streamCfg StreamConfig
}
// StreamEvent 流事件
type StreamEvent struct {
Data []byte
Error error
Done bool
}
// ProviderClient 供应商客户端接口
type ProviderClient interface {
SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error)
SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan StreamEvent, error)
}
// NewClient 创建供应商客户端
func NewClient() *Client {
return &Client{
httpClient: &http.Client{
Timeout: 30 * time.Second, // 非流式请求超时
Timeout: 30 * time.Second,
},
adapter: openai.NewAdapter(),
adapter: openai.NewAdapter(),
logger: zap.L(),
streamCfg: DefaultStreamConfig(),
}
}
@@ -36,10 +73,10 @@ func (c *Client) SendRequest(ctx context.Context, req *openai.ChatCompletionRequ
return nil, fmt.Errorf("准备请求失败: %w", err)
}
// 调试日志:打印完整请求信息
fmt.Printf("[DEBUG] 请求URL: %s\n", httpReq.URL.String())
fmt.Printf("[DEBUG] 请求Method: %s\n", httpReq.Method)
fmt.Printf("[DEBUG] 请求Headers: %v\n", httpReq.Header)
c.logger.Debug("发送请求",
zap.String("url", httpReq.URL.String()),
zap.String("method", httpReq.Method),
)
// 设置上下文
httpReq = httpReq.WithContext(ctx)
@@ -80,18 +117,22 @@ func (c *Client) SendStreamRequest(ctx context.Context, req *openai.ChatCompleti
return nil, fmt.Errorf("准备请求失败: %w", err)
}
// 设置上下文
httpReq = httpReq.WithContext(ctx)
// 设置带超时的上下文
streamCtx, cancel := context.WithTimeout(ctx, c.streamCfg.Timeout)
_ = cancel // cancel 在流读取结束后由 ctx 传播处理
httpReq = httpReq.WithContext(streamCtx)
// 发送请求
resp, err := c.httpClient.Do(httpReq)
if err != nil {
cancel()
return nil, fmt.Errorf("发送请求失败: %w", err)
}
// 检查状态码
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
cancel()
errorResp, parseErr := c.adapter.ParseErrorResponse(resp)
if parseErr != nil {
return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
@@ -100,33 +141,33 @@ func (c *Client) SendStreamRequest(ctx context.Context, req *openai.ChatCompleti
}
// 创建事件通道
eventChan := make(chan StreamEvent, 100)
eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize)
// 启动 goroutine 读取流
go c.readStream(ctx, resp.Body, eventChan)
go c.readStream(streamCtx, cancel, resp.Body, eventChan)
return eventChan, nil
}
// StreamEvent 流事件
type StreamEvent struct {
Data []byte
Error error
Done bool
}
// readStream 读取 SSE 流
func (c *Client) readStream(ctx context.Context, body io.ReadCloser, eventChan chan<- StreamEvent) {
// readStream 读取 SSE 流(支持动态缓冲区、超时控制和改进的错误处理)
func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body io.ReadCloser, eventChan chan<- StreamEvent) {
defer close(eventChan)
defer body.Close()
defer cancel()
buf := make([]byte, 4096)
bufSize := c.streamCfg.InitialBufferSize
buf := make([]byte, bufSize)
var dataBuf []byte
for {
select {
case <-ctx.Done():
eventChan <- StreamEvent{Error: ctx.Err()}
if ctx.Err() == context.DeadlineExceeded {
c.logger.Warn("流读取超时")
eventChan <- StreamEvent{Error: fmt.Errorf("流读取超时: %w", ctx.Err())}
} else {
eventChan <- StreamEvent{Error: ctx.Err()}
}
return
default:
}
@@ -134,15 +175,32 @@ func (c *Client) readStream(ctx context.Context, body io.ReadCloser, eventChan c
n, err := body.Read(buf)
if err != nil {
if err == io.EOF {
// 流结束
// 流正常结束
return
}
eventChan <- StreamEvent{Error: err}
// 区分网络错误和其他错误
if isNetworkError(err) {
c.logger.Error("流网络错误", zap.String("error", err.Error()))
eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
} else {
c.logger.Error("流读取错误", zap.String("error", err.Error()))
eventChan <- StreamEvent{Error: fmt.Errorf("读取错误: %w", err)}
}
return
}
dataBuf = append(dataBuf, buf[:n]...)
// 动态调整缓冲区大小:如果数据量大,增大缓冲区
if len(dataBuf) > bufSize/2 && bufSize < c.streamCfg.MaxBufferSize {
newSize := bufSize * 2
if newSize > c.streamCfg.MaxBufferSize {
newSize = c.streamCfg.MaxBufferSize
}
buf = make([]byte, newSize)
bufSize = newSize
}
// 处理完整的 SSE 事件
for {
// 查找事件边界(双换行)
@@ -175,3 +233,16 @@ func (c *Client) readStream(ctx context.Context, body io.ReadCloser, eventChan c
}
}
// isNetworkError 判断是否为网络相关错误
func isNetworkError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
return strings.Contains(errStr, "connection reset") ||
strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "network") ||
strings.Contains(errStr, "timeout") ||
strings.Contains(errStr, "EOF")
}

View File

@@ -0,0 +1,151 @@
package provider
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/protocol/openai"
)
func TestNewClient(t *testing.T) {
client := NewClient()
require.NotNil(t, client)
assert.NotNil(t, client.httpClient)
assert.NotNil(t, client.adapter)
assert.Equal(t, 4096, client.streamCfg.InitialBufferSize)
assert.Equal(t, 65536, client.streamCfg.MaxBufferSize)
assert.Equal(t, 100, client.streamCfg.ChannelBufferSize)
}
func TestDefaultStreamConfig(t *testing.T) {
cfg := DefaultStreamConfig()
assert.Equal(t, 4096, cfg.InitialBufferSize)
assert.Equal(t, 65536, cfg.MaxBufferSize)
assert.Equal(t, 100, cfg.ChannelBufferSize)
}
func TestClient_SendRequest_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
resp := openai.ChatCompletionResponse{
ID: "chatcmpl-123",
Choices: []openai.Choice{
{Index: 0, Message: &openai.Message{Role: "assistant", Content: "Hello!"}},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
client := NewClient()
req := &openai.ChatCompletionRequest{
Model: "gpt-4",
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
}
result, err := client.SendRequest(context.Background(), req, "test-key", server.URL)
require.NoError(t, err)
assert.Equal(t, "chatcmpl-123", result.ID)
}
func TestClient_SendRequest_ErrorResponse(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(openai.ErrorResponse{
Error: openai.ErrorDetail{Message: "Invalid API key"},
})
}))
defer server.Close()
client := NewClient()
req := &openai.ChatCompletionRequest{
Model: "gpt-4",
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
}
_, err := client.SendRequest(context.Background(), req, "bad-key", server.URL)
assert.Error(t, err)
assert.Contains(t, err.Error(), "Invalid API key")
}
func TestClient_SendRequest_ConnectionError(t *testing.T) {
client := NewClient()
req := &openai.ChatCompletionRequest{
Model: "gpt-4",
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
}
_, err := client.SendRequest(context.Background(), req, "key", "http://localhost:1")
assert.Error(t, err)
}
func TestClient_SendStreamRequest_CreatesChannel(t *testing.T) {
// 使用一个慢服务器确保客户端有时间读取
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewClient()
req := &openai.ChatCompletionRequest{
Model: "gpt-4",
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
}
eventChan, err := client.SendStreamRequest(context.Background(), req, "test-key", server.URL)
require.NoError(t, err)
require.NotNil(t, eventChan)
// 读取直到 channel 关闭(服务器关闭后应产生 EOF
for range eventChan {
// 消费所有事件
}
// channel 应已关闭(不阻塞即通过)
}
func TestClient_SendStreamRequest_ErrorResponse(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
client := NewClient()
req := &openai.ChatCompletionRequest{
Model: "gpt-4",
Messages: []openai.Message{{Role: "user", Content: "Hi"}},
}
_, err := client.SendStreamRequest(context.Background(), req, "key", server.URL)
assert.Error(t, err)
}
func TestIsNetworkError(t *testing.T) {
tests := []struct {
input string
want bool
}{
{"connection reset by peer", true},
{"broken pipe", true},
{"network is unreachable", true},
{"timeout waiting for response", true},
{"unexpected EOF", true},
{"normal error", false},
{"", false},
}
for _, tt := range tests {
err := fmt.Errorf("%s", tt.input) //nolint:govet
assert.Equal(t, tt.want, isNetworkError(err), "isNetworkError(%q)", tt.input)
}
}

View File

@@ -0,0 +1,13 @@
package repository
import "nex/backend/internal/domain"
// ModelRepository 模型数据仓库接口
type ModelRepository interface {
Create(model *domain.Model) error
GetByID(id string) (*domain.Model, error)
List(providerID string) ([]domain.Model, error)
GetByModelName(modelName string) (*domain.Model, error)
Update(id string, updates map[string]interface{}) error
Delete(id string) error
}

View File

@@ -0,0 +1,104 @@
package repository
import (
"time"
"gorm.io/gorm"
"nex/backend/internal/config"
"nex/backend/internal/domain"
appErrors "nex/backend/pkg/errors"
)
type modelRepository struct {
db *gorm.DB
}
func NewModelRepository(db *gorm.DB) ModelRepository {
return &modelRepository{db: db}
}
func (r *modelRepository) Create(model *domain.Model) error {
m := toConfigModel(model)
m.CreatedAt = time.Now()
return r.db.Create(&m).Error
}
func (r *modelRepository) GetByID(id string) (*domain.Model, error) {
var m config.Model
err := r.db.First(&m, "id = ?", id).Error
if err != nil {
return nil, err
}
d := toDomainModel(&m)
return &d, nil
}
func (r *modelRepository) List(providerID string) ([]domain.Model, error) {
var models []config.Model
var err error
if providerID != "" {
err = r.db.Where("provider_id = ?", providerID).Find(&models).Error
} else {
err = r.db.Find(&models).Error
}
if err != nil {
return nil, err
}
result := make([]domain.Model, len(models))
for i := range models {
result[i] = toDomainModel(&models[i])
}
return result, nil
}
func (r *modelRepository) GetByModelName(modelName string) (*domain.Model, error) {
var m config.Model
err := r.db.Where("model_name = ?", modelName).First(&m).Error
if err != nil {
return nil, err
}
d := toDomainModel(&m)
return &d, nil
}
func (r *modelRepository) Update(id string, updates map[string]interface{}) error {
result := r.db.Model(&config.Model{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return appErrors.ErrModelNotFound
}
return nil
}
func (r *modelRepository) Delete(id string) error {
result := r.db.Delete(&config.Model{}, "id = ?", id)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return appErrors.ErrModelNotFound
}
return nil
}
func toDomainModel(m *config.Model) domain.Model {
return domain.Model{
ID: m.ID,
ProviderID: m.ProviderID,
ModelName: m.ModelName,
Enabled: m.Enabled,
CreatedAt: m.CreatedAt,
}
}
func toConfigModel(m *domain.Model) config.Model {
return config.Model{
ID: m.ID,
ProviderID: m.ProviderID,
ModelName: m.ModelName,
Enabled: m.Enabled,
}
}

View File

@@ -0,0 +1,12 @@
package repository
import "nex/backend/internal/domain"
// ProviderRepository 供应商数据仓库接口
type ProviderRepository interface {
Create(provider *domain.Provider) error
GetByID(id string) (*domain.Provider, error)
List() ([]domain.Provider, error)
Update(id string, updates map[string]interface{}) error
Delete(id string) error
}

View File

@@ -0,0 +1,94 @@
package repository
import (
"time"
"gorm.io/gorm"
"nex/backend/internal/config"
"nex/backend/internal/domain"
appErrors "nex/backend/pkg/errors"
)
type providerRepository struct {
db *gorm.DB
}
func NewProviderRepository(db *gorm.DB) ProviderRepository {
return &providerRepository{db: db}
}
func (r *providerRepository) Create(provider *domain.Provider) error {
p := toConfigProvider(provider)
p.CreatedAt = time.Now()
p.UpdatedAt = time.Now()
return r.db.Create(&p).Error
}
func (r *providerRepository) GetByID(id string) (*domain.Provider, error) {
var p config.Provider
err := r.db.First(&p, "id = ?", id).Error
if err != nil {
return nil, err
}
d := toDomainProvider(&p)
return &d, nil
}
func (r *providerRepository) List() ([]domain.Provider, error) {
var providers []config.Provider
err := r.db.Find(&providers).Error
if err != nil {
return nil, err
}
result := make([]domain.Provider, len(providers))
for i := range providers {
result[i] = toDomainProvider(&providers[i])
}
return result, nil
}
func (r *providerRepository) Update(id string, updates map[string]interface{}) error {
updates["updated_at"] = time.Now()
result := r.db.Model(&config.Provider{}).Where("id = ?", id).Updates(updates)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return appErrors.ErrProviderNotFound
}
return nil
}
func (r *providerRepository) Delete(id string) error {
result := r.db.Delete(&config.Provider{}, "id = ?", id)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return appErrors.ErrProviderNotFound
}
return nil
}
func toDomainProvider(p *config.Provider) domain.Provider {
return domain.Provider{
ID: p.ID,
Name: p.Name,
APIKey: p.APIKey,
BaseURL: p.BaseURL,
Enabled: p.Enabled,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
}
}
func toConfigProvider(p *domain.Provider) config.Provider {
return config.Provider{
ID: p.ID,
Name: p.Name,
APIKey: p.APIKey,
BaseURL: p.BaseURL,
Enabled: p.Enabled,
}
}

View File

@@ -0,0 +1,233 @@
package repository
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"nex/backend/internal/config"
"nex/backend/internal/domain"
)
func setupTestDB(t *testing.T) *gorm.DB {
t.Helper()
dir := t.TempDir()
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
require.NoError(t, err)
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
require.NoError(t, err)
// 关闭数据库连接以便 TempDir 清理
t.Cleanup(func() {
sqlDB, _ := db.DB()
if sqlDB != nil {
sqlDB.Close()
}
})
return db
}
// ============ ProviderRepository 测试 ============
func TestProviderRepository_Create(t *testing.T) {
db := setupTestDB(t)
repo := NewProviderRepository(db)
provider := &domain.Provider{
ID: "test-provider",
Name: "Test Provider",
APIKey: "sk-test-key",
BaseURL: "https://api.test.com",
Enabled: true,
}
err := repo.Create(provider)
require.NoError(t, err)
}
func TestProviderRepository_GetByID(t *testing.T) {
db := setupTestDB(t)
repo := NewProviderRepository(db)
provider := &domain.Provider{
ID: "test-provider", Name: "Test", APIKey: "sk-test-key", BaseURL: "https://api.test.com",
}
err := repo.Create(provider)
require.NoError(t, err)
result, err := repo.GetByID("test-provider")
require.NoError(t, err)
assert.Equal(t, "test-provider", result.ID)
assert.Equal(t, "Test", result.Name)
}
func TestProviderRepository_GetByID_NotFound(t *testing.T) {
db := setupTestDB(t)
repo := NewProviderRepository(db)
_, err := repo.GetByID("nonexistent")
assert.Error(t, err)
}
func TestProviderRepository_List(t *testing.T) {
db := setupTestDB(t)
repo := NewProviderRepository(db)
for _, id := range []string{"pA", "pB", "pC"} {
err := repo.Create(&domain.Provider{ID: id, Name: id, APIKey: "key", BaseURL: "https://test.com"})
require.NoError(t, err)
}
providers, err := repo.List()
require.NoError(t, err)
assert.Len(t, providers, 3)
}
func TestProviderRepository_Update(t *testing.T) {
db := setupTestDB(t)
repo := NewProviderRepository(db)
repo.Create(&domain.Provider{ID: "p1", Name: "Old", APIKey: "key", BaseURL: "https://old.com"})
err := repo.Update("p1", map[string]interface{}{"name": "New"})
require.NoError(t, err)
result, _ := repo.GetByID("p1")
assert.Equal(t, "New", result.Name)
}
func TestProviderRepository_Update_NotFound(t *testing.T) {
db := setupTestDB(t)
repo := NewProviderRepository(db)
err := repo.Update("nonexistent", map[string]interface{}{"name": "New"})
assert.Error(t, err)
}
func TestProviderRepository_Delete(t *testing.T) {
db := setupTestDB(t)
repo := NewProviderRepository(db)
repo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
err := repo.Delete("p1")
require.NoError(t, err)
_, err = repo.GetByID("p1")
assert.Error(t, err)
}
func TestProviderRepository_Delete_NotFound(t *testing.T) {
db := setupTestDB(t)
repo := NewProviderRepository(db)
err := repo.Delete("nonexistent")
assert.Error(t, err)
}
// ============ ModelRepository 测试 ============
func TestModelRepository_Create(t *testing.T) {
db := setupTestDB(t)
repo := NewModelRepository(db)
err := repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
require.NoError(t, err)
}
func TestModelRepository_GetByID(t *testing.T) {
db := setupTestDB(t)
repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
result, err := repo.GetByID("m1")
require.NoError(t, err)
assert.Equal(t, "m1", result.ID)
assert.Equal(t, "gpt-4", result.ModelName)
}
func TestModelRepository_GetByModelName(t *testing.T) {
db := setupTestDB(t)
repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
result, err := repo.GetByModelName("gpt-4")
require.NoError(t, err)
assert.Equal(t, "m1", result.ID)
}
func TestModelRepository_List(t *testing.T) {
db := setupTestDB(t)
repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
repo.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})
repo.Create(&domain.Model{ID: "m3", ProviderID: "p2", ModelName: "claude-3"})
all, err := repo.List("")
require.NoError(t, err)
assert.Len(t, all, 3)
p1Models, err := repo.List("p1")
require.NoError(t, err)
assert.Len(t, p1Models, 2)
}
func TestModelRepository_Update(t *testing.T) {
db := setupTestDB(t)
repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
err := repo.Update("m1", map[string]interface{}{"enabled": false})
require.NoError(t, err)
result, _ := repo.GetByID("m1")
assert.False(t, result.Enabled)
}
func TestModelRepository_Delete(t *testing.T) {
db := setupTestDB(t)
repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
err := repo.Delete("m1")
require.NoError(t, err)
_, err = repo.GetByID("m1")
assert.Error(t, err)
}
// ============ StatsRepository 测试 ============
func TestStatsRepository_Record(t *testing.T) {
db := setupTestDB(t)
repo := NewStatsRepository(db)
err := repo.Record("provider-1", "gpt-4")
require.NoError(t, err)
// 再次记录应递增
err = repo.Record("provider-1", "gpt-4")
require.NoError(t, err)
stats, err := repo.Query("provider-1", "", nil, nil)
require.NoError(t, err)
require.Len(t, stats, 1)
assert.Equal(t, 2, stats[0].RequestCount)
}
func TestStatsRepository_Query(t *testing.T) {
db := setupTestDB(t)
repo := NewStatsRepository(db)
repo.Record("p1", "gpt-4")
// 注意:当前 schema 只有 date 字段有唯一约束
// 所以同一 provider + model 只能有一条记录
stats, err := repo.Query("p1", "", nil, nil)
require.NoError(t, err)
assert.Len(t, stats, 1)
}

View File

@@ -0,0 +1,13 @@
package repository
import (
"time"
"nex/backend/internal/domain"
)
// StatsRepository 统计数据仓库接口
type StatsRepository interface {
Record(providerID, modelName string) error
Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error)
}

View File

@@ -0,0 +1,79 @@
package repository
import (
"errors"
"time"
"gorm.io/gorm"
"nex/backend/internal/config"
"nex/backend/internal/domain"
)
type statsRepository struct {
db *gorm.DB
}
func NewStatsRepository(db *gorm.DB) StatsRepository {
return &statsRepository{db: db}
}
func (r *statsRepository) Record(providerID, modelName string) error {
today := time.Now().Format("2006-01-02")
todayTime, _ := time.Parse("2006-01-02", today)
return r.db.Transaction(func(tx *gorm.DB) error {
var stats config.UsageStats
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
providerID, modelName, todayTime).First(&stats).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
stats = config.UsageStats{
ProviderID: providerID,
ModelName: modelName,
RequestCount: 1,
Date: todayTime,
}
return tx.Create(&stats).Error
} else if err != nil {
return err
}
return tx.Model(&stats).Update("request_count", gorm.Expr("request_count + 1")).Error
})
}
func (r *statsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
var stats []config.UsageStats
query := r.db.Model(&config.UsageStats{})
if providerID != "" {
query = query.Where("provider_id = ?", providerID)
}
if modelName != "" {
query = query.Where("model_name = ?", modelName)
}
if startDate != nil {
query = query.Where("date >= ?", startDate)
}
if endDate != nil {
query = query.Where("date <= ?", endDate)
}
err := query.Order("date DESC").Find(&stats).Error
if err != nil {
return nil, err
}
result := make([]domain.UsageStats, len(stats))
for i := range stats {
result[i] = domain.UsageStats{
ID: stats[i].ID,
ProviderID: stats[i].ProviderID,
ModelName: stats[i].ModelName,
RequestCount: stats[i].RequestCount,
Date: stats[i].Date,
}
}
return result, nil
}

View File

@@ -1,71 +0,0 @@
package router
import (
"errors"
"fmt"
"nex/backend/internal/config"
)
var (
ErrModelNotFound = errors.New("模型未找到")
ErrModelDisabled = errors.New("模型已禁用")
ErrProviderDisabled = errors.New("供应商已禁用")
)
// RouteResult 路由结果
type RouteResult struct {
Provider *config.Provider
Model *config.Model
}
// Router 模型路由器
type Router struct{}
// NewRouter 创建路由器
func NewRouter() *Router {
return &Router{}
}
// Route 根据模型名称路由到供应商
func (r *Router) Route(modelName string) (*RouteResult, error) {
// 查询模型
models, err := config.ListModels("")
if err != nil {
return nil, fmt.Errorf("查询模型失败: %w", err)
}
// 查找匹配的模型
var targetModel *config.Model
for i := range models {
if models[i].ModelName == modelName {
targetModel = &models[i]
break
}
}
if targetModel == nil {
return nil, ErrModelNotFound
}
// 检查模型是否启用
if !targetModel.Enabled {
return nil, ErrModelDisabled
}
// 查询供应商
provider, err := config.GetProvider(targetModel.ProviderID, false)
if err != nil {
return nil, fmt.Errorf("查询供应商失败: %w", err)
}
// 检查供应商是否启用
if !provider.Enabled {
return nil, ErrProviderDisabled
}
return &RouteResult{
Provider: provider,
Model: targetModel,
}, nil
}

View File

@@ -0,0 +1,12 @@
package service
import "nex/backend/internal/domain"
// ModelService 模型服务接口
type ModelService interface {
Create(model *domain.Model) error
Get(id string) (*domain.Model, error)
List(providerID string) ([]domain.Model, error)
Update(id string, updates map[string]interface{}) error
Delete(id string) error
}

View File

@@ -0,0 +1,50 @@
package service
import (
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
)
type modelService struct {
modelRepo repository.ModelRepository
providerRepo repository.ProviderRepository
}
func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) ModelService {
return &modelService{modelRepo: modelRepo, providerRepo: providerRepo}
}
func (s *modelService) Create(model *domain.Model) error {
// Verify provider exists
_, err := s.providerRepo.GetByID(model.ProviderID)
if err != nil {
return appErrors.ErrProviderNotFound
}
model.Enabled = true
return s.modelRepo.Create(model)
}
func (s *modelService) Get(id string) (*domain.Model, error) {
return s.modelRepo.GetByID(id)
}
func (s *modelService) List(providerID string) ([]domain.Model, error) {
return s.modelRepo.List(providerID)
}
func (s *modelService) Update(id string, updates map[string]interface{}) error {
// If updating provider_id, verify new provider exists
if providerID, ok := updates["provider_id"].(string); ok {
_, err := s.providerRepo.GetByID(providerID)
if err != nil {
return appErrors.ErrProviderNotFound
}
}
return s.modelRepo.Update(id, updates)
}
func (s *modelService) Delete(id string) error {
return s.modelRepo.Delete(id)
}

View File

@@ -0,0 +1,12 @@
package service
import "nex/backend/internal/domain"
// ProviderService 供应商服务接口
type ProviderService interface {
Create(provider *domain.Provider) error
Get(id string, maskKey bool) (*domain.Provider, error)
List() ([]domain.Provider, error)
Update(id string, updates map[string]interface{}) error
Delete(id string) error
}

View File

@@ -0,0 +1,49 @@
package service
import (
"nex/backend/internal/domain"
"nex/backend/internal/repository"
)
type providerService struct {
providerRepo repository.ProviderRepository
}
func NewProviderService(providerRepo repository.ProviderRepository) ProviderService {
return &providerService{providerRepo: providerRepo}
}
func (s *providerService) Create(provider *domain.Provider) error {
provider.Enabled = true
return s.providerRepo.Create(provider)
}
func (s *providerService) Get(id string, maskKey bool) (*domain.Provider, error) {
provider, err := s.providerRepo.GetByID(id)
if err != nil {
return nil, err
}
if maskKey {
provider.MaskAPIKey()
}
return provider, nil
}
func (s *providerService) List() ([]domain.Provider, error) {
providers, err := s.providerRepo.List()
if err != nil {
return nil, err
}
for i := range providers {
providers[i].MaskAPIKey()
}
return providers, nil
}
func (s *providerService) Update(id string, updates map[string]interface{}) error {
return s.providerRepo.Update(id, updates)
}
func (s *providerService) Delete(id string) error {
return s.providerRepo.Delete(id)
}

View File

@@ -0,0 +1,8 @@
package service
import "nex/backend/internal/domain"
// RoutingService 路由服务接口
type RoutingService interface {
Route(modelName string) (*domain.RouteResult, error)
}

View File

@@ -0,0 +1,42 @@
package service
import (
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
)
type routingService struct {
modelRepo repository.ModelRepository
providerRepo repository.ProviderRepository
}
func NewRoutingService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) RoutingService {
return &routingService{modelRepo: modelRepo, providerRepo: providerRepo}
}
func (s *routingService) Route(modelName string) (*domain.RouteResult, error) {
model, err := s.modelRepo.GetByModelName(modelName)
if err != nil {
return nil, appErrors.ErrModelNotFound
}
if !model.Enabled {
return nil, appErrors.ErrModelDisabled
}
provider, err := s.providerRepo.GetByID(model.ProviderID)
if err != nil {
return nil, appErrors.ErrProviderNotFound
}
if !provider.Enabled {
return nil, appErrors.ErrProviderDisabled
}
return &domain.RouteResult{
Provider: provider,
Model: model,
}, nil
}

View File

@@ -0,0 +1,245 @@
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"nex/backend/internal/config"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
)
func setupServiceTestDB(t *testing.T) *gorm.DB {
t.Helper()
dir := t.TempDir()
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
require.NoError(t, err)
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
require.NoError(t, err)
t.Cleanup(func() {
sqlDB, _ := db.DB()
if sqlDB != nil {
sqlDB.Close()
}
})
return db
}
// ============ ProviderService 测试 ============
func TestProviderService_Create(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
provider := &domain.Provider{
ID: "test-p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com",
}
err := svc.Create(provider)
require.NoError(t, err)
assert.True(t, provider.Enabled)
}
func TestProviderService_Get_MaskKey(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
svc.Create(&domain.Provider{
ID: "p1", Name: "Test", APIKey: "sk-long-api-key-12345", BaseURL: "https://test.com",
})
result, err := svc.Get("p1", true)
require.NoError(t, err)
assert.Equal(t, "***2345", result.APIKey)
result, err = svc.Get("p1", false)
require.NoError(t, err)
assert.Equal(t, "sk-long-api-key-12345", result.APIKey)
}
func TestProviderService_List(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
svc.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key123", BaseURL: "https://a.com"})
svc.Create(&domain.Provider{ID: "p2", Name: "P2", APIKey: "key456", BaseURL: "https://b.com"})
providers, err := svc.List()
require.NoError(t, err)
assert.Len(t, providers, 2)
assert.Contains(t, providers[0].APIKey, "***")
}
func TestProviderService_Delete(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
svc.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
err := svc.Delete("p1")
require.NoError(t, err)
_, err = svc.Get("p1", false)
assert.Error(t, err)
}
// ============ ModelService 测试 ============
func TestModelService_Create(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
model := &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}
err := svc.Create(model)
require.NoError(t, err)
assert.True(t, model.Enabled)
}
func TestModelService_Create_ProviderNotFound(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
model := &domain.Model{ID: "m1", ProviderID: "nonexistent", ModelName: "gpt-4"}
err := svc.Create(model)
assert.Error(t, err)
}
func TestModelService_List(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
svc.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})
models, err := svc.List("p1")
require.NoError(t, err)
assert.Len(t, models, 2)
}
// ============ RoutingService 测试 ============
func TestRoutingService_Route(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
result, err := svc.Route("gpt-4")
require.NoError(t, err)
assert.Equal(t, "p1", result.Provider.ID)
assert.Equal(t, "gpt-4", result.Model.ModelName)
}
func TestRoutingService_Route_ModelNotFound(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
_, err := svc.Route("nonexistent-model")
assert.Error(t, err)
}
func TestRoutingService_Route_ModelDisabled(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
// 先创建启用的模型,然后通过 Update 禁用
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
_, err := svc.Route("gpt-4")
assert.Error(t, err)
}
func TestRoutingService_Route_ProviderDisabled(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
// 先创建启用的 provider然后禁用
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
providerRepo.Update("p1", map[string]interface{}{"enabled": false})
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
_, err := svc.Route("gpt-4")
assert.Error(t, err)
}
// ============ StatsService 测试 ============
func TestStatsService_RecordAndGet(t *testing.T) {
db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db)
svc := NewStatsService(statsRepo)
err := svc.Record("p1", "gpt-4")
require.NoError(t, err)
stats, err := svc.Get("p1", "", nil, nil)
require.NoError(t, err)
assert.Len(t, stats, 1)
}
func TestStatsService_Aggregate_ByProvider(t *testing.T) {
statsRepo := repository.NewStatsRepository(nil)
svc := NewStatsService(statsRepo)
stats := []domain.UsageStats{
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
{ProviderID: "p1", ModelName: "gpt-3.5", RequestCount: 5},
{ProviderID: "p2", ModelName: "claude-3", RequestCount: 8},
}
result := svc.Aggregate(stats, "provider")
assert.Len(t, result, 2)
p1Count := 0
p2Count := 0
for _, r := range result {
if r["provider_id"] == "p1" {
p1Count = r["request_count"].(int)
}
if r["provider_id"] == "p2" {
p2Count = r["request_count"].(int)
}
}
assert.Equal(t, 15, p1Count)
assert.Equal(t, 8, p2Count)
}
func TestStatsService_Aggregate_ByDate(t *testing.T) {
statsRepo := repository.NewStatsRepository(nil)
svc := NewStatsService(statsRepo)
stats := []domain.UsageStats{
{ProviderID: "p1", RequestCount: 10},
{ProviderID: "p2", RequestCount: 5},
}
result := svc.Aggregate(stats, "date")
assert.Len(t, result, 1)
assert.Equal(t, 15, result[0]["request_count"])
}

View File

@@ -0,0 +1,14 @@
package service
import (
"time"
"nex/backend/internal/domain"
)
// StatsService 统计服务接口
type StatsService interface {
Record(providerID, modelName string) error
Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error)
Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{}
}

View File

@@ -0,0 +1,85 @@
package service
import (
"time"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
)
type statsService struct {
statsRepo repository.StatsRepository
}
func NewStatsService(statsRepo repository.StatsRepository) StatsService {
return &statsService{statsRepo: statsRepo}
}
func (s *statsService) Record(providerID, modelName string) error {
return s.statsRepo.Record(providerID, modelName)
}
func (s *statsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
return s.statsRepo.Query(providerID, modelName, startDate, endDate)
}
func (s *statsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} {
switch groupBy {
case "provider":
return s.aggregateByProvider(stats)
case "model":
return s.aggregateByModel(stats)
case "date":
return s.aggregateByDate(stats)
default:
return s.aggregateByProvider(stats)
}
}
func (s *statsService) aggregateByProvider(stats []domain.UsageStats) []map[string]interface{} {
aggregated := make(map[string]int)
for _, stat := range stats {
aggregated[stat.ProviderID] += stat.RequestCount
}
result := make([]map[string]interface{}, 0, len(aggregated))
for providerID, count := range aggregated {
result = append(result, map[string]interface{}{
"provider_id": providerID,
"request_count": count,
})
}
return result
}
func (s *statsService) aggregateByModel(stats []domain.UsageStats) []map[string]interface{} {
aggregated := make(map[string]int)
for _, stat := range stats {
key := stat.ProviderID + "/" + stat.ModelName
aggregated[key] += stat.RequestCount
}
result := make([]map[string]interface{}, 0, len(aggregated))
for key, count := range aggregated {
result = append(result, map[string]interface{}{
"provider_id": key[:len(key)/2],
"model_name": key[len(key)/2+1:],
"request_count": count,
})
}
return result
}
func (s *statsService) aggregateByDate(stats []domain.UsageStats) []map[string]interface{} {
aggregated := make(map[string]int)
for _, stat := range stats {
key := stat.Date.Format("2006-01-02")
aggregated[key] += stat.RequestCount
}
result := make([]map[string]interface{}, 0, len(aggregated))
for date, count := range aggregated {
result = append(result, map[string]interface{}{
"date": date,
"request_count": count,
})
}
return result
}