## 高优先级修复
- stats_service_impl: 使用 strings.SplitN 替代错误的索引分割
- provider_handler: 使用 errors.Is(err, gorm.ErrDuplicatedKey) 替代字符串匹配
- client: 重写 isNetworkError 使用 errors.As/Is 类型安全判断
- proxy_handler: 使用 encoding/json 标准库解析 JSON(extractModelName、isStreamRequest)

## 中优先级修复
- stats_handler: 添加 parseDateParam 辅助函数消除重复日期解析
- pkg/errors: 新增 ErrRequestCreate/Send/ResponseRead 错误类型和 WithCause 方法
- client: 使用结构化错误替代 fmt.Errorf
- ConversionEngine: logger 依赖注入,替换所有 zap.L() 调用

## 低优先级修复
- encoder: 删除 joinStrings,使用 strings.Join
- adapter: 删除 modelInfoRegex 正则,使用 isModelInfoPath 字符串函数

## 文档更新
- README.md: 添加公共库使用指南和编码规范章节
- specs: 同步 delta specs 到 main specs(error-handling、structured-logging、request-validation)

## 归档
- openspec/changes/archive/2026-04-20-refactor-backend-code-quality/
271 lines
8.2 KiB
Go
271 lines
8.2 KiB
Go
package service
|
||
|
||
import (
	"path/filepath"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"

	"nex/backend/internal/config"
	"nex/backend/internal/domain"
	"nex/backend/internal/repository"
)
|
||
|
||
func setupServiceTestDB(t *testing.T) *gorm.DB {
|
||
t.Helper()
|
||
dir := t.TempDir()
|
||
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
|
||
require.NoError(t, err)
|
||
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
||
require.NoError(t, err)
|
||
t.Cleanup(func() {
|
||
sqlDB, _ := db.DB()
|
||
if sqlDB != nil {
|
||
sqlDB.Close()
|
||
}
|
||
})
|
||
return db
|
||
}
|
||
|
||
// ============ ProviderService 测试 ============
|
||
|
||
func TestProviderService_Create(t *testing.T) {
|
||
db := setupServiceTestDB(t)
|
||
repo := repository.NewProviderRepository(db)
|
||
svc := NewProviderService(repo)
|
||
|
||
provider := &domain.Provider{
|
||
ID: "test-p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com",
|
||
}
|
||
err := svc.Create(provider)
|
||
require.NoError(t, err)
|
||
assert.True(t, provider.Enabled)
|
||
}
|
||
|
||
func TestProviderService_Get_MaskKey(t *testing.T) {
|
||
db := setupServiceTestDB(t)
|
||
repo := repository.NewProviderRepository(db)
|
||
svc := NewProviderService(repo)
|
||
|
||
svc.Create(&domain.Provider{
|
||
ID: "p1", Name: "Test", APIKey: "sk-long-api-key-12345", BaseURL: "https://test.com",
|
||
})
|
||
|
||
result, err := svc.Get("p1", true)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, "***2345", result.APIKey)
|
||
|
||
result, err = svc.Get("p1", false)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, "sk-long-api-key-12345", result.APIKey)
|
||
}
|
||
|
||
func TestProviderService_List(t *testing.T) {
|
||
db := setupServiceTestDB(t)
|
||
repo := repository.NewProviderRepository(db)
|
||
svc := NewProviderService(repo)
|
||
|
||
svc.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key123", BaseURL: "https://a.com"})
|
||
svc.Create(&domain.Provider{ID: "p2", Name: "P2", APIKey: "key456", BaseURL: "https://b.com"})
|
||
|
||
providers, err := svc.List()
|
||
require.NoError(t, err)
|
||
assert.Len(t, providers, 2)
|
||
assert.Contains(t, providers[0].APIKey, "***")
|
||
}
|
||
|
||
func TestProviderService_Delete(t *testing.T) {
|
||
db := setupServiceTestDB(t)
|
||
repo := repository.NewProviderRepository(db)
|
||
svc := NewProviderService(repo)
|
||
|
||
svc.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
|
||
err := svc.Delete("p1")
|
||
require.NoError(t, err)
|
||
|
||
_, err = svc.Get("p1", false)
|
||
assert.Error(t, err)
|
||
}
|
||
|
||
// ============ ModelService 测试 ============
|
||
|
||
func TestModelService_Create(t *testing.T) {
|
||
db := setupServiceTestDB(t)
|
||
providerRepo := repository.NewProviderRepository(db)
|
||
modelRepo := repository.NewModelRepository(db)
|
||
svc := NewModelService(modelRepo, providerRepo)
|
||
|
||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||
|
||
model := &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}
|
||
err := svc.Create(model)
|
||
require.NoError(t, err)
|
||
assert.True(t, model.Enabled)
|
||
}
|
||
|
||
func TestModelService_Create_ProviderNotFound(t *testing.T) {
|
||
db := setupServiceTestDB(t)
|
||
providerRepo := repository.NewProviderRepository(db)
|
||
modelRepo := repository.NewModelRepository(db)
|
||
svc := NewModelService(modelRepo, providerRepo)
|
||
|
||
model := &domain.Model{ID: "m1", ProviderID: "nonexistent", ModelName: "gpt-4"}
|
||
err := svc.Create(model)
|
||
assert.Error(t, err)
|
||
}
|
||
|
||
func TestModelService_List(t *testing.T) {
|
||
db := setupServiceTestDB(t)
|
||
providerRepo := repository.NewProviderRepository(db)
|
||
modelRepo := repository.NewModelRepository(db)
|
||
svc := NewModelService(modelRepo, providerRepo)
|
||
|
||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||
svc.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})
|
||
|
||
models, err := svc.List("p1")
|
||
require.NoError(t, err)
|
||
assert.Len(t, models, 2)
|
||
}
|
||
|
||
// ============ RoutingService 测试 ============
|
||
|
||
func TestRoutingService_Route(t *testing.T) {
|
||
db := setupServiceTestDB(t)
|
||
providerRepo := repository.NewProviderRepository(db)
|
||
modelRepo := repository.NewModelRepository(db)
|
||
svc := NewRoutingService(modelRepo, providerRepo)
|
||
|
||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
|
||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||
|
||
result, err := svc.Route("gpt-4")
|
||
require.NoError(t, err)
|
||
assert.Equal(t, "p1", result.Provider.ID)
|
||
assert.Equal(t, "gpt-4", result.Model.ModelName)
|
||
}
|
||
|
||
func TestRoutingService_Route_ModelNotFound(t *testing.T) {
|
||
db := setupServiceTestDB(t)
|
||
providerRepo := repository.NewProviderRepository(db)
|
||
modelRepo := repository.NewModelRepository(db)
|
||
svc := NewRoutingService(modelRepo, providerRepo)
|
||
|
||
_, err := svc.Route("nonexistent-model")
|
||
assert.Error(t, err)
|
||
}
|
||
|
||
func TestRoutingService_Route_ModelDisabled(t *testing.T) {
|
||
db := setupServiceTestDB(t)
|
||
providerRepo := repository.NewProviderRepository(db)
|
||
modelRepo := repository.NewModelRepository(db)
|
||
svc := NewRoutingService(modelRepo, providerRepo)
|
||
|
||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
|
||
// 先创建启用的模型,然后通过 Update 禁用
|
||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
|
||
|
||
_, err := svc.Route("gpt-4")
|
||
assert.Error(t, err)
|
||
}
|
||
|
||
func TestRoutingService_Route_ProviderDisabled(t *testing.T) {
|
||
db := setupServiceTestDB(t)
|
||
providerRepo := repository.NewProviderRepository(db)
|
||
modelRepo := repository.NewModelRepository(db)
|
||
svc := NewRoutingService(modelRepo, providerRepo)
|
||
|
||
// 先创建启用的 provider,然后禁用
|
||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
|
||
providerRepo.Update("p1", map[string]interface{}{"enabled": false})
|
||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||
|
||
_, err := svc.Route("gpt-4")
|
||
assert.Error(t, err)
|
||
}
|
||
|
||
// ============ StatsService 测试 ============
|
||
|
||
func TestStatsService_RecordAndGet(t *testing.T) {
|
||
db := setupServiceTestDB(t)
|
||
statsRepo := repository.NewStatsRepository(db)
|
||
svc := NewStatsService(statsRepo)
|
||
|
||
err := svc.Record("p1", "gpt-4")
|
||
require.NoError(t, err)
|
||
|
||
stats, err := svc.Get("p1", "", nil, nil)
|
||
require.NoError(t, err)
|
||
assert.Len(t, stats, 1)
|
||
}
|
||
|
||
func TestStatsService_Aggregate_ByProvider(t *testing.T) {
|
||
statsRepo := repository.NewStatsRepository(nil)
|
||
svc := NewStatsService(statsRepo)
|
||
|
||
stats := []domain.UsageStats{
|
||
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
|
||
{ProviderID: "p1", ModelName: "gpt-3.5", RequestCount: 5},
|
||
{ProviderID: "p2", ModelName: "claude-3", RequestCount: 8},
|
||
}
|
||
|
||
result := svc.Aggregate(stats, "provider")
|
||
assert.Len(t, result, 2)
|
||
|
||
p1Count := 0
|
||
p2Count := 0
|
||
for _, r := range result {
|
||
if r["provider_id"] == "p1" {
|
||
p1Count = r["request_count"].(int)
|
||
}
|
||
if r["provider_id"] == "p2" {
|
||
p2Count = r["request_count"].(int)
|
||
}
|
||
}
|
||
assert.Equal(t, 15, p1Count)
|
||
assert.Equal(t, 8, p2Count)
|
||
}
|
||
|
||
func TestStatsService_Aggregate_ByDate(t *testing.T) {
|
||
statsRepo := repository.NewStatsRepository(nil)
|
||
svc := NewStatsService(statsRepo)
|
||
|
||
stats := []domain.UsageStats{
|
||
{ProviderID: "p1", RequestCount: 10},
|
||
{ProviderID: "p2", RequestCount: 5},
|
||
}
|
||
|
||
result := svc.Aggregate(stats, "date")
|
||
assert.Len(t, result, 1)
|
||
assert.Equal(t, 15, result[0]["request_count"])
|
||
}
|
||
|
||
func TestStatsService_Aggregate_ByModel(t *testing.T) {
|
||
statsRepo := repository.NewStatsRepository(nil)
|
||
svc := NewStatsService(statsRepo)
|
||
|
||
stats := []domain.UsageStats{
|
||
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10},
|
||
{ProviderID: "openai", ModelName: "gpt-3.5", RequestCount: 5},
|
||
{ProviderID: "anthropic", ModelName: "claude-3", RequestCount: 8},
|
||
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 3},
|
||
}
|
||
|
||
result := svc.Aggregate(stats, "model")
|
||
assert.Len(t, result, 3)
|
||
|
||
// 验证每个 provider/model 组合的计数
|
||
counts := make(map[string]int)
|
||
for _, r := range result {
|
||
key := r["provider_id"].(string) + "/" + r["model_name"].(string)
|
||
counts[key] = r["request_count"].(int)
|
||
}
|
||
assert.Equal(t, 13, counts["openai/gpt-4"])
|
||
assert.Equal(t, 5, counts["openai/gpt-3.5"])
|
||
assert.Equal(t, 8, counts["anthropic/claude-3"])
|
||
}
|