- 新增 RoutingCache 组件,使用 sync.Map 缓存 Provider 和 Model - 新增 StatsBuffer 组件,使用 sync.Map + atomic.Int64 缓冲统计数据 - 扩展 StatsRepository.BatchUpdate 支持批量增量更新 - 改造 RoutingService/StatsService/ProviderService/ModelService 集成缓存 - 更新 usage-statistics spec,新增 routing-cache 和 stats-buffer spec - 新增单元测试覆盖缓存命中/失效/并发场景
505 lines
16 KiB
Go
505 lines
16 KiB
Go
package service
|
|
|
|
import (
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/zap"
|
|
"gorm.io/gorm"
|
|
|
|
testHelpers "nex/backend/tests"
|
|
|
|
"nex/backend/internal/domain"
|
|
"nex/backend/internal/repository"
|
|
appErrors "nex/backend/pkg/errors"
|
|
)
|
|
|
|
func setupServiceTestDB(t *testing.T) *gorm.DB {
|
|
t.Helper()
|
|
return testHelpers.SetupTestDB(t)
|
|
}
|
|
|
|
func setupRoutingCache(t *testing.T, db *gorm.DB) *RoutingCache {
|
|
t.Helper()
|
|
providerRepo := repository.NewProviderRepository(db)
|
|
modelRepo := repository.NewModelRepository(db)
|
|
return NewRoutingCache(modelRepo, providerRepo, zap.NewNop())
|
|
}
|
|
|
|
// ============ RoutingService - RouteByModelName 测试 ============
|
|
|
|
func TestRoutingService_RouteByModelName_Success(t *testing.T) {
|
|
db := setupServiceTestDB(t)
|
|
providerRepo := repository.NewProviderRepository(db)
|
|
modelRepo := repository.NewModelRepository(db)
|
|
cache := setupRoutingCache(t, db)
|
|
svc := NewRoutingService(cache)
|
|
|
|
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
|
|
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
|
|
|
|
result, err := svc.RouteByModelName("openai", "gpt-4")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "openai", result.Provider.ID)
|
|
assert.Equal(t, "gpt-4", result.Model.ModelName)
|
|
}
|
|
|
|
func TestRoutingService_RouteByModelName_NotFound(t *testing.T) {
|
|
db := setupServiceTestDB(t)
|
|
cache := setupRoutingCache(t, db)
|
|
svc := NewRoutingService(cache)
|
|
|
|
_, err := svc.RouteByModelName("openai", "nonexistent-model")
|
|
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
|
|
}
|
|
|
|
func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
|
|
db := setupServiceTestDB(t)
|
|
providerRepo := repository.NewProviderRepository(db)
|
|
modelRepo := repository.NewModelRepository(db)
|
|
cache := setupRoutingCache(t, db)
|
|
svc := NewRoutingService(cache)
|
|
|
|
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
|
|
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
|
|
require.NoError(t, modelRepo.Update("m1", map[string]interface{}{"enabled": false}))
|
|
|
|
_, err := svc.RouteByModelName("openai", "gpt-4")
|
|
assert.True(t, errors.Is(err, appErrors.ErrModelDisabled))
|
|
}
|
|
|
|
func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
|
|
db := setupServiceTestDB(t)
|
|
providerRepo := repository.NewProviderRepository(db)
|
|
modelRepo := repository.NewModelRepository(db)
|
|
cache := setupRoutingCache(t, db)
|
|
svc := NewRoutingService(cache)
|
|
|
|
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
|
|
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
|
|
require.NoError(t, providerRepo.Update("openai", map[string]interface{}{"enabled": false}))
|
|
|
|
_, err := svc.RouteByModelName("openai", "gpt-4")
|
|
assert.True(t, errors.Is(err, appErrors.ErrProviderDisabled))
|
|
}
|
|
|
|
// ============ ModelService - Create with UUID 测试 ============
|
|
|
|
func TestModelService_Create_GeneratesUUID(t *testing.T) {
|
|
db := setupServiceTestDB(t)
|
|
providerRepo := repository.NewProviderRepository(db)
|
|
modelRepo := repository.NewModelRepository(db)
|
|
cache := setupRoutingCache(t, db)
|
|
svc := NewModelService(modelRepo, providerRepo, cache)
|
|
|
|
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
|
|
|
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
|
err := svc.Create(model)
|
|
require.NoError(t, err)
|
|
|
|
assert.NotEmpty(t, model.ID)
|
|
_, err = uuid.Parse(model.ID)
|
|
assert.NoError(t, err, "model.ID should be a valid UUID")
|
|
|
|
stored, err := svc.Get(model.ID)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, model.ID, stored.ID)
|
|
assert.Equal(t, "gpt-4", stored.ModelName)
|
|
}
|
|
|
|
func TestModelService_Create_DuplicateModelName(t *testing.T) {
|
|
db := setupServiceTestDB(t)
|
|
providerRepo := repository.NewProviderRepository(db)
|
|
modelRepo := repository.NewModelRepository(db)
|
|
cache := setupRoutingCache(t, db)
|
|
svc := NewModelService(modelRepo, providerRepo, cache)
|
|
|
|
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
|
|
|
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
|
err := svc.Create(model1)
|
|
require.NoError(t, err)
|
|
|
|
model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
|
err = svc.Create(model2)
|
|
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
|
|
}
|
|
|
|
func TestModelService_Create_ProviderNotFound(t *testing.T) {
|
|
db := setupServiceTestDB(t)
|
|
providerRepo := repository.NewProviderRepository(db)
|
|
modelRepo := repository.NewModelRepository(db)
|
|
cache := setupRoutingCache(t, db)
|
|
svc := NewModelService(modelRepo, providerRepo, cache)
|
|
|
|
model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
|
|
err := svc.Create(model)
|
|
assert.True(t, errors.Is(err, appErrors.ErrProviderNotFound))
|
|
}
|
|
|
|
// ============ ProviderService - Create with validation 测试 ============
|
|
|
|
func TestProviderService_Create_InvalidID(t *testing.T) {
|
|
db := setupServiceTestDB(t)
|
|
repo := repository.NewProviderRepository(db)
|
|
modelRepo := repository.NewModelRepository(db)
|
|
cache := setupRoutingCache(t, db)
|
|
svc := NewProviderService(repo, modelRepo, cache)
|
|
|
|
provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
|
err := svc.Create(provider)
|
|
assert.True(t, errors.Is(err, appErrors.ErrInvalidProviderID))
|
|
}
|
|
|
|
func TestProviderService_Create_ValidID(t *testing.T) {
|
|
db := setupServiceTestDB(t)
|
|
repo := repository.NewProviderRepository(db)
|
|
modelRepo := repository.NewModelRepository(db)
|
|
cache := setupRoutingCache(t, db)
|
|
svc := NewProviderService(repo, modelRepo, cache)
|
|
|
|
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
|
err := svc.Create(provider)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "openai", provider.ID)
|
|
assert.True(t, provider.Enabled)
|
|
}
|
|
|
|
// ============ ModelService - Update with duplicate check 测试 ============
|
|
|
|
func TestModelService_Update_DuplicateModelName(t *testing.T) {
|
|
db := setupServiceTestDB(t)
|
|
providerRepo := repository.NewProviderRepository(db)
|
|
modelRepo := repository.NewModelRepository(db)
|
|
cache := setupRoutingCache(t, db)
|
|
svc := NewModelService(modelRepo, providerRepo, cache)
|
|
|
|
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
|
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"}))
|
|
|
|
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
|
err := svc.Create(model1)
|
|
require.NoError(t, err)
|
|
|
|
model2 := &domain.Model{ProviderID: "anthropic", ModelName: "claude-3"}
|
|
err = svc.Create(model2)
|
|
require.NoError(t, err)
|
|
|
|
// 将 model2 的 model_name 改为 "gpt-4" 且 provider_id 改为 "openai",与 model1 冲突
|
|
err = svc.Update(model2.ID, map[string]interface{}{
|
|
"provider_id": "openai",
|
|
"model_name": "gpt-4",
|
|
})
|
|
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
|
|
}
|
|
|
|
func TestModelService_Update_ModelNotFound(t *testing.T) {
|
|
db := setupServiceTestDB(t)
|
|
providerRepo := repository.NewProviderRepository(db)
|
|
modelRepo := repository.NewModelRepository(db)
|
|
cache := setupRoutingCache(t, db)
|
|
svc := NewModelService(modelRepo, providerRepo, cache)
|
|
|
|
err := svc.Update("nonexistent-id", map[string]interface{}{
|
|
"model_name": "gpt-4",
|
|
})
|
|
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
|
|
}
|
|
|
|
func TestModelService_Update_Success(t *testing.T) {
|
|
db := setupServiceTestDB(t)
|
|
providerRepo := repository.NewProviderRepository(db)
|
|
modelRepo := repository.NewModelRepository(db)
|
|
cache := setupRoutingCache(t, db)
|
|
svc := NewModelService(modelRepo, providerRepo, cache)
|
|
|
|
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
|
|
|
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
|
err := svc.Create(model)
|
|
require.NoError(t, err)
|
|
|
|
// 更新 model_name 为不冲突的值
|
|
err = svc.Update(model.ID, map[string]interface{}{
|
|
"model_name": "gpt-4-turbo",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
updated, err := svc.Get(model.ID)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "gpt-4-turbo", updated.ModelName)
|
|
}
|
|
|
|
// ============ ProviderService - Update immutable ID 测试 ============
|
|
|
|
func TestProviderService_Update_ImmutableID(t *testing.T) {
|
|
db := setupServiceTestDB(t)
|
|
repo := repository.NewProviderRepository(db)
|
|
modelRepo := repository.NewModelRepository(db)
|
|
cache := setupRoutingCache(t, db)
|
|
svc := NewProviderService(repo, modelRepo, cache)
|
|
|
|
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
|
err := svc.Create(provider)
|
|
require.NoError(t, err)
|
|
|
|
// 尝试更新 id 字段
|
|
err = svc.Update("openai", map[string]interface{}{
|
|
"id": "new-id",
|
|
})
|
|
assert.True(t, errors.Is(err, appErrors.ErrImmutableField))
|
|
}
|
|
|
|
func TestProviderService_Update_Success(t *testing.T) {
|
|
db := setupServiceTestDB(t)
|
|
repo := repository.NewProviderRepository(db)
|
|
modelRepo := repository.NewModelRepository(db)
|
|
cache := setupRoutingCache(t, db)
|
|
svc := NewProviderService(repo, modelRepo, cache)
|
|
|
|
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
|
err := svc.Create(provider)
|
|
require.NoError(t, err)
|
|
|
|
// 更新 name
|
|
err = svc.Update("openai", map[string]interface{}{
|
|
"name": "OpenAI Updated",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
updated, err := svc.Get("openai")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "OpenAI Updated", updated.Name)
|
|
}
|
|
|
|
// ============ StatsService - Aggregate ByModel 测试 ============
|
|
|
|
func TestStatsService_Aggregate_ByModel(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
stats []domain.UsageStats
|
|
expected []map[string]interface{}
|
|
}{
|
|
{
|
|
name: "multiple providers with same model name",
|
|
stats: []domain.UsageStats{
|
|
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10},
|
|
{ProviderID: "azure", ModelName: "gpt-4", RequestCount: 20},
|
|
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 5},
|
|
},
|
|
expected: []map[string]interface{}{
|
|
{"provider_id": "openai", "model_name": "gpt-4", "request_count": 15},
|
|
{"provider_id": "azure", "model_name": "gpt-4", "request_count": 20},
|
|
},
|
|
},
|
|
{
|
|
name: "empty providerID",
|
|
stats: []domain.UsageStats{
|
|
{ProviderID: "", ModelName: "gpt-4", RequestCount: 10},
|
|
{ProviderID: "", ModelName: "gpt-4", RequestCount: 5},
|
|
},
|
|
expected: []map[string]interface{}{
|
|
{"provider_id": "", "model_name": "gpt-4", "request_count": 15},
|
|
},
|
|
},
|
|
{
|
|
name: "empty result set",
|
|
stats: []domain.UsageStats{},
|
|
expected: []map[string]interface{}{},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
db := setupServiceTestDB(t)
|
|
statsRepo := repository.NewStatsRepository(db)
|
|
buffer := NewStatsBuffer(statsRepo, nil); svc := NewStatsService(statsRepo, buffer)
|
|
|
|
result := svc.Aggregate(tt.stats, "model")
|
|
|
|
assert.Len(t, result, len(tt.expected))
|
|
for _, exp := range tt.expected {
|
|
found := false
|
|
for _, r := range result {
|
|
if r["provider_id"] == exp["provider_id"] && r["model_name"] == exp["model_name"] {
|
|
assert.Equal(t, exp["request_count"], r["request_count"])
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
assert.True(t, found, "expected result not found: %v", exp)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// ============ StatsService - Aggregate ByDate 测试 ============
|
|
|
|
func TestStatsService_Aggregate_ByDate(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
stats []domain.UsageStats
|
|
expected []map[string]interface{}
|
|
}{
|
|
{
|
|
name: "normal date grouping",
|
|
stats: []domain.UsageStats{
|
|
{Date: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), RequestCount: 10},
|
|
{Date: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), RequestCount: 5},
|
|
{Date: time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC), RequestCount: 20},
|
|
},
|
|
expected: []map[string]interface{}{
|
|
{"date": "2024-01-01", "request_count": 15},
|
|
{"date": "2024-01-02", "request_count": 20},
|
|
},
|
|
},
|
|
{
|
|
name: "zero-value time",
|
|
stats: []domain.UsageStats{
|
|
{Date: time.Time{}, RequestCount: 10},
|
|
{Date: time.Time{}, RequestCount: 5},
|
|
},
|
|
expected: []map[string]interface{}{
|
|
{"date": "0001-01-01", "request_count": 15},
|
|
},
|
|
},
|
|
{
|
|
name: "empty result set",
|
|
stats: []domain.UsageStats{},
|
|
expected: []map[string]interface{}{},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
db := setupServiceTestDB(t)
|
|
statsRepo := repository.NewStatsRepository(db)
|
|
buffer := NewStatsBuffer(statsRepo, nil); svc := NewStatsService(statsRepo, buffer)
|
|
|
|
result := svc.Aggregate(tt.stats, "date")
|
|
|
|
assert.Len(t, result, len(tt.expected))
|
|
for _, exp := range tt.expected {
|
|
found := false
|
|
for _, r := range result {
|
|
if r["date"] == exp["date"] {
|
|
assert.Equal(t, exp["request_count"], r["request_count"])
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
assert.True(t, found, "expected result not found: %v", exp)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// ============ ProviderService - isUniqueConstraintError 测试 ============
|
|
|
|
func TestProviderService_isUniqueConstraintError(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
err error
|
|
expected bool
|
|
}{
|
|
{
|
|
name: "nil error",
|
|
err: nil,
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "UNIQUE constraint failed",
|
|
err: errors.New("UNIQUE constraint failed"),
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "duplicate key value",
|
|
err: errors.New("duplicate key value"),
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "UNIQUE constraint case insensitive",
|
|
err: errors.New("unique constraint violation"),
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "other error",
|
|
err: errors.New("some other error"),
|
|
expected: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := isUniqueConstraintError(tt.err)
|
|
assert.Equal(t, tt.expected, result)
|
|
})
|
|
}
|
|
}
|
|
|
|
// ============ ProviderService - List API Key 测试 ============
|
|
|
|
func TestProviderService_List_APIKeyNotMasked(t *testing.T) {
|
|
db := setupServiceTestDB(t)
|
|
repo := repository.NewProviderRepository(db)
|
|
modelRepo := repository.NewModelRepository(db)
|
|
cache := setupRoutingCache(t, db)
|
|
svc := NewProviderService(repo, modelRepo, cache)
|
|
|
|
provider1 := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"}
|
|
provider2 := &domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "sk-anthropic1234", BaseURL: "https://api.anthropic.com"}
|
|
require.NoError(t, svc.Create(provider1))
|
|
require.NoError(t, svc.Create(provider2))
|
|
|
|
providers, err := svc.List()
|
|
require.NoError(t, err)
|
|
require.Len(t, providers, 2)
|
|
|
|
expectedKeys := map[string]string{
|
|
"openai": "sk-1234567890",
|
|
"anthropic": "sk-anthropic1234",
|
|
}
|
|
for _, p := range providers {
|
|
assert.NotContains(t, p.APIKey, "***")
|
|
assert.Equal(t, expectedKeys[p.ID], p.APIKey)
|
|
}
|
|
}
|
|
|
|
func TestModelService_ConcurrentCreate(t *testing.T) {
|
|
db := setupServiceTestDB(t)
|
|
providerRepo := repository.NewProviderRepository(db)
|
|
modelRepo := repository.NewModelRepository(db)
|
|
cache := setupRoutingCache(t, db)
|
|
svc := NewModelService(modelRepo, providerRepo, cache)
|
|
|
|
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
|
|
|
results := make(chan error, 2)
|
|
for i := 0; i < 2; i++ {
|
|
go func() {
|
|
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
|
results <- svc.Create(model)
|
|
}()
|
|
}
|
|
|
|
err1 := <-results
|
|
err2 := <-results
|
|
|
|
successCount := 0
|
|
errorCount := 0
|
|
for _, err := range []error{err1, err2} {
|
|
if err == nil {
|
|
successCount++
|
|
} else {
|
|
errorCount++
|
|
}
|
|
}
|
|
assert.Equal(t, 1, successCount)
|
|
assert.Equal(t, 1, errorCount)
|
|
}
|