1
0

feat: 实现统一模型 ID 机制

实现统一模型 ID 格式 (provider_id/model_name),支持跨协议模型标识和 Smart Passthrough。

核心变更:
- 新增 pkg/modelid 包:解析、格式化、校验统一模型 ID
- 数据库迁移:models 表使用 UUID 主键 + UNIQUE(provider_id, model_name) 约束
- Repository 层:FindByProviderAndModelName、ListEnabled 方法
- Service 层:联合唯一校验、provider ID 字符集校验
- Conversion 层:ExtractModelName、RewriteRequestModelName/RewriteResponseModelName 方法
- Handler 层:统一模型 ID 路由、Smart Passthrough、Models API 本地聚合
- 新增 error-responses、unified-model-id 规范

测试覆盖:
- 单元测试:modelid、conversion、handler、service、repository
- 集成测试:统一模型 ID 路由、Smart Passthrough 保真性、跨协议转换
- 迁移测试:UUID 主键、UNIQUE 约束、级联删除

OpenSpec:
- 归档 unified-model-id 变更到 archive/2026-04-21-unified-model-id
- 同步 11 个 delta specs 到 main specs
- 新增 error-responses、unified-model-id 规范文件
This commit is contained in:
2026-04-21 18:14:10 +08:00
parent 7f0f831226
commit 395887667d
73 changed files with 3360 additions and 1374 deletions

View File

@@ -7,6 +7,7 @@ type ModelService interface {
Create(model *domain.Model) error
Get(id string) (*domain.Model, error)
List(providerID string) ([]domain.Model, error)
ListEnabled() ([]domain.Model, error)
Update(id string, updates map[string]interface{}) error
Delete(id string) error
}

View File

