package service

import (
	"errors"
	"strings"
	"sync"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/zap"

	"nex/backend/internal/domain"
)

// mockModelRepo is an in-memory model repository for RoutingCache tests.
// Models are keyed by "providerID/modelName". Only FindByProviderAndModelName
// and List consult the map; every other method is a no-op returning nil.
type mockModelRepo struct {
	models map[string]*domain.Model
}

func (m *mockModelRepo) Create(model *domain.Model) error { return nil }

func (m *mockModelRepo) GetByID(id string) (*domain.Model, error) { return nil, nil }

// FindByProviderAndModelName looks up "providerID/modelName" and returns a
// generic "not found" error when the key is absent.
func (m *mockModelRepo) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
	key := providerID + "/" + modelName
	if model, ok := m.models[key]; ok {
		return model, nil
	}
	return nil, errors.New("not found")
}

// List returns copies of all stored models, or only those belonging to
// providerID when it is non-empty. Order follows map iteration (unspecified).
func (m *mockModelRepo) List(providerID string) ([]domain.Model, error) {
	var result []domain.Model
	for _, model := range m.models {
		if providerID == "" || model.ProviderID == providerID {
			result = append(result, *model)
		}
	}
	return result, nil
}

func (m *mockModelRepo) ListEnabled() ([]domain.Model, error) { return nil, nil }

func (m *mockModelRepo) Update(id string, updates map[string]interface{}) error { return nil }

func (m *mockModelRepo) Delete(id string) error { return nil }

// mockProviderRepo is an in-memory provider repository keyed by provider ID.
// Only GetByID and List consult the map; the mutating methods are no-ops.
type mockProviderRepo struct {
	providers map[string]*domain.Provider
}

func (m *mockProviderRepo) Create(provider *domain.Provider) error { return nil }

// GetByID returns the provider with the given ID, or a generic "not found"
// error when it is absent.
func (m *mockProviderRepo) GetByID(id string) (*domain.Provider, error) {
	if provider, ok := m.providers[id]; ok {
		return provider, nil
	}
	return nil, errors.New("not found")
}

// List returns copies of all stored providers in unspecified order.
func (m *mockProviderRepo) List() ([]domain.Provider, error) {
	var result []domain.Provider
	for _, provider := range m.providers {
		result = append(result, *provider)
	}
	return result, nil
}

func (m *mockProviderRepo) Update(id string, updates map[string]interface{}) error { return nil }

func (m *mockProviderRepo) Delete(id string) error { return nil }

// TestRoutingCache_GetProvider_CacheHit verifies that a second lookup of the
// same provider is served from the cache rather than the repository.
func TestRoutingCache_GetProvider_CacheHit(t *testing.T) {
	logger := zap.NewNop()
	modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
	providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
		"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
	}}
	cache := NewRoutingCache(modelRepo, providerRepo, logger)

	provider, err := cache.GetProvider("openai")
	require.NoError(t, err)
	assert.Equal(t, "openai", provider.ID)

	// Drop the provider from the backing repo: the second lookup can now only
	// succeed if it is answered from the cache. Without this, the test would
	// pass even with no caching, because the mock returns the same pointer.
	delete(providerRepo.providers, "openai")

	provider2, err := cache.GetProvider("openai")
	require.NoError(t, err)
	assert.Equal(t, provider, provider2)
}

// TestRoutingCache_GetProvider_CacheMiss verifies that a provider unknown to
// the repository surfaces an error.
func TestRoutingCache_GetProvider_CacheMiss(t *testing.T) {
	logger := zap.NewNop()
	modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
	providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
	cache := NewRoutingCache(modelRepo, providerRepo, logger)

	_, err := cache.GetProvider("notexist")
	assert.Error(t, err)
}

// TestRoutingCache_GetModel_CacheHit verifies that a second lookup of the
// same model is served from the cache rather than the repository.
func TestRoutingCache_GetModel_CacheHit(t *testing.T) {
	logger := zap.NewNop()
	modelRepo := &mockModelRepo{models: map[string]*domain.Model{
		"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
	}}
	providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
	cache := NewRoutingCache(modelRepo, providerRepo, logger)

	model, err := cache.GetModel("openai", "gpt-4")
	require.NoError(t, err)
	assert.Equal(t, "gpt-4", model.ModelName)

	// Drop the model from the backing repo so the second lookup must hit the cache.
	delete(modelRepo.models, "openai/gpt-4")

	model2, err := cache.GetModel("openai", "gpt-4")
	require.NoError(t, err)
	assert.Equal(t, model, model2)
}

// TestRoutingCache_GetModel_CacheMiss verifies that a model unknown to the
// repository surfaces an error.
func TestRoutingCache_GetModel_CacheMiss(t *testing.T) {
	logger := zap.NewNop()
	modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
	providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
	cache := NewRoutingCache(modelRepo, providerRepo, logger)

	_, err := cache.GetModel("openai", "notexist")
	assert.Error(t, err)
}

