Merge branch 'dev-database-write-optimization'
This commit is contained in:
@@ -67,13 +67,25 @@ func main() {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
|
||||
// 5. 初始化 service 层
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
||||
statsService := service.NewStatsService(statsRepo)
|
||||
// 5. 初始化缓存
|
||||
routingCache := service.NewRoutingCache(modelRepo, providerRepo, zapLogger)
|
||||
if err := routingCache.Preload(); err != nil {
|
||||
zapLogger.Warn("缓存预热失败,将使用懒加载", zap.Error(err))
|
||||
}
|
||||
|
||||
// 6. 创建 ConversionEngine
|
||||
// 6. 初始化统计缓冲
|
||||
statsBuffer := service.NewStatsBuffer(statsRepo, zapLogger,
|
||||
service.WithFlushInterval(5*time.Second),
|
||||
service.WithFlushThreshold(100))
|
||||
statsBuffer.Start()
|
||||
|
||||
// 7. 初始化 service 层
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
|
||||
routingService := service.NewRoutingService(routingCache)
|
||||
statsService := service.NewStatsService(statsRepo, statsBuffer)
|
||||
|
||||
// 8. 创建 ConversionEngine
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
if err := registry.Register(openai.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error()))
|
||||
@@ -83,16 +95,16 @@ func main() {
|
||||
}
|
||||
engine := conversion.NewConversionEngine(registry, zapLogger)
|
||||
|
||||
// 7. 初始化 provider client
|
||||
// 9. 初始化 provider client
|
||||
providerClient := provider.NewClient()
|
||||
|
||||
// 8. 初始化 handler 层
|
||||
// 10. 初始化 handler 层
|
||||
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
|
||||
providerHandler := handler.NewProviderHandler(providerService)
|
||||
modelHandler := handler.NewModelHandler(modelService)
|
||||
statsHandler := handler.NewStatsHandler(statsService)
|
||||
|
||||
// 9. 创建 Gin 引擎
|
||||
// 11. 创建 Gin 引擎
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
r := gin.New()
|
||||
|
||||
@@ -103,7 +115,7 @@ func main() {
|
||||
|
||||
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
|
||||
|
||||
// 10. 启动服务器
|
||||
// 12. 启动服务器
|
||||
srv := &http.Server{
|
||||
Addr: formatAddr(cfg.Server.Port),
|
||||
Handler: r,
|
||||
@@ -131,6 +143,8 @@ func main() {
|
||||
zapLogger.Fatal("服务器强制关闭", zap.String("error", err.Error()))
|
||||
}
|
||||
|
||||
statsBuffer.Stop()
|
||||
|
||||
zapLogger.Info("服务器已关闭")
|
||||
}
|
||||
|
||||
|
||||
@@ -11,5 +11,6 @@ import (
|
||||
// StatsRepository 统计数据仓库接口
|
||||
type StatsRepository interface {
|
||||
Record(providerID, modelName string) error
|
||||
BatchUpdate(providerID, modelName string, date time.Time, delta int) error
|
||||
Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error)
|
||||
}
|
||||
|
||||
@@ -43,6 +43,28 @@ func (r *statsRepository) Record(providerID, modelName string) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (r *statsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
var stats config.UsageStats
|
||||
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
|
||||
providerID, modelName, date).First(&stats).Error
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return tx.Create(&config.UsageStats{
|
||||
ProviderID: providerID,
|
||||
ModelName: modelName,
|
||||
RequestCount: delta,
|
||||
Date: date,
|
||||
}).Error
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Model(&stats).
|
||||
Update("request_count", gorm.Expr("request_count + ?", delta)).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (r *statsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
var stats []config.UsageStats
|
||||
query := r.db.Model(&config.UsageStats{})
|
||||
|
||||
@@ -11,27 +11,30 @@ import (
|
||||
type modelService struct {
|
||||
modelRepo repository.ModelRepository
|
||||
providerRepo repository.ProviderRepository
|
||||
cache *RoutingCache
|
||||
}
|
||||
|
||||
func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) ModelService {
|
||||
return &modelService{modelRepo: modelRepo, providerRepo: providerRepo}
|
||||
func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository, cache *RoutingCache) ModelService {
|
||||
return &modelService{modelRepo: modelRepo, providerRepo: providerRepo, cache: cache}
|
||||
}
|
||||
|
||||
func (s *modelService) Create(model *domain.Model) error {
|
||||
// 校验供应商存在
|
||||
if _, err := s.providerRepo.GetByID(model.ProviderID); err != nil {
|
||||
return appErrors.ErrProviderNotFound
|
||||
}
|
||||
|
||||
// 联合唯一校验:同一供应商下 model_name 不重复
|
||||
if err := s.checkDuplicateModelName(model.ProviderID, model.ModelName, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 自动生成 UUID 作为 id
|
||||
model.ID = uuid.New().String()
|
||||
model.Enabled = true
|
||||
return s.modelRepo.Create(model)
|
||||
err := s.modelRepo.Create(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.cache.SetModel(model)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *modelService) Get(id string) (*domain.Model, error) {
|
||||
@@ -47,20 +50,17 @@ func (s *modelService) ListEnabled() ([]domain.Model, error) {
|
||||
}
|
||||
|
||||
func (s *modelService) Update(id string, updates map[string]interface{}) error {
|
||||
// 获取当前模型
|
||||
current, err := s.modelRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return appErrors.ErrModelNotFound
|
||||
}
|
||||
|
||||
// 如果更新 provider_id,校验新供应商存在
|
||||
if providerID, ok := updates["provider_id"].(string); ok {
|
||||
if _, err := s.providerRepo.GetByID(providerID); err != nil {
|
||||
return appErrors.ErrProviderNotFound
|
||||
}
|
||||
}
|
||||
|
||||
// 确定更新后的 provider_id 和 model_name
|
||||
newProviderID := current.ProviderID
|
||||
if v, ok := updates["provider_id"].(string); ok {
|
||||
newProviderID = v
|
||||
@@ -70,18 +70,37 @@ func (s *modelService) Update(id string, updates map[string]interface{}) error {
|
||||
newModelName = v
|
||||
}
|
||||
|
||||
// 如果 provider_id 或 model_name 发生变化,校验联合唯一
|
||||
if newProviderID != current.ProviderID || newModelName != current.ModelName {
|
||||
if err := s.checkDuplicateModelName(newProviderID, newModelName, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return s.modelRepo.Update(id, updates)
|
||||
err = s.modelRepo.Update(id, updates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.cache.InvalidateModel(current.ProviderID, current.ModelName)
|
||||
if newProviderID != current.ProviderID || newModelName != current.ModelName {
|
||||
s.cache.InvalidateModel(newProviderID, newModelName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *modelService) Delete(id string) error {
|
||||
return s.modelRepo.Delete(id)
|
||||
model, err := s.modelRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return appErrors.ErrModelNotFound
|
||||
}
|
||||
|
||||
err = s.modelRepo.Delete(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.cache.InvalidateModel(model.ProviderID, model.ModelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkDuplicateModelName 校验同一供应商下 model_name 是否重复
|
||||
|
||||
@@ -13,23 +13,27 @@ import (
|
||||
type providerService struct {
|
||||
providerRepo repository.ProviderRepository
|
||||
modelRepo repository.ModelRepository
|
||||
cache *RoutingCache
|
||||
}
|
||||
|
||||
func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository) ProviderService {
|
||||
return &providerService{providerRepo: providerRepo, modelRepo: modelRepo}
|
||||
func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository, cache *RoutingCache) ProviderService {
|
||||
return &providerService{providerRepo: providerRepo, modelRepo: modelRepo, cache: cache}
|
||||
}
|
||||
|
||||
func (s *providerService) Create(provider *domain.Provider) error {
|
||||
// 校验 provider_id 字符集
|
||||
if err := modelid.ValidateProviderID(provider.ID); err != nil {
|
||||
return appErrors.ErrInvalidProviderID
|
||||
}
|
||||
provider.Enabled = true
|
||||
err := s.providerRepo.Create(provider)
|
||||
if err != nil && isUniqueConstraintError(err) {
|
||||
return appErrors.ErrConflict
|
||||
if err != nil {
|
||||
if isUniqueConstraintError(err) {
|
||||
return appErrors.ErrConflict
|
||||
}
|
||||
return err
|
||||
}
|
||||
return err
|
||||
s.cache.SetProvider(provider)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *providerService) Get(id string) (*domain.Provider, error) {
|
||||
@@ -44,11 +48,21 @@ func (s *providerService) Update(id string, updates map[string]interface{}) erro
|
||||
if _, ok := updates["id"]; ok {
|
||||
return appErrors.ErrImmutableField
|
||||
}
|
||||
return s.providerRepo.Update(id, updates)
|
||||
err := s.providerRepo.Update(id, updates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.cache.InvalidateProvider(id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *providerService) Delete(id string) error {
|
||||
return s.providerRepo.Delete(id)
|
||||
err := s.providerRepo.Delete(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.cache.InvalidateProvider(id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListEnabledModels 返回所有启用的模型(用于 Models 接口本地聚合)
|
||||
|
||||
134
backend/internal/service/routing_cache.go
Normal file
134
backend/internal/service/routing_cache.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
)
|
||||
|
||||
type RoutingCache struct {
|
||||
providers sync.Map
|
||||
models sync.Map
|
||||
|
||||
modelRepo repository.ModelRepository
|
||||
providerRepo repository.ProviderRepository
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewRoutingCache(
|
||||
modelRepo repository.ModelRepository,
|
||||
providerRepo repository.ProviderRepository,
|
||||
logger *zap.Logger,
|
||||
) *RoutingCache {
|
||||
return &RoutingCache{
|
||||
modelRepo: modelRepo,
|
||||
providerRepo: providerRepo,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
|
||||
if v, ok := c.providers.Load(id); ok {
|
||||
return v.(*domain.Provider), nil
|
||||
}
|
||||
|
||||
provider, err := c.providerRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if v, ok := c.providers.Load(id); ok {
|
||||
return v.(*domain.Provider), nil
|
||||
}
|
||||
|
||||
c.providers.Store(id, provider)
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, error) {
|
||||
key := providerID + "/" + modelName
|
||||
|
||||
if v, ok := c.models.Load(key); ok {
|
||||
return v.(*domain.Model), nil
|
||||
}
|
||||
|
||||
model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if v, ok := c.models.Load(key); ok {
|
||||
return v.(*domain.Model), nil
|
||||
}
|
||||
|
||||
c.models.Store(key, model)
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func (c *RoutingCache) SetProvider(provider *domain.Provider) {
|
||||
c.providers.Store(provider.ID, provider)
|
||||
}
|
||||
|
||||
func (c *RoutingCache) SetModel(model *domain.Model) {
|
||||
key := model.ProviderID + "/" + model.ModelName
|
||||
c.models.Store(key, model)
|
||||
}
|
||||
|
||||
func (c *RoutingCache) InvalidateProvider(id string) {
|
||||
c.providers.Delete(id)
|
||||
c.invalidateModelsByProvider(id)
|
||||
c.logger.Debug("Provider 缓存失效", zap.String("provider_id", id))
|
||||
}
|
||||
|
||||
func (c *RoutingCache) InvalidateModel(providerID, modelName string) {
|
||||
key := providerID + "/" + modelName
|
||||
c.models.Delete(key)
|
||||
c.logger.Debug("Model 缓存失效",
|
||||
zap.String("provider_id", providerID),
|
||||
zap.String("model_name", modelName))
|
||||
}
|
||||
|
||||
func (c *RoutingCache) invalidateModelsByProvider(providerID string) {
|
||||
prefix := providerID + "/"
|
||||
count := 0
|
||||
c.models.Range(func(key, value interface{}) bool {
|
||||
if strings.HasPrefix(key.(string), prefix) {
|
||||
c.models.Delete(key)
|
||||
count++
|
||||
}
|
||||
return true
|
||||
})
|
||||
if count > 0 {
|
||||
c.logger.Debug("清除 Provider 相关 Model 缓存",
|
||||
zap.String("provider_id", providerID),
|
||||
zap.Int("count", count))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RoutingCache) Preload() error {
|
||||
providers, err := c.providerRepo.List()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i := range providers {
|
||||
c.providers.Store(providers[i].ID, &providers[i])
|
||||
}
|
||||
|
||||
models, err := c.modelRepo.List("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i := range models {
|
||||
key := models[i].ProviderID + "/" + models[i].ModelName
|
||||
c.models.Store(key, &models[i])
|
||||
}
|
||||
|
||||
c.logger.Info("缓存预热完成",
|
||||
zap.Int("providers", len(providers)),
|
||||
zap.Int("models", len(models)))
|
||||
return nil
|
||||
}
|
||||
273
backend/internal/service/routing_cache_test.go
Normal file
273
backend/internal/service/routing_cache_test.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
type mockModelRepo struct {
|
||||
models map[string]*domain.Model
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) Create(model *domain.Model) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) GetByID(id string) (*domain.Model, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
|
||||
key := providerID + "/" + modelName
|
||||
if model, ok := m.models[key]; ok {
|
||||
return model, nil
|
||||
}
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) List(providerID string) ([]domain.Model, error) {
|
||||
var result []domain.Model
|
||||
for _, model := range m.models {
|
||||
if providerID == "" || model.ProviderID == providerID {
|
||||
result = append(result, *model)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) ListEnabled() ([]domain.Model, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) Update(id string, updates map[string]interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) Delete(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockProviderRepo struct {
|
||||
providers map[string]*domain.Provider
|
||||
}
|
||||
|
||||
func (m *mockProviderRepo) Create(provider *domain.Provider) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProviderRepo) GetByID(id string) (*domain.Provider, error) {
|
||||
if provider, ok := m.providers[id]; ok {
|
||||
return provider, nil
|
||||
}
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
|
||||
func (m *mockProviderRepo) List() ([]domain.Provider, error) {
|
||||
var result []domain.Provider
|
||||
for _, provider := range m.providers {
|
||||
result = append(result, *provider)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockProviderRepo) Update(id string, updates map[string]interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProviderRepo) Delete(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRoutingCache_GetProvider_CacheHit(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
||||
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
||||
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
||||
}}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
provider, err := cache.GetProvider("openai")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai", provider.ID)
|
||||
|
||||
provider2, err := cache.GetProvider("openai")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, provider, provider2)
|
||||
}
|
||||
|
||||
func TestRoutingCache_GetProvider_CacheMiss(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
||||
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
_, err := cache.GetProvider("notexist")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRoutingCache_GetModel_CacheHit(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
||||
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
||||
}}
|
||||
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
model, err := cache.GetModel("openai", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4", model.ModelName)
|
||||
|
||||
model2, err := cache.GetModel("openai", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, model, model2)
|
||||
}
|
||||
|
||||
func TestRoutingCache_GetModel_CacheMiss(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
||||
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
_, err := cache.GetModel("openai", "notexist")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRoutingCache_DoubleCheck(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
||||
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
||||
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
||||
}}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := cache.GetProvider("openai")
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
||||
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
||||
"openai/gpt-3.5": {ID: "2", ProviderID: "openai", ModelName: "gpt-3.5", Enabled: true},
|
||||
"anthropic/claude": {ID: "3", ProviderID: "anthropic", ModelName: "claude", Enabled: true},
|
||||
}}
|
||||
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
||||
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
||||
"anthropic": {ID: "anthropic", Name: "Anthropic", Enabled: true},
|
||||
}}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
_, err := cache.GetModel("openai", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
_, err = cache.GetModel("openai", "gpt-3.5")
|
||||
require.NoError(t, err)
|
||||
_, err = cache.GetModel("anthropic", "claude")
|
||||
require.NoError(t, err)
|
||||
|
||||
cache.InvalidateProvider("openai")
|
||||
|
||||
var openaiCount, anthropicCount int
|
||||
cache.models.Range(func(key, value interface{}) bool {
|
||||
if key.(string) == "anthropic/claude" {
|
||||
anthropicCount++
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, 0, openaiCount)
|
||||
assert.Equal(t, 1, anthropicCount)
|
||||
}
|
||||
|
||||
func TestRoutingCache_InvalidateModel(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
||||
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
||||
}}
|
||||
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
_, err := cache.GetModel("openai", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
cache.InvalidateModel("openai", "gpt-4")
|
||||
|
||||
var count int
|
||||
cache.models.Range(func(key, value interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, 0, count)
|
||||
}
|
||||
|
||||
func TestRoutingCache_Preload_Success(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
||||
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
||||
}}
|
||||
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
||||
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
||||
}}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
err := cache.Preload()
|
||||
require.NoError(t, err)
|
||||
|
||||
var providerCount, modelCount int
|
||||
cache.providers.Range(func(key, value interface{}) bool {
|
||||
providerCount++
|
||||
return true
|
||||
})
|
||||
cache.models.Range(func(key, value interface{}) bool {
|
||||
modelCount++
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, 1, providerCount)
|
||||
assert.Equal(t, 1, modelCount)
|
||||
}
|
||||
|
||||
func TestRoutingCache_ConcurrentAccess(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
||||
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
||||
}}
|
||||
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
||||
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
||||
}}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = cache.GetProvider("openai")
|
||||
_, _ = cache.GetModel("openai", "gpt-4")
|
||||
cache.InvalidateProvider("openai")
|
||||
cache.InvalidateModel("openai", "gpt-4")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -4,20 +4,18 @@ import (
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
)
|
||||
|
||||
type routingService struct {
|
||||
modelRepo repository.ModelRepository
|
||||
providerRepo repository.ProviderRepository
|
||||
cache *RoutingCache
|
||||
}
|
||||
|
||||
func NewRoutingService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) RoutingService {
|
||||
return &routingService{modelRepo: modelRepo, providerRepo: providerRepo}
|
||||
func NewRoutingService(cache *RoutingCache) RoutingService {
|
||||
return &routingService{cache: cache}
|
||||
}
|
||||
|
||||
func (s *routingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
|
||||
model, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
|
||||
model, err := s.cache.GetModel(providerID, modelName)
|
||||
if err != nil {
|
||||
return nil, appErrors.ErrModelNotFound
|
||||
}
|
||||
@@ -26,7 +24,7 @@ func (s *routingService) RouteByModelName(providerID, modelName string) (*domain
|
||||
return nil, appErrors.ErrModelDisabled
|
||||
}
|
||||
|
||||
provider, err := s.providerRepo.GetByID(model.ProviderID)
|
||||
provider, err := s.cache.GetProvider(model.ProviderID)
|
||||
if err != nil {
|
||||
return nil, appErrors.ErrProviderNotFound
|
||||
}
|
||||
|
||||
@@ -14,7 +14,8 @@ func TestProviderService_Update(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
require.NoError(t, svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
|
||||
@@ -30,7 +31,8 @@ func TestProviderService_Update_NotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
err := svc.Update("nonexistent", map[string]interface{}{"name": "test"})
|
||||
assert.Error(t, err)
|
||||
@@ -40,7 +42,8 @@ func TestModelService_Get(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
@@ -55,7 +58,8 @@ func TestModelService_Update(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
@@ -73,7 +77,8 @@ func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
@@ -87,7 +92,8 @@ func TestModelService_Delete(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
@@ -104,7 +110,8 @@ func TestModelService_Delete_NotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
err := svc.Delete("nonexistent")
|
||||
assert.Error(t, err)
|
||||
@@ -112,7 +119,8 @@ func TestModelService_Delete_NotFound(t *testing.T) {
|
||||
|
||||
func TestStatsService_Aggregate_Default(t *testing.T) {
|
||||
statsRepo := repository.NewStatsRepository(nil)
|
||||
svc := NewStatsService(statsRepo)
|
||||
buffer := NewStatsBuffer(statsRepo, nil)
|
||||
svc := NewStatsService(statsRepo, buffer)
|
||||
|
||||
stats := []domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
@@ -133,7 +141,8 @@ func TestModelService_Update_NotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
err := svc.Update("nonexistent", map[string]interface{}{"model_name": "test"})
|
||||
assert.Error(t, err)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
testHelpers "nex/backend/tests"
|
||||
@@ -22,13 +23,21 @@ func setupServiceTestDB(t *testing.T) *gorm.DB {
|
||||
return testHelpers.SetupTestDB(t)
|
||||
}
|
||||
|
||||
func setupRoutingCache(t *testing.T, db *gorm.DB) *RoutingCache {
|
||||
t.Helper()
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
return NewRoutingCache(modelRepo, providerRepo, zap.NewNop())
|
||||
}
|
||||
|
||||
// ============ RoutingService - RouteByModelName 测试 ============
|
||||
|
||||
func TestRoutingService_RouteByModelName_Success(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewRoutingService(cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
|
||||
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
|
||||
@@ -41,9 +50,8 @@ func TestRoutingService_RouteByModelName_Success(t *testing.T) {
|
||||
|
||||
func TestRoutingService_RouteByModelName_NotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewRoutingService(cache)
|
||||
|
||||
_, err := svc.RouteByModelName("openai", "nonexistent-model")
|
||||
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
|
||||
@@ -53,7 +61,8 @@ func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewRoutingService(cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
|
||||
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
|
||||
@@ -67,7 +76,8 @@ func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewRoutingService(cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
|
||||
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
|
||||
@@ -83,7 +93,8 @@ func TestModelService_Create_GeneratesUUID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
|
||||
@@ -91,12 +102,10 @@ func TestModelService_Create_GeneratesUUID(t *testing.T) {
|
||||
err := svc.Create(model)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证返回的 model 拥有有效的 UUID
|
||||
assert.NotEmpty(t, model.ID)
|
||||
_, err = uuid.Parse(model.ID)
|
||||
assert.NoError(t, err, "model.ID should be a valid UUID")
|
||||
|
||||
// 通过 Get 验证持久化
|
||||
stored, err := svc.Get(model.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, model.ID, stored.ID)
|
||||
@@ -107,7 +116,8 @@ func TestModelService_Create_DuplicateModelName(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
|
||||
@@ -115,7 +125,6 @@ func TestModelService_Create_DuplicateModelName(t *testing.T) {
|
||||
err := svc.Create(model1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 使用相同的 (providerID, modelName) 创建第二个模型应失败
|
||||
model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err = svc.Create(model2)
|
||||
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
|
||||
@@ -125,7 +134,8 @@ func TestModelService_Create_ProviderNotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
|
||||
err := svc.Create(model)
|
||||
@@ -138,7 +148,8 @@ func TestProviderService_Create_InvalidID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
@@ -149,7 +160,8 @@ func TestProviderService_Create_ValidID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
@@ -164,7 +176,8 @@ func TestModelService_Update_DuplicateModelName(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"}))
|
||||
@@ -189,7 +202,8 @@ func TestModelService_Update_ModelNotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
err := svc.Update("nonexistent-id", map[string]interface{}{
|
||||
"model_name": "gpt-4",
|
||||
@@ -201,7 +215,8 @@ func TestModelService_Update_Success(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
|
||||
@@ -226,7 +241,8 @@ func TestProviderService_Update_ImmutableID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
@@ -243,7 +259,8 @@ func TestProviderService_Update_Success(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
@@ -301,7 +318,7 @@ func TestStatsService_Aggregate_ByModel(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
svc := NewStatsService(statsRepo)
|
||||
buffer := NewStatsBuffer(statsRepo, nil); svc := NewStatsService(statsRepo, buffer)
|
||||
|
||||
result := svc.Aggregate(tt.stats, "model")
|
||||
|
||||
@@ -362,7 +379,7 @@ func TestStatsService_Aggregate_ByDate(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
svc := NewStatsService(statsRepo)
|
||||
buffer := NewStatsBuffer(statsRepo, nil); svc := NewStatsService(statsRepo, buffer)
|
||||
|
||||
result := svc.Aggregate(tt.stats, "date")
|
||||
|
||||
@@ -431,7 +448,8 @@ func TestProviderService_List_APIKeyNotMasked(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider1 := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"}
|
||||
provider2 := &domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "sk-anthropic1234", BaseURL: "https://api.anthropic.com"}
|
||||
@@ -456,7 +474,8 @@ func TestModelService_ConcurrentCreate(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
|
||||
|
||||
169
backend/internal/service/stats_buffer.go
Normal file
169
backend/internal/service/stats_buffer.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/repository"
|
||||
)
|
||||
|
||||
type StatsBuffer struct {
|
||||
counters sync.Map
|
||||
|
||||
flushInterval time.Duration
|
||||
flushThreshold int
|
||||
totalCount atomic.Int64
|
||||
|
||||
statsRepo repository.StatsRepository
|
||||
logger *zap.Logger
|
||||
|
||||
stopCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
}
|
||||
|
||||
type StatsBufferOption func(*StatsBuffer)
|
||||
|
||||
func WithFlushInterval(d time.Duration) StatsBufferOption {
|
||||
return func(b *StatsBuffer) {
|
||||
b.flushInterval = d
|
||||
}
|
||||
}
|
||||
|
||||
func WithFlushThreshold(threshold int) StatsBufferOption {
|
||||
return func(b *StatsBuffer) {
|
||||
b.flushThreshold = threshold
|
||||
}
|
||||
}
|
||||
|
||||
func NewStatsBuffer(
|
||||
statsRepo repository.StatsRepository,
|
||||
logger *zap.Logger,
|
||||
opts ...StatsBufferOption,
|
||||
) *StatsBuffer {
|
||||
b := &StatsBuffer{
|
||||
statsRepo: statsRepo,
|
||||
logger: logger,
|
||||
flushInterval: 5 * time.Second,
|
||||
flushThreshold: 100,
|
||||
stopCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(b)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *StatsBuffer) Increment(providerID, modelName string) {
|
||||
today := time.Now().Format("2006-01-02")
|
||||
key := providerID + "/" + modelName + "/" + today
|
||||
|
||||
var counter *int64
|
||||
if v, ok := b.counters.Load(key); ok {
|
||||
counter = v.(*int64)
|
||||
} else {
|
||||
val := int64(0)
|
||||
counter = &val
|
||||
actual, loaded := b.counters.LoadOrStore(key, counter)
|
||||
if loaded {
|
||||
counter = actual.(*int64)
|
||||
}
|
||||
}
|
||||
|
||||
atomic.AddInt64(counter, 1)
|
||||
|
||||
if b.totalCount.Add(1) >= int64(b.flushThreshold) {
|
||||
go b.flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *StatsBuffer) Start() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(b.flushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
b.flush()
|
||||
case <-b.stopCh:
|
||||
b.flush()
|
||||
close(b.doneCh)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (b *StatsBuffer) Stop() {
|
||||
close(b.stopCh)
|
||||
<-b.doneCh
|
||||
}
|
||||
|
||||
func (b *StatsBuffer) flush() {
|
||||
type statEntry struct {
|
||||
providerID string
|
||||
modelName string
|
||||
date string
|
||||
count int64
|
||||
}
|
||||
|
||||
var entries []statEntry
|
||||
b.counters.Range(func(key, value interface{}) bool {
|
||||
keyStr := key.(string)
|
||||
parts := strings.Split(keyStr, "/")
|
||||
if len(parts) != 3 {
|
||||
return true
|
||||
}
|
||||
|
||||
counter := value.(*int64)
|
||||
count := atomic.SwapInt64(counter, 0)
|
||||
|
||||
if count > 0 {
|
||||
entries = append(entries, statEntry{
|
||||
providerID: parts[0],
|
||||
modelName: parts[1],
|
||||
date: parts[2],
|
||||
count: count,
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if len(entries) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
success := 0
|
||||
for _, entry := range entries {
|
||||
date, _ := time.Parse("2006-01-02", entry.date)
|
||||
err := b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count))
|
||||
if err != nil {
|
||||
b.logger.Error("批量更新统计失败",
|
||||
zap.String("provider_id", entry.providerID),
|
||||
zap.String("model_name", entry.modelName),
|
||||
zap.Int64("count", entry.count),
|
||||
zap.Error(err))
|
||||
|
||||
key := entry.providerID + "/" + entry.modelName + "/" + entry.date
|
||||
if v, ok := b.counters.Load(key); ok {
|
||||
counter := v.(*int64)
|
||||
atomic.AddInt64(counter, entry.count)
|
||||
}
|
||||
} else {
|
||||
success++
|
||||
}
|
||||
}
|
||||
|
||||
b.totalCount.Store(0)
|
||||
b.logger.Debug("统计刷新完成",
|
||||
zap.Int("total", len(entries)),
|
||||
zap.Int("success", success),
|
||||
zap.Int("failed", len(entries)-success))
|
||||
}
|
||||
251
backend/internal/service/stats_buffer_test.go
Normal file
251
backend/internal/service/stats_buffer_test.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
type mockStatsRepo struct {
|
||||
records []struct {
|
||||
providerID string
|
||||
modelName string
|
||||
date time.Time
|
||||
delta int
|
||||
}
|
||||
fail bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (m *mockStatsRepo) Record(providerID, modelName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStatsRepo) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.fail {
|
||||
return errors.New("db error")
|
||||
}
|
||||
m.records = append(m.records, struct {
|
||||
providerID string
|
||||
modelName string
|
||||
date time.Time
|
||||
delta int
|
||||
}{providerID, modelName, date, delta})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStatsRepo) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestStatsBuffer_Increment(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
buffer.Increment("openai", "gpt-3.5")
|
||||
|
||||
var count int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
count += atomic.LoadInt64(counter)
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(3), count)
|
||||
}
|
||||
|
||||
func TestStatsBuffer_ConcurrentIncrement(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
var count int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
count = atomic.LoadInt64(counter)
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(100), count)
|
||||
}
|
||||
|
||||
func TestStatsBuffer_LoadOrStore(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
var counterCount int
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counterCount++
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, 1, counterCount)
|
||||
}
|
||||
|
||||
func TestStatsBuffer_FlushByInterval(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger,
|
||||
WithFlushInterval(100*time.Millisecond))
|
||||
|
||||
buffer.Start()
|
||||
defer buffer.Stop()
|
||||
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
statsRepo.mu.Lock()
|
||||
assert.GreaterOrEqual(t, len(statsRepo.records), 1)
|
||||
statsRepo.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestStatsBuffer_FlushByThreshold(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger,
|
||||
WithFlushThreshold(10))
|
||||
|
||||
buffer.Start()
|
||||
defer buffer.Stop()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
statsRepo.mu.Lock()
|
||||
assert.GreaterOrEqual(t, len(statsRepo.records), 1)
|
||||
statsRepo.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestStatsBuffer_SwapInt64(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
|
||||
var beforeCount int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
beforeCount = atomic.LoadInt64(counter)
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(2), beforeCount)
|
||||
|
||||
buffer.flush()
|
||||
|
||||
var afterCount int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
afterCount = atomic.LoadInt64(counter)
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(0), afterCount)
|
||||
}
|
||||
|
||||
func TestStatsBuffer_FailRetry(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{fail: true}
|
||||
buffer := NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
|
||||
buffer.flush()
|
||||
|
||||
var count int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
count = atomic.LoadInt64(counter)
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(2), count)
|
||||
}
|
||||
|
||||
func TestStatsBuffer_Stop(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger,
|
||||
WithFlushInterval(10*time.Second))
|
||||
|
||||
buffer.Start()
|
||||
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
|
||||
start := time.Now()
|
||||
buffer.Stop()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.Less(t, elapsed, 1*time.Second)
|
||||
|
||||
statsRepo.mu.Lock()
|
||||
assert.GreaterOrEqual(t, len(statsRepo.records), 1)
|
||||
statsRepo.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestStatsBuffer_ConcurrentIncrementAndFlush(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger,
|
||||
WithFlushInterval(50*time.Millisecond))
|
||||
|
||||
buffer.Start()
|
||||
defer buffer.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
statsRepo.mu.Lock()
|
||||
totalDelta := 0
|
||||
for _, r := range statsRepo.records {
|
||||
totalDelta += r.delta
|
||||
}
|
||||
statsRepo.mu.Unlock()
|
||||
|
||||
assert.Equal(t, 100, totalDelta)
|
||||
}
|
||||
@@ -10,14 +10,16 @@ import (
|
||||
|
||||
type statsService struct {
|
||||
statsRepo repository.StatsRepository
|
||||
buffer *StatsBuffer
|
||||
}
|
||||
|
||||
func NewStatsService(statsRepo repository.StatsRepository) StatsService {
|
||||
return &statsService{statsRepo: statsRepo}
|
||||
func NewStatsService(statsRepo repository.StatsRepository, buffer *StatsBuffer) StatsService {
|
||||
return &statsService{statsRepo: statsRepo, buffer: buffer}
|
||||
}
|
||||
|
||||
func (s *statsService) Record(providerID, modelName string) error {
|
||||
return s.statsRepo.Record(providerID, modelName)
|
||||
s.buffer.Increment(providerID, modelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *statsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
@@ -54,10 +55,14 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
||||
statsService := service.NewStatsService(statsRepo)
|
||||
logger := zap.NewNop()
|
||||
routingCache := service.NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
statsBuffer := service.NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
|
||||
routingService := service.NewRoutingService(routingCache)
|
||||
statsService := service.NewStatsService(statsRepo, statsBuffer)
|
||||
|
||||
// 创建 ConversionEngine
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
@@ -48,10 +49,14 @@ func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
||||
statsService := service.NewStatsService(statsRepo)
|
||||
logger := zap.NewNop()
|
||||
routingCache := service.NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
statsBuffer := service.NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
|
||||
routingService := service.NewRoutingService(routingCache)
|
||||
statsService := service.NewStatsService(statsRepo, statsBuffer)
|
||||
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
require.NoError(t, registry.Register(openaiConv.NewAdapter()))
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
@@ -30,10 +31,14 @@ func setupIntegrationTest(t *testing.T) (*gin.Engine, *gorm.DB) {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||
_ = service.NewRoutingService(modelRepo, providerRepo)
|
||||
statsService := service.NewStatsService(statsRepo)
|
||||
logger := zap.NewNop()
|
||||
routingCache := service.NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
statsBuffer := service.NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
|
||||
_ = service.NewRoutingService(routingCache)
|
||||
statsService := service.NewStatsService(statsRepo, statsBuffer)
|
||||
|
||||
providerHandler := handler.NewProviderHandler(providerService)
|
||||
modelHandler := handler.NewModelHandler(modelService)
|
||||
|
||||
@@ -41,6 +41,20 @@ func (m *MockStatsRepository) EXPECT() *MockStatsRepositoryMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// BatchUpdate mocks base method.
|
||||
func (m *MockStatsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BatchUpdate", providerID, modelName, date, delta)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BatchUpdate indicates an expected call of BatchUpdate.
|
||||
func (mr *MockStatsRepositoryMockRecorder) BatchUpdate(providerID, modelName, date, delta any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdate", reflect.TypeOf((*MockStatsRepository)(nil).BatchUpdate), providerID, modelName, date, delta)
|
||||
}
|
||||
|
||||
// Query mocks base method.
|
||||
func (m *MockStatsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
136
openspec/specs/routing-cache/spec.md
Normal file
136
openspec/specs/routing-cache/spec.md
Normal file
@@ -0,0 +1,136 @@
|
||||
# Routing Cache
|
||||
|
||||
## Purpose
|
||||
|
||||
TBD - 为 Provider 和 Model 提供内存缓存,优化路由查询性能。
|
||||
|
||||
## Requirements
|
||||
|
||||
### Requirement: RoutingCache 缓存 Provider 和 Model
|
||||
|
||||
系统 SHALL 为 Provider 和 Model 提供内存缓存,使用 sync.Map 作为缓存数据结构。
|
||||
|
||||
#### Scenario: 缓存数据结构
|
||||
|
||||
- **WHEN** 创建 RoutingCache
|
||||
- **THEN** SHALL 使用 sync.Map 存储 Provider(key = providerID)
|
||||
- **THEN** SHALL 使用 sync.Map 存储 Model(key = providerID/modelName)
|
||||
|
||||
### Requirement: 缓存查询优先
|
||||
|
||||
RoutingCache SHALL 优先从缓存查询,缓存未命中时查询数据库并更新缓存。
|
||||
|
||||
#### Scenario: 缓存命中
|
||||
|
||||
- **WHEN** 查询 Provider 或 Model
|
||||
- **THEN** SHALL 先从缓存查询
|
||||
- **THEN** 如果缓存命中,SHALL 直接返回缓存值
|
||||
- **THEN** SHALL 不查询数据库
|
||||
|
||||
#### Scenario: 缓存未命中
|
||||
|
||||
- **WHEN** 查询 Provider 或 Model
|
||||
- **THEN** SHALL 先从缓存查询
|
||||
- **THEN** 如果缓存未命中,SHALL 查询数据库
|
||||
- **THEN** SHALL 将查询结果存入缓存
|
||||
- **THEN** SHALL 返回查询结果
|
||||
|
||||
#### Scenario: 双重检查防止竞态
|
||||
|
||||
- **WHEN** 并发查询同一个 Provider 或 Model
|
||||
- **THEN** SHALL 在查询数据库后再次检查缓存
|
||||
- **THEN** 如果已有其他 goroutine 存入缓存,SHALL 使用缓存值
|
||||
- **THEN** SHALL 防止存入旧值
|
||||
|
||||
### Requirement: 缓存更新
|
||||
|
||||
RoutingCache SHALL 支持 Create 操作后更新缓存。
|
||||
|
||||
#### Scenario: Create Provider 后更新缓存
|
||||
|
||||
- **WHEN** 创建 Provider 成功
|
||||
- **THEN** SHALL 调用 RoutingCache.SetProvider
|
||||
- **THEN** SHALL 将新 Provider 存入缓存
|
||||
|
||||
#### Scenario: Create Model 后更新缓存
|
||||
|
||||
- **WHEN** 创建 Model 成功
|
||||
- **THEN** SHALL 调用 RoutingCache.SetModel
|
||||
- **THEN** SHALL 将新 Model 存入缓存
|
||||
|
||||
### Requirement: 缓存失效
|
||||
|
||||
RoutingCache SHALL 支持 Update/Delete 操作后清除缓存。
|
||||
|
||||
#### Scenario: Update Provider 后清除缓存
|
||||
|
||||
- **WHEN** 更新 Provider 成功
|
||||
- **THEN** SHALL 调用 RoutingCache.InvalidateProvider
|
||||
- **THEN** SHALL 清除该 Provider 的缓存
|
||||
- **THEN** SHALL 级联清除该 Provider 的所有 Model 缓存
|
||||
|
||||
#### Scenario: Delete Provider 后清除缓存
|
||||
|
||||
- **WHEN** 删除 Provider 成功
|
||||
- **THEN** SHALL 调用 RoutingCache.InvalidateProvider
|
||||
- **THEN** SHALL 清除该 Provider 的缓存
|
||||
- **THEN** SHALL 级联清除该 Provider 的所有 Model 缓存
|
||||
|
||||
#### Scenario: Update Model 后清除缓存
|
||||
|
||||
- **WHEN** 更新 Model 成功
|
||||
- **THEN** SHALL 调用 RoutingCache.InvalidateModel
|
||||
- **THEN** SHALL 清除旧位置的 Model 缓存
|
||||
- **THEN** 如果 provider_id 或 model_name 变化,SHALL 也清除新位置的缓存
|
||||
|
||||
#### Scenario: Delete Model 后清除缓存
|
||||
|
||||
- **WHEN** 删除 Model 成功
|
||||
- **THEN** SHALL 调用 RoutingCache.InvalidateModel
|
||||
- **THEN** SHALL 清除该 Model 的缓存
|
||||
|
||||
### Requirement: 缓存预热
|
||||
|
||||
RoutingCache SHALL 支持启动时预热缓存。
|
||||
|
||||
#### Scenario: 预热成功
|
||||
|
||||
- **WHEN** 服务启动时
|
||||
- **THEN** SHALL 调用 RoutingCache.Preload
|
||||
- **THEN** SHALL 从数据库加载所有 Provider 到缓存
|
||||
- **THEN** SHALL 从数据库加载所有 Model 到缓存
|
||||
- **THEN** SHALL 记录预热完成的日志
|
||||
|
||||
#### Scenario: 预热失败
|
||||
|
||||
- **WHEN** 预热失败时
|
||||
- **THEN** SHALL 记录警告日志
|
||||
- **THEN** SHALL 继续启动服务
|
||||
- **THEN** SHALL 使用懒加载(首次查询时加载)
|
||||
|
||||
### Requirement: RoutingService 使用缓存
|
||||
|
||||
RoutingService SHALL 使用 RoutingCache 进行路由查询。
|
||||
|
||||
#### Scenario: RouteByModelName 使用缓存
|
||||
|
||||
- **WHEN** 调用 RoutingService.RouteByModelName
|
||||
- **THEN** SHALL 调用 RoutingCache.GetModel 获取 Model
|
||||
- **THEN** SHALL 调用 RoutingCache.GetProvider 获取 Provider
|
||||
- **THEN** SHALL 不直接调用 Repository
|
||||
|
||||
### Requirement: 并发安全
|
||||
|
||||
RoutingCache SHALL 支持并发访问。
|
||||
|
||||
#### Scenario: 并发查询
|
||||
|
||||
- **WHEN** 多个 goroutine 并发查询缓存
|
||||
- **THEN** SHALL 无竞态条件
|
||||
- **THEN** SHALL 无 panic
|
||||
|
||||
#### Scenario: 并发查询和失效
|
||||
|
||||
- **WHEN** 并发查询和失效缓存
|
||||
- **THEN** SHALL 无竞态条件
|
||||
- **THEN** SHALL 保证一致性
|
||||
156
openspec/specs/stats-buffer/spec.md
Normal file
156
openspec/specs/stats-buffer/spec.md
Normal file
@@ -0,0 +1,156 @@
|
||||
# Stats Buffer
|
||||
|
||||
## Purpose
|
||||
|
||||
TBD - 为统计数据提供内存缓冲,优化写入性能。
|
||||
|
||||
## Requirements
|
||||
|
||||
### Requirement: StatsBuffer 内存缓冲
|
||||
|
||||
系统 SHALL 为统计数据提供内存缓冲,使用 sync.Map + atomic.Int64 进行计数。
|
||||
|
||||
#### Scenario: 缓冲数据结构
|
||||
|
||||
- **WHEN** 创建 StatsBuffer
|
||||
- **THEN** SHALL 使用 sync.Map 存储计数器(key = providerID/modelName/date)
|
||||
- **THEN** SHALL 使用 atomic.Int64 进行计数
|
||||
- **THEN** SHALL 支持配置刷新间隔和阈值
|
||||
|
||||
### Requirement: 原子计数
|
||||
|
||||
StatsBuffer SHALL 使用原子操作进行计数。
|
||||
|
||||
#### Scenario: Increment 计数
|
||||
|
||||
- **WHEN** 调用 StatsBuffer.Increment
|
||||
- **THEN** SHALL 使用 atomic.AddInt64 增加计数
|
||||
- **THEN** SHALL 无锁操作
|
||||
- **THEN** SHALL 线程安全
|
||||
|
||||
#### Scenario: Increment 创建计数器
|
||||
|
||||
- **WHEN** 调用 StatsBuffer.Increment 且计数器不存在
|
||||
- **THEN** SHALL 使用 sync.Map.LoadOrStore 创建计数器
|
||||
- **THEN** SHALL 初始化计数器为 0
|
||||
- **THEN** SHALL 原子增加计数
|
||||
|
||||
#### Scenario: 并发计数
|
||||
|
||||
- **WHEN** 多个 goroutine 并发 Increment
|
||||
- **THEN** SHALL 计数准确
|
||||
- **THEN** SHALL 无竞态条件
|
||||
|
||||
### Requirement: 定时刷新
|
||||
|
||||
StatsBuffer SHALL 支持定时刷新到数据库。
|
||||
|
||||
#### Scenario: 定时刷新触发
|
||||
|
||||
- **WHEN** 后台刷新协程运行
|
||||
- **THEN** SHALL 每隔 flushInterval 触发刷新
|
||||
- **THEN** SHALL 调用 StatsRepository.BatchUpdate 写入数据库
|
||||
|
||||
#### Scenario: 刷新间隔配置
|
||||
|
||||
- **WHEN** 创建 StatsBuffer
|
||||
- **THEN** 默认 flushInterval 为 5 秒
|
||||
- **THEN** 可通过 WithFlushInterval 选项配置
|
||||
|
||||
### Requirement: 阈值触发刷新
|
||||
|
||||
StatsBuffer SHALL 支持累计阈值触发刷新。
|
||||
|
||||
#### Scenario: 阈值触发
|
||||
|
||||
- **WHEN** 累计计数达到 flushThreshold
|
||||
- **THEN** SHALL 异步触发刷新
|
||||
- **THEN** SHALL 不阻塞请求
|
||||
|
||||
#### Scenario: 阈值配置
|
||||
|
||||
- **WHEN** 创建 StatsBuffer
|
||||
- **THEN** 默认 flushThreshold 为 100
|
||||
- **THEN** 可通过 WithFlushThreshold 选项配置
|
||||
|
||||
### Requirement: 批量写入数据库
|
||||
|
||||
StatsBuffer SHALL 批量写入统计数据到数据库。
|
||||
|
||||
#### Scenario: 批量写入
|
||||
|
||||
- **WHEN** 刷新触发
|
||||
- **THEN** SHALL 遍历所有计数器
|
||||
- **THEN** SHALL 使用 atomic.SwapInt64(counter, 0) 获取并清零计数器
|
||||
- **THEN** SHALL 调用 StatsRepository.BatchUpdate 批量写入
|
||||
- **THEN** SHALL 重置 totalCount 为 0
|
||||
|
||||
#### Scenario: SwapInt64 清零计数器
|
||||
|
||||
- **WHEN** flush 收集计数器
|
||||
- **THEN** SHALL 使用 SwapInt64 原子操作获取当前计数并清零
|
||||
- **THEN** SHALL 保证计数不丢失(新计数会累加到已清零的计数器)
|
||||
- **THEN** SHALL 不阻塞后续 Increment 操作
|
||||
|
||||
#### Scenario: 写入失败保留计数器
|
||||
|
||||
- **WHEN** BatchUpdate 失败
|
||||
- **THEN** SHALL 将计数加回计数器(使用 atomic.AddInt64)
|
||||
- **THEN** SHALL 记录错误日志
|
||||
- **THEN** SHALL 继续处理其他条目
|
||||
|
||||
### Requirement: 优雅关闭
|
||||
|
||||
StatsBuffer SHALL 支持优雅关闭,确保最后的统计写入数据库。
|
||||
|
||||
#### Scenario: Stop 等待刷新完成
|
||||
|
||||
- **WHEN** 调用 StatsBuffer.Stop
|
||||
- **THEN** SHALL 停止后台刷新协程
|
||||
- **THEN** SHALL 执行最后一次刷新
|
||||
- **THEN** SHALL 等待刷新完成
|
||||
- **THEN** SHALL 保证统计数据不丢失
|
||||
|
||||
#### Scenario: 无超时
|
||||
|
||||
- **WHEN** Stop 等待刷新
|
||||
- **THEN** SHALL 无超时限制
|
||||
- **THEN** SHALL 等待刷新完成
|
||||
|
||||
### Requirement: StatsService 使用缓冲
|
||||
|
||||
StatsService SHALL 使用 StatsBuffer 进行统计记录。
|
||||
|
||||
#### Scenario: Record 使用缓冲
|
||||
|
||||
- **WHEN** 调用 StatsService.Record
|
||||
- **THEN** SHALL 调用 StatsBuffer.Increment
|
||||
- **THEN** SHALL 不直接调用 StatsRepository.Record
|
||||
- **THEN** SHALL 立即返回,不阻塞
|
||||
|
||||
### Requirement: StatsRepository 扩展
|
||||
|
||||
StatsRepository SHALL 新增 BatchUpdate 方法。
|
||||
|
||||
#### Scenario: BatchUpdate 方法
|
||||
|
||||
- **WHEN** 调用 StatsRepository.BatchUpdate
|
||||
- **THEN** SHALL 使用事务更新或创建统计记录
|
||||
- **THEN** SHALL 使用 request_count + delta 更新
|
||||
- **THEN** SHALL 支持批量增量更新
|
||||
|
||||
### Requirement: 并发安全
|
||||
|
||||
StatsBuffer SHALL 支持并发访问。
|
||||
|
||||
#### Scenario: 并发 Increment 和 flush
|
||||
|
||||
- **WHEN** 并发 Increment 和 flush
|
||||
- **THEN** SHALL 无竞态条件
|
||||
- **THEN** SHALL 计数准确(可能延迟到下次 flush)
|
||||
|
||||
#### Scenario: flush 期间 Increment
|
||||
|
||||
- **WHEN** flush 正在执行
|
||||
- **THEN** 新的 Increment SHALL 继续计数
|
||||
- **THEN** SHALL 不会丢失计数
|
||||
@@ -13,14 +13,15 @@
|
||||
#### Scenario: 记录成功请求
|
||||
|
||||
- **WHEN** 请求成功转发到供应商
|
||||
- **THEN** 网关 SHALL 增加该供应商和模型的请求计数
|
||||
- **THEN** 网关 SHALL 记录当前日期的统计
|
||||
- **THEN** 网关 SHALL 通过 StatsBuffer 增加该供应商和模型的请求计数
|
||||
- **THEN** 网关 SHALL 异步批量写入数据库(定时或阈值触发)
|
||||
- **THEN** 网关 SHALL 不阻塞响应
|
||||
|
||||
#### Scenario: 记录流式请求
|
||||
|
||||
- **WHEN** 流式请求成功完成
|
||||
- **THEN** 网关 SHALL 增加该供应商和模型的请求计数
|
||||
- **THEN** 网关 SHALL 在流结束后记录统计
|
||||
- **THEN** 网关 SHALL 通过 StatsBuffer 增加该供应商和模型的请求计数
|
||||
- **THEN** 网关 SHALL 在流结束后异步记录统计
|
||||
|
||||
### Requirement: 使用统计记录统一模型标识
|
||||
|
||||
@@ -90,10 +91,9 @@
|
||||
#### Scenario: 并发请求
|
||||
|
||||
- **WHEN** 同时处理多个并发请求
|
||||
- **THEN** 网关 SHALL 正确为每个请求增加请求计数
|
||||
- **THEN** 网关 SHALL 使用原子操作正确增加每个请求的计数
|
||||
- **THEN** 不 SHALL 因并发写入而丢失统计
|
||||
|
||||
**变更说明:** 并发控制在 StatsRepository 中通过数据库事务实现。API 接口保持不变。
|
||||
- **THEN** SHALL 使用 StatsBuffer 的内存计数器
|
||||
|
||||
### Requirement: 使用 service 层处理业务逻辑
|
||||
|
||||
@@ -121,6 +121,13 @@ Service SHALL 通过 StatsRepository 访问数据。
|
||||
- **THEN** SHALL 调用对应的 StatsRepository 方法
|
||||
- **THEN** SHALL 使用 domain.UsageStats 类型
|
||||
|
||||
#### Scenario: 批量更新统计
|
||||
|
||||
- **WHEN** StatsBuffer 刷新统计
|
||||
- **THEN** SHALL 调用 StatsRepository.BatchUpdate
|
||||
- **THEN** SHALL 使用事务更新或创建统计记录
|
||||
- **THEN** SHALL 支持增量更新(request_count + delta)
|
||||
|
||||
#### Scenario: 事务处理
|
||||
|
||||
- **WHEN** 记录统计
|
||||
@@ -136,3 +143,36 @@ Service SHALL 通过 StatsRepository 访问数据。
|
||||
- **WHEN** 查询统计
|
||||
- **THEN** SHALL 使用 (provider_id, model_name, date) 复合索引
|
||||
- **THEN** SHALL 优化查询性能
|
||||
|
||||
### Requirement: 统计数据可接受少量丢失
|
||||
|
||||
统计记录方式改为内存缓冲,可接受少量丢失。
|
||||
|
||||
#### Scenario: 进程崩溃丢失统计
|
||||
|
||||
- **WHEN** 进程崩溃
|
||||
- **THEN** MAY 丢失最近 flushInterval 内的统计
|
||||
- **THEN** 统计数据用于监控,可接受少量丢失
|
||||
|
||||
#### Scenario: 优雅关闭保证不丢失
|
||||
|
||||
- **WHEN** 服务优雅关闭
|
||||
- **THEN** SHALL 调用 StatsBuffer.Stop
|
||||
- **THEN** SHALL 等待最后一次刷新完成
|
||||
- **THEN** SHALL 保证统计数据不丢失
|
||||
|
||||
### Requirement: StatsRepository 支持批量更新
|
||||
|
||||
StatsRepository SHALL 新增 BatchUpdate 方法支持批量增量更新。
|
||||
|
||||
#### Scenario: BatchUpdate 更新现有记录
|
||||
|
||||
- **WHEN** 调用 BatchUpdate 且当日记录存在
|
||||
- **THEN** SHALL 使用事务更新 request_count = request_count + delta
|
||||
- **THEN** SHALL 不创建新记录
|
||||
|
||||
#### Scenario: BatchUpdate 创建新记录
|
||||
|
||||
- **WHEN** 调用 BatchUpdate 且当日记录不存在
|
||||
- **THEN** SHALL 创建新记录,request_count = delta
|
||||
- **THEN** SHALL 使用事务保证原子性
|
||||
|
||||
Reference in New Issue
Block a user