package service

import (
	"sync"
	"sync/atomic"
	"time"

	"go.uber.org/zap"

	"nex/backend/internal/repository"
)

// statsBufferKey identifies one aggregated counter: a provider/model pair on
// a given local calendar day, formatted with the "2006-01-02" layout.
//
// A struct key is used instead of a "provider/model/date" string so that
// provider IDs or model names that themselves contain "/" are still
// aggregated and flushed instead of being silently skipped.
type statsBufferKey struct {
	providerID string
	modelName  string
	date       string
}

// StatsBuffer aggregates per-provider, per-model, per-day request counts in
// memory and writes them to a repository.StatsRepository in batches.
//
// Pending counts are flushed when the ticker started by Start fires, when the
// number of increments since the last flush reaches the flush threshold, and
// one final time on Stop. Writes that fail are added back to their counter
// and retried by a later flush.
type StatsBuffer struct {
	// counters maps statsBufferKey -> *atomic.Int64. Entries are never
	// removed. NOTE(review): the map grows by one entry per
	// provider/model/day for the lifetime of the buffer.
	counters       sync.Map
	flushInterval  time.Duration
	flushThreshold int

	// totalCount approximates the number of increments since the most recent
	// flush began; it is only used to trigger threshold flushes.
	totalCount atomic.Int64
	// flushing is true while a threshold-triggered flush goroutine is pending
	// or running, so a burst of Increment calls spawns at most one.
	flushing atomic.Bool
	// flushMu serializes all flushes (ticker, threshold and final Stop flush).
	flushMu sync.Mutex

	statsRepo repository.StatsRepository
	logger    *zap.Logger

	startOnce sync.Once
	stopOnce  sync.Once
	started   atomic.Bool
	stopCh    chan struct{}
	doneCh    chan struct{}
}

// StatsBufferOption configures a StatsBuffer created by NewStatsBuffer.
type StatsBufferOption func(*StatsBuffer)

// WithFlushInterval sets how often the background loop started by Start
// flushes pending counts. Non-positive values are ignored and the default
// (5s) is used instead.
func WithFlushInterval(d time.Duration) StatsBufferOption {
	return func(b *StatsBuffer) {
		b.flushInterval = d
	}
}

// WithFlushThreshold sets how many increments trigger an immediate
// asynchronous flush (default 100).
func WithFlushThreshold(threshold int) StatsBufferOption {
	return func(b *StatsBuffer) {
		b.flushThreshold = threshold
	}
}

// NewStatsBuffer returns a StatsBuffer that persists counts through statsRepo
// and reports flush results to logger. Defaults are a 5s flush interval and a
// flush threshold of 100 increments; opts override them.
func NewStatsBuffer(
	statsRepo repository.StatsRepository,
	logger *zap.Logger,
	opts ...StatsBufferOption,
) *StatsBuffer {
	b := &StatsBuffer{
		statsRepo:      statsRepo,
		logger:         logger,
		flushInterval:  5 * time.Second,
		flushThreshold: 100,
		stopCh:         make(chan struct{}),
		doneCh:         make(chan struct{}),
	}
	for _, opt := range opts {
		opt(b)
	}
	// time.NewTicker panics on non-positive durations.
	if b.flushInterval <= 0 {
		b.flushInterval = 5 * time.Second
	}
	return b
}

// Increment records one request for providerID/modelName on the current
// local date. It is safe for concurrent use. When the number of increments
// since the last flush reaches the threshold, a single asynchronous flush is
// started (further calls do not spawn more while it is in flight).
func (b *StatsBuffer) Increment(providerID, modelName string) {
	key := statsBufferKey{
		providerID: providerID,
		modelName:  modelName,
		date:       time.Now().Format("2006-01-02"),
	}

	// Fast path: plain Load avoids allocating a counter for existing keys.
	counter, ok := b.counters.Load(key)
	if !ok {
		counter, _ = b.counters.LoadOrStore(key, new(atomic.Int64))
	}
	counter.(*atomic.Int64).Add(1)

	if b.totalCount.Add(1) >= int64(b.flushThreshold) && b.flushing.CompareAndSwap(false, true) {
		go func() {
			defer b.flushing.Store(false)
			b.flush()
		}()
	}
}

// Start launches the background loop that flushes every flush interval until
// Stop is called. Calling Start more than once has no additional effect.
func (b *StatsBuffer) Start() {
	b.startOnce.Do(func() {
		b.started.Store(true)
		go b.run()
	})
}

// run is the ticker loop started by Start; it performs a final flush and
// closes doneCh when stopCh is closed.
func (b *StatsBuffer) run() {
	defer close(b.doneCh)

	ticker := time.NewTicker(b.flushInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			b.flush()
		case <-b.stopCh:
			b.flush()
			return
		}
	}
}

// Stop performs a final flush and, if Start was called, waits for the
// background loop to exit. It is safe to call more than once and without a
// prior Start (previously that case blocked forever).
func (b *StatsBuffer) Stop() {
	b.stopOnce.Do(func() {
		close(b.stopCh)
		if b.started.Load() {
			// The loop flushes once more before closing doneCh.
			<-b.doneCh
			return
		}
		b.flush()
	})
}

// flush atomically drains every non-zero counter and writes each one to the
// repository via BatchUpdate. Counts whose write fails are added back to
// their counter so a later flush retries them. Flushes are serialized.
func (b *StatsBuffer) flush() {
	b.flushMu.Lock()
	defer b.flushMu.Unlock()

	// Reset before draining so increments that arrive while this flush is
	// writing still count toward the next threshold flush.
	b.totalCount.Store(0)

	type statEntry struct {
		key     statsBufferKey
		counter *atomic.Int64
		count   int64
	}

	var entries []statEntry
	b.counters.Range(func(k, v any) bool {
		counter := v.(*atomic.Int64)
		if n := counter.Swap(0); n > 0 {
			entries = append(entries, statEntry{
				key:     k.(statsBufferKey),
				counter: counter,
				count:   n,
			})
		}
		return true
	})

	if len(entries) == 0 {
		return
	}

	success := 0
	for _, e := range entries {
		// The date string is always produced by Increment with the same
		// layout, so a parse error is not expected; treat it like a failed
		// write rather than sending a zero time to the repository.
		date, err := time.Parse("2006-01-02", e.key.date)
		if err == nil {
			err = b.statsRepo.BatchUpdate(e.key.providerID, e.key.modelName, date, int(e.count))
		}
		if err != nil {
			b.logger.Error("批量更新统计失败",
				zap.String("provider_id", e.key.providerID),
				zap.String("model_name", e.key.modelName),
				zap.Int64("count", e.count),
				zap.Error(err))
			// Return the drained count so the next flush retries it.
			e.counter.Add(e.count)
			continue
		}
		success++
	}

	b.logger.Debug("统计刷新完成",
		zap.Int("total", len(entries)),
		zap.Int("success", success),
		zap.Int("failed", len(entries)-success))
}