- 新增 backend/.golangci.yml 配置 12 个 linter(forbidigo、errorlint、errcheck、staticcheck、revive、gocritic、gosec、bodyclose、noctx、nilerr、goimports、gocyclo) - 新增 lefthook.yml 配置 pre-commit hook 自动运行 lint - 修复存量代码违规:errors.Is/As 替换、zap.Error 替换、import 排序、errcheck 修复 - 更新 README 补充编码规范说明 - 归档 backend-code-lint 变更
275 lines
7.5 KiB
Go
275 lines
7.5 KiB
Go
package service
|
|
|
|
import (
|
|
"errors"
|
|
"sync"
|
|
"testing"
|
|
|
|
"nex/backend/internal/domain"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
type mockModelRepo struct {
|
|
models map[string]*domain.Model
|
|
}
|
|
|
|
func (m *mockModelRepo) Create(model *domain.Model) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *mockModelRepo) GetByID(id string) (*domain.Model, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *mockModelRepo) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
|
|
key := providerID + "/" + modelName
|
|
if model, ok := m.models[key]; ok {
|
|
return model, nil
|
|
}
|
|
return nil, errors.New("not found")
|
|
}
|
|
|
|
func (m *mockModelRepo) List(providerID string) ([]domain.Model, error) {
|
|
var result []domain.Model
|
|
for _, model := range m.models {
|
|
if providerID == "" || model.ProviderID == providerID {
|
|
result = append(result, *model)
|
|
}
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (m *mockModelRepo) ListEnabled() ([]domain.Model, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *mockModelRepo) Update(id string, updates map[string]interface{}) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *mockModelRepo) Delete(id string) error {
|
|
return nil
|
|
}
|
|
|
|
type mockProviderRepo struct {
|
|
providers map[string]*domain.Provider
|
|
}
|
|
|
|
func (m *mockProviderRepo) Create(provider *domain.Provider) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *mockProviderRepo) GetByID(id string) (*domain.Provider, error) {
|
|
if provider, ok := m.providers[id]; ok {
|
|
return provider, nil
|
|
}
|
|
return nil, errors.New("not found")
|
|
}
|
|
|
|
func (m *mockProviderRepo) List() ([]domain.Provider, error) {
|
|
var result []domain.Provider
|
|
for _, provider := range m.providers {
|
|
result = append(result, *provider)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (m *mockProviderRepo) Update(id string, updates map[string]interface{}) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *mockProviderRepo) Delete(id string) error {
|
|
return nil
|
|
}
|
|
|
|
func TestRoutingCache_GetProvider_CacheHit(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
|
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
|
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
|
}}
|
|
|
|
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
|
|
|
provider, err := cache.GetProvider("openai")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "openai", provider.ID)
|
|
|
|
provider2, err := cache.GetProvider("openai")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, provider, provider2)
|
|
}
|
|
|
|
func TestRoutingCache_GetProvider_CacheMiss(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
|
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
|
|
|
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
|
|
|
_, err := cache.GetProvider("notexist")
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
func TestRoutingCache_GetModel_CacheHit(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
|
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
|
}}
|
|
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
|
|
|
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
|
|
|
model, err := cache.GetModel("openai", "gpt-4")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "gpt-4", model.ModelName)
|
|
|
|
model2, err := cache.GetModel("openai", "gpt-4")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, model, model2)
|
|
}
|
|
|
|
func TestRoutingCache_GetModel_CacheMiss(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
|
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
|
|
|
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
|
|
|
_, err := cache.GetModel("openai", "notexist")
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
func TestRoutingCache_DoubleCheck(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
|
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
|
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
|
}}
|
|
|
|
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 10; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
_, err := cache.GetProvider("openai")
|
|
assert.NoError(t, err)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
|
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
|
"openai/gpt-3.5": {ID: "2", ProviderID: "openai", ModelName: "gpt-3.5", Enabled: true},
|
|
"anthropic/claude": {ID: "3", ProviderID: "anthropic", ModelName: "claude", Enabled: true},
|
|
}}
|
|
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
|
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
|
"anthropic": {ID: "anthropic", Name: "Anthropic", Enabled: true},
|
|
}}
|
|
|
|
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
|
|
|
_, err := cache.GetModel("openai", "gpt-4")
|
|
require.NoError(t, err)
|
|
_, err = cache.GetModel("openai", "gpt-3.5")
|
|
require.NoError(t, err)
|
|
_, err = cache.GetModel("anthropic", "claude")
|
|
require.NoError(t, err)
|
|
|
|
cache.InvalidateProvider("openai")
|
|
|
|
var openaiCount, anthropicCount int
|
|
cache.models.Range(func(key, value interface{}) bool {
|
|
keyStr, ok := key.(string)
|
|
if ok && keyStr == "anthropic/claude" {
|
|
anthropicCount++
|
|
}
|
|
return true
|
|
})
|
|
assert.Equal(t, 0, openaiCount)
|
|
assert.Equal(t, 1, anthropicCount)
|
|
}
|
|
|
|
func TestRoutingCache_InvalidateModel(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
|
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
|
}}
|
|
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
|
|
|
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
|
|
|
_, err := cache.GetModel("openai", "gpt-4")
|
|
require.NoError(t, err)
|
|
|
|
cache.InvalidateModel("openai", "gpt-4")
|
|
|
|
var count int
|
|
cache.models.Range(func(key, value interface{}) bool {
|
|
count++
|
|
return true
|
|
})
|
|
assert.Equal(t, 0, count)
|
|
}
|
|
|
|
func TestRoutingCache_Preload_Success(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
|
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
|
}}
|
|
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
|
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
|
}}
|
|
|
|
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
|
|
|
err := cache.Preload()
|
|
require.NoError(t, err)
|
|
|
|
var providerCount, modelCount int
|
|
cache.providers.Range(func(key, value interface{}) bool {
|
|
providerCount++
|
|
return true
|
|
})
|
|
cache.models.Range(func(key, value interface{}) bool {
|
|
modelCount++
|
|
return true
|
|
})
|
|
assert.Equal(t, 1, providerCount)
|
|
assert.Equal(t, 1, modelCount)
|
|
}
|
|
|
|
func TestRoutingCache_ConcurrentAccess(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
|
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
|
}}
|
|
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
|
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
|
}}
|
|
|
|
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < 100; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
_, _ = cache.GetProvider("openai")
|
|
_, _ = cache.GetModel("openai", "gpt-4")
|
|
cache.InvalidateProvider("openai")
|
|
cache.InvalidateModel("openai", "gpt-4")
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|