|
|
|
|
@@ -3,14 +3,15 @@ package service
|
|
|
|
|
import (
|
|
|
|
|
"errors"
|
|
|
|
|
"testing"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
"github.com/google/uuid"
|
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
|
"gorm.io/driver/sqlite"
|
|
|
|
|
"gorm.io/gorm"
|
|
|
|
|
|
|
|
|
|
"nex/backend/internal/config"
|
|
|
|
|
testHelpers "nex/backend/tests"
|
|
|
|
|
|
|
|
|
|
"nex/backend/internal/domain"
|
|
|
|
|
"nex/backend/internal/repository"
|
|
|
|
|
appErrors "nex/backend/pkg/errors"
|
|
|
|
|
@@ -18,18 +19,7 @@ import (
|
|
|
|
|
|
|
|
|
|
func setupServiceTestDB(t *testing.T) *gorm.DB {
|
|
|
|
|
t.Helper()
|
|
|
|
|
dir := t.TempDir()
|
|
|
|
|
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
t.Cleanup(func() {
|
|
|
|
|
sqlDB, _ := db.DB()
|
|
|
|
|
if sqlDB != nil {
|
|
|
|
|
sqlDB.Close()
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
return db
|
|
|
|
|
return testHelpers.SetupTestDB(t)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ============ RoutingService - RouteByModelName 测试 ============
|
|
|
|
|
@@ -40,9 +30,8 @@ func TestRoutingService_RouteByModelName_Success(t *testing.T) {
|
|
|
|
|
modelRepo := repository.NewModelRepository(db)
|
|
|
|
|
svc := NewRoutingService(modelRepo, providerRepo)
|
|
|
|
|
|
|
|
|
|
// 创建供应商和模型
|
|
|
|
|
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
|
|
|
|
|
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
|
|
|
|
|
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
|
|
|
|
|
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
|
|
|
|
|
|
|
|
|
|
result, err := svc.RouteByModelName("openai", "gpt-4")
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
@@ -66,10 +55,9 @@ func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
|
|
|
|
|
modelRepo := repository.NewModelRepository(db)
|
|
|
|
|
svc := NewRoutingService(modelRepo, providerRepo)
|
|
|
|
|
|
|
|
|
|
// 创建启用的供应商和禁用的模型
|
|
|
|
|
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
|
|
|
|
|
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
|
|
|
|
|
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
|
|
|
|
|
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
|
|
|
|
|
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
|
|
|
|
|
require.NoError(t, modelRepo.Update("m1", map[string]interface{}{"enabled": false}))
|
|
|
|
|
|
|
|
|
|
_, err := svc.RouteByModelName("openai", "gpt-4")
|
|
|
|
|
assert.True(t, errors.Is(err, appErrors.ErrModelDisabled))
|
|
|
|
|
@@ -81,10 +69,9 @@ func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
|
|
|
|
|
modelRepo := repository.NewModelRepository(db)
|
|
|
|
|
svc := NewRoutingService(modelRepo, providerRepo)
|
|
|
|
|
|
|
|
|
|
// 创建启用的供应商和模型,然后禁用供应商
|
|
|
|
|
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
|
|
|
|
|
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
|
|
|
|
|
providerRepo.Update("openai", map[string]interface{}{"enabled": false})
|
|
|
|
|
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
|
|
|
|
|
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
|
|
|
|
|
require.NoError(t, providerRepo.Update("openai", map[string]interface{}{"enabled": false}))
|
|
|
|
|
|
|
|
|
|
_, err := svc.RouteByModelName("openai", "gpt-4")
|
|
|
|
|
assert.True(t, errors.Is(err, appErrors.ErrProviderDisabled))
|
|
|
|
|
@@ -98,7 +85,7 @@ func TestModelService_Create_GeneratesUUID(t *testing.T) {
|
|
|
|
|
modelRepo := repository.NewModelRepository(db)
|
|
|
|
|
svc := NewModelService(modelRepo, providerRepo)
|
|
|
|
|
|
|
|
|
|
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
|
|
|
|
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
|
|
|
|
|
|
|
|
|
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
|
|
|
|
err := svc.Create(model)
|
|
|
|
|
@@ -122,7 +109,7 @@ func TestModelService_Create_DuplicateModelName(t *testing.T) {
|
|
|
|
|
modelRepo := repository.NewModelRepository(db)
|
|
|
|
|
svc := NewModelService(modelRepo, providerRepo)
|
|
|
|
|
|
|
|
|
|
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
|
|
|
|
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
|
|
|
|
|
|
|
|
|
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
|
|
|
|
err := svc.Create(model1)
|
|
|
|
|
@@ -179,8 +166,8 @@ func TestModelService_Update_DuplicateModelName(t *testing.T) {
|
|
|
|
|
modelRepo := repository.NewModelRepository(db)
|
|
|
|
|
svc := NewModelService(modelRepo, providerRepo)
|
|
|
|
|
|
|
|
|
|
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
|
|
|
|
providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"})
|
|
|
|
|
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
|
|
|
|
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"}))
|
|
|
|
|
|
|
|
|
|
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
|
|
|
|
err := svc.Create(model1)
|
|
|
|
|
@@ -216,7 +203,7 @@ func TestModelService_Update_Success(t *testing.T) {
|
|
|
|
|
modelRepo := repository.NewModelRepository(db)
|
|
|
|
|
svc := NewModelService(modelRepo, providerRepo)
|
|
|
|
|
|
|
|
|
|
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
|
|
|
|
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
|
|
|
|
|
|
|
|
|
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
|
|
|
|
err := svc.Create(model)
|
|
|
|
|
@@ -272,3 +259,223 @@ func TestProviderService_Update_Success(t *testing.T) {
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
assert.Equal(t, "OpenAI Updated", updated.Name)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ============ StatsService - Aggregate ByModel 测试 ============
|
|
|
|
|
|
|
|
|
|
func TestStatsService_Aggregate_ByModel(t *testing.T) {
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
stats []domain.UsageStats
|
|
|
|
|
expected []map[string]interface{}
|
|
|
|
|
}{
|
|
|
|
|
{
|
|
|
|
|
name: "multiple providers with same model name",
|
|
|
|
|
stats: []domain.UsageStats{
|
|
|
|
|
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10},
|
|
|
|
|
{ProviderID: "azure", ModelName: "gpt-4", RequestCount: 20},
|
|
|
|
|
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 5},
|
|
|
|
|
},
|
|
|
|
|
expected: []map[string]interface{}{
|
|
|
|
|
{"provider_id": "openai", "model_name": "gpt-4", "request_count": 15},
|
|
|
|
|
{"provider_id": "azure", "model_name": "gpt-4", "request_count": 20},
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "empty providerID",
|
|
|
|
|
stats: []domain.UsageStats{
|
|
|
|
|
{ProviderID: "", ModelName: "gpt-4", RequestCount: 10},
|
|
|
|
|
{ProviderID: "", ModelName: "gpt-4", RequestCount: 5},
|
|
|
|
|
},
|
|
|
|
|
expected: []map[string]interface{}{
|
|
|
|
|
{"provider_id": "", "model_name": "gpt-4", "request_count": 15},
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "empty result set",
|
|
|
|
|
stats: []domain.UsageStats{},
|
|
|
|
|
expected: []map[string]interface{}{},
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
db := setupServiceTestDB(t)
|
|
|
|
|
statsRepo := repository.NewStatsRepository(db)
|
|
|
|
|
svc := NewStatsService(statsRepo)
|
|
|
|
|
|
|
|
|
|
result := svc.Aggregate(tt.stats, "model")
|
|
|
|
|
|
|
|
|
|
assert.Len(t, result, len(tt.expected))
|
|
|
|
|
for _, exp := range tt.expected {
|
|
|
|
|
found := false
|
|
|
|
|
for _, r := range result {
|
|
|
|
|
if r["provider_id"] == exp["provider_id"] && r["model_name"] == exp["model_name"] {
|
|
|
|
|
assert.Equal(t, exp["request_count"], r["request_count"])
|
|
|
|
|
found = true
|
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.True(t, found, "expected result not found: %v", exp)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ============ StatsService - Aggregate ByDate 测试 ============
|
|
|
|
|
|
|
|
|
|
func TestStatsService_Aggregate_ByDate(t *testing.T) {
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
stats []domain.UsageStats
|
|
|
|
|
expected []map[string]interface{}
|
|
|
|
|
}{
|
|
|
|
|
{
|
|
|
|
|
name: "normal date grouping",
|
|
|
|
|
stats: []domain.UsageStats{
|
|
|
|
|
{Date: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), RequestCount: 10},
|
|
|
|
|
{Date: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), RequestCount: 5},
|
|
|
|
|
{Date: time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC), RequestCount: 20},
|
|
|
|
|
},
|
|
|
|
|
expected: []map[string]interface{}{
|
|
|
|
|
{"date": "2024-01-01", "request_count": 15},
|
|
|
|
|
{"date": "2024-01-02", "request_count": 20},
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "zero-value time",
|
|
|
|
|
stats: []domain.UsageStats{
|
|
|
|
|
{Date: time.Time{}, RequestCount: 10},
|
|
|
|
|
{Date: time.Time{}, RequestCount: 5},
|
|
|
|
|
},
|
|
|
|
|
expected: []map[string]interface{}{
|
|
|
|
|
{"date": "0001-01-01", "request_count": 15},
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "empty result set",
|
|
|
|
|
stats: []domain.UsageStats{},
|
|
|
|
|
expected: []map[string]interface{}{},
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
db := setupServiceTestDB(t)
|
|
|
|
|
statsRepo := repository.NewStatsRepository(db)
|
|
|
|
|
svc := NewStatsService(statsRepo)
|
|
|
|
|
|
|
|
|
|
result := svc.Aggregate(tt.stats, "date")
|
|
|
|
|
|
|
|
|
|
assert.Len(t, result, len(tt.expected))
|
|
|
|
|
for _, exp := range tt.expected {
|
|
|
|
|
found := false
|
|
|
|
|
for _, r := range result {
|
|
|
|
|
if r["date"] == exp["date"] {
|
|
|
|
|
assert.Equal(t, exp["request_count"], r["request_count"])
|
|
|
|
|
found = true
|
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.True(t, found, "expected result not found: %v", exp)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ============ ProviderService - isUniqueConstraintError 测试 ============
|
|
|
|
|
|
|
|
|
|
func TestProviderService_isUniqueConstraintError(t *testing.T) {
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
err error
|
|
|
|
|
expected bool
|
|
|
|
|
}{
|
|
|
|
|
{
|
|
|
|
|
name: "nil error",
|
|
|
|
|
err: nil,
|
|
|
|
|
expected: false,
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "UNIQUE constraint failed",
|
|
|
|
|
err: errors.New("UNIQUE constraint failed"),
|
|
|
|
|
expected: true,
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "duplicate key value",
|
|
|
|
|
err: errors.New("duplicate key value"),
|
|
|
|
|
expected: true,
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "UNIQUE constraint case insensitive",
|
|
|
|
|
err: errors.New("unique constraint violation"),
|
|
|
|
|
expected: true,
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "other error",
|
|
|
|
|
err: errors.New("some other error"),
|
|
|
|
|
expected: false,
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
result := isUniqueConstraintError(tt.err)
|
|
|
|
|
assert.Equal(t, tt.expected, result)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ============ ProviderService - List MaskAPIKey 测试 ============
|
|
|
|
|
|
|
|
|
|
func TestProviderService_List_MaskAPIKey(t *testing.T) {
|
|
|
|
|
db := setupServiceTestDB(t)
|
|
|
|
|
repo := repository.NewProviderRepository(db)
|
|
|
|
|
modelRepo := repository.NewModelRepository(db)
|
|
|
|
|
svc := NewProviderService(repo, modelRepo)
|
|
|
|
|
|
|
|
|
|
provider1 := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"}
|
|
|
|
|
provider2 := &domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "sk-anthropic1234", BaseURL: "https://api.anthropic.com"}
|
|
|
|
|
require.NoError(t, svc.Create(provider1))
|
|
|
|
|
require.NoError(t, svc.Create(provider2))
|
|
|
|
|
|
|
|
|
|
providers, err := svc.List()
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
require.Len(t, providers, 2)
|
|
|
|
|
|
|
|
|
|
for _, p := range providers {
|
|
|
|
|
assert.Contains(t, p.APIKey, "***")
|
|
|
|
|
assert.Len(t, p.APIKey, 7)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestModelService_ConcurrentCreate(t *testing.T) {
|
|
|
|
|
db := setupServiceTestDB(t)
|
|
|
|
|
providerRepo := repository.NewProviderRepository(db)
|
|
|
|
|
modelRepo := repository.NewModelRepository(db)
|
|
|
|
|
svc := NewModelService(modelRepo, providerRepo)
|
|
|
|
|
|
|
|
|
|
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
|
|
|
|
|
|
|
|
|
results := make(chan error, 2)
|
|
|
|
|
for i := 0; i < 2; i++ {
|
|
|
|
|
go func() {
|
|
|
|
|
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
|
|
|
|
results <- svc.Create(model)
|
|
|
|
|
}()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
err1 := <-results
|
|
|
|
|
err2 := <-results
|
|
|
|
|
|
|
|
|
|
successCount := 0
|
|
|
|
|
errorCount := 0
|
|
|
|
|
for _, err := range []error{err1, err2} {
|
|
|
|
|
if err == nil {
|
|
|
|
|
successCount++
|
|
|
|
|
} else {
|
|
|
|
|
errorCount++
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.Equal(t, 1, successCount)
|
|
|
|
|
assert.Equal(t, 1, errorCount)
|
|
|
|
|
}
|
|
|
|
|
|