1
0
Files
nex/backend/internal/service/service_test.go
lanyuanxiaoyao df253559a5 feat(cache): 实现 RoutingCache 和 StatsBuffer 优化数据库写入
- 新增 RoutingCache 组件,使用 sync.Map 缓存 Provider 和 Model
- 新增 StatsBuffer 组件,使用 sync.Map + atomic.Int64 缓冲统计数据
- 扩展 StatsRepository.BatchUpdate 支持批量增量更新
- 改造 RoutingService/StatsService/ProviderService/ModelService 集成缓存
- 更新 usage-statistics spec,新增 routing-cache 和 stats-buffer spec
- 新增单元测试覆盖缓存命中/失效/并发场景
2026-04-22 19:24:36 +08:00

505 lines
16 KiB
Go

package service
import (
"errors"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"gorm.io/gorm"
testHelpers "nex/backend/tests"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
appErrors "nex/backend/pkg/errors"
)
// setupServiceTestDB returns the *gorm.DB produced by the shared
// testHelpers.SetupTestDB helper, marking itself as a test helper so failures
// are reported at the caller's line.
func setupServiceTestDB(t *testing.T) *gorm.DB {
	t.Helper()
	db := testHelpers.SetupTestDB(t)
	return db
}
// setupRoutingCache constructs a RoutingCache over provider and model
// repositories bound to db, using a no-op zap logger.
func setupRoutingCache(t *testing.T, db *gorm.DB) *RoutingCache {
	t.Helper()
	providers := repository.NewProviderRepository(db)
	models := repository.NewModelRepository(db)
	logger := zap.NewNop()
	// NewRoutingCache takes the model repository first, then the provider one.
	return NewRoutingCache(models, providers, logger)
}
// ============ RoutingService - RouteByModelName 测试 ============
// TestRoutingService_RouteByModelName_Success verifies that an enabled model
// under an enabled provider resolves to that provider/model pair.
func TestRoutingService_RouteByModelName_Success(t *testing.T) {
	db := setupServiceTestDB(t)
	providers := repository.NewProviderRepository(db)
	models := repository.NewModelRepository(db)
	routing := NewRoutingService(setupRoutingCache(t, db))

	openai := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}
	require.NoError(t, providers.Create(openai))
	gpt4 := &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}
	require.NoError(t, models.Create(gpt4))

	route, err := routing.RouteByModelName("openai", "gpt-4")
	require.NoError(t, err)
	assert.Equal(t, "openai", route.Provider.ID)
	assert.Equal(t, "gpt-4", route.Model.ModelName)
}
// TestRoutingService_RouteByModelName_NotFound verifies that routing a model
// name that was never stored returns appErrors.ErrModelNotFound.
func TestRoutingService_RouteByModelName_NotFound(t *testing.T) {
	db := setupServiceTestDB(t)
	cache := setupRoutingCache(t, db)
	svc := NewRoutingService(cache)

	_, err := svc.RouteByModelName("openai", "nonexistent-model")
	// assert.ErrorIs prints the actual error chain on failure, unlike
	// assert.True(errors.Is(...)) which only reports "Should be true".
	assert.ErrorIs(t, err, appErrors.ErrModelNotFound)
}
// TestRoutingService_RouteByModelName_DisabledModel verifies that routing to a
// model whose "enabled" column is false returns appErrors.ErrModelDisabled.
func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	cache := setupRoutingCache(t, db)
	svc := NewRoutingService(cache)
	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
	require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))

	// The model is created enabled and then disabled through Update instead of
	// being created with Enabled: false. NOTE(review): presumably because GORM's
	// Create skips zero-value fields whose column declares a default; confirm
	// against the domain.Model schema.
	require.NoError(t, modelRepo.Update("m1", map[string]interface{}{"enabled": false}))

	_, err := svc.RouteByModelName("openai", "gpt-4")
	assert.ErrorIs(t, err, appErrors.ErrModelDisabled)
}
// TestRoutingService_RouteByModelName_DisabledProvider verifies that routing to
// an enabled model whose provider's "enabled" column is false returns
// appErrors.ErrProviderDisabled.
func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	cache := setupRoutingCache(t, db)
	svc := NewRoutingService(cache)
	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
	require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))

	// The provider is created enabled and then disabled through Update instead of
	// being created with Enabled: false. NOTE(review): presumably because GORM's
	// Create skips zero-value fields whose column declares a default; confirm
	// against the domain.Provider schema.
	require.NoError(t, providerRepo.Update("openai", map[string]interface{}{"enabled": false}))

	_, err := svc.RouteByModelName("openai", "gpt-4")
	assert.ErrorIs(t, err, appErrors.ErrProviderDisabled)
}
// ============ ModelService - Create with UUID 测试 ============
// TestModelService_Create_GeneratesUUID verifies that ModelService.Create fills
// in a UUID-formatted ID when none is supplied and that the stored model can be
// fetched back by that ID.
func TestModelService_Create_GeneratesUUID(t *testing.T) {
	db := setupServiceTestDB(t)
	providers := repository.NewProviderRepository(db)
	models := repository.NewModelRepository(db)
	svc := NewModelService(models, providers, setupRoutingCache(t, db))
	require.NoError(t, providers.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))

	m := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
	require.NoError(t, svc.Create(m))
	assert.NotEmpty(t, m.ID)

	_, parseErr := uuid.Parse(m.ID)
	assert.NoError(t, parseErr, "model.ID should be a valid UUID")

	got, err := svc.Get(m.ID)
	require.NoError(t, err)
	assert.Equal(t, m.ID, got.ID)
	assert.Equal(t, "gpt-4", got.ModelName)
}
// TestModelService_Create_DuplicateModelName verifies that creating a second
// model with the same provider_id/model_name pair fails with
// appErrors.ErrDuplicateModel.
func TestModelService_Create_DuplicateModelName(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	cache := setupRoutingCache(t, db)
	svc := NewModelService(modelRepo, providerRepo, cache)
	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))

	model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
	err := svc.Create(model1)
	require.NoError(t, err)

	model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
	err = svc.Create(model2)
	// ErrorIs reports the actual error on failure, unlike assert.True(errors.Is(...)).
	assert.ErrorIs(t, err, appErrors.ErrDuplicateModel)
}
// TestModelService_Create_ProviderNotFound verifies that creating a model that
// references a provider ID with no stored provider fails with
// appErrors.ErrProviderNotFound.
func TestModelService_Create_ProviderNotFound(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	cache := setupRoutingCache(t, db)
	svc := NewModelService(modelRepo, providerRepo, cache)

	model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
	err := svc.Create(model)
	assert.ErrorIs(t, err, appErrors.ErrProviderNotFound)
}
// ============ ProviderService - Create with validation 测试 ============
// TestProviderService_Create_InvalidID verifies that ProviderService.Create
// rejects the ID "open-ai" with appErrors.ErrInvalidProviderID.
// NOTE(review): the exact ID validation rule lives in ProviderService; this case
// only pins that an ID containing a hyphen is rejected.
func TestProviderService_Create_InvalidID(t *testing.T) {
	db := setupServiceTestDB(t)
	repo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	cache := setupRoutingCache(t, db)
	svc := NewProviderService(repo, modelRepo, cache)

	provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
	err := svc.Create(provider)
	assert.ErrorIs(t, err, appErrors.ErrInvalidProviderID)
}
// TestProviderService_Create_ValidID verifies that a provider with an accepted
// ID is created, keeps its ID, and has Enabled set by Create.
func TestProviderService_Create_ValidID(t *testing.T) {
	db := setupServiceTestDB(t)
	svc := NewProviderService(
		repository.NewProviderRepository(db),
		repository.NewModelRepository(db),
		setupRoutingCache(t, db),
	)

	p := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
	require.NoError(t, svc.Create(p))
	assert.Equal(t, "openai", p.ID)
	// Enabled was not set in the literal, so seeing true means Create set it.
	assert.True(t, p.Enabled)
}
// ============ ModelService - Update with duplicate check 测试 ============
// TestModelService_Update_DuplicateModelName verifies that updating a model so
// that its provider_id/model_name pair collides with another model's fails with
// appErrors.ErrDuplicateModel.
func TestModelService_Update_DuplicateModelName(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	cache := setupRoutingCache(t, db)
	svc := NewModelService(modelRepo, providerRepo, cache)
	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
	require.NoError(t, providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"}))

	model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
	err := svc.Create(model1)
	require.NoError(t, err)
	model2 := &domain.Model{ProviderID: "anthropic", ModelName: "claude-3"}
	err = svc.Create(model2)
	require.NoError(t, err)

	// Move model2 to provider "openai" with model_name "gpt-4", which collides
	// with model1.
	err = svc.Update(model2.ID, map[string]interface{}{
		"provider_id": "openai",
		"model_name":  "gpt-4",
	})
	assert.ErrorIs(t, err, appErrors.ErrDuplicateModel)
}
// TestModelService_Update_ModelNotFound verifies that updating a model ID that
// was never stored fails with appErrors.ErrModelNotFound.
func TestModelService_Update_ModelNotFound(t *testing.T) {
	db := setupServiceTestDB(t)
	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	cache := setupRoutingCache(t, db)
	svc := NewModelService(modelRepo, providerRepo, cache)

	err := svc.Update("nonexistent-id", map[string]interface{}{
		"model_name": "gpt-4",
	})
	assert.ErrorIs(t, err, appErrors.ErrModelNotFound)
}
// TestModelService_Update_Success verifies that renaming a model to a
// non-conflicting model_name succeeds and is visible through ModelService.Get.
func TestModelService_Update_Success(t *testing.T) {
	db := setupServiceTestDB(t)
	providers := repository.NewProviderRepository(db)
	models := repository.NewModelRepository(db)
	svc := NewModelService(models, providers, setupRoutingCache(t, db))
	require.NoError(t, providers.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))

	m := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
	require.NoError(t, svc.Create(m))

	// Rename to a model_name that no other model uses.
	changes := map[string]interface{}{"model_name": "gpt-4-turbo"}
	require.NoError(t, svc.Update(m.ID, changes))

	got, err := svc.Get(m.ID)
	require.NoError(t, err)
	assert.Equal(t, "gpt-4-turbo", got.ModelName)
}
// ============ ProviderService - Update immutable ID 测试 ============
// TestProviderService_Update_ImmutableID verifies that an update touching the
// "id" field is rejected with appErrors.ErrImmutableField.
func TestProviderService_Update_ImmutableID(t *testing.T) {
	db := setupServiceTestDB(t)
	repo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	cache := setupRoutingCache(t, db)
	svc := NewProviderService(repo, modelRepo, cache)

	provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
	err := svc.Create(provider)
	require.NoError(t, err)

	// Attempt to change the id field.
	err = svc.Update("openai", map[string]interface{}{
		"id": "new-id",
	})
	assert.ErrorIs(t, err, appErrors.ErrImmutableField)
}
// TestProviderService_Update_Success verifies that changing a provider's name
// succeeds and is visible through ProviderService.Get.
func TestProviderService_Update_Success(t *testing.T) {
	db := setupServiceTestDB(t)
	svc := NewProviderService(
		repository.NewProviderRepository(db),
		repository.NewModelRepository(db),
		setupRoutingCache(t, db),
	)

	p := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
	require.NoError(t, svc.Create(p))

	// Change only the display name.
	changes := map[string]interface{}{"name": "OpenAI Updated"}
	require.NoError(t, svc.Update("openai", changes))

	got, err := svc.Get("openai")
	require.NoError(t, err)
	assert.Equal(t, "OpenAI Updated", got.Name)
}
// ============ StatsService - Aggregate ByModel 测试 ============
// TestStatsService_Aggregate_ByModel verifies that Aggregate with the "model"
// grouping sums request_count per (provider_id, model_name) pair, including
// rows with an empty provider_id and an empty input.
func TestStatsService_Aggregate_ByModel(t *testing.T) {
	tests := []struct {
		name     string
		stats    []domain.UsageStats
		expected []map[string]interface{}
	}{
		{
			name: "multiple providers with same model name",
			stats: []domain.UsageStats{
				{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10},
				{ProviderID: "azure", ModelName: "gpt-4", RequestCount: 20},
				{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 5},
			},
			expected: []map[string]interface{}{
				{"provider_id": "openai", "model_name": "gpt-4", "request_count": 15},
				{"provider_id": "azure", "model_name": "gpt-4", "request_count": 20},
			},
		},
		{
			name: "empty providerID",
			stats: []domain.UsageStats{
				{ProviderID: "", ModelName: "gpt-4", RequestCount: 10},
				{ProviderID: "", ModelName: "gpt-4", RequestCount: 5},
			},
			expected: []map[string]interface{}{
				{"provider_id": "", "model_name": "gpt-4", "request_count": 15},
			},
		},
		{
			name:     "empty result set",
			stats:    []domain.UsageStats{},
			expected: []map[string]interface{}{},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			db := setupServiceTestDB(t)
			statsRepo := repository.NewStatsRepository(db)
			// NOTE(review): NewStatsBuffer's second argument is nil here. Aggregate is
			// fed in-memory stats below, so presumably the buffer is never exercised;
			// if that argument is a logger, zap.NewNop() (as setupRoutingCache uses)
			// would be safer. Confirm against NewStatsBuffer's signature.
			buffer := NewStatsBuffer(statsRepo, nil)
			svc := NewStatsService(statsRepo, buffer)

			result := svc.Aggregate(tt.stats, "model")
			assert.Len(t, result, len(tt.expected))
			// Row order is not asserted: each expected row is located by its
			// (provider_id, model_name) key, then only request_count is compared.
			for _, exp := range tt.expected {
				found := false
				for _, r := range result {
					if r["provider_id"] == exp["provider_id"] && r["model_name"] == exp["model_name"] {
						assert.Equal(t, exp["request_count"], r["request_count"])
						found = true
						break
					}
				}
				assert.True(t, found, "expected result not found: %v", exp)
			}
		})
	}
}
// ============ StatsService - Aggregate ByDate 测试 ============
// TestStatsService_Aggregate_ByDate verifies that Aggregate with the "date"
// grouping sums request_count per calendar day (formatted YYYY-MM-DD). Two
// timestamps on the same day share a bucket, and a zero time.Time lands in
// "0001-01-01".
func TestStatsService_Aggregate_ByDate(t *testing.T) {
	tests := []struct {
		name     string
		stats    []domain.UsageStats
		expected []map[string]interface{}
	}{
		{
			name: "normal date grouping",
			stats: []domain.UsageStats{
				{Date: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), RequestCount: 10},
				{Date: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), RequestCount: 5},
				{Date: time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC), RequestCount: 20},
			},
			expected: []map[string]interface{}{
				{"date": "2024-01-01", "request_count": 15},
				{"date": "2024-01-02", "request_count": 20},
			},
		},
		{
			name: "zero-value time",
			stats: []domain.UsageStats{
				{Date: time.Time{}, RequestCount: 10},
				{Date: time.Time{}, RequestCount: 5},
			},
			expected: []map[string]interface{}{
				{"date": "0001-01-01", "request_count": 15},
			},
		},
		{
			name:     "empty result set",
			stats:    []domain.UsageStats{},
			expected: []map[string]interface{}{},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			db := setupServiceTestDB(t)
			statsRepo := repository.NewStatsRepository(db)
			// NOTE(review): NewStatsBuffer's second argument is nil here. Aggregate is
			// fed in-memory stats below, so presumably the buffer is never exercised;
			// if that argument is a logger, zap.NewNop() (as setupRoutingCache uses)
			// would be safer. Confirm against NewStatsBuffer's signature.
			buffer := NewStatsBuffer(statsRepo, nil)
			svc := NewStatsService(statsRepo, buffer)

			result := svc.Aggregate(tt.stats, "date")
			assert.Len(t, result, len(tt.expected))
			// Row order is not asserted: each expected row is located by its "date"
			// key, then only request_count is compared.
			for _, exp := range tt.expected {
				found := false
				for _, r := range result {
					if r["date"] == exp["date"] {
						assert.Equal(t, exp["request_count"], r["request_count"])
						found = true
						break
					}
				}
				assert.True(t, found, "expected result not found: %v", exp)
			}
		})
	}
}
// ============ ProviderService - isUniqueConstraintError 测试 ============
// TestProviderService_isUniqueConstraintError pins which error messages
// isUniqueConstraintError classifies as unique-constraint violations. Per the
// cases below: nil is not one, "UNIQUE constraint" text matches regardless of
// case, "duplicate key value" matches, and unrelated text does not.
func TestProviderService_isUniqueConstraintError(t *testing.T) {
	cases := []struct {
		name string
		in   error
		want bool
	}{
		{"nil error", nil, false},
		{"UNIQUE constraint failed", errors.New("UNIQUE constraint failed"), true},
		{"duplicate key value", errors.New("duplicate key value"), true},
		{"UNIQUE constraint case insensitive", errors.New("unique constraint violation"), true},
		{"other error", errors.New("some other error"), false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := isUniqueConstraintError(tc.in)
			assert.Equal(t, tc.want, got)
		})
	}
}
// ============ ProviderService - List API Key 测试 ============
// TestProviderService_List_APIKeyNotMasked verifies that ProviderService.List
// returns each provider's APIKey exactly as stored, with no "***" masking.
func TestProviderService_List_APIKeyNotMasked(t *testing.T) {
	db := setupServiceTestDB(t)
	svc := NewProviderService(
		repository.NewProviderRepository(db),
		repository.NewModelRepository(db),
		setupRoutingCache(t, db),
	)

	openai := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"}
	anthropic := &domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "sk-anthropic1234", BaseURL: "https://api.anthropic.com"}
	require.NoError(t, svc.Create(openai))
	require.NoError(t, svc.Create(anthropic))

	listed, err := svc.List()
	require.NoError(t, err)
	require.Len(t, listed, 2)

	wantKey := map[string]string{
		"openai":    "sk-1234567890",
		"anthropic": "sk-anthropic1234",
	}
	for _, got := range listed {
		assert.NotContains(t, got.APIKey, "***")
		assert.Equal(t, wantKey[got.ID], got.APIKey)
	}
}
// TestModelService_ConcurrentCreate races two ModelService.Create calls for the
// same provider_id/model_name pair and checks that exactly one succeeds and
// exactly one returns an error. The error's type is not checked.
func TestModelService_ConcurrentCreate(t *testing.T) {
	db := setupServiceTestDB(t)
	providers := repository.NewProviderRepository(db)
	models := repository.NewModelRepository(db)
	svc := NewModelService(models, providers, setupRoutingCache(t, db))
	require.NoError(t, providers.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))

	const workers = 2
	// Buffered so neither goroutine blocks on send, even if the test exits early.
	results := make(chan error, workers)
	for i := 0; i < workers; i++ {
		go func() {
			results <- svc.Create(&domain.Model{ProviderID: "openai", ModelName: "gpt-4"})
		}()
	}

	var succeeded, failed int
	for i := 0; i < workers; i++ {
		if err := <-results; err != nil {
			failed++
		} else {
			succeeded++
		}
	}
	assert.Equal(t, 1, succeeded)
	assert.Equal(t, 1, failed)
}