feat: 实现统一模型 ID 机制
实现统一模型 ID 格式 (provider_id/model_name),支持跨协议模型标识和 Smart Passthrough。 核心变更: - 新增 pkg/modelid 包:解析、格式化、校验统一模型 ID - 数据库迁移:models 表使用 UUID 主键 + UNIQUE(provider_id, model_name) 约束 - Repository 层:FindByProviderAndModelName、ListEnabled 方法 - Service 层:联合唯一校验、provider ID 字符集校验 - Conversion 层:ExtractModelName、RewriteRequestModelName/RewriteResponseModelName 方法 - Handler 层:统一模型 ID 路由、Smart Passthrough、Models API 本地聚合 - 新增 error-responses、unified-model-id 规范 测试覆盖: - 单元测试:modelid、conversion、handler、service、repository - 集成测试:统一模型 ID 路由、Smart Passthrough 保真性、跨协议转换 - 迁移测试:UUID 主键、UNIQUE 约束、级联删除 OpenSpec: - 归档 unified-model-id 变更到 archive/2026-04-21-unified-model-id - 同步 11 个 delta specs 到 main specs - 新增 error-responses、unified-model-id 规范文件
This commit is contained in:
@@ -49,17 +49,20 @@ func NewAppError(code, message string, httpStatus int) *AppError {
|
||||
|
||||
// Predefined errors
|
||||
var (
|
||||
ErrModelNotFound = NewAppError("model_not_found", "模型未找到", http.StatusNotFound)
|
||||
ErrModelDisabled = NewAppError("model_disabled", "模型已禁用", http.StatusNotFound)
|
||||
ErrProviderNotFound = NewAppError("provider_not_found", "供应商未找到", http.StatusNotFound)
|
||||
ErrProviderDisabled = NewAppError("provider_disabled", "供应商已禁用", http.StatusNotFound)
|
||||
ErrInvalidRequest = NewAppError("invalid_request", "无效的请求", http.StatusBadRequest)
|
||||
ErrInternal = NewAppError("internal_error", "内部错误", http.StatusInternalServerError)
|
||||
ErrDatabaseNotInit = NewAppError("database_not_initialized", "数据库未初始化", http.StatusInternalServerError)
|
||||
ErrConflict = NewAppError("conflict", "资源已存在", http.StatusConflict)
|
||||
ErrRequestCreate = NewAppError("request_create_error", "创建请求失败", http.StatusInternalServerError)
|
||||
ErrRequestSend = NewAppError("request_send_error", "发送请求失败", http.StatusBadGateway)
|
||||
ErrResponseRead = NewAppError("response_read_error", "读取响应失败", http.StatusBadGateway)
|
||||
ErrModelNotFound = NewAppError("model_not_found", "模型未找到", http.StatusNotFound)
|
||||
ErrModelDisabled = NewAppError("model_disabled", "模型已禁用", http.StatusNotFound)
|
||||
ErrProviderNotFound = NewAppError("provider_not_found", "供应商未找到", http.StatusNotFound)
|
||||
ErrProviderDisabled = NewAppError("provider_disabled", "供应商已禁用", http.StatusNotFound)
|
||||
ErrInvalidRequest = NewAppError("invalid_request", "无效的请求", http.StatusBadRequest)
|
||||
ErrInternal = NewAppError("internal_error", "内部错误", http.StatusInternalServerError)
|
||||
ErrDatabaseNotInit = NewAppError("database_not_initialized", "数据库未初始化", http.StatusInternalServerError)
|
||||
ErrConflict = NewAppError("conflict", "资源已存在", http.StatusConflict)
|
||||
ErrRequestCreate = NewAppError("request_create_error", "创建请求失败", http.StatusInternalServerError)
|
||||
ErrRequestSend = NewAppError("request_send_error", "发送请求失败", http.StatusBadGateway)
|
||||
ErrResponseRead = NewAppError("response_read_error", "读取响应失败", http.StatusBadGateway)
|
||||
ErrInvalidProviderID = NewAppError("invalid_provider_id", "供应商 ID 仅允许字母、数字、下划线,长度 1-64", http.StatusBadRequest)
|
||||
ErrDuplicateModel = NewAppError("duplicate_model", "同一供应商下模型名称已存在", http.StatusConflict)
|
||||
ErrImmutableField = NewAppError("immutable_field", "供应商 ID 不允许修改", http.StatusBadRequest)
|
||||
)
|
||||
|
||||
// AsAppError 尝试将 error 转换为 *AppError
|
||||
|
||||
@@ -9,6 +9,14 @@ import (
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// stdoutWriter 包装 os.Stdout,忽略 Sync() 错误。
|
||||
// 在非 TTY 环境(如 go test)中,os.Stdout 被重定向为 pipe,
|
||||
// 底层 fsync 会返回 "bad file descriptor"。zap 社区标准做法。
|
||||
type stdoutWriter struct{}
|
||||
|
||||
func (stdoutWriter) Write(p []byte) (int, error) { return os.Stdout.Write(p) }
|
||||
func (stdoutWriter) Sync() error { return nil }
|
||||
|
||||
// Config 日志配置
|
||||
type Config struct {
|
||||
Level string // 日志级别: debug, info, warn, error
|
||||
@@ -46,7 +54,7 @@ func New(cfg Config) (*zap.Logger, error) {
|
||||
|
||||
stdoutCore := zapcore.NewCore(
|
||||
stdoutEncoder,
|
||||
zapcore.AddSync(os.Stdout),
|
||||
zapcore.AddSync(stdoutWriter{}),
|
||||
level,
|
||||
)
|
||||
|
||||
|
||||
63
backend/pkg/modelid/model_id.go
Normal file
63
backend/pkg/modelid/model_id.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package modelid
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var providerIDRegex = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
|
||||
|
||||
// ParseUnifiedModelID 将 "provider_id/model_name" 格式的字符串解析为 providerID 和 modelName
|
||||
// 在第一个 "/" 处分割,model_name 可以包含 "/"
|
||||
func ParseUnifiedModelID(id string) (providerID, modelName string, err error) {
|
||||
if id == "" {
|
||||
return "", "", errors.New("统一模型 ID 不能为空")
|
||||
}
|
||||
|
||||
parts := strings.SplitN(id, "/", 2)
|
||||
if len(parts) != 2 {
|
||||
return "", "", errors.New("统一模型 ID 格式错误,缺少分隔符 \"/\"")
|
||||
}
|
||||
|
||||
providerID = parts[0]
|
||||
modelName = parts[1]
|
||||
|
||||
if providerID == "" {
|
||||
return "", "", errors.New("provider_id 不能为空")
|
||||
}
|
||||
if modelName == "" {
|
||||
return "", "", errors.New("model_name 不能为空")
|
||||
}
|
||||
|
||||
if !providerIDRegex.MatchString(providerID) {
|
||||
return "", "", errors.New("provider_id 仅允许字母、数字、下划线")
|
||||
}
|
||||
|
||||
return providerID, modelName, nil
|
||||
}
|
||||
|
||||
// FormatUnifiedModelID 将 providerID 和 modelName 组合格式化为统一模型 ID
|
||||
func FormatUnifiedModelID(providerID, modelName string) string {
|
||||
return providerID + "/" + modelName
|
||||
}
|
||||
|
||||
// ValidateProviderID 校验 providerID 仅包含字母、数字、下划线,长度 1-64
|
||||
func ValidateProviderID(id string) error {
|
||||
if id == "" {
|
||||
return errors.New("provider_id 不能为空")
|
||||
}
|
||||
if len(id) > 64 {
|
||||
return errors.New("provider_id 长度不能超过 64 个字符")
|
||||
}
|
||||
if !providerIDRegex.MatchString(id) {
|
||||
return errors.New("provider_id 仅允许字母、数字、下划线")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsValidUnifiedModelID 判断字符串是否为合法的统一模型 ID 格式
|
||||
func IsValidUnifiedModelID(id string) bool {
|
||||
_, _, err := ParseUnifiedModelID(id)
|
||||
return err == nil
|
||||
}
|
||||
96
backend/pkg/modelid/model_id_test.go
Normal file
96
backend/pkg/modelid/model_id_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package modelid
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseUnifiedModelID_StandardFormat(t *testing.T) {
|
||||
providerID, modelName, err := ParseUnifiedModelID("openai/gpt-4")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "openai", providerID)
|
||||
assert.Equal(t, "gpt-4", modelName)
|
||||
}
|
||||
|
||||
func TestParseUnifiedModelID_ModelNameWithSlashes(t *testing.T) {
|
||||
providerID, modelName, err := ParseUnifiedModelID("azure/accounts/org-123/models/gpt-4")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "azure", providerID)
|
||||
assert.Equal(t, "accounts/org-123/models/gpt-4", modelName)
|
||||
}
|
||||
|
||||
func TestParseUnifiedModelID_MissingSeparator(t *testing.T) {
|
||||
_, _, err := ParseUnifiedModelID("gpt-4")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseUnifiedModelID_EmptyString(t *testing.T) {
|
||||
_, _, err := ParseUnifiedModelID("")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseUnifiedModelID_OnlySeparator(t *testing.T) {
|
||||
tests := []string{"/model", "provider/", "/"}
|
||||
for _, tc := range tests {
|
||||
_, _, err := ParseUnifiedModelID(tc)
|
||||
assert.Error(t, err, "输入 %q 应返回错误", tc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUnifiedModelID_InvalidProviderID(t *testing.T) {
|
||||
tests := []string{
|
||||
"open-ai/gpt-4",
|
||||
"open.ai/gpt-4",
|
||||
"供应商/gpt-4",
|
||||
"open ai/gpt-4",
|
||||
}
|
||||
for _, tc := range tests {
|
||||
_, _, err := ParseUnifiedModelID(tc)
|
||||
assert.Error(t, err, "providerID 含非法字符 %q 应返回错误", tc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatUnifiedModelID(t *testing.T) {
|
||||
assert.Equal(t, "openai/gpt-4", FormatUnifiedModelID("openai", "gpt-4"))
|
||||
assert.Equal(t, "anthropic/claude-3-opus", FormatUnifiedModelID("anthropic", "claude-3-opus"))
|
||||
}
|
||||
|
||||
func TestValidateProviderID_Valid(t *testing.T) {
|
||||
validIDs := []string{"openai", "deep_seek", "provider01", "OpenAI"}
|
||||
for _, id := range validIDs {
|
||||
assert.NoError(t, ValidateProviderID(id), "%q 应校验通过", id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateProviderID_InvalidChars(t *testing.T) {
|
||||
invalidIDs := []string{"open-ai", "open.ai", "open ai", "供应商", "open/ai"}
|
||||
for _, id := range invalidIDs {
|
||||
assert.Error(t, ValidateProviderID(id), "%q 应校验失败", id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateProviderID_Empty(t *testing.T) {
|
||||
assert.Error(t, ValidateProviderID(""))
|
||||
}
|
||||
|
||||
func TestValidateProviderID_TooLong(t *testing.T) {
|
||||
longID := strings.Repeat("a", 65)
|
||||
assert.Error(t, ValidateProviderID(longID))
|
||||
|
||||
exactly64 := strings.Repeat("a", 64)
|
||||
assert.NoError(t, ValidateProviderID(exactly64))
|
||||
}
|
||||
|
||||
func TestIsValidUnifiedModelID(t *testing.T) {
|
||||
assert.True(t, IsValidUnifiedModelID("openai/gpt-4"))
|
||||
assert.True(t, IsValidUnifiedModelID("anthropic/claude-3-opus-20240229"))
|
||||
assert.True(t, IsValidUnifiedModelID("azure/accounts/org/models/gpt-4"))
|
||||
|
||||
assert.False(t, IsValidUnifiedModelID(""))
|
||||
assert.False(t, IsValidUnifiedModelID("gpt-4"))
|
||||
assert.False(t, IsValidUnifiedModelID("open-ai/gpt-4"))
|
||||
assert.False(t, IsValidUnifiedModelID("/model"))
|
||||
assert.False(t, IsValidUnifiedModelID("provider/"))
|
||||
}
|
||||
Reference in New Issue
Block a user