// TestRoutingCache_DoubleCheck issues concurrent first-time lookups of the
// same provider; every goroutine must get the provider back without error.
// Run with -race to also catch unsynchronized cache population.
func TestRoutingCache_DoubleCheck(t *testing.T) {
	logger := zap.NewNop()
	modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
	providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
		"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
	}}
	cache := NewRoutingCache(modelRepo, providerRepo, logger)

	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			provider, err := cache.GetProvider("openai")
			// assert (not require): require's FailNow must not be called
			// from a goroutine other than the test's own.
			if assert.NoError(t, err) {
				assert.Equal(t, "openai", provider.ID)
			}
		}()
	}
	wg.Wait()
}

// TestRoutingCache_InvalidateProvider_CascadingModels verifies that
// invalidating a provider evicts every cached model of that provider while
// leaving other providers' models in place.
func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) {
	logger := zap.NewNop()
	modelRepo := &mockModelRepo{models: map[string]*domain.Model{
		"openai/gpt-4":     {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
		"openai/gpt-3.5":   {ID: "2", ProviderID: "openai", ModelName: "gpt-3.5", Enabled: true},
		"anthropic/claude": {ID: "3", ProviderID: "anthropic", ModelName: "claude", Enabled: true},
	}}
	providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
		"openai":    {ID: "openai", Name: "OpenAI", Enabled: true},
		"anthropic": {ID: "anthropic", Name: "Anthropic", Enabled: true},
	}}
	cache := NewRoutingCache(modelRepo, providerRepo, logger)

	_, err := cache.GetModel("openai", "gpt-4")
	require.NoError(t, err)
	_, err = cache.GetModel("openai", "gpt-3.5")
	require.NoError(t, err)
	_, err = cache.GetModel("anthropic", "claude")
	require.NoError(t, err)

	// countByPrefix counts cached model entries whose "provider/model" key
	// starts with prefix.
	countByPrefix := func(prefix string) int {
		n := 0
		cache.models.Range(func(key, _ interface{}) bool {
			if strings.HasPrefix(key.(string), prefix) {
				n++
			}
			return true
		})
		return n
	}

	// Establish that the entries are actually cached, otherwise the
	// post-invalidation check below would pass trivially.
	require.Equal(t, 2, countByPrefix("openai/"))
	require.Equal(t, 1, countByPrefix("anthropic/"))

	cache.InvalidateProvider("openai")

	assert.Equal(t, 0, countByPrefix("openai/"))
	assert.Equal(t, 1, countByPrefix("anthropic/"))
}

// TestRoutingCache_InvalidateModel verifies that invalidating a single model
// removes it from the model cache.
func TestRoutingCache_InvalidateModel(t *testing.T) {
	logger := zap.NewNop()
	modelRepo := &mockModelRepo{models: map[string]*domain.Model{
		"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
	}}
	providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
	cache := NewRoutingCache(modelRepo, providerRepo, logger)

	_, err := cache.GetModel("openai", "gpt-4")
	require.NoError(t, err)

	cache.InvalidateModel("openai", "gpt-4")

	var count int
	cache.models.Range(func(key, value interface{}) bool {
		count++
		return true
	})
	assert.Equal(t, 0, count)
}

// TestRoutingCache_Preload_Success verifies that Preload fills both the
// provider and model caches from the repositories.
func TestRoutingCache_Preload_Success(t *testing.T) {
	logger := zap.NewNop()
	modelRepo := &mockModelRepo{models: map[string]*domain.Model{
		"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
	}}
	providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
		"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
	}}
	cache := NewRoutingCache(modelRepo, providerRepo, logger)

	err := cache.Preload()
	require.NoError(t, err)

	var providerCount, modelCount int
	cache.providers.Range(func(key, value interface{}) bool {
		providerCount++
		return true
	})
	cache.models.Range(func(key, value interface{}) bool {
		modelCount++
		return true
	})
	assert.Equal(t, 1, providerCount)
	assert.Equal(t, 1, modelCount)
}

// TestRoutingCache_ConcurrentAccess hammers lookups and invalidations from
// many goroutines at once. It asserts nothing about results; it exists to
// surface panics and, under -race, data races.
func TestRoutingCache_ConcurrentAccess(t *testing.T) {
	logger := zap.NewNop()
	modelRepo := &mockModelRepo{models: map[string]*domain.Model{
		"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
	}}
	providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
		"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
	}}
	cache := NewRoutingCache(modelRepo, providerRepo, logger)

	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Results are intentionally ignored: lookups race with
			// invalidations, so only crash/race freedom is under test.
			_, _ = cache.GetProvider("openai")
			_, _ = cache.GetModel("openai", "gpt-4")
			cache.InvalidateProvider("openai")
			cache.InvalidateModel("openai", "gpt-4")
		}()
	}
	wg.Wait()
}