feat: 实现统一模型 ID 机制
实现统一模型 ID 格式 (provider_id/model_name),支持跨协议模型标识和 Smart Passthrough。 核心变更: - 新增 pkg/modelid 包:解析、格式化、校验统一模型 ID - 数据库迁移:models 表使用 UUID 主键 + UNIQUE(provider_id, model_name) 约束 - Repository 层:FindByProviderAndModelName、ListEnabled 方法 - Service 层:联合唯一校验、provider ID 字符集校验 - Conversion 层:ExtractModelName、RewriteRequestModelName/RewriteResponseModelName 方法 - Handler 层:统一模型 ID 路由、Smart Passthrough、Models API 本地聚合 - 新增 error-responses、unified-model-id 规范 测试覆盖: - 单元测试:modelid、conversion、handler、service、repository - 集成测试:统一模型 ID 路由、Smart Passthrough 保真性、跨协议转换 - 迁移测试:UUID 主键、UNIQUE 约束、级联删除 OpenSpec: - 归档 unified-model-id 变更到 archive/2026-04-21-unified-model-id - 同步 11 个 delta specs 到 main specs - 新增 error-responses、unified-model-id 规范文件
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
@@ -11,6 +13,7 @@ import (
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
func setupServiceTestDB(t *testing.T) *gorm.DB {
|
||||
@@ -29,80 +32,106 @@ func setupServiceTestDB(t *testing.T) *gorm.DB {
|
||||
return db
|
||||
}
|
||||
|
||||
// ============ ProviderService 测试 ============
|
||||
// ============ RoutingService - RouteByModelName 测试 ============
|
||||
|
||||
func TestProviderService_Create(t *testing.T) {
|
||||
func TestRoutingService_RouteByModelName_Success(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
provider := &domain.Provider{
|
||||
ID: "test-p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com",
|
||||
}
|
||||
err := svc.Create(provider)
|
||||
// 创建供应商和模型
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
|
||||
|
||||
result, err := svc.RouteByModelName("openai", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, provider.Enabled)
|
||||
assert.Equal(t, "openai", result.Provider.ID)
|
||||
assert.Equal(t, "gpt-4", result.Model.ModelName)
|
||||
}
|
||||
|
||||
func TestProviderService_Get_MaskKey(t *testing.T) {
|
||||
func TestRoutingService_RouteByModelName_NotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
svc.Create(&domain.Provider{
|
||||
ID: "p1", Name: "Test", APIKey: "sk-long-api-key-12345", BaseURL: "https://test.com",
|
||||
})
|
||||
|
||||
result, err := svc.Get("p1", true)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "***2345", result.APIKey)
|
||||
|
||||
result, err = svc.Get("p1", false)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sk-long-api-key-12345", result.APIKey)
|
||||
_, err := svc.RouteByModelName("openai", "nonexistent-model")
|
||||
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
|
||||
}
|
||||
|
||||
func TestProviderService_List(t *testing.T) {
|
||||
func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
svc.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key123", BaseURL: "https://a.com"})
|
||||
svc.Create(&domain.Provider{ID: "p2", Name: "P2", APIKey: "key456", BaseURL: "https://b.com"})
|
||||
// 创建启用的供应商和禁用的模型
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
|
||||
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
|
||||
|
||||
providers, err := svc.List()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, providers, 2)
|
||||
assert.Contains(t, providers[0].APIKey, "***")
|
||||
_, err := svc.RouteByModelName("openai", "gpt-4")
|
||||
assert.True(t, errors.Is(err, appErrors.ErrModelDisabled))
|
||||
}
|
||||
|
||||
func TestProviderService_Delete(t *testing.T) {
|
||||
func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
svc.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
|
||||
err := svc.Delete("p1")
|
||||
require.NoError(t, err)
|
||||
// 创建启用的供应商和模型,然后禁用供应商
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
|
||||
providerRepo.Update("openai", map[string]interface{}{"enabled": false})
|
||||
|
||||
_, err = svc.Get("p1", false)
|
||||
assert.Error(t, err)
|
||||
_, err := svc.RouteByModelName("openai", "gpt-4")
|
||||
assert.True(t, errors.Is(err, appErrors.ErrProviderDisabled))
|
||||
}
|
||||
|
||||
// ============ ModelService 测试 ============
|
||||
// ============ ModelService - Create with UUID 测试 ============
|
||||
|
||||
func TestModelService_Create(t *testing.T) {
|
||||
func TestModelService_Create_GeneratesUUID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||
|
||||
model := &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}
|
||||
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, model.Enabled)
|
||||
|
||||
// 验证返回的 model 拥有有效的 UUID
|
||||
assert.NotEmpty(t, model.ID)
|
||||
_, err = uuid.Parse(model.ID)
|
||||
assert.NoError(t, err, "model.ID should be a valid UUID")
|
||||
|
||||
// 通过 Get 验证持久化
|
||||
stored, err := svc.Get(model.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, model.ID, stored.ID)
|
||||
assert.Equal(t, "gpt-4", stored.ModelName)
|
||||
}
|
||||
|
||||
func TestModelService_Create_DuplicateModelName(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||
|
||||
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 使用相同的 (providerID, modelName) 创建第二个模型应失败
|
||||
model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err = svc.Create(model2)
|
||||
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
|
||||
}
|
||||
|
||||
func TestModelService_Create_ProviderNotFound(t *testing.T) {
|
||||
@@ -111,160 +140,135 @@ func TestModelService_Create_ProviderNotFound(t *testing.T) {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
model := &domain.Model{ID: "m1", ProviderID: "nonexistent", ModelName: "gpt-4"}
|
||||
model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
|
||||
err := svc.Create(model)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, appErrors.ErrProviderNotFound))
|
||||
}
|
||||
|
||||
func TestModelService_List(t *testing.T) {
|
||||
// ============ ProviderService - Create with validation 测试 ============
|
||||
|
||||
func TestProviderService_Create_InvalidID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
|
||||
provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
assert.True(t, errors.Is(err, appErrors.ErrInvalidProviderID))
|
||||
}
|
||||
|
||||
func TestProviderService_Create_ValidID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
|
||||
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai", provider.ID)
|
||||
assert.True(t, provider.Enabled)
|
||||
}
|
||||
|
||||
// ============ ModelService - Update with duplicate check 测试 ============
|
||||
|
||||
func TestModelService_Update_DuplicateModelName(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
svc.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||
providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"})
|
||||
|
||||
models, err := svc.List("p1")
|
||||
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model1)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, models, 2)
|
||||
|
||||
model2 := &domain.Model{ProviderID: "anthropic", ModelName: "claude-3"}
|
||||
err = svc.Create(model2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 将 model2 的 model_name 改为 "gpt-4" 且 provider_id 改为 "openai",与 model1 冲突
|
||||
err = svc.Update(model2.ID, map[string]interface{}{
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
|
||||
}
|
||||
|
||||
// ============ RoutingService 测试 ============
|
||||
|
||||
func TestRoutingService_Route(t *testing.T) {
|
||||
func TestModelService_Update_ModelNotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
|
||||
result, err := svc.Route("gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "p1", result.Provider.ID)
|
||||
assert.Equal(t, "gpt-4", result.Model.ModelName)
|
||||
err := svc.Update("nonexistent-id", map[string]interface{}{
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
|
||||
}
|
||||
|
||||
func TestRoutingService_Route_ModelNotFound(t *testing.T) {
|
||||
func TestModelService_Update_Success(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
|
||||
_, err := svc.Route("nonexistent-model")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||
|
||||
func TestRoutingService_Route_ModelDisabled(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
|
||||
// 先创建启用的模型,然后通过 Update 禁用
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
|
||||
|
||||
_, err := svc.Route("gpt-4")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRoutingService_Route_ProviderDisabled(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
// 先创建启用的 provider,然后禁用
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
|
||||
providerRepo.Update("p1", map[string]interface{}{"enabled": false})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
|
||||
_, err := svc.Route("gpt-4")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// ============ StatsService 测试 ============
|
||||
|
||||
func TestStatsService_RecordAndGet(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
svc := NewStatsService(statsRepo)
|
||||
|
||||
err := svc.Record("p1", "gpt-4")
|
||||
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model)
|
||||
require.NoError(t, err)
|
||||
|
||||
stats, err := svc.Get("p1", "", nil, nil)
|
||||
// 更新 model_name 为不冲突的值
|
||||
err = svc.Update(model.ID, map[string]interface{}{
|
||||
"model_name": "gpt-4-turbo",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, stats, 1)
|
||||
|
||||
updated, err := svc.Get(model.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4-turbo", updated.ModelName)
|
||||
}
|
||||
|
||||
func TestStatsService_Aggregate_ByProvider(t *testing.T) {
|
||||
statsRepo := repository.NewStatsRepository(nil)
|
||||
svc := NewStatsService(statsRepo)
|
||||
// ============ ProviderService - Update immutable ID 测试 ============
|
||||
|
||||
stats := []domain.UsageStats{
|
||||
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
|
||||
{ProviderID: "p1", ModelName: "gpt-3.5", RequestCount: 5},
|
||||
{ProviderID: "p2", ModelName: "claude-3", RequestCount: 8},
|
||||
}
|
||||
func TestProviderService_Update_ImmutableID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
|
||||
result := svc.Aggregate(stats, "provider")
|
||||
assert.Len(t, result, 2)
|
||||
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
p1Count := 0
|
||||
p2Count := 0
|
||||
for _, r := range result {
|
||||
if r["provider_id"] == "p1" {
|
||||
p1Count = r["request_count"].(int)
|
||||
}
|
||||
if r["provider_id"] == "p2" {
|
||||
p2Count = r["request_count"].(int)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 15, p1Count)
|
||||
assert.Equal(t, 8, p2Count)
|
||||
// 尝试更新 id 字段
|
||||
err = svc.Update("openai", map[string]interface{}{
|
||||
"id": "new-id",
|
||||
})
|
||||
assert.True(t, errors.Is(err, appErrors.ErrImmutableField))
|
||||
}
|
||||
|
||||
func TestStatsService_Aggregate_ByDate(t *testing.T) {
|
||||
statsRepo := repository.NewStatsRepository(nil)
|
||||
svc := NewStatsService(statsRepo)
|
||||
func TestProviderService_Update_Success(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
|
||||
stats := []domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
{ProviderID: "p2", RequestCount: 5},
|
||||
}
|
||||
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := svc.Aggregate(stats, "date")
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, 15, result[0]["request_count"])
|
||||
}
|
||||
|
||||
func TestStatsService_Aggregate_ByModel(t *testing.T) {
|
||||
statsRepo := repository.NewStatsRepository(nil)
|
||||
svc := NewStatsService(statsRepo)
|
||||
|
||||
stats := []domain.UsageStats{
|
||||
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10},
|
||||
{ProviderID: "openai", ModelName: "gpt-3.5", RequestCount: 5},
|
||||
{ProviderID: "anthropic", ModelName: "claude-3", RequestCount: 8},
|
||||
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 3},
|
||||
}
|
||||
|
||||
result := svc.Aggregate(stats, "model")
|
||||
assert.Len(t, result, 3)
|
||||
|
||||
// 验证每个 provider/model 组合的计数
|
||||
counts := make(map[string]int)
|
||||
for _, r := range result {
|
||||
key := r["provider_id"].(string) + "/" + r["model_name"].(string)
|
||||
counts[key] = r["request_count"].(int)
|
||||
}
|
||||
assert.Equal(t, 13, counts["openai/gpt-4"])
|
||||
assert.Equal(t, 5, counts["openai/gpt-3.5"])
|
||||
assert.Equal(t, 8, counts["anthropic/claude-3"])
|
||||
// 更新 name
|
||||
err = svc.Update("openai", map[string]interface{}{
|
||||
"name": "OpenAI Updated",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := svc.Get("openai", false)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "OpenAI Updated", updated.Name)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user