@@ -1,6 +1,7 @@
package service
import (
"github.com/google/uuid"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
@@ -17,11 +18,18 @@ func NewModelService(modelRepo repository.ModelRepository, providerRepo reposito
}
func (s *modelService) Create(model *domain.Model) error {
// Verify provider exists
_, err := s.providerRepo.GetByID(model.ProviderID)
if err != nil {
// 校验供应商存在
if _, err := s.providerRepo.GetByID(model.ProviderID); err != nil {
return appErrors.ErrProviderNotFound
}
// 联合唯一校验:同一供应商下 model_name 不重复
if err := s.checkDuplicateModelName(model.ProviderID, model.ModelName, ""); err != nil {
return err
}
// 自动生成 UUID 作为 id
model.ID = uuid.New().String()
model.Enabled = true
return s.modelRepo.Create(model)
}
@@ -34,17 +42,57 @@ func (s *modelService) List(providerID string) ([]domain.Model, error) {
return s.modelRepo.List(providerID)
}
func (s *modelService) ListEnabled() ([]domain.Model, error) {
return s.modelRepo.ListEnabled()
}
func (s *modelService) Update(id string, updates map[string]interface{}) error {
// If updating provider_id, verify new provider exists
// 获取当前模型
current, err := s.modelRepo.GetByID(id)
if err != nil {
return appErrors.ErrModelNotFound
}
// 如果更新 provider_id校验新供应商存在
if providerID, ok := updates["provider_id"].(string); ok {
_, err := s.providerRepo.GetByID(providerID)
if err != nil {
if _, err := s.providerRepo.GetByID(providerID); err != nil {
return appErrors.ErrProviderNotFound
}
}
// 确定更新后的 provider_id 和 model_name
newProviderID := current.ProviderID
if v, ok := updates["provider_id"].(string); ok {
newProviderID = v
}
newModelName := current.ModelName
if v, ok := updates["model_name"].(string); ok {
newModelName = v
}
// 如果 provider_id 或 model_name 发生变化,校验联合唯一
if newProviderID != current.ProviderID || newModelName != current.ModelName {
if err := s.checkDuplicateModelName(newProviderID, newModelName, id); err != nil {
return err
}
}
return s.modelRepo.Update(id, updates)
}
func (s *modelService) Delete(id string) error {
return s.modelRepo.Delete(id)
}
// checkDuplicateModelName 校验同一供应商下 model_name 是否重复
// excludeID 用于更新时排除自身
func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error {
existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
if err != nil {
return nil // 未找到,不重复
}
if excludeID != "" && existing.ID == excludeID {
return nil // 排除自身
}
return appErrors.ErrDuplicateModel
}

View File

@@ -9,4 +9,7 @@ type ProviderService interface {
List() ([]domain.Provider, error)
Update(id string, updates map[string]interface{}) error
Delete(id string) error
// 统一模型 ID 相关方法
ListEnabledModels() ([]domain.Model, error)
GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error)
}

View File

@@ -1,21 +1,35 @@
package service
import (
"strings"
"nex/backend/pkg/modelid"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
appErrors "nex/backend/pkg/errors"
)
type providerService struct {
providerRepo repository.ProviderRepository
modelRepo repository.ModelRepository
}
func NewProviderService(providerRepo repository.ProviderRepository) ProviderService {
return &providerService{providerRepo: providerRepo}
func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository) ProviderService {
return &providerService{providerRepo: providerRepo, modelRepo: modelRepo}
}
func (s *providerService) Create(provider *domain.Provider) error {
// 校验 provider_id 字符集
if err := modelid.ValidateProviderID(provider.ID); err != nil {
return appErrors.ErrInvalidProviderID
}
provider.Enabled = true
return s.providerRepo.Create(provider)
err := s.providerRepo.Create(provider)
if err != nil && isUniqueConstraintError(err) {
return appErrors.ErrConflict
}
return err
}
func (s *providerService) Get(id string, maskKey bool) (*domain.Provider, error) {
@@ -41,9 +55,31 @@ func (s *providerService) List() ([]domain.Provider, error) {
}
func (s *providerService) Update(id string, updates map[string]interface{}) error {
if _, ok := updates["id"]; ok {
return appErrors.ErrImmutableField
}
return s.providerRepo.Update(id, updates)
}
func (s *providerService) Delete(id string) error {
return s.providerRepo.Delete(id)
}
// ListEnabledModels 返回所有启用的模型(用于 Models 接口本地聚合)
func (s *providerService) ListEnabledModels() ([]domain.Model, error) {
return s.modelRepo.ListEnabled()
}
// GetModelByProviderAndName 按 provider_id 和 model_name 查询模型(用于 ModelInfo 接口本地查询)
func (s *providerService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
return s.modelRepo.FindByProviderAndModelName(providerID, modelName)
}
// isUniqueConstraintError 判断是否为数据库唯一约束冲突错误
func isUniqueConstraintError(err error) bool {
if err == nil {
return false
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "unique constraint") || strings.Contains(msg, "duplicate")
}

View File

@@ -4,5 +4,5 @@ import "nex/backend/internal/domain"
// RoutingService 路由服务接口
type RoutingService interface {
Route(modelName string) (*domain.RouteResult, error)
RouteByModelName(providerID, modelName string) (*domain.RouteResult, error)
}

View File

@@ -16,8 +16,8 @@ func NewRoutingService(modelRepo repository.ModelRepository, providerRepo reposi
return &routingService{modelRepo: modelRepo, providerRepo: providerRepo}
}
func (s *routingService) Route(modelName string) (*domain.RouteResult, error) {
model, err := s.modelRepo.GetByModelName(modelName)
func (s *routingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
model, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
if err != nil {
return nil, appErrors.ErrModelNotFound
}

View File

@@ -13,7 +13,8 @@ import (
func TestProviderService_Update(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"})
@@ -28,7 +29,8 @@ func TestProviderService_Update(t *testing.T) {
func TestProviderService_Update_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
err := svc.Update("nonexistent", map[string]interface{}{"name": "test"})
assert.Error(t, err)
@@ -41,11 +43,12 @@ func TestModelService_Get(t *testing.T) {
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model))
model, err := svc.Get("m1")
result, err := svc.Get(model.ID)
require.NoError(t, err)
assert.Equal(t, "gpt-4", model.ModelName)
assert.Equal(t, "gpt-4", result.ModelName)
}
func TestModelService_Update(t *testing.T) {
@@ -55,14 +58,15 @@ func TestModelService_Update(t *testing.T) {
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model))
err := svc.Update("m1", map[string]interface{}{"model_name": "gpt-4o"})
err := svc.Update(model.ID, map[string]interface{}{"model_name": "gpt-4o"})
require.NoError(t, err)
model, err := svc.Get("m1")
result, err := svc.Get(model.ID)
require.NoError(t, err)
assert.Equal(t, "gpt-4o", model.ModelName)
assert.Equal(t, "gpt-4o", result.ModelName)
}
func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
@@ -72,9 +76,10 @@ func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model))
err := svc.Update("m1", map[string]interface{}{"provider_id": "nonexistent"})
err := svc.Update(model.ID, map[string]interface{}{"provider_id": "nonexistent"})
assert.Error(t, err)
}
@@ -85,12 +90,13 @@ func TestModelService_Delete(t *testing.T) {
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model))
err := svc.Delete("m1")
err := svc.Delete(model.ID)
require.NoError(t, err)
_, err = svc.Get("m1")
_, err = svc.Get(model.ID)
assert.Error(t, err)
}

View File

