- 新增 RoutingCache 组件,使用 sync.Map 缓存 Provider 和 Model - 新增 StatsBuffer 组件,使用 sync.Map + atomic.Int64 缓冲统计数据 - 扩展 StatsRepository.BatchUpdate 支持批量增量更新 - 改造 RoutingService/StatsService/ProviderService/ModelService 集成缓存 - 更新 usage-statistics spec,新增 routing-cache 和 stats-buffer spec - 新增单元测试覆盖缓存命中/失效/并发场景
274 lines
7.4 KiB
Go
274 lines
7.4 KiB
Go
package service
|
|
|
|
import (
|
|
"errors"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/zap"
|
|
|
|
"nex/backend/internal/domain"
|
|
)
|
|
|
|
type mockModelRepo struct {
|
|
models map[string]*domain.Model
|
|
}
|
|
|
|
func (m *mockModelRepo) Create(model *domain.Model) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *mockModelRepo) GetByID(id string) (*domain.Model, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *mockModelRepo) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
|
|
key := providerID + "/" + modelName
|
|
if model, ok := m.models[key]; ok {
|
|
return model, nil
|
|
}
|
|
return nil, errors.New("not found")
|
|
}
|
|
|
|
func (m *mockModelRepo) List(providerID string) ([]domain.Model, error) {
|
|
var result []domain.Model
|
|
for _, model := range m.models {
|
|
if providerID == "" || model.ProviderID == providerID {
|
|
result = append(result, *model)
|
|
}
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (m *mockModelRepo) ListEnabled() ([]domain.Model, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *mockModelRepo) Update(id string, updates map[string]interface{}) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *mockModelRepo) Delete(id string) error {
|
|
return nil
|
|
}
|
|
|
|
type mockProviderRepo struct {
|
|
providers map[string]*domain.Provider
|
|
}
|
|
|
|
func (m *mockProviderRepo) Create(provider *domain.Provider) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *mockProviderRepo) GetByID(id string) (*domain.Provider, error) {
|
|
if provider, ok := m.providers[id]; ok {
|
|
return provider, nil
|
|
}
|
|
return nil, errors.New("not found")
|
|
}
|
|
|
|
func (m *mockProviderRepo) List() ([]domain.Provider, error) {
|
|
var result []domain.Provider
|
|
for _, provider := range m.providers {
|
|
result = append(result, *provider)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (m *mockProviderRepo) Update(id string, updates map[string]interface{}) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *mockProviderRepo) Delete(id string) error {
|
|
return nil
|
|
}
|
|
|
|
func TestRoutingCache_GetProvider_CacheHit(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
|
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
|
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
|
}}
|
|
|
|
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
|
|
|
provider, err := cache.GetProvider("openai")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "openai", provider.ID)
|
|
|
|
provider2, err := cache.GetProvider("openai")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, provider, provider2)
|
|
}
|
|
|
|
func TestRoutingCache_GetProvider_CacheMiss(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
|
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
|
|
|
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
|
|
|
_, err := cache.GetProvider("notexist")
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
func TestRoutingCache_GetModel_CacheHit(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
|
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
|
}}
|
|
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
|
|
|
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
|
|
|
model, err := cache.GetModel("openai", "gpt-4")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "gpt-4", model.ModelName)
|
|
|
|
model2, err := cache.GetModel("openai", "gpt-4")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, model, model2)
|
|
}
|
|
|
|
func TestRoutingCache_GetModel_CacheMiss(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
|
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
|
|
|
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
|
|
|
_, err := cache.GetModel("openai", "notexist")
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
func TestRoutingCache_DoubleCheck(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
|
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
|
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
|
}}
|
|
|
|
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 10; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
_, err := cache.GetProvider("openai")
|
|
assert.NoError(t, err)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
|
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
|
"openai/gpt-3.5": {ID: "2", ProviderID: "openai", ModelName: "gpt-3.5", Enabled: true},
|
|
"anthropic/claude": {ID: "3", ProviderID: "anthropic", ModelName: "claude", Enabled: true},
|
|
}}
|
|
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
|
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
|
"anthropic": {ID: "anthropic", Name: "Anthropic", Enabled: true},
|
|
}}
|
|
|
|
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
|
|
|
_, err := cache.GetModel("openai", "gpt-4")
|
|
require.NoError(t, err)
|
|
_, err = cache.GetModel("openai", "gpt-3.5")
|
|
require.NoError(t, err)
|
|
_, err = cache.GetModel("anthropic", "claude")
|
|
require.NoError(t, err)
|
|
|
|
cache.InvalidateProvider("openai")
|
|
|
|
var openaiCount, anthropicCount int
|
|
cache.models.Range(func(key, value interface{}) bool {
|
|
if key.(string) == "anthropic/claude" {
|
|
anthropicCount++
|
|
}
|
|
return true
|
|
})
|
|
assert.Equal(t, 0, openaiCount)
|
|
assert.Equal(t, 1, anthropicCount)
|
|
}
|
|
|
|
func TestRoutingCache_InvalidateModel(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
|
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
|
}}
|
|
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
|
|
|
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
|
|
|
_, err := cache.GetModel("openai", "gpt-4")
|
|
require.NoError(t, err)
|
|
|
|
cache.InvalidateModel("openai", "gpt-4")
|
|
|
|
var count int
|
|
cache.models.Range(func(key, value interface{}) bool {
|
|
count++
|
|
return true
|
|
})
|
|
assert.Equal(t, 0, count)
|
|
}
|
|
|
|
func TestRoutingCache_Preload_Success(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
|
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
|
}}
|
|
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
|
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
|
}}
|
|
|
|
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
|
|
|
err := cache.Preload()
|
|
require.NoError(t, err)
|
|
|
|
var providerCount, modelCount int
|
|
cache.providers.Range(func(key, value interface{}) bool {
|
|
providerCount++
|
|
return true
|
|
})
|
|
cache.models.Range(func(key, value interface{}) bool {
|
|
modelCount++
|
|
return true
|
|
})
|
|
assert.Equal(t, 1, providerCount)
|
|
assert.Equal(t, 1, modelCount)
|
|
}
|
|
|
|
func TestRoutingCache_ConcurrentAccess(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
|
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
|
}}
|
|
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
|
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
|
}}
|
|
|
|
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 100; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
_, _ = cache.GetProvider("openai")
|
|
_, _ = cache.GetModel("openai", "gpt-4")
|
|
cache.InvalidateProvider("openai")
|
|
cache.InvalidateModel("openai", "gpt-4")
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|