feat: 实现统一模型 ID 机制
实现统一模型 ID 格式 (provider_id/model_name),支持跨协议模型标识和 Smart Passthrough。 核心变更: - 新增 pkg/modelid 包:解析、格式化、校验统一模型 ID - 数据库迁移:models 表使用 UUID 主键 + UNIQUE(provider_id, model_name) 约束 - Repository 层:FindByProviderAndModelName、ListEnabled 方法 - Service 层:联合唯一校验、provider ID 字符集校验 - Conversion 层:ExtractModelName、RewriteRequestModelName/RewriteResponseModelName 方法 - Handler 层:统一模型 ID 路由、Smart Passthrough、Models API 本地聚合 - 新增 error-responses、unified-model-id 规范 测试覆盖: - 单元测试:modelid、conversion、handler、service、repository - 集成测试:统一模型 ID 路由、Smart Passthrough 保真性、跨协议转换 - 迁移测试:UUID 主键、UNIQUE 约束、级联删除 OpenSpec: - 归档 unified-model-id 变更到 archive/2026-04-21-unified-model-id - 同步 11 个 delta specs 到 main specs - 新增 error-responses、unified-model-id 规范文件
This commit is contained in:
@@ -7,6 +7,7 @@ type ModelService interface {
|
||||
Create(model *domain.Model) error
|
||||
Get(id string) (*domain.Model, error)
|
||||
List(providerID string) ([]domain.Model, error)
|
||||
ListEnabled() ([]domain.Model, error)
|
||||
Update(id string, updates map[string]interface{}) error
|
||||
Delete(id string) error
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
@@ -17,11 +18,18 @@ func NewModelService(modelRepo repository.ModelRepository, providerRepo reposito
|
||||
}
|
||||
|
||||
func (s *modelService) Create(model *domain.Model) error {
|
||||
// Verify provider exists
|
||||
_, err := s.providerRepo.GetByID(model.ProviderID)
|
||||
if err != nil {
|
||||
// 校验供应商存在
|
||||
if _, err := s.providerRepo.GetByID(model.ProviderID); err != nil {
|
||||
return appErrors.ErrProviderNotFound
|
||||
}
|
||||
|
||||
// 联合唯一校验:同一供应商下 model_name 不重复
|
||||
if err := s.checkDuplicateModelName(model.ProviderID, model.ModelName, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 自动生成 UUID 作为 id
|
||||
model.ID = uuid.New().String()
|
||||
model.Enabled = true
|
||||
return s.modelRepo.Create(model)
|
||||
}
|
||||
@@ -34,17 +42,57 @@ func (s *modelService) List(providerID string) ([]domain.Model, error) {
|
||||
return s.modelRepo.List(providerID)
|
||||
}
|
||||
|
||||
func (s *modelService) ListEnabled() ([]domain.Model, error) {
|
||||
return s.modelRepo.ListEnabled()
|
||||
}
|
||||
|
||||
func (s *modelService) Update(id string, updates map[string]interface{}) error {
|
||||
// If updating provider_id, verify new provider exists
|
||||
// 获取当前模型
|
||||
current, err := s.modelRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return appErrors.ErrModelNotFound
|
||||
}
|
||||
|
||||
// 如果更新 provider_id,校验新供应商存在
|
||||
if providerID, ok := updates["provider_id"].(string); ok {
|
||||
_, err := s.providerRepo.GetByID(providerID)
|
||||
if err != nil {
|
||||
if _, err := s.providerRepo.GetByID(providerID); err != nil {
|
||||
return appErrors.ErrProviderNotFound
|
||||
}
|
||||
}
|
||||
|
||||
// 确定更新后的 provider_id 和 model_name
|
||||
newProviderID := current.ProviderID
|
||||
if v, ok := updates["provider_id"].(string); ok {
|
||||
newProviderID = v
|
||||
}
|
||||
newModelName := current.ModelName
|
||||
if v, ok := updates["model_name"].(string); ok {
|
||||
newModelName = v
|
||||
}
|
||||
|
||||
// 如果 provider_id 或 model_name 发生变化,校验联合唯一
|
||||
if newProviderID != current.ProviderID || newModelName != current.ModelName {
|
||||
if err := s.checkDuplicateModelName(newProviderID, newModelName, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return s.modelRepo.Update(id, updates)
|
||||
}
|
||||
|
||||
func (s *modelService) Delete(id string) error {
|
||||
return s.modelRepo.Delete(id)
|
||||
}
|
||||
|
||||
// checkDuplicateModelName 校验同一供应商下 model_name 是否重复
|
||||
// excludeID 用于更新时排除自身
|
||||
func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error {
|
||||
existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
|
||||
if err != nil {
|
||||
return nil // 未找到,不重复
|
||||
}
|
||||
if excludeID != "" && existing.ID == excludeID {
|
||||
return nil // 排除自身
|
||||
}
|
||||
return appErrors.ErrDuplicateModel
|
||||
}
|
||||
|
||||
@@ -9,4 +9,7 @@ type ProviderService interface {
|
||||
List() ([]domain.Provider, error)
|
||||
Update(id string, updates map[string]interface{}) error
|
||||
Delete(id string) error
|
||||
// 统一模型 ID 相关方法
|
||||
ListEnabledModels() ([]domain.Model, error)
|
||||
GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error)
|
||||
}
|
||||
|
||||
@@ -1,21 +1,35 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"nex/backend/pkg/modelid"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
type providerService struct {
|
||||
providerRepo repository.ProviderRepository
|
||||
modelRepo repository.ModelRepository
|
||||
}
|
||||
|
||||
func NewProviderService(providerRepo repository.ProviderRepository) ProviderService {
|
||||
return &providerService{providerRepo: providerRepo}
|
||||
func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository) ProviderService {
|
||||
return &providerService{providerRepo: providerRepo, modelRepo: modelRepo}
|
||||
}
|
||||
|
||||
func (s *providerService) Create(provider *domain.Provider) error {
|
||||
// 校验 provider_id 字符集
|
||||
if err := modelid.ValidateProviderID(provider.ID); err != nil {
|
||||
return appErrors.ErrInvalidProviderID
|
||||
}
|
||||
provider.Enabled = true
|
||||
return s.providerRepo.Create(provider)
|
||||
err := s.providerRepo.Create(provider)
|
||||
if err != nil && isUniqueConstraintError(err) {
|
||||
return appErrors.ErrConflict
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *providerService) Get(id string, maskKey bool) (*domain.Provider, error) {
|
||||
@@ -41,9 +55,31 @@ func (s *providerService) List() ([]domain.Provider, error) {
|
||||
}
|
||||
|
||||
func (s *providerService) Update(id string, updates map[string]interface{}) error {
|
||||
if _, ok := updates["id"]; ok {
|
||||
return appErrors.ErrImmutableField
|
||||
}
|
||||
return s.providerRepo.Update(id, updates)
|
||||
}
|
||||
|
||||
func (s *providerService) Delete(id string) error {
|
||||
return s.providerRepo.Delete(id)
|
||||
}
|
||||
|
||||
// ListEnabledModels 返回所有启用的模型(用于 Models 接口本地聚合)
|
||||
func (s *providerService) ListEnabledModels() ([]domain.Model, error) {
|
||||
return s.modelRepo.ListEnabled()
|
||||
}
|
||||
|
||||
// GetModelByProviderAndName 按 provider_id 和 model_name 查询模型(用于 ModelInfo 接口本地查询)
|
||||
func (s *providerService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
|
||||
return s.modelRepo.FindByProviderAndModelName(providerID, modelName)
|
||||
}
|
||||
|
||||
// isUniqueConstraintError 判断是否为数据库唯一约束冲突错误
|
||||
func isUniqueConstraintError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
return strings.Contains(msg, "unique constraint") || strings.Contains(msg, "duplicate")
|
||||
}
|
||||
|
||||
@@ -4,5 +4,5 @@ import "nex/backend/internal/domain"
|
||||
|
||||
// RoutingService 路由服务接口
|
||||
type RoutingService interface {
|
||||
Route(modelName string) (*domain.RouteResult, error)
|
||||
RouteByModelName(providerID, modelName string) (*domain.RouteResult, error)
|
||||
}
|
||||
|
||||
@@ -16,8 +16,8 @@ func NewRoutingService(modelRepo repository.ModelRepository, providerRepo reposi
|
||||
return &routingService{modelRepo: modelRepo, providerRepo: providerRepo}
|
||||
}
|
||||
|
||||
func (s *routingService) Route(modelName string) (*domain.RouteResult, error) {
|
||||
model, err := s.modelRepo.GetByModelName(modelName)
|
||||
func (s *routingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
|
||||
model, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
|
||||
if err != nil {
|
||||
return nil, appErrors.ErrModelNotFound
|
||||
}
|
||||
|
||||
@@ -13,7 +13,8 @@ import (
|
||||
func TestProviderService_Update(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
|
||||
svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"})
|
||||
|
||||
@@ -28,7 +29,8 @@ func TestProviderService_Update(t *testing.T) {
|
||||
func TestProviderService_Update_NotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
|
||||
err := svc.Update("nonexistent", map[string]interface{}{"name": "test"})
|
||||
assert.Error(t, err)
|
||||
@@ -41,11 +43,12 @@ func TestModelService_Get(t *testing.T) {
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
require.NoError(t, svc.Create(model))
|
||||
|
||||
model, err := svc.Get("m1")
|
||||
result, err := svc.Get(model.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4", model.ModelName)
|
||||
assert.Equal(t, "gpt-4", result.ModelName)
|
||||
}
|
||||
|
||||
func TestModelService_Update(t *testing.T) {
|
||||
@@ -55,14 +58,15 @@ func TestModelService_Update(t *testing.T) {
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
require.NoError(t, svc.Create(model))
|
||||
|
||||
err := svc.Update("m1", map[string]interface{}{"model_name": "gpt-4o"})
|
||||
err := svc.Update(model.ID, map[string]interface{}{"model_name": "gpt-4o"})
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := svc.Get("m1")
|
||||
result, err := svc.Get(model.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4o", model.ModelName)
|
||||
assert.Equal(t, "gpt-4o", result.ModelName)
|
||||
}
|
||||
|
||||
func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
|
||||
@@ -72,9 +76,10 @@ func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
require.NoError(t, svc.Create(model))
|
||||
|
||||
err := svc.Update("m1", map[string]interface{}{"provider_id": "nonexistent"})
|
||||
err := svc.Update(model.ID, map[string]interface{}{"provider_id": "nonexistent"})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -85,12 +90,13 @@ func TestModelService_Delete(t *testing.T) {
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
require.NoError(t, svc.Create(model))
|
||||
|
||||
err := svc.Delete("m1")
|
||||
err := svc.Delete(model.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = svc.Get("m1")
|
||||
_, err = svc.Get(model.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
@@ -11,6 +13,7 @@ import (
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
func setupServiceTestDB(t *testing.T) *gorm.DB {
|
||||
@@ -29,80 +32,106 @@ func setupServiceTestDB(t *testing.T) *gorm.DB {
|
||||
return db
|
||||
}
|
||||
|
||||
// ============ ProviderService 测试 ============
|
||||
// ============ RoutingService - RouteByModelName 测试 ============
|
||||
|
||||
func TestProviderService_Create(t *testing.T) {
|
||||
func TestRoutingService_RouteByModelName_Success(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
provider := &domain.Provider{
|
||||
ID: "test-p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com",
|
||||
}
|
||||
err := svc.Create(provider)
|
||||
// 创建供应商和模型
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
|
||||
|
||||
result, err := svc.RouteByModelName("openai", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, provider.Enabled)
|
||||
assert.Equal(t, "openai", result.Provider.ID)
|
||||
assert.Equal(t, "gpt-4", result.Model.ModelName)
|
||||
}
|
||||
|
||||
func TestProviderService_Get_MaskKey(t *testing.T) {
|
||||
func TestRoutingService_RouteByModelName_NotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
svc.Create(&domain.Provider{
|
||||
ID: "p1", Name: "Test", APIKey: "sk-long-api-key-12345", BaseURL: "https://test.com",
|
||||
})
|
||||
|
||||
result, err := svc.Get("p1", true)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "***2345", result.APIKey)
|
||||
|
||||
result, err = svc.Get("p1", false)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sk-long-api-key-12345", result.APIKey)
|
||||
_, err := svc.RouteByModelName("openai", "nonexistent-model")
|
||||
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
|
||||
}
|
||||
|
||||
func TestProviderService_List(t *testing.T) {
|
||||
func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
svc.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key123", BaseURL: "https://a.com"})
|
||||
svc.Create(&domain.Provider{ID: "p2", Name: "P2", APIKey: "key456", BaseURL: "https://b.com"})
|
||||
// 创建启用的供应商和禁用的模型
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
|
||||
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
|
||||
|
||||
providers, err := svc.List()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, providers, 2)
|
||||
assert.Contains(t, providers[0].APIKey, "***")
|
||||
_, err := svc.RouteByModelName("openai", "gpt-4")
|
||||
assert.True(t, errors.Is(err, appErrors.ErrModelDisabled))
|
||||
}
|
||||
|
||||
func TestProviderService_Delete(t *testing.T) {
|
||||
func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
svc.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
|
||||
err := svc.Delete("p1")
|
||||
require.NoError(t, err)
|
||||
// 创建启用的供应商和模型,然后禁用供应商
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
|
||||
providerRepo.Update("openai", map[string]interface{}{"enabled": false})
|
||||
|
||||
_, err = svc.Get("p1", false)
|
||||
assert.Error(t, err)
|
||||
_, err := svc.RouteByModelName("openai", "gpt-4")
|
||||
assert.True(t, errors.Is(err, appErrors.ErrProviderDisabled))
|
||||
}
|
||||
|
||||
// ============ ModelService 测试 ============
|
||||
// ============ ModelService - Create with UUID 测试 ============
|
||||
|
||||
func TestModelService_Create(t *testing.T) {
|
||||
func TestModelService_Create_GeneratesUUID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||
|
||||
model := &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}
|
||||
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, model.Enabled)
|
||||
|
||||
// 验证返回的 model 拥有有效的 UUID
|
||||
assert.NotEmpty(t, model.ID)
|
||||
_, err = uuid.Parse(model.ID)
|
||||
assert.NoError(t, err, "model.ID should be a valid UUID")
|
||||
|
||||
// 通过 Get 验证持久化
|
||||
stored, err := svc.Get(model.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, model.ID, stored.ID)
|
||||
assert.Equal(t, "gpt-4", stored.ModelName)
|
||||
}
|
||||
|
||||
func TestModelService_Create_DuplicateModelName(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||
|
||||
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 使用相同的 (providerID, modelName) 创建第二个模型应失败
|
||||
model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err = svc.Create(model2)
|
||||
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
|
||||
}
|
||||
|
||||
func TestModelService_Create_ProviderNotFound(t *testing.T) {
|
||||
@@ -111,160 +140,135 @@ func TestModelService_Create_ProviderNotFound(t *testing.T) {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
model := &domain.Model{ID: "m1", ProviderID: "nonexistent", ModelName: "gpt-4"}
|
||||
model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
|
||||
err := svc.Create(model)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, appErrors.ErrProviderNotFound))
|
||||
}
|
||||
|
||||
func TestModelService_List(t *testing.T) {
|
||||
// ============ ProviderService - Create with validation 测试 ============
|
||||
|
||||
func TestProviderService_Create_InvalidID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
|
||||
provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
assert.True(t, errors.Is(err, appErrors.ErrInvalidProviderID))
|
||||
}
|
||||
|
||||
func TestProviderService_Create_ValidID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
|
||||
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai", provider.ID)
|
||||
assert.True(t, provider.Enabled)
|
||||
}
|
||||
|
||||
// ============ ModelService - Update with duplicate check 测试 ============
|
||||
|
||||
func TestModelService_Update_DuplicateModelName(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
svc.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||
providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"})
|
||||
|
||||
models, err := svc.List("p1")
|
||||
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model1)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, models, 2)
|
||||
|
||||
model2 := &domain.Model{ProviderID: "anthropic", ModelName: "claude-3"}
|
||||
err = svc.Create(model2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 将 model2 的 model_name 改为 "gpt-4" 且 provider_id 改为 "openai",与 model1 冲突
|
||||
err = svc.Update(model2.ID, map[string]interface{}{
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
|
||||
}
|
||||
|
||||
// ============ RoutingService 测试 ============
|
||||
|
||||
func TestRoutingService_Route(t *testing.T) {
|
||||
func TestModelService_Update_ModelNotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
|
||||
result, err := svc.Route("gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "p1", result.Provider.ID)
|
||||
assert.Equal(t, "gpt-4", result.Model.ModelName)
|
||||
err := svc.Update("nonexistent-id", map[string]interface{}{
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
|
||||
}
|
||||
|
||||
func TestRoutingService_Route_ModelNotFound(t *testing.T) {
|
||||
func TestModelService_Update_Success(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
_, err := svc.Route("nonexistent-model")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||
|
||||
func TestRoutingService_Route_ModelDisabled(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
|
||||
// 先创建启用的模型,然后通过 Update 禁用
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
|
||||
|
||||
_, err := svc.Route("gpt-4")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRoutingService_Route_ProviderDisabled(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
// 先创建启用的 provider,然后禁用
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
|
||||
providerRepo.Update("p1", map[string]interface{}{"enabled": false})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
|
||||
_, err := svc.Route("gpt-4")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// ============ StatsService 测试 ============
|
||||
|
||||
func TestStatsService_RecordAndGet(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
svc := NewStatsService(statsRepo)
|
||||
|
||||
err := svc.Record("p1", "gpt-4")
|
||||
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model)
|
||||
require.NoError(t, err)
|
||||
|
||||
stats, err := svc.Get("p1", "", nil, nil)
|
||||
// 更新 model_name 为不冲突的值
|
||||
err = svc.Update(model.ID, map[string]interface{}{
|
||||
"model_name": "gpt-4-turbo",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, stats, 1)
|
||||
|
||||
updated, err := svc.Get(model.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4-turbo", updated.ModelName)
|
||||
}
|
||||
|
||||
func TestStatsService_Aggregate_ByProvider(t *testing.T) {
|
||||
statsRepo := repository.NewStatsRepository(nil)
|
||||
svc := NewStatsService(statsRepo)
|
||||
// ============ ProviderService - Update immutable ID 测试 ============
|
||||
|
||||
stats := []domain.UsageStats{
|
||||
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
|
||||
{ProviderID: "p1", ModelName: "gpt-3.5", RequestCount: 5},
|
||||
{ProviderID: "p2", ModelName: "claude-3", RequestCount: 8},
|
||||
}
|
||||
func TestProviderService_Update_ImmutableID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
|
||||
result := svc.Aggregate(stats, "provider")
|
||||
assert.Len(t, result, 2)
|
||||
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
p1Count := 0
|
||||
p2Count := 0
|
||||
for _, r := range result {
|
||||
if r["provider_id"] == "p1" {
|
||||
p1Count = r["request_count"].(int)
|
||||
}
|
||||
if r["provider_id"] == "p2" {
|
||||
p2Count = r["request_count"].(int)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 15, p1Count)
|
||||
assert.Equal(t, 8, p2Count)
|
||||
// 尝试更新 id 字段
|
||||
err = svc.Update("openai", map[string]interface{}{
|
||||
"id": "new-id",
|
||||
})
|
||||
assert.True(t, errors.Is(err, appErrors.ErrImmutableField))
|
||||
}
|
||||
|
||||
func TestStatsService_Aggregate_ByDate(t *testing.T) {
|
||||
statsRepo := repository.NewStatsRepository(nil)
|
||||
svc := NewStatsService(statsRepo)
|
||||
func TestProviderService_Update_Success(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
|
||||
stats := []domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
{ProviderID: "p2", RequestCount: 5},
|
||||
}
|
||||
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := svc.Aggregate(stats, "date")
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, 15, result[0]["request_count"])
|
||||
}
|
||||
|
||||
func TestStatsService_Aggregate_ByModel(t *testing.T) {
|
||||
statsRepo := repository.NewStatsRepository(nil)
|
||||
svc := NewStatsService(statsRepo)
|
||||
|
||||
stats := []domain.UsageStats{
|
||||
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10},
|
||||
{ProviderID: "openai", ModelName: "gpt-3.5", RequestCount: 5},
|
||||
{ProviderID: "anthropic", ModelName: "claude-3", RequestCount: 8},
|
||||
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 3},
|
||||
}
|
||||
|
||||
result := svc.Aggregate(stats, "model")
|
||||
assert.Len(t, result, 3)
|
||||
|
||||
// 验证每个 provider/model 组合的计数
|
||||
counts := make(map[string]int)
|
||||
for _, r := range result {
|
||||
key := r["provider_id"].(string) + "/" + r["model_name"].(string)
|
||||
counts[key] = r["request_count"].(int)
|
||||
}
|
||||
assert.Equal(t, 13, counts["openai/gpt-4"])
|
||||
assert.Equal(t, 5, counts["openai/gpt-3.5"])
|
||||
assert.Equal(t, 8, counts["anthropic/claude-3"])
|
||||
// 更新 name
|
||||
err = svc.Update("openai", map[string]interface{}{
|
||||
"name": "OpenAI Updated",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := svc.Get("openai", false)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "OpenAI Updated", updated.Name)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user