@@ -1,8 +1,10 @@
package service
import (
"errors"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
@@ -11,6 +13,7 @@ import (
"nex/backend/internal/config"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
appErrors "nex/backend/pkg/errors"
)
func setupServiceTestDB(t *testing.T) *gorm.DB {
@@ -29,80 +32,106 @@ func setupServiceTestDB(t *testing.T) *gorm.DB {
return db
}
// ============ ProviderService 测试 ============
// ============ RoutingService - RouteByModelName 测试 ============
func TestProviderService_Create(t *testing.T) {
func TestRoutingService_RouteByModelName_Success(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
provider := &domain.Provider{
ID: "test-p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com",
}
err := svc.Create(provider)
// 创建供应商和模型
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
result, err := svc.RouteByModelName("openai", "gpt-4")
require.NoError(t, err)
assert.True(t, provider.Enabled)
assert.Equal(t, "openai", result.Provider.ID)
assert.Equal(t, "gpt-4", result.Model.ModelName)
}
func TestProviderService_Get_MaskKey(t *testing.T) {
func TestRoutingService_RouteByModelName_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
svc.Create(&domain.Provider{
ID: "p1", Name: "Test", APIKey: "sk-long-api-key-12345", BaseURL: "https://test.com",
})
result, err := svc.Get("p1", true)
require.NoError(t, err)
assert.Equal(t, "***2345", result.APIKey)
result, err = svc.Get("p1", false)
require.NoError(t, err)
assert.Equal(t, "sk-long-api-key-12345", result.APIKey)
_, err := svc.RouteByModelName("openai", "nonexistent-model")
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
}
func TestProviderService_List(t *testing.T) {
func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
svc.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key123", BaseURL: "https://a.com"})
svc.Create(&domain.Provider{ID: "p2", Name: "P2", APIKey: "key456", BaseURL: "https://b.com"})
// 创建启用的供应商和禁用的模型
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
providers, err := svc.List()
require.NoError(t, err)
assert.Len(t, providers, 2)
assert.Contains(t, providers[0].APIKey, "***")
_, err := svc.RouteByModelName("openai", "gpt-4")
assert.True(t, errors.Is(err, appErrors.ErrModelDisabled))
}
func TestProviderService_Delete(t *testing.T) {
func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
svc := NewProviderService(repo)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
svc.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
err := svc.Delete("p1")
require.NoError(t, err)
// 创建启用的供应商和模型,然后禁用供应商
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
providerRepo.Update("openai", map[string]interface{}{"enabled": false})
_, err = svc.Get("p1", false)
assert.Error(t, err)
_, err := svc.RouteByModelName("openai", "gpt-4")
assert.True(t, errors.Is(err, appErrors.ErrProviderDisabled))
}
// ============ ModelService 测试 ============
// ============ ModelService - Create with UUID 测试 ============
func TestModelService_Create(t *testing.T) {
func TestModelService_Create_GeneratesUUID(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
model := &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model)
require.NoError(t, err)
assert.True(t, model.Enabled)
// 验证返回的 model 拥有有效的 UUID
assert.NotEmpty(t, model.ID)
_, err = uuid.Parse(model.ID)
assert.NoError(t, err, "model.ID should be a valid UUID")
// 通过 Get 验证持久化
stored, err := svc.Get(model.ID)
require.NoError(t, err)
assert.Equal(t, model.ID, stored.ID)
assert.Equal(t, "gpt-4", stored.ModelName)
}
func TestModelService_Create_DuplicateModelName(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model1)
require.NoError(t, err)
// 使用相同的 (providerID, modelName) 创建第二个模型应失败
model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err = svc.Create(model2)
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
}
func TestModelService_Create_ProviderNotFound(t *testing.T) {
@@ -111,160 +140,135 @@ func TestModelService_Create_ProviderNotFound(t *testing.T) {
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
model := &domain.Model{ID: "m1", ProviderID: "nonexistent", ModelName: "gpt-4"}
model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
err := svc.Create(model)
assert.Error(t, err)
assert.True(t, errors.Is(err, appErrors.ErrProviderNotFound))
}
func TestModelService_List(t *testing.T) {
// ============ ProviderService - Create with validation 测试 ============
func TestProviderService_Create_InvalidID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
assert.True(t, errors.Is(err, appErrors.ErrInvalidProviderID))
}
func TestProviderService_Create_ValidID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
require.NoError(t, err)
assert.Equal(t, "openai", provider.ID)
assert.True(t, provider.Enabled)
}
// ============ ModelService - Update with duplicate check 测试 ============
func TestModelService_Update_DuplicateModelName(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
svc.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"})
models, err := svc.List("p1")
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model1)
require.NoError(t, err)
assert.Len(t, models, 2)
model2 := &domain.Model{ProviderID: "anthropic", ModelName: "claude-3"}
err = svc.Create(model2)
require.NoError(t, err)
// 将 model2 的 model_name 改为 "gpt-4" 且 provider_id 改为 "openai",与 model1 冲突
err = svc.Update(model2.ID, map[string]interface{}{
"provider_id": "openai",
"model_name": "gpt-4",
})
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
}
// ============ RoutingService 测试 ============
func TestRoutingService_Route(t *testing.T) {
func TestModelService_Update_ModelNotFound(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
svc := NewModelService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
result, err := svc.Route("gpt-4")
require.NoError(t, err)
assert.Equal(t, "p1", result.Provider.ID)
assert.Equal(t, "gpt-4", result.Model.ModelName)
err := svc.Update("nonexistent-id", map[string]interface{}{
"model_name": "gpt-4",
})
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
}
func TestRoutingService_Route_ModelNotFound(t *testing.T) {
func TestModelService_Update_Success(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
svc := NewModelService(modelRepo, providerRepo)
_, err := svc.Route("nonexistent-model")
assert.Error(t, err)
}
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
func TestRoutingService_Route_ModelDisabled(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
// 先创建启用的模型,然后通过 Update 禁用
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
_, err := svc.Route("gpt-4")
assert.Error(t, err)
}
func TestRoutingService_Route_ProviderDisabled(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
// 先创建启用的 provider然后禁用
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
providerRepo.Update("p1", map[string]interface{}{"enabled": false})
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
_, err := svc.Route("gpt-4")
assert.Error(t, err)
}
// ============ StatsService 测试 ============
func TestStatsService_RecordAndGet(t *testing.T) {
db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db)
svc := NewStatsService(statsRepo)
err := svc.Record("p1", "gpt-4")
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model)
require.NoError(t, err)
stats, err := svc.Get("p1", "", nil, nil)
// 更新 model_name 为不冲突的值
err = svc.Update(model.ID, map[string]interface{}{
"model_name": "gpt-4-turbo",
})
require.NoError(t, err)
assert.Len(t, stats, 1)
updated, err := svc.Get(model.ID)
require.NoError(t, err)
assert.Equal(t, "gpt-4-turbo", updated.ModelName)
}
func TestStatsService_Aggregate_ByProvider(t *testing.T) {
statsRepo := repository.NewStatsRepository(nil)
svc := NewStatsService(statsRepo)
// ============ ProviderService - Update immutable ID 测试 ============
stats := []domain.UsageStats{
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
{ProviderID: "p1", ModelName: "gpt-3.5", RequestCount: 5},
{ProviderID: "p2", ModelName: "claude-3", RequestCount: 8},
}
func TestProviderService_Update_ImmutableID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
result := svc.Aggregate(stats, "provider")
assert.Len(t, result, 2)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
require.NoError(t, err)
p1Count := 0
p2Count := 0
for _, r := range result {
if r["provider_id"] == "p1" {
p1Count = r["request_count"].(int)
}
if r["provider_id"] == "p2" {
p2Count = r["request_count"].(int)
}
}
assert.Equal(t, 15, p1Count)
assert.Equal(t, 8, p2Count)
// 尝试更新 id 字段
err = svc.Update("openai", map[string]interface{}{
"id": "new-id",
})
assert.True(t, errors.Is(err, appErrors.ErrImmutableField))
}
func TestStatsService_Aggregate_ByDate(t *testing.T) {
statsRepo := repository.NewStatsRepository(nil)
svc := NewStatsService(statsRepo)
func TestProviderService_Update_Success(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
stats := []domain.UsageStats{
{ProviderID: "p1", RequestCount: 10},
{ProviderID: "p2", RequestCount: 5},
}
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
require.NoError(t, err)
result := svc.Aggregate(stats, "date")
assert.Len(t, result, 1)
assert.Equal(t, 15, result[0]["request_count"])
}
func TestStatsService_Aggregate_ByModel(t *testing.T) {
statsRepo := repository.NewStatsRepository(nil)
svc := NewStatsService(statsRepo)
stats := []domain.UsageStats{
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10},
{ProviderID: "openai", ModelName: "gpt-3.5", RequestCount: 5},
{ProviderID: "anthropic", ModelName: "claude-3", RequestCount: 8},
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 3},
}
result := svc.Aggregate(stats, "model")
assert.Len(t, result, 3)
// 验证每个 provider/model 组合的计数
counts := make(map[string]int)
for _, r := range result {
key := r["provider_id"].(string) + "/" + r["model_name"].(string)
counts[key] = r["request_count"].(int)
}
assert.Equal(t, 13, counts["openai/gpt-4"])
assert.Equal(t, 5, counts["openai/gpt-3.5"])
assert.Equal(t, 8, counts["anthropic/claude-3"])
// 更新 name
err = svc.Update("openai", map[string]interface{}{
"name": "OpenAI Updated",
})
require.NoError(t, err)
updated, err := svc.Get("openai", false)
require.NoError(t, err)
assert.Equal(t, "OpenAI Updated", updated.Name)
}