diff --git a/.gitignore b/.gitignore index 3b8b772..b073885 100644 --- a/.gitignore +++ b/.gitignore @@ -404,4 +404,5 @@ cython_debug/ openspec/changes/archive temp .agents -skills-lock.json \ No newline at end of file +skills-lock.json +.worktrees \ No newline at end of file diff --git a/README.md b/README.md index a706a2b..946e6e2 100644 --- a/README.md +++ b/README.md @@ -38,10 +38,11 @@ nex/ ## 功能特性 - **双协议支持**:同时支持 OpenAI 和 Anthropic 协议 -- **透明代理**:对 OpenAI 兼容供应商透传请求 +- **统一模型 ID**:`provider_id/model_name` 格式全局唯一标识模型(如 `openai/gpt-4`) +- **透明代理**:对 OpenAI 兼容供应商 Smart Passthrough,最小化改写保持参数保真 - **流式响应**:完整支持 SSE 流式传输 - **Function Calling**:支持工具调用(Tools) -- **多供应商管理**:配置和管理多个供应商 +- **多供应商管理**:配置和管理多个供应商(供应商 ID 仅限字母、数字、下划线) - **用量统计**:按供应商、模型、日期统计请求数量 - **Web 配置界面**:提供供应商和模型配置管理 @@ -99,23 +100,27 @@ bun dev ### 代理接口(对外部应用) +代理接口统一使用 `provider_id/model_name` 格式的模型 ID(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough,最小化 JSON 改写保持参数保真;跨协议请求走完整 decode/encode 转换。 + - `POST /v1/chat/completions` - OpenAI Chat Completions API - `POST /v1/messages` - Anthropic Messages API +- `GET /v1/models` - 模型列表(本地数据库聚合,不请求上游) +- `GET /v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询) ### 管理接口(对前端) #### 供应商管理 - `GET /api/providers` - 列出所有供应商 -- `POST /api/providers` - 创建供应商 +- `POST /api/providers` - 创建供应商(`id` 仅限字母、数字、下划线,长度 1-64) - `GET /api/providers/:id` - 获取供应商 -- `PUT /api/providers/:id` - 更新供应商 +- `PUT /api/providers/:id` - 更新供应商(`id` 不可修改) - `DELETE /api/providers/:id` - 删除供应商 #### 模型管理 - `GET /api/models` - 列出模型(支持 `?provider_id=xxx` 过滤) -- `POST /api/models` - 创建模型 -- `GET /api/models/:id` - 获取模型 -- `PUT /api/models/:id` - 更新模型 +- `POST /api/models` - 创建模型(`id` 由系统自动生成 UUID,`provider_id` + `model_name` 联合唯一) +- `GET /api/models/:id` - 获取模型(响应含 `unified_id` 字段,格式 `provider_id/model_name`) +- `PUT /api/models/:id` - 更新模型(不可修改 `id`) - `DELETE /api/models/:id` - 删除模型 #### 统计查询 diff --git a/backend/README.md b/backend/README.md index b2fcc00..7e8990f 100644 --- a/backend/README.md 
+++ b/backend/README.md @@ -108,12 +108,13 @@ backend/ │ │ ├── logger.go │ │ ├── rotate.go │ │ └── context.go +│ ├── modelid/ # 统一模型 ID 工具包 +│ │ ├── model_id.go +│ │ └── model_id_test.go │ └── validator/ # 验证器 │ └── validator.go ├── migrations/ # 数据库迁移 -│ ├── 20260401000001_initial_schema.sql -│ ├── 20260401000002_add_indexes.sql -│ └── 20260419000001_add_provider_protocol.sql +│ └── 20260421000001_initial_schema.sql ├── tests/ # 集成测试 │ ├── helpers.go │ └── integration/ @@ -292,6 +293,8 @@ GET /anthropic/v1/models **协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传,零序列化开销。 +**统一模型 ID**:代理请求中的 `model` 字段使用 `provider_id/model_name` 格式(如 `openai/gpt-4`),网关据此路由到对应供应商。同协议时自动改写为上游 `model_name`,跨协议时通过全量转换处理。 + ### 管理接口 #### 供应商管理 @@ -324,14 +327,30 @@ GET /anthropic/v1/models - `PUT /api/models/:id` - 更新模型 - `DELETE /api/models/:id` - 删除模型 +**创建请求**(id 由系统自动生成 UUID): + ```json { - "id": "gpt-4", "provider_id": "openai", "model_name": "gpt-4" } ``` +**响应示例**: + +```json +{ + "id": "550e8400-e29b-41d4-a716-446655440000", + "provider_id": "openai", + "model_name": "gpt-4", + "unified_id": "openai/gpt-4", + "enabled": true, + "created_at": "2026-04-21T00:00:00Z" +} +``` + +**统一模型 ID**:`unified_id` 字段为 `provider_id/model_name` 格式,用于代理请求的 `model` 参数。 + #### 统计查询 - `GET /api/stats` - 查询统计 diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 588ec44..103eda4 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -68,7 +68,7 @@ func main() { statsRepo := repository.NewStatsRepository(db) // 5. 
初始化 service 层 - providerService := service.NewProviderService(providerRepo) + providerService := service.NewProviderService(providerRepo, modelRepo) modelService := service.NewModelService(modelRepo, providerRepo) routingService := service.NewRoutingService(modelRepo, providerRepo) statsService := service.NewStatsService(statsRepo) diff --git a/backend/internal/config/models.go b/backend/internal/config/models.go index 9656937..739e5c7 100644 --- a/backend/internal/config/models.go +++ b/backend/internal/config/models.go @@ -17,11 +17,11 @@ type Provider struct { Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"` } -// Model 模型配置 +// Model 模型配置(id 为 UUID 自动生成,UNIQUE(provider_id, model_name)) type Model struct { ID string `gorm:"primaryKey" json:"id"` - ProviderID string `gorm:"not null;index" json:"provider_id"` - ModelName string `gorm:"not null;index" json:"model_name"` + ProviderID string `gorm:"not null;index;uniqueIndex:idx_provider_model" json:"provider_id"` + ModelName string `gorm:"not null;index;uniqueIndex:idx_provider_model" json:"model_name"` Enabled bool `gorm:"default:true" json:"enabled"` CreatedAt time.Time `json:"created_at"` } diff --git a/backend/internal/conversion/adapter.go b/backend/internal/conversion/adapter.go index 7f5f5bd..0b9e0d8 100644 --- a/backend/internal/conversion/adapter.go +++ b/backend/internal/conversion/adapter.go @@ -40,6 +40,12 @@ type ProtocolAdapter interface { EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *TargetProvider) ([]byte, error) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) + + // 统一模型 ID 相关方法 + ExtractUnifiedModelID(nativePath string) (string, error) + ExtractModelName(body []byte, ifaceType InterfaceType) (string, error) + RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) + 
RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) } // AdapterRegistry 适配器注册表接口 diff --git a/backend/internal/conversion/anthropic/adapter.go b/backend/internal/conversion/anthropic/adapter.go index 9ae6c17..74053be 100644 --- a/backend/internal/conversion/anthropic/adapter.go +++ b/backend/internal/conversion/anthropic/adapter.go @@ -2,6 +2,7 @@ package anthropic import ( "encoding/json" + "fmt" "strings" "nex/backend/internal/conversion" @@ -39,13 +40,13 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp } } -// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id}) +// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id},允许 id 含 /) func isModelInfoPath(path string) bool { if !strings.HasPrefix(path, "/v1/models/") { return false } suffix := path[len("/v1/models/"):] - return suffix != "" && !strings.Contains(suffix, "/") + return suffix != "" } // BuildUrl 根据接口类型构建 URL @@ -203,3 +204,74 @@ func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankRe func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) { return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口") } + +// ExtractUnifiedModelID 从路径中提取统一模型 ID(/v1/models/{provider_id}/{model_name}) +func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) { + if !strings.HasPrefix(nativePath, "/v1/models/") { + return "", fmt.Errorf("不是模型详情路径: %s", nativePath) + } + suffix := nativePath[len("/v1/models/"):] + if suffix == "" { + return "", fmt.Errorf("路径缺少模型 ID") + } + return suffix, nil +} + +// locateModelFieldInRequest 定位请求体中 model 字段的值并提供改写函数 +func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) { + var m map[string]json.RawMessage + if err := json.Unmarshal(body, &m); err != nil { + return "", nil, err + } + + switch ifaceType { + case 
conversion.InterfaceTypeChat: + raw, exists := m["model"] + if !exists { + return "", nil, fmt.Errorf("请求体中缺少 model 字段") + } + var current string + if err := json.Unmarshal(raw, &current); err != nil { + return "", nil, err + } + rewriteFunc := func(newModel string) ([]byte, error) { + m["model"], _ = json.Marshal(newModel) + return json.Marshal(m) + } + return current, rewriteFunc, nil + default: + return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType) + } +} + +// ExtractModelName 从请求体中提取 model 值 +func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) { + model, _, err := locateModelFieldInRequest(body, ifaceType) + return model, err +} + +// RewriteRequestModelName 最小化改写请求体中的 model 字段 +func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) { + _, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType) + if err != nil { + return nil, err + } + return rewriteFunc(newModel) +} + +// RewriteResponseModelName 最小化改写响应体中的 model 字段 +func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) { + var m map[string]json.RawMessage + if err := json.Unmarshal(body, &m); err != nil { + return nil, err + } + + switch ifaceType { + case conversion.InterfaceTypeChat: + // Chat 响应必须有 model 字段,存在则改写,不存在则添加 + m["model"], _ = json.Marshal(newModel) + return json.Marshal(m) + default: + return body, nil + } +} diff --git a/backend/internal/conversion/anthropic/adapter_unified_test.go b/backend/internal/conversion/anthropic/adapter_unified_test.go new file mode 100644 index 0000000..f243cc7 --- /dev/null +++ b/backend/internal/conversion/anthropic/adapter_unified_test.go @@ -0,0 +1,263 @@ +package anthropic + +import ( + "encoding/json" + "testing" + + "nex/backend/internal/conversion" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// 
--------------------------------------------------------------------------- +// ExtractUnifiedModelID +// --------------------------------------------------------------------------- + +func TestExtractUnifiedModelID(t *testing.T) { + a := NewAdapter() + + t.Run("standard_path", func(t *testing.T) { + id, err := a.ExtractUnifiedModelID("/v1/models/anthropic/claude-3") + require.NoError(t, err) + assert.Equal(t, "anthropic/claude-3", id) + }) + + t.Run("multi_segment_path", func(t *testing.T) { + id, err := a.ExtractUnifiedModelID("/v1/models/some/deep/nested/model") + require.NoError(t, err) + assert.Equal(t, "some/deep/nested/model", id) + }) + + t.Run("single_segment", func(t *testing.T) { + id, err := a.ExtractUnifiedModelID("/v1/models/claude-3") + require.NoError(t, err) + assert.Equal(t, "claude-3", id) + }) + + t.Run("non_model_path", func(t *testing.T) { + _, err := a.ExtractUnifiedModelID("/v1/messages") + require.Error(t, err) + }) + + t.Run("empty_suffix", func(t *testing.T) { + _, err := a.ExtractUnifiedModelID("/v1/models/") + require.Error(t, err) + }) + + t.Run("models_list_no_slash", func(t *testing.T) { + _, err := a.ExtractUnifiedModelID("/v1/models") + require.Error(t, err) + }) + + t.Run("unrelated_path", func(t *testing.T) { + _, err := a.ExtractUnifiedModelID("/v1/other") + require.Error(t, err) + }) +} + +// --------------------------------------------------------------------------- +// ExtractModelName (Chat only for Anthropic) +// --------------------------------------------------------------------------- + +func TestExtractModelName(t *testing.T) { + a := NewAdapter() + + t.Run("chat", func(t *testing.T) { + body := []byte(`{"model":"anthropic/claude-3","messages":[]}`) + model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat) + require.NoError(t, err) + assert.Equal(t, "anthropic/claude-3", model) + }) + + t.Run("chat_with_max_tokens", func(t *testing.T) { + body := 
[]byte(`{"model":"anthropic/claude-3-opus","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`) + model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat) + require.NoError(t, err) + assert.Equal(t, "anthropic/claude-3-opus", model) + }) + + t.Run("no_model_field", func(t *testing.T) { + body := []byte(`{"messages":[]}`) + _, err := a.ExtractModelName(body, conversion.InterfaceTypeChat) + require.Error(t, err) + }) + + t.Run("invalid_json", func(t *testing.T) { + body := []byte(`{invalid}`) + _, err := a.ExtractModelName(body, conversion.InterfaceTypeChat) + require.Error(t, err) + }) + + t.Run("unsupported_interface_type_embedding", func(t *testing.T) { + body := []byte(`{"model":"anthropic/claude-3"}`) + _, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings) + require.Error(t, err) + }) + + t.Run("unsupported_interface_type_rerank", func(t *testing.T) { + body := []byte(`{"model":"anthropic/claude-3"}`) + _, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank) + require.Error(t, err) + }) +} + +// --------------------------------------------------------------------------- +// RewriteRequestModelName (Chat only for Anthropic) +// --------------------------------------------------------------------------- + +func TestRewriteRequestModelName(t *testing.T) { + a := NewAdapter() + + t.Run("chat", func(t *testing.T) { + body := []byte(`{"model":"anthropic/claude-3","messages":[]}`) + rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat) + require.NoError(t, err) + + var m map[string]interface{} + require.NoError(t, json.Unmarshal(rewritten, &m)) + assert.Equal(t, "claude-3", m["model"]) + + msgs, ok := m["messages"] + require.True(t, ok) + msgsArr, ok := msgs.([]interface{}) + require.True(t, ok) + assert.Len(t, msgsArr, 0) + }) + + t.Run("preserves_unknown_fields", func(t *testing.T) { + body := []byte(`{"model":"anthropic/claude-3","max_tokens":1024,"temperature":0.7}`) + 
rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat) + require.NoError(t, err) + + var m map[string]interface{} + require.NoError(t, json.Unmarshal(rewritten, &m)) + assert.Equal(t, "claude-3", m["model"]) + assert.Equal(t, 0.7, m["temperature"]) + + // max_tokens is encoded as float in JSON numbers + maxTokens, ok := m["max_tokens"] + require.True(t, ok) + assert.Equal(t, float64(1024), maxTokens) + }) + + t.Run("no_model_field", func(t *testing.T) { + body := []byte(`{"messages":[]}`) + _, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat) + require.Error(t, err) + }) + + t.Run("invalid_json", func(t *testing.T) { + body := []byte(`{invalid}`) + _, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat) + require.Error(t, err) + }) + + t.Run("unsupported_interface_type", func(t *testing.T) { + body := []byte(`{"model":"anthropic/claude-3"}`) + _, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeEmbeddings) + require.Error(t, err) + }) +} + +// --------------------------------------------------------------------------- +// RewriteResponseModelName (Chat only for Anthropic) +// --------------------------------------------------------------------------- + +func TestRewriteResponseModelName(t *testing.T) { + a := NewAdapter() + + t.Run("chat_existing_model", func(t *testing.T) { + body := []byte(`{"model":"claude-3","content":[],"stop_reason":"end_turn"}`) + rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat) + require.NoError(t, err) + + var m map[string]interface{} + require.NoError(t, json.Unmarshal(rewritten, &m)) + assert.Equal(t, "anthropic/claude-3", m["model"]) + + // other fields preserved + _, hasContent := m["content"] + assert.True(t, hasContent) + assert.Equal(t, "end_turn", m["stop_reason"]) + }) + + t.Run("chat_without_model_field_adds_it", func(t *testing.T) { + body := 
[]byte(`{"content":[],"stop_reason":"end_turn"}`) + rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat) + require.NoError(t, err) + + var m map[string]interface{} + require.NoError(t, json.Unmarshal(rewritten, &m)) + assert.Equal(t, "anthropic/claude-3", m["model"]) + }) + + t.Run("passthrough_returns_body_unchanged", func(t *testing.T) { + body := []byte(`{"model":"claude-3"}`) + rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypePassthrough) + require.NoError(t, err) + assert.Equal(t, string(body), string(rewritten)) + }) + + t.Run("invalid_json", func(t *testing.T) { + body := []byte(`{invalid}`) + _, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat) + require.Error(t, err) + }) +} + +// --------------------------------------------------------------------------- +// ExtractModelName and RewriteRequest consistency +// --------------------------------------------------------------------------- + +func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) { + a := NewAdapter() + + t.Run("chat_round_trip", func(t *testing.T) { + original := []byte(`{"model":"anthropic/claude-3","messages":[],"max_tokens":1024}`) + + // Extract the unified model ID from the body + extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat) + require.NoError(t, err) + assert.Equal(t, "anthropic/claude-3", extracted) + + // Rewrite to the native model name + rewritten, err := a.RewriteRequestModelName(original, "claude-3", conversion.InterfaceTypeChat) + require.NoError(t, err) + + // Extract again from the rewritten body to verify the same location was targeted + afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat) + require.NoError(t, err) + assert.Equal(t, "claude-3", afterRewrite) + + // Verify other fields are preserved + var m map[string]interface{} + require.NoError(t, 
json.Unmarshal(rewritten, &m)) + assert.Equal(t, float64(1024), m["max_tokens"]) + }) +} + +// --------------------------------------------------------------------------- +// isModelInfoPath (additional unified model ID cases) +// --------------------------------------------------------------------------- + +func TestIsModelInfoPath_UnifiedModelID(t *testing.T) { + tests := []struct { + name string + path string + expected bool + }{ + {"simple_model_id", "/v1/models/claude-3", true}, + {"unified_model_id_with_slash", "/v1/models/anthropic/claude-3", true}, + {"models_list", "/v1/models", false}, + {"models_list_trailing_slash", "/v1/models/", false}, + {"messages_path", "/v1/messages", false}, + {"deeply_nested", "/v1/models/org/workspace/claude-3-opus", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, isModelInfoPath(tt.path)) + }) + } +} diff --git a/backend/internal/conversion/engine.go b/backend/internal/conversion/engine.go index 2301705..0dd59a1 100644 --- a/backend/internal/conversion/engine.go +++ b/backend/internal/conversion/engine.go @@ -79,11 +79,29 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc if err != nil { return nil, err } + + // Smart Passthrough: 同协议时最小化改写 model 字段 + interfaceType := providerAdapter.DetectInterfaceType(nativePath) + rewrittenBody := spec.Body + + // 对于 Chat/Embedding/Rerank 接口,改写请求体中的 model 字段 + if interfaceType == InterfaceTypeChat || interfaceType == InterfaceTypeEmbeddings || interfaceType == InterfaceTypeRerank { + if len(spec.Body) > 0 && provider.ModelName != "" { + rewrittenBody, err = providerAdapter.RewriteRequestModelName(spec.Body, provider.ModelName, interfaceType) + if err != nil { + e.logger.Warn("Smart Passthrough 改写请求失败,使用原始请求体", + zap.String("error", err.Error()), + zap.String("interface", string(interfaceType))) + rewrittenBody = spec.Body + } + } + } + return &HTTPRequestSpec{ URL: provider.BaseURL + nativePath, 
Method: spec.Method, Headers: providerAdapter.BuildHeaders(provider), - Body: spec.Body, + Body: rewrittenBody, }, nil } @@ -112,9 +130,30 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc }, nil } -// ConvertHttpResponse 转换 HTTP 响应 -func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType) (*HTTPResponseSpec, error) { +// ConvertHttpResponse 转换 HTTP 响应,modelOverride 用于跨协议场景覆写 model 字段 +func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType, modelOverride string) (*HTTPResponseSpec, error) { if e.IsPassthrough(clientProtocol, providerProtocol) { + // Smart Passthrough: 同协议时最小化改写 model 字段 + if modelOverride != "" && len(spec.Body) > 0 { + adapter, err := e.registry.Get(clientProtocol) + if err != nil { + return &spec, nil + } + + rewrittenBody, err := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType) + if err != nil { + e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体", + zap.String("error", err.Error()), + zap.String("interface", string(interfaceType))) + return &spec, nil + } + + return &HTTPResponseSpec{ + StatusCode: spec.StatusCode, + Headers: spec.Headers, + Body: rewrittenBody, + }, nil + } return &spec, nil } @@ -127,7 +166,7 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt return nil, err } - convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body) + convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body, modelOverride) if err != nil { return nil, err } @@ -139,9 +178,17 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt }, nil } -// CreateStreamConverter 创建流式转换器 -func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string) (StreamConverter, error) { +// 
CreateStreamConverter 创建流式转换器,modelOverride 用于跨协议场景覆写 model 字段 +func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string, modelOverride string, interfaceType InterfaceType) (StreamConverter, error) { if e.IsPassthrough(clientProtocol, providerProtocol) { + // Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段 + if modelOverride != "" { + adapter, err := e.registry.Get(clientProtocol) + if err != nil { + return NewPassthroughStreamConverter(), nil + } + return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil + } return NewPassthroughStreamConverter(), nil } @@ -167,6 +214,7 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco ctx, clientProtocol, providerProtocol, + modelOverride, ), nil } @@ -192,11 +240,11 @@ func (e *ConversionEngine) convertBody(interfaceType InterfaceType, clientAdapte } } -// convertResponseBody 转换响应体 -func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { +// convertResponseBody 转换响应体,modelOverride 非空时在 canonical 层面覆写 Model 字段 +func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) { switch interfaceType { case InterfaceTypeChat: - return e.convertChatResponseBody(clientAdapter, providerAdapter, body) + return e.convertChatResponseBody(clientAdapter, providerAdapter, body, modelOverride) case InterfaceTypeModels: if !clientAdapter.SupportsInterface(InterfaceTypeModels) || !providerAdapter.SupportsInterface(InterfaceTypeModels) { return body, nil @@ -211,12 +259,12 @@ func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clie if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) { return body, nil } - return 
e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body) + return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body, modelOverride) case InterfaceTypeRerank: if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) { return body, nil } - return e.convertRerankResponseBody(clientAdapter, providerAdapter, body) + return e.convertRerankResponseBody(clientAdapter, providerAdapter, body, modelOverride) default: return body, nil } @@ -241,11 +289,14 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc return encoded, nil } -func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { +func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) { canonicalResp, err := providerAdapter.DecodeResponse(body) if err != nil { return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(err) } + if modelOverride != "" { + canonicalResp.Model = modelOverride + } encoded, err := clientAdapter.EncodeResponse(canonicalResp) if err != nil { return nil, NewConversionError(ErrorCodeEncodingFailure, "编码响应失败").WithCause(err) @@ -290,12 +341,15 @@ func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter P return providerAdapter.EncodeEmbeddingRequest(req, provider) } -func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { +func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) { resp, err := providerAdapter.DecodeEmbeddingResponse(body) if err != nil { e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error())) return body, nil } + if modelOverride != "" { + resp.Model = modelOverride + 
} return clientAdapter.EncodeEmbeddingResponse(resp) } @@ -308,11 +362,14 @@ func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter Prot return providerAdapter.EncodeRerankRequest(req, provider) } -func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { +func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) { resp, err := providerAdapter.DecodeRerankResponse(body) if err != nil { return body, nil } + if modelOverride != "" { + resp.Model = modelOverride + } return clientAdapter.EncodeRerankResponse(resp) } diff --git a/backend/internal/conversion/engine_supplemental_test.go b/backend/internal/conversion/engine_supplemental_test.go index b9c12c8..4b3dd1e 100644 --- a/backend/internal/conversion/engine_supplemental_test.go +++ b/backend/internal/conversion/engine_supplemental_test.go @@ -113,7 +113,7 @@ func TestConvertHttpResponse_CrossProtocol(t *testing.T) { result, err := engine.ConvertHttpResponse(HTTPResponseSpec{ StatusCode: 200, Body: []byte(`{"id":"resp-1"}`), - }, "client", "provider", InterfaceTypeChat) + }, "client", "provider", InterfaceTypeChat, "") require.NoError(t, err) assert.Equal(t, 200, result.StatusCode) assert.Contains(t, string(result.Body), "resp-1") @@ -129,7 +129,7 @@ func TestConvertHttpResponse_DecodeError(t *testing.T) { _ = engine.RegisterAdapter(providerAdapter) _ = engine.RegisterAdapter(newMockAdapter("client", false)) - _, err := engine.ConvertHttpResponse(HTTPResponseSpec{Body: []byte(`{}`)}, "client", "provider", InterfaceTypeChat) + _, err := engine.ConvertHttpResponse(HTTPResponseSpec{Body: []byte(`{}`)}, "client", "provider", InterfaceTypeChat, "") assert.Error(t, err) } @@ -189,7 +189,7 @@ func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) { result, err := engine.ConvertHttpResponse(HTTPResponseSpec{ StatusCode: 200, Body: 
[]byte(`{"object":"list","data":[],"model":"test"}`), - }, "client", "provider", InterfaceTypeEmbeddings) + }, "client", "provider", InterfaceTypeEmbeddings, "") require.NoError(t, err) assert.NotNil(t, result) } @@ -207,7 +207,7 @@ func TestConvertHttpResponse_RerankInterface(t *testing.T) { result, err := engine.ConvertHttpResponse(HTTPResponseSpec{ StatusCode: 200, Body: []byte(`{"results":[],"model":"test"}`), - }, "client", "provider", InterfaceTypeRerank) + }, "client", "provider", InterfaceTypeRerank, "") require.NoError(t, err) assert.NotNil(t, result) } @@ -242,7 +242,7 @@ func TestConvertHttpResponse_ModelsInterface(t *testing.T) { result, err := engine.ConvertHttpResponse(HTTPResponseSpec{ StatusCode: 200, Body: []byte(`{"object":"list","data":[]}`), - }, "client", "provider", InterfaceTypeModels) + }, "client", "provider", InterfaceTypeModels, "") require.NoError(t, err) assert.NotNil(t, result) } @@ -259,7 +259,7 @@ func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) { result, err := engine.ConvertHttpResponse(HTTPResponseSpec{ StatusCode: 200, Body: []byte(`{"id":"gpt-4","object":"model"}`), - }, "client", "provider", InterfaceTypeModelInfo) + }, "client", "provider", InterfaceTypeModelInfo, "") require.NoError(t, err) assert.NotNil(t, result) } diff --git a/backend/internal/conversion/engine_test.go b/backend/internal/conversion/engine_test.go index 5c1b38e..37bba12 100644 --- a/backend/internal/conversion/engine_test.go +++ b/backend/internal/conversion/engine_test.go @@ -13,16 +13,18 @@ import ( // mockProtocolAdapter 模拟协议适配器 type mockProtocolAdapter struct { - protocolName string - passthrough bool - ifaceType InterfaceType - supportsIface map[InterfaceType]bool - decodeReqFn func([]byte) (*canonical.CanonicalRequest, error) - encodeReqFn func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error) - decodeRespFn func([]byte) (*canonical.CanonicalResponse, error) - encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error) - 
streamDecoderFn func() StreamDecoder - streamEncoderFn func() StreamEncoder + protocolName string + passthrough bool + ifaceType InterfaceType + supportsIface map[InterfaceType]bool + decodeReqFn func([]byte) (*canonical.CanonicalRequest, error) + encodeReqFn func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error) + decodeRespFn func([]byte) (*canonical.CanonicalResponse, error) + encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error) + streamDecoderFn func() StreamDecoder + streamEncoderFn func() StreamEncoder + rewriteReqFn func([]byte, string, InterfaceType) ([]byte, error) + rewriteRespFn func([]byte, string, InterfaceType) ([]byte, error) } func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter { @@ -155,6 +157,28 @@ func (m *mockProtocolAdapter) EncodeRerankResponse(resp *canonical.CanonicalRera return json.Marshal(resp) } +func (m *mockProtocolAdapter) ExtractUnifiedModelID(nativePath string) (string, error) { + return "", nil +} + +func (m *mockProtocolAdapter) ExtractModelName(body []byte, ifaceType InterfaceType) (string, error) { + return "", nil +} + +func (m *mockProtocolAdapter) RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) { + if m.rewriteReqFn != nil { + return m.rewriteReqFn(body, newModel, ifaceType) + } + return body, nil +} + +func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) { + if m.rewriteRespFn != nil { + return m.rewriteRespFn(body, newModel, ifaceType) + } + return body, nil +} + // noopStreamDecoder 空流式解码器 type noopStreamDecoder struct{} @@ -309,7 +333,7 @@ func TestConvertHttpResponse_Passthrough(t *testing.T) { Body: []byte(`{"id":"123"}`), } - result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat) + result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat, "") require.NoError(t, err) assert.Equal(t, 200, 
result.StatusCode) assert.Equal(t, spec.Body, result.Body) @@ -320,7 +344,7 @@ func TestCreateStreamConverter_Passthrough(t *testing.T) { engine := NewConversionEngine(registry, nil) _ = engine.RegisterAdapter(newMockAdapter("openai", true)) - converter, err := engine.CreateStreamConverter("openai", "openai") + converter, err := engine.CreateStreamConverter("openai", "openai", "", InterfaceTypeChat) require.NoError(t, err) _, ok := converter.(*PassthroughStreamConverter) assert.True(t, ok) @@ -332,7 +356,7 @@ func TestCreateStreamConverter_Canonical(t *testing.T) { _ = engine.RegisterAdapter(newMockAdapter("client", false)) _ = engine.RegisterAdapter(newMockAdapter("provider", false)) - converter, err := engine.CreateStreamConverter("client", "provider") + converter, err := engine.CreateStreamConverter("client", "provider", "", InterfaceTypeChat) require.NoError(t, err) _, ok := converter.(*CanonicalStreamConverter) assert.True(t, ok) @@ -380,3 +404,230 @@ func TestRegistry_GetNonExistent(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "未找到适配器") } + +// ============ modelOverride 测试 ============ + +func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry, nil) + + clientAdapter := newMockAdapter("client", false) + clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) { + return json.Marshal(map[string]any{"model": resp.Model}) + } + _ = engine.RegisterAdapter(clientAdapter) + + providerAdapter := newMockAdapter("provider", false) + providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) { + return &canonical.CanonicalResponse{ID: "test", Model: "native-model", Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}, nil + } + _ = engine.RegisterAdapter(providerAdapter) + + spec := HTTPResponseSpec{ + StatusCode: 200, + Body: []byte(`{"model":"native-model"}`), + } + + result, err := 
engine.ConvertHttpResponse(spec, "client", "provider", InterfaceTypeChat, "provider/gpt-4") + require.NoError(t, err) + + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(result.Body, &resp)) + assert.Equal(t, "provider/gpt-4", resp["model"]) +} + +func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry, nil) + + // 使用 mock OpenAI adapter(注入 rewriteRespFn)验证 Smart Passthrough 改写 + openaiAdapter := newMockAdapter("openai", true) + openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) { + var m map[string]json.RawMessage + if err := json.Unmarshal(body, &m); err != nil { + return nil, err + } + m["model"], _ = json.Marshal(newModel) + return json.Marshal(m) + } + _ = engine.RegisterAdapter(openaiAdapter) + + spec := HTTPResponseSpec{ + StatusCode: 200, + Body: []byte(`{"id":"resp-1","model":"gpt-4"}`), + } + + result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat, "openai/gpt-4") + require.NoError(t, err) + + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(result.Body, &resp)) + assert.Equal(t, "openai/gpt-4", resp["model"]) + assert.Equal(t, "resp-1", resp["id"]) +} + +func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry, nil) + + openaiAdapter := newMockAdapter("openai", true) + openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) { + var m map[string]json.RawMessage + if err := json.Unmarshal(body, &m); err != nil { + return nil, err + } + m["model"], _ = json.Marshal(newModel) + return json.Marshal(m) + } + _ = engine.RegisterAdapter(openaiAdapter) + + converter, err := engine.CreateStreamConverter("openai", "openai", "openai/gpt-4", InterfaceTypeChat) + require.NoError(t, err) + + _, ok := 
converter.(*SmartPassthroughStreamConverter) + assert.True(t, ok) + + // 验证 chunk 改写 + chunks := converter.ProcessChunk([]byte(`{"model":"gpt-4","choices":[]}`)) + require.Len(t, chunks, 1) + + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(chunks[0], &resp)) + assert.Equal(t, "openai/gpt-4", resp["model"]) +} + +func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry, nil) + + // provider adapter 解码出含 model 的流式事件 + providerAdapter := newMockAdapter("provider", false) + providerAdapter.streamDecoderFn = func() StreamDecoder { + return &engineTestStreamDecoder{ + processFn: func(raw []byte) []canonical.CanonicalStreamEvent { + return []canonical.CanonicalStreamEvent{ + canonical.NewMessageStartEvent("msg-1", "native-model"), + canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "text", Text: "hi"}), + canonical.NewMessageStopEvent(), + } + }, + } + } + _ = engine.RegisterAdapter(providerAdapter) + + // client adapter 编码时输出 model 字段 + clientAdapter := newMockAdapter("client", false) + clientAdapter.streamEncoderFn = func() StreamEncoder { + return &engineTestStreamEncoder{ + encodeFn: func(event canonical.CanonicalStreamEvent) [][]byte { + if event.Message != nil { + data, _ := json.Marshal(map[string]string{ + "type": string(event.Type), + "model": event.Message.Model, + }) + return [][]byte{data} + } + data, _ := json.Marshal(map[string]string{"type": string(event.Type)}) + return [][]byte{data} + }, + } + } + _ = engine.RegisterAdapter(clientAdapter) + + converter, err := engine.CreateStreamConverter("client", "provider", "provider/gpt-4", InterfaceTypeChat) + require.NoError(t, err) + + // 验证类型是 CanonicalStreamConverter + _, ok := converter.(*CanonicalStreamConverter) + assert.True(t, ok) + + // 处理一个 chunk,验证 model 被覆写为统一模型 ID + chunks := converter.ProcessChunk([]byte("raw")) + require.Len(t, chunks, 3) // message_start + 
content_block_start + message_stop + + var startEvent map[string]string + require.NoError(t, json.Unmarshal(chunks[0], &startEvent)) + assert.Equal(t, "provider/gpt-4", startEvent["model"], "跨协议流式中 modelOverride 应覆写 Message.Model") +} + +func TestCreateStreamConverter_ModelOverride_CrossProtocol_Empty(t *testing.T) { + registry := NewMemoryRegistry() + engine := NewConversionEngine(registry, nil) + + providerAdapter := newMockAdapter("provider", false) + providerAdapter.streamDecoderFn = func() StreamDecoder { + return &engineTestStreamDecoder{ + processFn: func(raw []byte) []canonical.CanonicalStreamEvent { + return []canonical.CanonicalStreamEvent{ + canonical.NewMessageStartEvent("msg-1", "native-model"), + } + }, + } + } + _ = engine.RegisterAdapter(providerAdapter) + + clientAdapter := newMockAdapter("client", false) + clientAdapter.streamEncoderFn = func() StreamEncoder { + return &engineTestStreamEncoder{ + encodeFn: func(event canonical.CanonicalStreamEvent) [][]byte { + if event.Message != nil { + data, _ := json.Marshal(map[string]string{ + "model": event.Message.Model, + }) + return [][]byte{data} + } + return nil + }, + } + } + _ = engine.RegisterAdapter(clientAdapter) + + // modelOverride 为空,不应覆写 + converter, err := engine.CreateStreamConverter("client", "provider", "", InterfaceTypeChat) + require.NoError(t, err) + + chunks := converter.ProcessChunk([]byte("raw")) + require.Len(t, chunks, 1) + + var resp map[string]string + require.NoError(t, json.Unmarshal(chunks[0], &resp)) + assert.Equal(t, "native-model", resp["model"], "modelOverride 为空时不应覆写") +} + +// engineTestStreamDecoder 可控的流式解码器(用于 engine_test) +type engineTestStreamDecoder struct { + processFn func([]byte) []canonical.CanonicalStreamEvent + flushFn func() []canonical.CanonicalStreamEvent +} + +func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.CanonicalStreamEvent { + if d.processFn != nil { + return d.processFn(raw) + } + return nil +} +func (d 
*engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent { + if d.flushFn != nil { + return d.flushFn() + } + return nil +} + +// engineTestStreamEncoder 可控的流式编码器(用于 engine_test) +type engineTestStreamEncoder struct { + encodeFn func(canonical.CanonicalStreamEvent) [][]byte + flushFn func() [][]byte +} + +func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte { + if e.encodeFn != nil { + return e.encodeFn(event) + } + return nil +} +func (e *engineTestStreamEncoder) Flush() [][]byte { + if e.flushFn != nil { + return e.flushFn() + } + return nil +} diff --git a/backend/internal/conversion/openai/adapter.go b/backend/internal/conversion/openai/adapter.go index 1f21509..78e57ac 100644 --- a/backend/internal/conversion/openai/adapter.go +++ b/backend/internal/conversion/openai/adapter.go @@ -2,6 +2,7 @@ package openai import ( "encoding/json" + "fmt" "strings" "nex/backend/internal/conversion" @@ -43,13 +44,13 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp } } -// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id}) +// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id},允许 id 含 /) func isModelInfoPath(path string) bool { if !strings.HasPrefix(path, "/v1/models/") { return false } suffix := path[len("/v1/models/"):] - return suffix != "" && !strings.Contains(suffix, "/") + return suffix != "" } // BuildUrl 根据接口类型构建 URL @@ -216,3 +217,80 @@ func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankRe func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) { return encodeRerankResponse(resp) } + +// ExtractUnifiedModelID 从路径中提取统一模型 ID(/v1/models/{provider_id}/{model_name}) +func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) { + if !strings.HasPrefix(nativePath, "/v1/models/") { + return "", fmt.Errorf("不是模型详情路径: %s", nativePath) + } + suffix := nativePath[len("/v1/models/"):] + if suffix == "" { + return "", 
fmt.Errorf("路径缺少模型 ID") + } + return suffix, nil +} + +// locateModelFieldInRequest 定位请求体中 model 字段的值并提供改写函数 +func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) { + var m map[string]json.RawMessage + if err := json.Unmarshal(body, &m); err != nil { + return "", nil, err + } + + switch ifaceType { + case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank: + raw, exists := m["model"] + if !exists { + return "", nil, fmt.Errorf("请求体中缺少 model 字段") + } + var current string + if err := json.Unmarshal(raw, ¤t); err != nil { + return "", nil, err + } + rewriteFunc := func(newModel string) ([]byte, error) { + m["model"], _ = json.Marshal(newModel) + return json.Marshal(m) + } + return current, rewriteFunc, nil + default: + return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType) + } +} + +// ExtractModelName 从请求体中提取 model 值 +func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) { + model, _, err := locateModelFieldInRequest(body, ifaceType) + return model, err +} + +// RewriteRequestModelName 最小化改写请求体中的 model 字段 +func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) { + _, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType) + if err != nil { + return nil, err + } + return rewriteFunc(newModel) +} + +// RewriteResponseModelName 最小化改写响应体中的 model 字段 +func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) { + var m map[string]json.RawMessage + if err := json.Unmarshal(body, &m); err != nil { + return nil, err + } + + switch ifaceType { + case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings: + // Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加 + m["model"], _ = json.Marshal(newModel) + return json.Marshal(m) + case conversion.InterfaceTypeRerank: + // 
Rerank 响应:存在 model 字段则改写,不存在则不添加 + if _, exists := m["model"]; exists { + m["model"], _ = json.Marshal(newModel) + } + return json.Marshal(m) + default: + return body, nil + } +} diff --git a/backend/internal/conversion/openai/adapter_test.go b/backend/internal/conversion/openai/adapter_test.go index ef31e5b..e692b85 100644 --- a/backend/internal/conversion/openai/adapter_test.go +++ b/backend/internal/conversion/openai/adapter_test.go @@ -121,7 +121,7 @@ func TestIsModelInfoPath(t *testing.T) { {"model_info", "/v1/models/gpt-4", true}, {"model_info_with_dots", "/v1/models/gpt-4.1-preview", true}, {"models_list", "/v1/models", false}, - {"nested_path", "/v1/models/gpt-4/versions", false}, + {"nested_path", "/v1/models/gpt-4/versions", true}, {"empty_suffix", "/v1/models/", false}, {"unrelated", "/v1/chat/completions", false}, {"partial_prefix", "/v1/model", false}, diff --git a/backend/internal/conversion/openai/adapter_unified_test.go b/backend/internal/conversion/openai/adapter_unified_test.go new file mode 100644 index 0000000..5f50e00 --- /dev/null +++ b/backend/internal/conversion/openai/adapter_unified_test.go @@ -0,0 +1,360 @@ +package openai + +import ( + "encoding/json" + "testing" + + "nex/backend/internal/conversion" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// ExtractUnifiedModelID +// --------------------------------------------------------------------------- + +func TestExtractUnifiedModelID(t *testing.T) { + a := NewAdapter() + + t.Run("standard_path", func(t *testing.T) { + id, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4") + require.NoError(t, err) + assert.Equal(t, "openai/gpt-4", id) + }) + + t.Run("multi_segment_path", func(t *testing.T) { + id, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4") + require.NoError(t, err) + assert.Equal(t, "azure/accounts/org/models/gpt-4", id) + }) + + 
t.Run("single_segment", func(t *testing.T) { + id, err := a.ExtractUnifiedModelID("/v1/models/gpt-4") + require.NoError(t, err) + assert.Equal(t, "gpt-4", id) + }) + + t.Run("non_model_path", func(t *testing.T) { + _, err := a.ExtractUnifiedModelID("/v1/chat/completions") + require.Error(t, err) + }) + + t.Run("empty_suffix", func(t *testing.T) { + _, err := a.ExtractUnifiedModelID("/v1/models/") + require.Error(t, err) + }) + + t.Run("models_list_no_slash", func(t *testing.T) { + _, err := a.ExtractUnifiedModelID("/v1/models") + require.Error(t, err) + }) + + t.Run("unrelated_path", func(t *testing.T) { + _, err := a.ExtractUnifiedModelID("/v1/other") + require.Error(t, err) + }) +} + +// --------------------------------------------------------------------------- +// ExtractModelName +// --------------------------------------------------------------------------- + +func TestExtractModelName(t *testing.T) { + a := NewAdapter() + + t.Run("chat", func(t *testing.T) { + body := []byte(`{"model":"openai/gpt-4","messages":[]}`) + model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat) + require.NoError(t, err) + assert.Equal(t, "openai/gpt-4", model) + }) + + t.Run("embedding", func(t *testing.T) { + body := []byte(`{"model":"openai/text-embedding","input":"hello"}`) + model, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings) + require.NoError(t, err) + assert.Equal(t, "openai/text-embedding", model) + }) + + t.Run("rerank", func(t *testing.T) { + body := []byte(`{"model":"openai/rerank","query":"test"}`) + model, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank) + require.NoError(t, err) + assert.Equal(t, "openai/rerank", model) + }) + + t.Run("no_model_field", func(t *testing.T) { + body := []byte(`{"messages":[]}`) + _, err := a.ExtractModelName(body, conversion.InterfaceTypeChat) + require.Error(t, err) + }) + + t.Run("invalid_json", func(t *testing.T) { + body := []byte(`{invalid}`) + _, err := a.ExtractModelName(body, 
conversion.InterfaceTypeChat) + require.Error(t, err) + }) + + t.Run("unsupported_interface_type", func(t *testing.T) { + body := []byte(`{"model":"openai/gpt-4"}`) + _, err := a.ExtractModelName(body, conversion.InterfaceTypePassthrough) + require.Error(t, err) + }) +} + +// --------------------------------------------------------------------------- +// RewriteRequestModelName +// --------------------------------------------------------------------------- + +func TestRewriteRequestModelName(t *testing.T) { + a := NewAdapter() + + t.Run("chat", func(t *testing.T) { + body := []byte(`{"model":"openai/gpt-4","messages":[]}`) + rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat) + require.NoError(t, err) + + var m map[string]interface{} + require.NoError(t, json.Unmarshal(rewritten, &m)) + assert.Equal(t, "gpt-4", m["model"]) + + // messages field preserved + msgs, ok := m["messages"] + require.True(t, ok) + msgsArr, ok := msgs.([]interface{}) + require.True(t, ok) + assert.Len(t, msgsArr, 0) + }) + + t.Run("preserves_unknown_fields", func(t *testing.T) { + body := []byte(`{"model":"openai/gpt-4","temperature":0.7}`) + rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat) + require.NoError(t, err) + + var m map[string]interface{} + require.NoError(t, json.Unmarshal(rewritten, &m)) + assert.Equal(t, "gpt-4", m["model"]) + assert.Equal(t, 0.7, m["temperature"]) + }) + + t.Run("embedding", func(t *testing.T) { + body := []byte(`{"model":"openai/text-embedding","input":"hello"}`) + rewritten, err := a.RewriteRequestModelName(body, "text-embedding", conversion.InterfaceTypeEmbeddings) + require.NoError(t, err) + + var m map[string]interface{} + require.NoError(t, json.Unmarshal(rewritten, &m)) + assert.Equal(t, "text-embedding", m["model"]) + assert.Equal(t, "hello", m["input"]) + }) + + t.Run("rerank", func(t *testing.T) { + body := []byte(`{"model":"openai/rerank","query":"test"}`) + rewritten, err 
:= a.RewriteRequestModelName(body, "rerank", conversion.InterfaceTypeRerank) + require.NoError(t, err) + + var m map[string]interface{} + require.NoError(t, json.Unmarshal(rewritten, &m)) + assert.Equal(t, "rerank", m["model"]) + assert.Equal(t, "test", m["query"]) + }) + + t.Run("no_model_field", func(t *testing.T) { + body := []byte(`{"messages":[]}`) + _, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat) + require.Error(t, err) + }) + + t.Run("invalid_json", func(t *testing.T) { + body := []byte(`{invalid}`) + _, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat) + require.Error(t, err) + }) + + t.Run("unsupported_interface_type", func(t *testing.T) { + body := []byte(`{"model":"openai/gpt-4"}`) + _, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypePassthrough) + require.Error(t, err) + }) +} + +// --------------------------------------------------------------------------- +// RewriteResponseModelName +// --------------------------------------------------------------------------- + +func TestRewriteResponseModelName(t *testing.T) { + a := NewAdapter() + + t.Run("chat_existing_model", func(t *testing.T) { + body := []byte(`{"model":"gpt-4","choices":[]}`) + rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat) + require.NoError(t, err) + + var m map[string]interface{} + require.NoError(t, json.Unmarshal(rewritten, &m)) + assert.Equal(t, "openai/gpt-4", m["model"]) + + choices, ok := m["choices"] + require.True(t, ok) + choicesArr, ok := choices.([]interface{}) + require.True(t, ok) + assert.Len(t, choicesArr, 0) + }) + + t.Run("chat_without_model_field", func(t *testing.T) { + body := []byte(`{"choices":[]}`) + rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat) + require.NoError(t, err) + + var m map[string]interface{} + require.NoError(t, json.Unmarshal(rewritten, &m)) + assert.Equal(t, 
"openai/gpt-4", m["model"]) + + choices, ok := m["choices"] + require.True(t, ok) + choicesArr, ok := choices.([]interface{}) + require.True(t, ok) + assert.Len(t, choicesArr, 0) + }) + + t.Run("rerank_existing_model", func(t *testing.T) { + body := []byte(`{"model":"rerank","results":[]}`) + rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank) + require.NoError(t, err) + + var m map[string]interface{} + require.NoError(t, json.Unmarshal(rewritten, &m)) + assert.Equal(t, "openai/rerank", m["model"]) + }) + + t.Run("rerank_without_model_field_should_not_add", func(t *testing.T) { + body := []byte(`{"results":[]}`) + rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank) + require.NoError(t, err) + + var m map[string]interface{} + require.NoError(t, json.Unmarshal(rewritten, &m)) + _, hasModel := m["model"] + assert.False(t, hasModel, "rerank response without model field should not have one added") + }) + + t.Run("embedding_existing_model", func(t *testing.T) { + body := []byte(`{"model":"text-embedding","data":[]}`) + rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings) + require.NoError(t, err) + + var m map[string]interface{} + require.NoError(t, json.Unmarshal(rewritten, &m)) + assert.Equal(t, "openai/text-embedding", m["model"]) + }) + + t.Run("embedding_without_model_field_adds", func(t *testing.T) { + body := []byte(`{"data":[]}`) + rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings) + require.NoError(t, err) + + var m map[string]interface{} + require.NoError(t, json.Unmarshal(rewritten, &m)) + assert.Equal(t, "openai/text-embedding", m["model"]) + }) + + t.Run("passthrough_returns_body_unchanged", func(t *testing.T) { + body := []byte(`{"model":"gpt-4"}`) + rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", 
conversion.InterfaceTypePassthrough) + require.NoError(t, err) + assert.Equal(t, string(body), string(rewritten)) + }) + + t.Run("invalid_json", func(t *testing.T) { + body := []byte(`{invalid}`) + _, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat) + require.Error(t, err) + }) +} + +// --------------------------------------------------------------------------- +// ExtractModelName and RewriteRequest consistency +// --------------------------------------------------------------------------- + +func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) { + a := NewAdapter() + + t.Run("chat_round_trip", func(t *testing.T) { + original := []byte(`{"model":"openai/gpt-4","messages":[],"temperature":0.7}`) + + // Extract the unified model ID from the body + extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat) + require.NoError(t, err) + assert.Equal(t, "openai/gpt-4", extracted) + + // Rewrite to the native model name + rewritten, err := a.RewriteRequestModelName(original, "gpt-4", conversion.InterfaceTypeChat) + require.NoError(t, err) + + // Extract again from the rewritten body to verify the same location was targeted + afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat) + require.NoError(t, err) + assert.Equal(t, "gpt-4", afterRewrite) + + // Verify other fields are preserved + var m map[string]interface{} + require.NoError(t, json.Unmarshal(rewritten, &m)) + assert.Equal(t, 0.7, m["temperature"]) + }) + + t.Run("embedding_round_trip", func(t *testing.T) { + original := []byte(`{"model":"openai/text-embedding","input":"hello"}`) + + extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeEmbeddings) + require.NoError(t, err) + assert.Equal(t, "openai/text-embedding", extracted) + + rewritten, err := a.RewriteRequestModelName(original, "text-embedding", conversion.InterfaceTypeEmbeddings) + require.NoError(t, err) + + afterRewrite, err := 
a.ExtractModelName(rewritten, conversion.InterfaceTypeEmbeddings) + require.NoError(t, err) + assert.Equal(t, "text-embedding", afterRewrite) + }) + + t.Run("rerank_round_trip", func(t *testing.T) { + original := []byte(`{"model":"openai/rerank","query":"test"}`) + + extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeRerank) + require.NoError(t, err) + assert.Equal(t, "openai/rerank", extracted) + + rewritten, err := a.RewriteRequestModelName(original, "rerank", conversion.InterfaceTypeRerank) + require.NoError(t, err) + + afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeRerank) + require.NoError(t, err) + assert.Equal(t, "rerank", afterRewrite) + }) +} + +// --------------------------------------------------------------------------- +// isModelInfoPath (additional unified model ID cases) +// --------------------------------------------------------------------------- + +func TestIsModelInfoPath_UnifiedModelID(t *testing.T) { + tests := []struct { + name string + path string + expected bool + }{ + {"simple_model_id", "/v1/models/gpt-4", true}, + {"unified_model_id_with_slash", "/v1/models/openai/gpt-4", true}, + {"models_list", "/v1/models", false}, + {"models_list_trailing_slash", "/v1/models/", false}, + {"chat_completions", "/v1/chat/completions", false}, + {"deeply_nested", "/v1/models/azure/eastus/deployments/my-dept/models/gpt-4", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, isModelInfoPath(tt.path)) + }) + } +} diff --git a/backend/internal/conversion/stream.go b/backend/internal/conversion/stream.go index b4def9f..98cd4e0 100644 --- a/backend/internal/conversion/stream.go +++ b/backend/internal/conversion/stream.go @@ -38,14 +38,52 @@ func (c *PassthroughStreamConverter) Flush() [][]byte { return nil } +// SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器 +// 逐 chunk 改写 model 字段 +type SmartPassthroughStreamConverter struct { + adapter 
ProtocolAdapter + modelOverride string + interfaceType InterfaceType +} + +// NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器 +func NewSmartPassthroughStreamConverter(adapter ProtocolAdapter, modelOverride string, interfaceType InterfaceType) *SmartPassthroughStreamConverter { + return &SmartPassthroughStreamConverter{ + adapter: adapter, + modelOverride: modelOverride, + interfaceType: interfaceType, + } +} + +// ProcessChunk 改写 chunk 中的 model 字段 +func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte { + if len(rawChunk) == 0 { + return nil + } + + rewrittenChunk, err := c.adapter.RewriteResponseModelName(rawChunk, c.modelOverride, c.interfaceType) + if err != nil { + // 改写失败,返回原始 chunk + return [][]byte{rawChunk} + } + + return [][]byte{rewrittenChunk} +} + +// Flush 无缓冲数据 +func (c *SmartPassthroughStreamConverter) Flush() [][]byte { + return nil +} + // CanonicalStreamConverter 跨协议规范流式转换器 type CanonicalStreamConverter struct { - decoder StreamDecoder - encoder StreamEncoder - chain *MiddlewareChain - ctx ConversionContext - clientProtocol string + decoder StreamDecoder + encoder StreamEncoder + chain *MiddlewareChain + ctx ConversionContext + clientProtocol string providerProtocol string + modelOverride string } // NewCanonicalStreamConverter 创建规范流式转换器 @@ -57,18 +95,19 @@ func NewCanonicalStreamConverter(decoder StreamDecoder, encoder StreamEncoder) * } // NewCanonicalStreamConverterWithMiddleware 创建带中间件的规范流式转换器 -func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol string) *CanonicalStreamConverter { +func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol, modelOverride string) *CanonicalStreamConverter { return &CanonicalStreamConverter{ - decoder: decoder, - encoder: encoder, - chain: 
chain, - ctx: ctx, - clientProtocol: clientProtocol, + decoder: decoder, + encoder: encoder, + chain: chain, + ctx: ctx, + clientProtocol: clientProtocol, providerProtocol: providerProtocol, + modelOverride: modelOverride, } } -// ProcessChunk 解码 → 中间件 → 编码管道 +// ProcessChunk 解码 → 中间件 → modelOverride → 编码管道 func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte { events := c.decoder.ProcessChunk(rawChunk) var result [][]byte @@ -80,6 +119,7 @@ func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte { } events[i] = *processed } + c.applyModelOverride(&events[i]) chunks := c.encoder.EncodeEvent(events[i]) result = append(result, chunks...) } @@ -98,6 +138,7 @@ func (c *CanonicalStreamConverter) Flush() [][]byte { } events[i] = *processed } + c.applyModelOverride(&events[i]) chunks := c.encoder.EncodeEvent(events[i]) result = append(result, chunks...) } @@ -105,3 +146,10 @@ func (c *CanonicalStreamConverter) Flush() [][]byte { result = append(result, encoderChunks...) 
return result } + +// applyModelOverride 在跨协议场景下覆写流式事件中的 Model 字段 +func (c *CanonicalStreamConverter) applyModelOverride(event *canonical.CanonicalStreamEvent) { + if c.modelOverride != "" && event.Message != nil { + event.Message.Model = c.modelOverride + } +} diff --git a/backend/internal/conversion/stream_test.go b/backend/internal/conversion/stream_test.go index c7f111d..516a628 100644 --- a/backend/internal/conversion/stream_test.go +++ b/backend/internal/conversion/stream_test.go @@ -93,7 +93,7 @@ func TestCanonicalStreamConverter_WithMiddleware(t *testing.T) { chain.Use(&recordingMiddleware{name: "mw1", records: &records}) ctx := NewConversionContext(InterfaceTypeChat) - converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic") + converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "") result := converter.ProcessChunk([]byte("raw")) assert.Len(t, result, 1) @@ -143,7 +143,7 @@ func TestCanonicalStreamConverter_MiddlewareError_Continue(t *testing.T) { chain.Use(&errorMiddleware{}) ctx := NewConversionContext(InterfaceTypeChat) - converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic") + converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "") result := converter.ProcessChunk([]byte("raw")) assert.Nil(t, result, "middleware error should cause the event to be skipped (continue)") @@ -163,7 +163,7 @@ func TestCanonicalStreamConverter_Flush_MiddlewareError_Continue(t *testing.T) { chain.Use(&errorMiddleware{}) ctx := NewConversionContext(InterfaceTypeChat) - converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic") + converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "") result := converter.Flush() assert.Len(t, result, 1) diff --git 
a/backend/internal/domain/model.go b/backend/internal/domain/model.go index 30b4b49..60e2684 100644 --- a/backend/internal/domain/model.go +++ b/backend/internal/domain/model.go @@ -1,8 +1,12 @@ package domain -import "time" +import ( + "time" -// Model 模型领域模型 + "nex/backend/pkg/modelid" +) + +// Model 模型领域模型(id 为 UUID 自动生成) type Model struct { ID string `json:"id"` ProviderID string `json:"provider_id"` @@ -10,3 +14,8 @@ type Model struct { Enabled bool `json:"enabled"` CreatedAt time.Time `json:"created_at"` } + +// UnifiedModelID 返回统一模型 ID(格式:provider_id/model_name) +func (m *Model) UnifiedModelID() string { + return modelid.FormatUnifiedModelID(m.ProviderID, m.ModelName) +} diff --git a/backend/internal/handler/handler_supplemental_test.go b/backend/internal/handler/handler_supplemental_test.go index 82a43a5..716c406 100644 --- a/backend/internal/handler/handler_supplemental_test.go +++ b/backend/internal/handler/handler_supplemental_test.go @@ -113,7 +113,6 @@ func TestModelHandler_CreateModel_Success(t *testing.T) { h := NewModelHandler(&mockModelService{}) body, _ := json.Marshal(map[string]string{ - "id": "m1", "provider_id": "p1", "model_name": "gpt-4", }) @@ -127,7 +126,7 @@ func TestModelHandler_CreateModel_Success(t *testing.T) { var result domain.Model require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result)) - assert.Equal(t, "m1", result.ID) + assert.NotEmpty(t, result.ID) } func TestModelHandler_GetModel(t *testing.T) { diff --git a/backend/internal/handler/handler_test.go b/backend/internal/handler/handler_test.go index 89db9ac..74e0bb5 100644 --- a/backend/internal/handler/handler_test.go +++ b/backend/internal/handler/handler_test.go @@ -13,7 +13,6 @@ import ( "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gorm.io/gorm" "nex/backend/internal/domain" "nex/backend/internal/provider" @@ -31,7 +30,7 @@ type mockRoutingService struct { err error } -func (m *mockRoutingService) Route(modelName 
string) (*domain.RouteResult, error) { +func (m *mockRoutingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) { return m.result, m.err } @@ -57,6 +56,14 @@ type mockProviderService struct { err error } +func (m *mockProviderService) ListEnabledModels() ([]domain.Model, error) { + return nil, nil +} + +func (m *mockProviderService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) { + return nil, nil +} + func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err } func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) { return m.provider, m.err @@ -73,13 +80,21 @@ type mockModelService struct { err error } -func (m *mockModelService) Create(model *domain.Model) error { return m.err } +func (m *mockModelService) Create(model *domain.Model) error { + if m.err == nil { + model.ID = "mock-uuid-1234" + } + return m.err +} func (m *mockModelService) Get(id string) (*domain.Model, error) { return m.model, m.err } func (m *mockModelService) List(providerID string) ([]domain.Model, error) { return m.models, m.err } +func (m *mockModelService) ListEnabled() ([]domain.Model, error) { + return []domain.Model{}, nil +} func (m *mockModelService) Update(id string, updates map[string]interface{}) error { return m.err } @@ -163,8 +178,8 @@ func TestModelHandler_CreateModel_MissingFields(t *testing.T) { func TestModelHandler_ListModels(t *testing.T) { h := NewModelHandler(&mockModelService{ models: []domain.Model{ - {ID: "m1", ModelName: "gpt-4"}, - {ID: "m2", ModelName: "gpt-3.5"}, + {ID: "m1", ProviderID: "openai", ModelName: "gpt-4"}, + {ID: "m2", ProviderID: "anthropic", ModelName: "claude-3"}, }, }) @@ -174,6 +189,72 @@ func TestModelHandler_ListModels(t *testing.T) { h.ListModels(c) assert.Equal(t, 200, w.Code) + + var result []modelResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result)) + require.Len(t, result, 2) + assert.Equal(t, 
"openai/gpt-4", result[0].UnifiedModelID) + assert.Equal(t, "anthropic/claude-3", result[1].UnifiedModelID) +} + +func TestModelHandler_GetModel_UnifiedID(t *testing.T) { + h := NewModelHandler(&mockModelService{ + model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"}, + }) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "m1"}} + c.Request = httptest.NewRequest("GET", "/api/models/m1", nil) + + h.GetModel(c) + assert.Equal(t, 200, w.Code) + + var result modelResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result)) + assert.Equal(t, "m1", result.ID) + assert.Equal(t, "openai/gpt-4", result.UnifiedModelID) +} + +func TestModelHandler_CreateModel_UnifiedID(t *testing.T) { + h := NewModelHandler(&mockModelService{}) + + body, _ := json.Marshal(map[string]string{ + "provider_id": "openai", + "model_name": "gpt-4", + }) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + h.CreateModel(c) + assert.Equal(t, 201, w.Code) + + var result modelResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result)) + assert.Equal(t, "mock-uuid-1234", result.ID) + assert.Equal(t, "openai/gpt-4", result.UnifiedModelID) +} + +func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) { + h := NewModelHandler(&mockModelService{ + model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4-turbo"}, + }) + + body, _ := json.Marshal(map[string]interface{}{"enabled": false}) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "m1"}} + c.Request = httptest.NewRequest("PUT", "/api/models/m1", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + h.UpdateModel(c) + assert.Equal(t, 200, w.Code) + + var result modelResponse + require.NoError(t, 
json.Unmarshal(w.Body.Bytes(), &result)) + assert.Equal(t, "openai/gpt-4-turbo", result.UnifiedModelID) } // ============ Stats Handler 测试 ============ @@ -256,7 +337,7 @@ func formatMapErrors(errs map[string]string) string { func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) { h := NewProviderHandler(&mockProviderService{ - err: gorm.ErrDuplicatedKey, + err: appErrors.ErrConflict, }) body, _ := json.Marshal(map[string]string{ diff --git a/backend/internal/handler/model_handler.go b/backend/internal/handler/model_handler.go index 9135c66..5432f55 100644 --- a/backend/internal/handler/model_handler.go +++ b/backend/internal/handler/model_handler.go @@ -1,6 +1,7 @@ package handler import ( + "errors" "net/http" "github.com/gin-gonic/gin" @@ -22,23 +23,35 @@ func NewModelHandler(modelService service.ModelService) *ModelHandler { return &ModelHandler{modelService: modelService} } +// modelResponse 模型响应 DTO,扩展 unified_id 字段 +type modelResponse struct { + domain.Model + UnifiedModelID string `json:"unified_id"` +} + +// newModelResponse 从 domain.Model 构造响应 DTO +func newModelResponse(m *domain.Model) modelResponse { + return modelResponse{ + Model: *m, + UnifiedModelID: m.UnifiedModelID(), + } +} + // CreateModel 创建模型 func (h *ModelHandler) CreateModel(c *gin.Context) { var req struct { - ID string `json:"id" binding:"required"` ProviderID string `json:"provider_id" binding:"required"` ModelName string `json:"model_name" binding:"required"` } if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ - "error": "缺少必需字段: id, provider_id, model_name", + "error": "缺少必需字段: provider_id, model_name", }) return } model := &domain.Model{ - ID: req.ID, ProviderID: req.ProviderID, ModelName: req.ModelName, } @@ -51,11 +64,18 @@ func (h *ModelHandler) CreateModel(c *gin.Context) { }) return } + if err == appErrors.ErrDuplicateModel { + c.JSON(http.StatusConflict, gin.H{ + "error": "同一供应商下模型名称已存在", + "code": appErrors.ErrDuplicateModel.Code, 
+ }) + return + } writeError(c, err) return } - c.JSON(http.StatusCreated, model) + c.JSON(http.StatusCreated, newModelResponse(model)) } // ListModels 列出模型 @@ -68,7 +88,11 @@ func (h *ModelHandler) ListModels(c *gin.Context) { return } - c.JSON(http.StatusOK, models) + resp := make([]modelResponse, len(models)) + for i, m := range models { + resp[i] = newModelResponse(&m) + } + c.JSON(http.StatusOK, resp) } // GetModel 获取模型 @@ -87,7 +111,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) { return } - c.JSON(http.StatusOK, model) + c.JSON(http.StatusOK, newModelResponse(model)) } // UpdateModel 更新模型 @@ -104,18 +128,25 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) { err := h.modelService.Update(id, req) if err != nil { - if err == gorm.ErrRecordNotFound { + if errors.Is(err, appErrors.ErrModelNotFound) { c.JSON(http.StatusNotFound, gin.H{ "error": "模型未找到", }) return } - if err == appErrors.ErrProviderNotFound { + if errors.Is(err, appErrors.ErrProviderNotFound) { c.JSON(http.StatusBadRequest, gin.H{ "error": "供应商不存在", }) return } + if errors.Is(err, appErrors.ErrDuplicateModel) { + c.JSON(http.StatusConflict, gin.H{ + "error": appErrors.ErrDuplicateModel.Message, + "code": appErrors.ErrDuplicateModel.Code, + }) + return + } writeError(c, err) return } @@ -126,7 +157,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) { return } - c.JSON(http.StatusOK, model) + c.JSON(http.StatusOK, newModelResponse(model)) } // DeleteModel 删除模型 diff --git a/backend/internal/handler/provider_handler.go b/backend/internal/handler/provider_handler.go index 34d8f36..db08a32 100644 --- a/backend/internal/handler/provider_handler.go +++ b/backend/internal/handler/provider_handler.go @@ -55,9 +55,10 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) { err := h.providerService.Create(provider) if err != nil { - if errors.Is(err, gorm.ErrDuplicatedKey) { - c.JSON(http.StatusConflict, gin.H{ - "error": "供应商 ID 已存在", + if err == appErrors.ErrInvalidProviderID { + 
c.JSON(http.StatusBadRequest, gin.H{ + "error": appErrors.ErrInvalidProviderID.Message, + "code": appErrors.ErrInvalidProviderID.Code, }) return } @@ -119,6 +120,13 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) { }) return } + if errors.Is(err, appErrors.ErrImmutableField) { + c.JSON(http.StatusBadRequest, gin.H{ + "error": appErrors.ErrImmutableField.Message, + "code": appErrors.ErrImmutableField.Code, + }) + return + } writeError(c, err) return } diff --git a/backend/internal/handler/proxy_handler.go b/backend/internal/handler/proxy_handler.go index 03923e8..7b96c89 100644 --- a/backend/internal/handler/proxy_handler.go +++ b/backend/internal/handler/proxy_handler.go @@ -11,9 +11,11 @@ import ( "go.uber.org/zap" "nex/backend/internal/conversion" + "nex/backend/internal/conversion/canonical" "nex/backend/internal/domain" "nex/backend/internal/provider" "nex/backend/internal/service" + "nex/backend/pkg/modelid" ) // ProxyHandler 统一代理处理器 @@ -54,6 +56,34 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) { } nativePath := "/v1/" + path + // 获取 client adapter + registry := h.engine.GetRegistry() + clientAdapter, err := registry.Get(clientProtocol) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol}) + return + } + + // 检测接口类型 + ifaceType := clientAdapter.DetectInterfaceType(nativePath) + + // 处理 Models 接口:本地聚合 + if ifaceType == conversion.InterfaceTypeModels { + h.handleModelsList(c, clientAdapter) + return + } + + // 处理 ModelInfo 接口:本地查询 + if ifaceType == conversion.InterfaceTypeModelInfo { + unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的模型 ID 格式"}) + return + } + h.handleModelInfo(c, unifiedID, clientAdapter) + return + } + // 读取请求体 body, err := io.ReadAll(c.Request.Body) if err != nil { @@ -61,10 +91,17 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) { return } - // 解析 model 名称(从 JSON body 中提取,GET 请求无 body) - 
modelName := "" + // 解析统一模型 ID(使用 adapter.ExtractModelName) + var providerID, modelName string if len(body) > 0 { - modelName = extractModelName(body) + unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType) + if err == nil && unifiedID != "" { + pid, mn, err := modelid.ParseUnifiedModelID(unifiedID) + if err == nil { + providerID = pid + modelName = mn + } + } } // 构建输入 HTTPRequestSpec @@ -76,7 +113,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) { } // 路由 - routeResult, err := h.routingService.Route(modelName) + routeResult, err := h.routingService.RouteByModelName(providerID, modelName) if err != nil { // GET 请求或无法提取 model 时,直接转发到上游 if len(body) == 0 || modelName == "" { @@ -94,24 +131,30 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) { } // 构建 TargetProvider + // 注意:ModelName 字段用于 Smart Passthrough 场景改写请求体 + // 同协议:请求体中的统一 ID 会被改写为 ModelName(上游名) + // 跨协议:全量转换时 ModelName 会被编码到请求体中 targetProvider := conversion.NewTargetProvider( routeResult.Provider.BaseURL, routeResult.Provider.APIKey, - routeResult.Model.ModelName, + routeResult.Model.ModelName, // 上游模型名,用于请求改写 ) // 判断是否流式 isStream := h.isStreamRequest(body, clientProtocol, nativePath) + // 计算统一模型 ID(用于响应覆写) + unifiedModelID := routeResult.Model.UnifiedModelID() + if isStream { - h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult) + h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType) } else { - h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult) + h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType) } } // handleNonStream 处理非流式请求 -func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) { +func (h *ProxyHandler) handleNonStream(c 
*gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) { // 转换请求 outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider) if err != nil { @@ -128,9 +171,8 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq return } - // 转换响应 - interfaceType, _ := h.engine.DetectInterfaceType(inSpec.URL, clientProtocol) - convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, interfaceType) + // 转换响应,传入 modelOverride(跨协议场景覆写 model 字段) + convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID) if err != nil { h.logger.Error("转换响应失败", zap.String("error", err.Error())) h.writeConversionError(c, err, clientProtocol) @@ -153,7 +195,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq } // handleStream 处理流式请求 -func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) { +func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) { // 转换请求 outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider) if err != nil { @@ -161,8 +203,8 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques return } - // 创建流式转换器 - streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol) + // 创建流式转换器,传入 modelOverride(跨协议场景覆写 model 字段) + streamConverter, err := 
h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType) if err != nil { h.writeConversionError(c, err, clientProtocol) return @@ -224,6 +266,79 @@ func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath s return req.Stream } +// handleModelsList 处理 GET /v1/models 本地聚合 +func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.ProtocolAdapter) { + // 从数据库查询所有启用的模型 + models, err := h.providerService.ListEnabledModels() + if err != nil { + h.logger.Error("查询启用模型失败", zap.String("error", err.Error())) + c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"}) + return + } + + // 构建 CanonicalModelList + modelList := &canonical.CanonicalModelList{ + Models: make([]canonical.CanonicalModel, 0, len(models)), + } + + for _, m := range models { + modelList.Models = append(modelList.Models, canonical.CanonicalModel{ + ID: m.UnifiedModelID(), + Name: m.ModelName, + Created: m.CreatedAt.Unix(), + OwnedBy: m.ProviderID, + }) + } + + // 使用 adapter 编码返回 + body, err := adapter.EncodeModelsResponse(modelList) + if err != nil { + h.logger.Error("编码 Models 响应失败", zap.String("error", err.Error())) + c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"}) + return + } + + c.Data(http.StatusOK, "application/json", body) +} + +// handleModelInfo 处理 GET /v1/models/{unified_id} 本地查询 +func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter conversion.ProtocolAdapter) { + // 解析统一模型 ID + providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "无效的统一模型 ID 格式", + "code": "INVALID_MODEL_ID", + }) + return + } + + // 从数据库查询模型 + model, err := h.providerService.GetModelByProviderAndName(providerID, modelName) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "模型未找到"}) + return + } + + // 构建 CanonicalModelInfo + modelInfo := &canonical.CanonicalModelInfo{ + ID: model.UnifiedModelID(), + Name: 
model.ModelName, + Created: model.CreatedAt.Unix(), + OwnedBy: model.ProviderID, + } + + // 使用 adapter 编码返回 + body, err := adapter.EncodeModelInfoResponse(modelInfo) + if err != nil { + h.logger.Error("编码 ModelInfo 响应失败", zap.String("error", err.Error())) + c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"}) + return + } + + c.Data(http.StatusOK, "application/json", body) +} + // writeConversionError 写入转换错误 func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) { if convErr, ok := err.(*conversion.ConversionError); ok { @@ -292,7 +407,7 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP return } - convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType) + convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, "") if err != nil { h.writeConversionError(c, err, clientProtocol) return @@ -307,17 +422,6 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body) } -// extractModelName 从 JSON body 中提取 model -func extractModelName(body []byte) string { - var req struct { - Model string `json:"model"` - } - if err := json.Unmarshal(body, &req); err != nil { - return "" - } - return req.Model -} - // extractHeaders 从 Gin context 提取请求头 func extractHeaders(c *gin.Context) map[string]string { headers := make(map[string]string) diff --git a/backend/internal/handler/proxy_handler_test.go b/backend/internal/handler/proxy_handler_test.go index ae29b0b..d0ee800 100644 --- a/backend/internal/handler/proxy_handler_test.go +++ b/backend/internal/handler/proxy_handler_test.go @@ -60,13 +60,23 @@ type mockProxyRoutingService struct { err error } -func (m *mockProxyRoutingService) Route(modelName string) (*domain.RouteResult, error) { +func (m *mockProxyRoutingService) RouteByModelName(providerID, modelName string) 
(*domain.RouteResult, error) { return m.result, m.err } type mockProxyProviderService struct { - providers []domain.Provider - err error + providers []domain.Provider + err error + enabledModels []domain.Model + modelByProvName *domain.Model +} + +func (m *mockProxyProviderService) ListEnabledModels() ([]domain.Model, error) { + return m.enabledModels, nil +} + +func (m *mockProxyProviderService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) { + return m.modelByProvName, nil } func (m *mockProxyProviderService) Create(p *domain.Provider) error { return nil } @@ -319,7 +329,8 @@ func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) { c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil) h.HandleProxy(c) - assert.Equal(t, 404, w.Code) + // Models 接口现在本地聚合,返回空列表 200 + assert.Equal(t, 200, w.Code) } func TestExtractHeaders(t *testing.T) { @@ -716,58 +727,6 @@ func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) { assert.Equal(t, 200, w.Code) } -// ============ extractModelName 测试 ============ - -func TestExtractModelName(t *testing.T) { - tests := []struct { - name string - body []byte - expected string - }{ - { - name: "valid model", - body: []byte(`{"model": "gpt-4", "messages": []}`), - expected: "gpt-4", - }, - { - name: "empty body", - body: []byte(`{}`), - expected: "", - }, - { - name: "invalid json", - body: []byte(`{invalid}`), - expected: "", - }, - { - name: "nested structure", - body: []byte(`{"model": "claude-3", "messages": [{"role": "user", "content": "hello"}]}`), - expected: "claude-3", - }, - { - name: "model with special chars", - body: []byte(`{"model": "gpt-4-0125-preview", "stream": true}`), - expected: "gpt-4-0125-preview", - }, - { - name: "empty body bytes", - body: []byte{}, - expected: "", - }, - { - name: "model is null", - body: []byte(`{"model": null}`), - expected: "", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := 
extractModelName(tt.body) - assert.Equal(t, tt.expected, result) - }) - } -} - // ============ isStreamRequest 测试 ============ func TestIsStreamRequest(t *testing.T) { @@ -831,3 +790,270 @@ func TestIsStreamRequest(t *testing.T) { }) } } + +// ============ Models / ModelInfo 本地聚合测试 ============ + +func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) { + engine := setupProxyEngine(t) + providerSvc := &mockProxyProviderService{ + enabledModels: []domain.Model{ + {ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}, + {ID: "m2", ProviderID: "anthropic", ModelName: "claude-3", Enabled: true}, + }, + } + h := newTestProxyHandler(engine, &mockProxyProviderClient{}, &mockProxyRoutingService{}, providerSvc) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}} + c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil) + + h.HandleProxy(c) + assert.Equal(t, 200, w.Code) + + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + data, ok := resp["data"].([]interface{}) + require.True(t, ok) + assert.Len(t, data, 2) + + // 验证统一模型 ID 格式 + first := data[0].(map[string]interface{}) + assert.Equal(t, "openai/gpt-4", first["id"]) +} + +func TestProxyHandler_HandleProxy_ModelInfo_LocalQuery(t *testing.T) { + engine := setupProxyEngine(t) + providerSvc := &mockProxyProviderService{ + modelByProvName: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}, + } + h := newTestProxyHandler(engine, &mockProxyProviderClient{}, &mockProxyRoutingService{}, providerSvc) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/openai/gpt-4"}} + c.Request = httptest.NewRequest("GET", "/openai/v1/models/openai/gpt-4", nil) + + h.HandleProxy(c) + assert.Equal(t, 200, w.Code) + + var resp 
map[string]interface{} + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + assert.Equal(t, "openai/gpt-4", resp["id"]) +} + +func TestProxyHandler_HandleProxy_Models_EmptySuffix_ForwardPassthrough(t *testing.T) { + engine := setupProxyEngine(t) + providerSvc := &mockProxyProviderService{ + providers: []domain.Provider{ + {ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai"}, + }, + } + client := &mockProxyProviderClient{ + sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) { + return &conversion.HTTPResponseSpec{ + StatusCode: 200, + Body: []byte(`{"object":"list","data":[]}`), + }, nil + }, + } + h := newTestProxyHandler(engine, client, &mockProxyRoutingService{err: appErrors.ErrModelNotFound}, providerSvc) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/"}} + c.Request = httptest.NewRequest("GET", "/openai/v1/models/", nil) + + h.HandleProxy(c) + assert.Equal(t, 200, w.Code) +} + +// ============ Smart Passthrough 统一模型 ID 路由测试 ============ + +func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) { + engine := setupProxyEngine(t) + routingSvc := &mockProxyRoutingService{ + result: &domain.RouteResult{ + Provider: &domain.Provider{ID: "openai_p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true}, + Model: &domain.Model{ID: "m1", ProviderID: "openai_p", ModelName: "gpt-4", Enabled: true}, + }, + } + client := &mockProxyProviderClient{ + sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) { + // 验证请求体中的 model 已被改写为上游模型名 + var req map[string]interface{} + json.Unmarshal(spec.Body, &req) + assert.Equal(t, "gpt-4", req["model"]) + + return &conversion.HTTPResponseSpec{ + StatusCode: 200, + Headers: map[string]string{"Content-Type": 
"application/json"}, + Body: []byte(`{"id":"resp-1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8}}`), + }, nil + }, + } + h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{}) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} + // 客户端发送统一模型 ID + c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}]}`))) + + h.HandleProxy(c) + assert.Equal(t, 200, w.Code) + + // 验证响应中的 model 已被改写为统一模型 ID + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + assert.Equal(t, "openai_p/gpt-4", resp["model"]) +} + +// ============ 跨协议统一模型 ID 路由测试 ============ + +func TestProxyHandler_HandleProxy_CrossProtocol_NonStream_UnifiedID(t *testing.T) { + engine := setupProxyEngine(t) + routingSvc := &mockProxyRoutingService{ + result: &domain.RouteResult{ + Provider: &domain.Provider{ID: "anthropic_p", Name: "Anthropic", APIKey: "sk-test", BaseURL: "https://api.anthropic.com", Protocol: "anthropic", Enabled: true}, + Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true}, + }, + } + client := &mockProxyProviderClient{ + sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) { + return &conversion.HTTPResponseSpec{ + StatusCode: 200, + Headers: map[string]string{"Content-Type": "application/json"}, + Body: []byte(`{"id":"msg-1","type":"message","role":"assistant","model":"claude-3","content":[{"type":"text","text":"Hello"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":3}}`), + }, nil + }, + } + h := newTestProxyHandler(engine, 
client, routingSvc, &mockProxyProviderService{}) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} + // OpenAI 客户端使用统一模型 ID 路由到 Anthropic 供应商 + c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}]}`))) + + h.HandleProxy(c) + assert.Equal(t, 200, w.Code) + + // 验证跨协议转换后响应中的 model 被覆写为统一模型 ID + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + assert.Equal(t, "anthropic_p/claude-3", resp["model"]) +} + +func TestProxyHandler_HandleProxy_CrossProtocol_Stream_UnifiedID(t *testing.T) { + engine := setupProxyEngine(t) + routingSvc := &mockProxyRoutingService{ + result: &domain.RouteResult{ + Provider: &domain.Provider{ID: "anthropic_p", Name: "Anthropic", APIKey: "sk-test", BaseURL: "https://api.anthropic.com", Protocol: "anthropic", Enabled: true}, + Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true}, + }, + } + client := &mockProxyProviderClient{ + sendStreamFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) { + ch := make(chan provider.StreamEvent, 10) + go func() { + defer close(ch) + ch <- provider.StreamEvent{Data: []byte(`event: message_start +data: {"type":"message_start","message":{"id":"msg-1","type":"message","role":"assistant","model":"claude-3","content":[]}} + +`)} + ch <- provider.StreamEvent{Data: []byte(`event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}} + +`)} + ch <- provider.StreamEvent{Data: []byte(`event: message_stop +data: {"type":"message_stop"} + +`)} + ch <- provider.StreamEvent{Done: true} + }() + return ch, nil + }, + } + h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{}) + + w 
:= httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} + c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}],"stream":true}`))) + + h.HandleProxy(c) + assert.Equal(t, 200, w.Code) + assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type")) + + body := w.Body.String() + // 验证跨协议流式中 model 被覆写为统一模型 ID + assert.Contains(t, body, "anthropic_p/claude-3", "跨协议流式响应中 model 应被覆写为统一模型 ID") +} + +func TestProxyHandler_HandleProxy_SmartPassthrough_Fidelity(t *testing.T) { + engine := setupProxyEngine(t) + routingSvc := &mockProxyRoutingService{ + result: &domain.RouteResult{ + Provider: &domain.Provider{ID: "openai_p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true}, + Model: &domain.Model{ID: "m1", ProviderID: "openai_p", ModelName: "gpt-4", Enabled: true}, + }, + } + var capturedRequestBody []byte + client := &mockProxyProviderClient{ + sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) { + capturedRequestBody = spec.Body + return &conversion.HTTPResponseSpec{ + StatusCode: 200, + Headers: map[string]string{"Content-Type": "application/json"}, + Body: []byte(`{"id":"resp-1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8},"unknown_field":"preserved"}`), + }, nil + }, + } + h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{}) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} + // 包含未知参数,验证 Smart Passthrough 保真性 + c.Request = 
httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}],"custom_param":"should_be_preserved"}`))) + + h.HandleProxy(c) + assert.Equal(t, 200, w.Code) + + // 验证请求中 model 被改写为上游模型名,但未知参数保留 + var reqBody map[string]interface{} + require.NoError(t, json.Unmarshal(capturedRequestBody, &reqBody)) + assert.Equal(t, "gpt-4", reqBody["model"], "请求中 model 应被改写为上游模型名") + assert.Equal(t, "should_be_preserved", reqBody["custom_param"], "Smart Passthrough 应保留未知参数") + + // 验证响应中 model 被改写为统一模型 ID,但未知参数保留 + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + assert.Equal(t, "openai_p/gpt-4", resp["model"], "响应中 model 应被改写为统一模型 ID") + assert.Equal(t, "preserved", resp["unknown_field"], "Smart Passthrough 应保留未知响应字段") +} + +func TestProxyHandler_HandleProxy_UnifiedID_ModelNotFound(t *testing.T) { + engine := setupProxyEngine(t) + routingSvc := &mockProxyRoutingService{err: appErrors.ErrModelNotFound} + providerSvc := &mockProxyProviderService{ + providers: []domain.Provider{ + {ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai"}, + }, + } + h := newTestProxyHandler(engine, &mockProxyProviderClient{}, routingSvc, providerSvc) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} + // 使用统一模型 ID 格式但模型不存在 + c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`))) + + h.HandleProxy(c) + assert.Equal(t, 404, w.Code) + + var resp map[string]interface{} + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + assert.Contains(t, resp, "error") +} diff --git a/backend/internal/repository/model_repo.go b/backend/internal/repository/model_repo.go index daeb76f..929e0bd 100644 --- 
a/backend/internal/repository/model_repo.go +++ b/backend/internal/repository/model_repo.go @@ -7,7 +7,8 @@ type ModelRepository interface { Create(model *domain.Model) error GetByID(id string) (*domain.Model, error) List(providerID string) ([]domain.Model, error) - GetByModelName(modelName string) (*domain.Model, error) + FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) + ListEnabled() ([]domain.Model, error) Update(id string, updates map[string]interface{}) error Delete(id string) error } diff --git a/backend/internal/repository/model_repo_impl.go b/backend/internal/repository/model_repo_impl.go index 31a0e14..c99a579 100644 --- a/backend/internal/repository/model_repo_impl.go +++ b/backend/internal/repository/model_repo_impl.go @@ -52,9 +52,9 @@ func (r *modelRepository) List(providerID string) ([]domain.Model, error) { return result, nil } -func (r *modelRepository) GetByModelName(modelName string) (*domain.Model, error) { +func (r *modelRepository) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) { var m config.Model - err := r.db.Where("model_name = ?", modelName).First(&m).Error + err := r.db.Where("provider_id = ? AND model_name = ?", providerID, modelName).First(&m).Error if err != nil { return nil, err } @@ -62,6 +62,21 @@ func (r *modelRepository) GetByModelName(modelName string) (*domain.Model, error return &d, nil } +func (r *modelRepository) ListEnabled() ([]domain.Model, error) { + var models []config.Model + err := r.db.Joins("JOIN providers ON providers.id = models.provider_id"). + Where("models.enabled = ? AND providers.enabled = ?", true, true). 
+ Find(&models).Error + if err != nil { + return nil, err + } + result := make([]domain.Model, len(models)) + for i := range models { + result[i] = toDomainModel(&models[i]) + } + return result, nil +} + func (r *modelRepository) Update(id string, updates map[string]interface{}) error { result := r.db.Model(&config.Model{}).Where("id = ?", id).Updates(updates) if result.Error != nil { diff --git a/backend/internal/repository/provider_repo.go b/backend/internal/repository/provider_repo.go index 18986dc..263d7a7 100644 --- a/backend/internal/repository/provider_repo.go +++ b/backend/internal/repository/provider_repo.go @@ -9,4 +9,7 @@ type ProviderRepository interface { List() ([]domain.Provider, error) Update(id string, updates map[string]interface{}) error Delete(id string) error + // 统一模型 ID 相关方法 + ListEnabledModels() ([]domain.Model, error) + FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) } diff --git a/backend/internal/repository/provider_repo_impl.go b/backend/internal/repository/provider_repo_impl.go index 6ea917b..d59d2a7 100644 --- a/backend/internal/repository/provider_repo_impl.go +++ b/backend/internal/repository/provider_repo_impl.go @@ -71,6 +71,25 @@ func (r *providerRepository) Delete(id string) error { return nil } +// ListEnabledModels 返回所有启用的模型(关联启用的供应商) +func (r *providerRepository) ListEnabledModels() ([]domain.Model, error) { + var models []domain.Model + err := r.db.Joins("JOIN providers ON providers.id = models.provider_id"). + Where("models.enabled = ? AND providers.enabled = ?", true, true). + Find(&models).Error + return models, err +} + +// FindByProviderAndModelName 按 provider_id 和 model_name 查询模型 +func (r *providerRepository) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) { + var model domain.Model + err := r.db.Where("provider_id = ? 
AND model_name = ?", providerID, modelName).First(&model).Error + if err != nil { + return nil, err + } + return &model, nil +} + func toDomainProvider(p *config.Provider) domain.Provider { return domain.Provider{ ID: p.ID, diff --git a/backend/internal/repository/repository_test.go b/backend/internal/repository/repository_test.go index 584b15b..a0f7064 100644 --- a/backend/internal/repository/repository_test.go +++ b/backend/internal/repository/repository_test.go @@ -147,15 +147,36 @@ func TestModelRepository_GetByID(t *testing.T) { assert.Equal(t, "gpt-4", result.ModelName) } -func TestModelRepository_GetByModelName(t *testing.T) { +func TestModelRepository_FindByProviderAndModelName(t *testing.T) { db := setupTestDB(t) repo := NewModelRepository(db) repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}) - result, err := repo.GetByModelName("gpt-4") + result, err := repo.FindByProviderAndModelName("p1", "gpt-4") require.NoError(t, err) assert.Equal(t, "m1", result.ID) + assert.Equal(t, "p1", result.ProviderID) + assert.Equal(t, "gpt-4", result.ModelName) +} + +func TestModelRepository_FindByProviderAndModelName_NotFound(t *testing.T) { + db := setupTestDB(t) + repo := NewModelRepository(db) + + repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}) + + // Wrong provider_id + _, err := repo.FindByProviderAndModelName("p2", "gpt-4") + assert.Error(t, err) + + // Wrong model_name + _, err = repo.FindByProviderAndModelName("p1", "gpt-3.5") + assert.Error(t, err) + + // Both wrong + _, err = repo.FindByProviderAndModelName("p2", "claude-3") + assert.Error(t, err) } func TestModelRepository_List(t *testing.T) { @@ -175,6 +196,54 @@ func TestModelRepository_List(t *testing.T) { assert.Len(t, p1Models, 2) } +func TestModelRepository_ListEnabled(t *testing.T) { + db := setupTestDB(t) + providerRepo := NewProviderRepository(db) + modelRepo := NewModelRepository(db) + + // Create two providers (both start 
enabled due to gorm:"default:true") + err := providerRepo.Create(&domain.Provider{ + ID: "enabled-provider", Name: "Enabled Provider", + APIKey: "key1", BaseURL: "https://enabled.com", Enabled: true, + }) + require.NoError(t, err) + err = providerRepo.Create(&domain.Provider{ + ID: "disabled-provider", Name: "Disabled Provider", + APIKey: "key2", BaseURL: "https://disabled.com", Enabled: true, + }) + require.NoError(t, err) + + // Disable the second provider via Update (GORM default:true skips zero values on Create) + err = providerRepo.Update("disabled-provider", map[string]interface{}{"enabled": false}) + require.NoError(t, err) + + // Create models (all start enabled due to gorm:"default:true") + err = modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "enabled-provider", ModelName: "gpt-4", Enabled: true}) + require.NoError(t, err) + err = modelRepo.Create(&domain.Model{ID: "m2", ProviderID: "enabled-provider", ModelName: "gpt-3.5", Enabled: true}) + require.NoError(t, err) + err = modelRepo.Create(&domain.Model{ID: "m3", ProviderID: "disabled-provider", ModelName: "claude-3", Enabled: true}) + require.NoError(t, err) + err = modelRepo.Create(&domain.Model{ID: "m4", ProviderID: "disabled-provider", ModelName: "claude-3.5", Enabled: true}) + require.NoError(t, err) + + // Disable m2 via Update + err = modelRepo.Update("m2", map[string]interface{}{"enabled": false}) + require.NoError(t, err) + + // ListEnabled should only return models where both model and provider are enabled: + // - m1: enabled provider + enabled model -> returned + // - m2: enabled provider + disabled model -> filtered out + // - m3: disabled provider + enabled model -> filtered out + // - m4: disabled provider + enabled model -> filtered out + enabled, err := modelRepo.ListEnabled() + require.NoError(t, err) + require.Len(t, enabled, 1) + assert.Equal(t, "m1", enabled[0].ID) + assert.Equal(t, "enabled-provider", enabled[0].ProviderID) + assert.Equal(t, "gpt-4", enabled[0].ModelName) +} + 
func TestModelRepository_Update(t *testing.T) { db := setupTestDB(t) repo := NewModelRepository(db) diff --git a/backend/internal/service/model_service.go b/backend/internal/service/model_service.go index e927abb..032e0b4 100644 --- a/backend/internal/service/model_service.go +++ b/backend/internal/service/model_service.go @@ -7,6 +7,7 @@ type ModelService interface { Create(model *domain.Model) error Get(id string) (*domain.Model, error) List(providerID string) ([]domain.Model, error) + ListEnabled() ([]domain.Model, error) Update(id string, updates map[string]interface{}) error Delete(id string) error } diff --git a/backend/internal/service/model_service_impl.go b/backend/internal/service/model_service_impl.go index 990c6c2..7ab4c3a 100644 --- a/backend/internal/service/model_service_impl.go +++ b/backend/internal/service/model_service_impl.go @@ -1,6 +1,7 @@ package service import ( + "github.com/google/uuid" appErrors "nex/backend/pkg/errors" "nex/backend/internal/domain" @@ -17,11 +18,18 @@ func NewModelService(modelRepo repository.ModelRepository, providerRepo reposito } func (s *modelService) Create(model *domain.Model) error { - // Verify provider exists - _, err := s.providerRepo.GetByID(model.ProviderID) - if err != nil { + // 校验供应商存在 + if _, err := s.providerRepo.GetByID(model.ProviderID); err != nil { return appErrors.ErrProviderNotFound } + + // 联合唯一校验:同一供应商下 model_name 不重复 + if err := s.checkDuplicateModelName(model.ProviderID, model.ModelName, ""); err != nil { + return err + } + + // 自动生成 UUID 作为 id + model.ID = uuid.New().String() model.Enabled = true return s.modelRepo.Create(model) } @@ -34,17 +42,57 @@ func (s *modelService) List(providerID string) ([]domain.Model, error) { return s.modelRepo.List(providerID) } +func (s *modelService) ListEnabled() ([]domain.Model, error) { + return s.modelRepo.ListEnabled() +} + func (s *modelService) Update(id string, updates map[string]interface{}) error { - // If updating provider_id, verify new provider 
exists + // 获取当前模型 + current, err := s.modelRepo.GetByID(id) + if err != nil { + return appErrors.ErrModelNotFound + } + + // 如果更新 provider_id,校验新供应商存在 if providerID, ok := updates["provider_id"].(string); ok { - _, err := s.providerRepo.GetByID(providerID) - if err != nil { + if _, err := s.providerRepo.GetByID(providerID); err != nil { return appErrors.ErrProviderNotFound } } + + // 确定更新后的 provider_id 和 model_name + newProviderID := current.ProviderID + if v, ok := updates["provider_id"].(string); ok { + newProviderID = v + } + newModelName := current.ModelName + if v, ok := updates["model_name"].(string); ok { + newModelName = v + } + + // 如果 provider_id 或 model_name 发生变化,校验联合唯一 + if newProviderID != current.ProviderID || newModelName != current.ModelName { + if err := s.checkDuplicateModelName(newProviderID, newModelName, id); err != nil { + return err + } + } + return s.modelRepo.Update(id, updates) } func (s *modelService) Delete(id string) error { return s.modelRepo.Delete(id) } + +// checkDuplicateModelName 校验同一供应商下 model_name 是否重复 +// excludeID 用于更新时排除自身 +func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error { + existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName) + if err != nil { + return nil // 未找到,不重复 + } + if excludeID != "" && existing.ID == excludeID { + return nil // 排除自身 + } + return appErrors.ErrDuplicateModel +} diff --git a/backend/internal/service/provider_service.go b/backend/internal/service/provider_service.go index fdebc7c..2ed3f76 100644 --- a/backend/internal/service/provider_service.go +++ b/backend/internal/service/provider_service.go @@ -9,4 +9,7 @@ type ProviderService interface { List() ([]domain.Provider, error) Update(id string, updates map[string]interface{}) error Delete(id string) error + // 统一模型 ID 相关方法 + ListEnabledModels() ([]domain.Model, error) + GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) } diff --git 
a/backend/internal/service/provider_service_impl.go b/backend/internal/service/provider_service_impl.go index b34883a..28a992a 100644 --- a/backend/internal/service/provider_service_impl.go +++ b/backend/internal/service/provider_service_impl.go @@ -1,21 +1,35 @@ package service import ( + "strings" + + "nex/backend/pkg/modelid" + "nex/backend/internal/domain" "nex/backend/internal/repository" + appErrors "nex/backend/pkg/errors" ) type providerService struct { providerRepo repository.ProviderRepository + modelRepo repository.ModelRepository } -func NewProviderService(providerRepo repository.ProviderRepository) ProviderService { - return &providerService{providerRepo: providerRepo} +func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository) ProviderService { + return &providerService{providerRepo: providerRepo, modelRepo: modelRepo} } func (s *providerService) Create(provider *domain.Provider) error { + // 校验 provider_id 字符集 + if err := modelid.ValidateProviderID(provider.ID); err != nil { + return appErrors.ErrInvalidProviderID + } provider.Enabled = true - return s.providerRepo.Create(provider) + err := s.providerRepo.Create(provider) + if err != nil && isUniqueConstraintError(err) { + return appErrors.ErrConflict + } + return err } func (s *providerService) Get(id string, maskKey bool) (*domain.Provider, error) { @@ -41,9 +55,31 @@ func (s *providerService) List() ([]domain.Provider, error) { } func (s *providerService) Update(id string, updates map[string]interface{}) error { + if _, ok := updates["id"]; ok { + return appErrors.ErrImmutableField + } return s.providerRepo.Update(id, updates) } func (s *providerService) Delete(id string) error { return s.providerRepo.Delete(id) } + +// ListEnabledModels 返回所有启用的模型(用于 Models 接口本地聚合) +func (s *providerService) ListEnabledModels() ([]domain.Model, error) { + return s.modelRepo.ListEnabled() +} + +// GetModelByProviderAndName 按 provider_id 和 model_name 查询模型(用于 ModelInfo 
接口本地查询) +func (s *providerService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) { + return s.modelRepo.FindByProviderAndModelName(providerID, modelName) +} + +// isUniqueConstraintError 判断是否为数据库唯一约束冲突错误 +func isUniqueConstraintError(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "unique constraint") || strings.Contains(msg, "duplicate") +} diff --git a/backend/internal/service/routing_service.go b/backend/internal/service/routing_service.go index 85db8a4..ca58760 100644 --- a/backend/internal/service/routing_service.go +++ b/backend/internal/service/routing_service.go @@ -4,5 +4,5 @@ import "nex/backend/internal/domain" // RoutingService 路由服务接口 type RoutingService interface { - Route(modelName string) (*domain.RouteResult, error) + RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) } diff --git a/backend/internal/service/routing_service_impl.go b/backend/internal/service/routing_service_impl.go index 482c136..39cf407 100644 --- a/backend/internal/service/routing_service_impl.go +++ b/backend/internal/service/routing_service_impl.go @@ -16,8 +16,8 @@ func NewRoutingService(modelRepo repository.ModelRepository, providerRepo reposi return &routingService{modelRepo: modelRepo, providerRepo: providerRepo} } -func (s *routingService) Route(modelName string) (*domain.RouteResult, error) { - model, err := s.modelRepo.GetByModelName(modelName) +func (s *routingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) { + model, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName) if err != nil { return nil, appErrors.ErrModelNotFound } diff --git a/backend/internal/service/service_supplemental_test.go b/backend/internal/service/service_supplemental_test.go index 5501cb2..5ba9e02 100644 --- a/backend/internal/service/service_supplemental_test.go +++ 
b/backend/internal/service/service_supplemental_test.go @@ -13,7 +13,8 @@ import ( func TestProviderService_Update(t *testing.T) { db := setupServiceTestDB(t) repo := repository.NewProviderRepository(db) - svc := NewProviderService(repo) + modelRepo := repository.NewModelRepository(db) + svc := NewProviderService(repo, modelRepo) svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"}) @@ -28,7 +29,8 @@ func TestProviderService_Update(t *testing.T) { func TestProviderService_Update_NotFound(t *testing.T) { db := setupServiceTestDB(t) repo := repository.NewProviderRepository(db) - svc := NewProviderService(repo) + modelRepo := repository.NewModelRepository(db) + svc := NewProviderService(repo, modelRepo) err := svc.Update("nonexistent", map[string]interface{}{"name": "test"}) assert.Error(t, err) @@ -41,11 +43,12 @@ func TestModelService_Get(t *testing.T) { svc := NewModelService(modelRepo, providerRepo) providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}) - svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}) + model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"} + require.NoError(t, svc.Create(model)) - model, err := svc.Get("m1") + result, err := svc.Get(model.ID) require.NoError(t, err) - assert.Equal(t, "gpt-4", model.ModelName) + assert.Equal(t, "gpt-4", result.ModelName) } func TestModelService_Update(t *testing.T) { @@ -55,14 +58,15 @@ func TestModelService_Update(t *testing.T) { svc := NewModelService(modelRepo, providerRepo) providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}) - svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}) + model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"} + require.NoError(t, svc.Create(model)) - err := svc.Update("m1", map[string]interface{}{"model_name": "gpt-4o"}) + err := svc.Update(model.ID, 
map[string]interface{}{"model_name": "gpt-4o"}) require.NoError(t, err) - model, err := svc.Get("m1") + result, err := svc.Get(model.ID) require.NoError(t, err) - assert.Equal(t, "gpt-4o", model.ModelName) + assert.Equal(t, "gpt-4o", result.ModelName) } func TestModelService_Update_ProviderID_Invalid(t *testing.T) { @@ -72,9 +76,10 @@ func TestModelService_Update_ProviderID_Invalid(t *testing.T) { svc := NewModelService(modelRepo, providerRepo) providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}) - svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}) + model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"} + require.NoError(t, svc.Create(model)) - err := svc.Update("m1", map[string]interface{}{"provider_id": "nonexistent"}) + err := svc.Update(model.ID, map[string]interface{}{"provider_id": "nonexistent"}) assert.Error(t, err) } @@ -85,12 +90,13 @@ func TestModelService_Delete(t *testing.T) { svc := NewModelService(modelRepo, providerRepo) providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}) - svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}) + model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"} + require.NoError(t, svc.Create(model)) - err := svc.Delete("m1") + err := svc.Delete(model.ID) require.NoError(t, err) - _, err = svc.Get("m1") + _, err = svc.Get(model.ID) assert.Error(t, err) } diff --git a/backend/internal/service/service_test.go b/backend/internal/service/service_test.go index 8ba1dd6..224db39 100644 --- a/backend/internal/service/service_test.go +++ b/backend/internal/service/service_test.go @@ -1,8 +1,10 @@ package service import ( + "errors" "testing" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/driver/sqlite" @@ -11,6 +13,7 @@ import ( "nex/backend/internal/config" "nex/backend/internal/domain" 
"nex/backend/internal/repository" + appErrors "nex/backend/pkg/errors" ) func setupServiceTestDB(t *testing.T) *gorm.DB { @@ -29,80 +32,106 @@ func setupServiceTestDB(t *testing.T) *gorm.DB { return db } -// ============ ProviderService 测试 ============ +// ============ RoutingService - RouteByModelName 测试 ============ -func TestProviderService_Create(t *testing.T) { +func TestRoutingService_RouteByModelName_Success(t *testing.T) { db := setupServiceTestDB(t) - repo := repository.NewProviderRepository(db) - svc := NewProviderService(repo) + providerRepo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + svc := NewRoutingService(modelRepo, providerRepo) - provider := &domain.Provider{ - ID: "test-p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", - } - err := svc.Create(provider) + // 创建供应商和模型 + providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}) + modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}) + + result, err := svc.RouteByModelName("openai", "gpt-4") require.NoError(t, err) - assert.True(t, provider.Enabled) + assert.Equal(t, "openai", result.Provider.ID) + assert.Equal(t, "gpt-4", result.Model.ModelName) } -func TestProviderService_Get_MaskKey(t *testing.T) { +func TestRoutingService_RouteByModelName_NotFound(t *testing.T) { db := setupServiceTestDB(t) - repo := repository.NewProviderRepository(db) - svc := NewProviderService(repo) + providerRepo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + svc := NewRoutingService(modelRepo, providerRepo) - svc.Create(&domain.Provider{ - ID: "p1", Name: "Test", APIKey: "sk-long-api-key-12345", BaseURL: "https://test.com", - }) - - result, err := svc.Get("p1", true) - require.NoError(t, err) - assert.Equal(t, "***2345", result.APIKey) - - result, err = svc.Get("p1", false) - require.NoError(t, 
err) - assert.Equal(t, "sk-long-api-key-12345", result.APIKey) + _, err := svc.RouteByModelName("openai", "nonexistent-model") + assert.True(t, errors.Is(err, appErrors.ErrModelNotFound)) } -func TestProviderService_List(t *testing.T) { +func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) { db := setupServiceTestDB(t) - repo := repository.NewProviderRepository(db) - svc := NewProviderService(repo) + providerRepo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + svc := NewRoutingService(modelRepo, providerRepo) - svc.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key123", BaseURL: "https://a.com"}) - svc.Create(&domain.Provider{ID: "p2", Name: "P2", APIKey: "key456", BaseURL: "https://b.com"}) + // 创建启用的供应商和禁用的模型 + providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}) + modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}) + modelRepo.Update("m1", map[string]interface{}{"enabled": false}) - providers, err := svc.List() - require.NoError(t, err) - assert.Len(t, providers, 2) - assert.Contains(t, providers[0].APIKey, "***") + _, err := svc.RouteByModelName("openai", "gpt-4") + assert.True(t, errors.Is(err, appErrors.ErrModelDisabled)) } -func TestProviderService_Delete(t *testing.T) { +func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) { db := setupServiceTestDB(t) - repo := repository.NewProviderRepository(db) - svc := NewProviderService(repo) + providerRepo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + svc := NewRoutingService(modelRepo, providerRepo) - svc.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}) - err := svc.Delete("p1") - require.NoError(t, err) + // 创建启用的供应商和模型,然后禁用供应商 + providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: 
"https://api.openai.com", Enabled: true}) + modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}) + providerRepo.Update("openai", map[string]interface{}{"enabled": false}) - _, err = svc.Get("p1", false) - assert.Error(t, err) + _, err := svc.RouteByModelName("openai", "gpt-4") + assert.True(t, errors.Is(err, appErrors.ErrProviderDisabled)) } -// ============ ModelService 测试 ============ +// ============ ModelService - Create with UUID 测试 ============ -func TestModelService_Create(t *testing.T) { +func TestModelService_Create_GeneratesUUID(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) svc := NewModelService(modelRepo, providerRepo) - providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}) + providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}) - model := &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"} + model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"} err := svc.Create(model) require.NoError(t, err) - assert.True(t, model.Enabled) + + // 验证返回的 model 拥有有效的 UUID + assert.NotEmpty(t, model.ID) + _, err = uuid.Parse(model.ID) + assert.NoError(t, err, "model.ID should be a valid UUID") + + // 通过 Get 验证持久化 + stored, err := svc.Get(model.ID) + require.NoError(t, err) + assert.Equal(t, model.ID, stored.ID) + assert.Equal(t, "gpt-4", stored.ModelName) +} + +func TestModelService_Create_DuplicateModelName(t *testing.T) { + db := setupServiceTestDB(t) + providerRepo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + svc := NewModelService(modelRepo, providerRepo) + + providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}) + + model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"} + err 
:= svc.Create(model1) + require.NoError(t, err) + + // 使用相同的 (providerID, modelName) 创建第二个模型应失败 + model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"} + err = svc.Create(model2) + assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel)) } func TestModelService_Create_ProviderNotFound(t *testing.T) { @@ -111,160 +140,135 @@ func TestModelService_Create_ProviderNotFound(t *testing.T) { modelRepo := repository.NewModelRepository(db) svc := NewModelService(modelRepo, providerRepo) - model := &domain.Model{ID: "m1", ProviderID: "nonexistent", ModelName: "gpt-4"} + model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"} err := svc.Create(model) - assert.Error(t, err) + assert.True(t, errors.Is(err, appErrors.ErrProviderNotFound)) } -func TestModelService_List(t *testing.T) { +// ============ ProviderService - Create with validation 测试 ============ + +func TestProviderService_Create_InvalidID(t *testing.T) { + db := setupServiceTestDB(t) + repo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + svc := NewProviderService(repo, modelRepo) + + provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"} + err := svc.Create(provider) + assert.True(t, errors.Is(err, appErrors.ErrInvalidProviderID)) +} + +func TestProviderService_Create_ValidID(t *testing.T) { + db := setupServiceTestDB(t) + repo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + svc := NewProviderService(repo, modelRepo) + + provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"} + err := svc.Create(provider) + require.NoError(t, err) + assert.Equal(t, "openai", provider.ID) + assert.True(t, provider.Enabled) +} + +// ============ ModelService - Update with duplicate check 测试 ============ + +func TestModelService_Update_DuplicateModelName(t *testing.T) { db := setupServiceTestDB(t) providerRepo := 
repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) svc := NewModelService(modelRepo, providerRepo) - providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}) - svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}) - svc.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"}) + providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}) + providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"}) - models, err := svc.List("p1") + model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"} + err := svc.Create(model1) require.NoError(t, err) - assert.Len(t, models, 2) + + model2 := &domain.Model{ProviderID: "anthropic", ModelName: "claude-3"} + err = svc.Create(model2) + require.NoError(t, err) + + // 将 model2 的 model_name 改为 "gpt-4" 且 provider_id 改为 "openai",与 model1 冲突 + err = svc.Update(model2.ID, map[string]interface{}{ + "provider_id": "openai", + "model_name": "gpt-4", + }) + assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel)) } -// ============ RoutingService 测试 ============ - -func TestRoutingService_Route(t *testing.T) { +func TestModelService_Update_ModelNotFound(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewRoutingService(modelRepo, providerRepo) + svc := NewModelService(modelRepo, providerRepo) - providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true}) - modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}) - - result, err := svc.Route("gpt-4") - require.NoError(t, err) - assert.Equal(t, "p1", result.Provider.ID) - assert.Equal(t, "gpt-4", result.Model.ModelName) + err := 
svc.Update("nonexistent-id", map[string]interface{}{ + "model_name": "gpt-4", + }) + assert.True(t, errors.Is(err, appErrors.ErrModelNotFound)) } -func TestRoutingService_Route_ModelNotFound(t *testing.T) { +func TestModelService_Update_Success(t *testing.T) { db := setupServiceTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) - svc := NewRoutingService(modelRepo, providerRepo) + svc := NewModelService(modelRepo, providerRepo) - _, err := svc.Route("nonexistent-model") - assert.Error(t, err) -} + providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}) -func TestRoutingService_Route_ModelDisabled(t *testing.T) { - db := setupServiceTestDB(t) - providerRepo := repository.NewProviderRepository(db) - modelRepo := repository.NewModelRepository(db) - svc := NewRoutingService(modelRepo, providerRepo) - - providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true}) - // 先创建启用的模型,然后通过 Update 禁用 - modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}) - modelRepo.Update("m1", map[string]interface{}{"enabled": false}) - - _, err := svc.Route("gpt-4") - assert.Error(t, err) -} - -func TestRoutingService_Route_ProviderDisabled(t *testing.T) { - db := setupServiceTestDB(t) - providerRepo := repository.NewProviderRepository(db) - modelRepo := repository.NewModelRepository(db) - svc := NewRoutingService(modelRepo, providerRepo) - - // 先创建启用的 provider,然后禁用 - providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true}) - providerRepo.Update("p1", map[string]interface{}{"enabled": false}) - modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}) - - _, err := svc.Route("gpt-4") - assert.Error(t, err) -} - -// ============ StatsService 测试 ============ - -func 
TestStatsService_RecordAndGet(t *testing.T) { - db := setupServiceTestDB(t) - statsRepo := repository.NewStatsRepository(db) - svc := NewStatsService(statsRepo) - - err := svc.Record("p1", "gpt-4") + model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"} + err := svc.Create(model) require.NoError(t, err) - stats, err := svc.Get("p1", "", nil, nil) + // 更新 model_name 为不冲突的值 + err = svc.Update(model.ID, map[string]interface{}{ + "model_name": "gpt-4-turbo", + }) require.NoError(t, err) - assert.Len(t, stats, 1) + + updated, err := svc.Get(model.ID) + require.NoError(t, err) + assert.Equal(t, "gpt-4-turbo", updated.ModelName) } -func TestStatsService_Aggregate_ByProvider(t *testing.T) { - statsRepo := repository.NewStatsRepository(nil) - svc := NewStatsService(statsRepo) +// ============ ProviderService - Update immutable ID 测试 ============ - stats := []domain.UsageStats{ - {ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10}, - {ProviderID: "p1", ModelName: "gpt-3.5", RequestCount: 5}, - {ProviderID: "p2", ModelName: "claude-3", RequestCount: 8}, - } +func TestProviderService_Update_ImmutableID(t *testing.T) { + db := setupServiceTestDB(t) + repo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + svc := NewProviderService(repo, modelRepo) - result := svc.Aggregate(stats, "provider") - assert.Len(t, result, 2) + provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"} + err := svc.Create(provider) + require.NoError(t, err) - p1Count := 0 - p2Count := 0 - for _, r := range result { - if r["provider_id"] == "p1" { - p1Count = r["request_count"].(int) - } - if r["provider_id"] == "p2" { - p2Count = r["request_count"].(int) - } - } - assert.Equal(t, 15, p1Count) - assert.Equal(t, 8, p2Count) + // 尝试更新 id 字段 + err = svc.Update("openai", map[string]interface{}{ + "id": "new-id", + }) + assert.True(t, errors.Is(err, appErrors.ErrImmutableField)) } -func 
TestStatsService_Aggregate_ByDate(t *testing.T) { - statsRepo := repository.NewStatsRepository(nil) - svc := NewStatsService(statsRepo) +func TestProviderService_Update_Success(t *testing.T) { + db := setupServiceTestDB(t) + repo := repository.NewProviderRepository(db) + modelRepo := repository.NewModelRepository(db) + svc := NewProviderService(repo, modelRepo) - stats := []domain.UsageStats{ - {ProviderID: "p1", RequestCount: 10}, - {ProviderID: "p2", RequestCount: 5}, - } + provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"} + err := svc.Create(provider) + require.NoError(t, err) - result := svc.Aggregate(stats, "date") - assert.Len(t, result, 1) - assert.Equal(t, 15, result[0]["request_count"]) -} - -func TestStatsService_Aggregate_ByModel(t *testing.T) { - statsRepo := repository.NewStatsRepository(nil) - svc := NewStatsService(statsRepo) - - stats := []domain.UsageStats{ - {ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10}, - {ProviderID: "openai", ModelName: "gpt-3.5", RequestCount: 5}, - {ProviderID: "anthropic", ModelName: "claude-3", RequestCount: 8}, - {ProviderID: "openai", ModelName: "gpt-4", RequestCount: 3}, - } - - result := svc.Aggregate(stats, "model") - assert.Len(t, result, 3) - - // 验证每个 provider/model 组合的计数 - counts := make(map[string]int) - for _, r := range result { - key := r["provider_id"].(string) + "/" + r["model_name"].(string) - counts[key] = r["request_count"].(int) - } - assert.Equal(t, 13, counts["openai/gpt-4"]) - assert.Equal(t, 5, counts["openai/gpt-3.5"]) - assert.Equal(t, 8, counts["anthropic/claude-3"]) + // 更新 name + err = svc.Update("openai", map[string]interface{}{ + "name": "OpenAI Updated", + }) + require.NoError(t, err) + + updated, err := svc.Get("openai", false) + require.NoError(t, err) + assert.Equal(t, "OpenAI Updated", updated.Name) } diff --git a/backend/migrations/20260401000002_add_indexes.sql b/backend/migrations/20260401000002_add_indexes.sql 
deleted file mode 100644 index a3900ad..0000000 --- a/backend/migrations/20260401000002_add_indexes.sql +++ /dev/null @@ -1,9 +0,0 @@ --- +goose Up -CREATE INDEX IF NOT EXISTS idx_models_provider_id ON models(provider_id); -CREATE INDEX IF NOT EXISTS idx_models_model_name ON models(model_name); -CREATE INDEX IF NOT EXISTS idx_usage_stats_provider_model_date ON usage_stats(provider_id, model_name, date); - --- +goose Down -DROP INDEX IF EXISTS idx_usage_stats_provider_model_date; -DROP INDEX IF EXISTS idx_models_model_name; -DROP INDEX IF EXISTS idx_models_provider_id; diff --git a/backend/migrations/20260419000001_add_provider_protocol.sql b/backend/migrations/20260419000001_add_provider_protocol.sql deleted file mode 100644 index 6ed08b7..0000000 --- a/backend/migrations/20260419000001_add_provider_protocol.sql +++ /dev/null @@ -1,6 +0,0 @@ --- +goose Up -ALTER TABLE providers ADD COLUMN protocol TEXT DEFAULT 'openai'; - --- +goose Down --- SQLite 不支持 DROP COLUMN(3.35.0 之前),但 goose 的 Down 通常不需要 -CREATE TABLE providers_backup AS SELECT id, name, api_key, base_url, enabled, created_at, updated_at FROM providers; diff --git a/backend/migrations/20260401000001_initial_schema.sql b/backend/migrations/20260421000001_initial_schema.sql similarity index 56% rename from backend/migrations/20260401000001_initial_schema.sql rename to backend/migrations/20260421000001_initial_schema.sql index 5c94dfa..2076207 100644 --- a/backend/migrations/20260401000001_initial_schema.sql +++ b/backend/migrations/20260421000001_initial_schema.sql @@ -1,9 +1,13 @@ -- +goose Up +-- 统一初始迁移:providers、models、usage_stats 完整表结构 +-- models 表使用 UUID 主键 + UNIQUE(provider_id, model_name) 联合唯一约束 + CREATE TABLE IF NOT EXISTS providers ( id TEXT PRIMARY KEY, name TEXT NOT NULL, api_key TEXT NOT NULL, base_url TEXT NOT NULL, + protocol TEXT DEFAULT 'openai', enabled INTEGER DEFAULT 1, created_at DATETIME, updated_at DATETIME @@ -15,7 +19,8 @@ CREATE TABLE IF NOT EXISTS models ( model_name TEXT NOT NULL, 
enabled INTEGER DEFAULT 1, created_at DATETIME, - FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE + FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE, + UNIQUE(provider_id, model_name) ); CREATE TABLE IF NOT EXISTS usage_stats ( @@ -27,7 +32,14 @@ CREATE TABLE IF NOT EXISTS usage_stats ( UNIQUE(provider_id, model_name, date) ); +CREATE INDEX IF NOT EXISTS idx_models_provider_id ON models(provider_id); +CREATE INDEX IF NOT EXISTS idx_models_model_name ON models(model_name); +CREATE INDEX IF NOT EXISTS idx_usage_stats_provider_model_date ON usage_stats(provider_id, model_name, date); + -- +goose Down +DROP INDEX IF EXISTS idx_usage_stats_provider_model_date; +DROP INDEX IF EXISTS idx_models_model_name; +DROP INDEX IF EXISTS idx_models_provider_id; DROP TABLE IF EXISTS usage_stats; DROP TABLE IF EXISTS models; DROP TABLE IF EXISTS providers; diff --git a/backend/pkg/errors/errors.go b/backend/pkg/errors/errors.go index eba147a..4fee2c9 100644 --- a/backend/pkg/errors/errors.go +++ b/backend/pkg/errors/errors.go @@ -49,17 +49,20 @@ func NewAppError(code, message string, httpStatus int) *AppError { // Predefined errors var ( - ErrModelNotFound = NewAppError("model_not_found", "模型未找到", http.StatusNotFound) - ErrModelDisabled = NewAppError("model_disabled", "模型已禁用", http.StatusNotFound) - ErrProviderNotFound = NewAppError("provider_not_found", "供应商未找到", http.StatusNotFound) - ErrProviderDisabled = NewAppError("provider_disabled", "供应商已禁用", http.StatusNotFound) - ErrInvalidRequest = NewAppError("invalid_request", "无效的请求", http.StatusBadRequest) - ErrInternal = NewAppError("internal_error", "内部错误", http.StatusInternalServerError) - ErrDatabaseNotInit = NewAppError("database_not_initialized", "数据库未初始化", http.StatusInternalServerError) - ErrConflict = NewAppError("conflict", "资源已存在", http.StatusConflict) - ErrRequestCreate = NewAppError("request_create_error", "创建请求失败", http.StatusInternalServerError) - ErrRequestSend = 
NewAppError("request_send_error", "发送请求失败", http.StatusBadGateway) - ErrResponseRead = NewAppError("response_read_error", "读取响应失败", http.StatusBadGateway) + ErrModelNotFound = NewAppError("model_not_found", "模型未找到", http.StatusNotFound) + ErrModelDisabled = NewAppError("model_disabled", "模型已禁用", http.StatusNotFound) + ErrProviderNotFound = NewAppError("provider_not_found", "供应商未找到", http.StatusNotFound) + ErrProviderDisabled = NewAppError("provider_disabled", "供应商已禁用", http.StatusNotFound) + ErrInvalidRequest = NewAppError("invalid_request", "无效的请求", http.StatusBadRequest) + ErrInternal = NewAppError("internal_error", "内部错误", http.StatusInternalServerError) + ErrDatabaseNotInit = NewAppError("database_not_initialized", "数据库未初始化", http.StatusInternalServerError) + ErrConflict = NewAppError("conflict", "资源已存在", http.StatusConflict) + ErrRequestCreate = NewAppError("request_create_error", "创建请求失败", http.StatusInternalServerError) + ErrRequestSend = NewAppError("request_send_error", "发送请求失败", http.StatusBadGateway) + ErrResponseRead = NewAppError("response_read_error", "读取响应失败", http.StatusBadGateway) + ErrInvalidProviderID = NewAppError("invalid_provider_id", "供应商 ID 仅允许字母、数字、下划线,长度 1-64", http.StatusBadRequest) + ErrDuplicateModel = NewAppError("duplicate_model", "同一供应商下模型名称已存在", http.StatusConflict) + ErrImmutableField = NewAppError("immutable_field", "供应商 ID 不允许修改", http.StatusBadRequest) ) // AsAppError 尝试将 error 转换为 *AppError diff --git a/backend/pkg/logger/logger.go b/backend/pkg/logger/logger.go index 6dec031..186df9b 100644 --- a/backend/pkg/logger/logger.go +++ b/backend/pkg/logger/logger.go @@ -9,6 +9,14 @@ import ( "go.uber.org/zap/zapcore" ) +// stdoutWriter 包装 os.Stdout,忽略 Sync() 错误。 +// 在非 TTY 环境(如 go test)中,os.Stdout 被重定向为 pipe, +// 底层 fsync 会返回 "bad file descriptor"。zap 社区标准做法。 +type stdoutWriter struct{} + +func (stdoutWriter) Write(p []byte) (int, error) { return os.Stdout.Write(p) } +func (stdoutWriter) Sync() error { return nil } + // Config 日志配置 
type Config struct { Level string // 日志级别: debug, info, warn, error @@ -46,7 +54,7 @@ func New(cfg Config) (*zap.Logger, error) { stdoutCore := zapcore.NewCore( stdoutEncoder, - zapcore.AddSync(os.Stdout), + zapcore.AddSync(stdoutWriter{}), level, ) diff --git a/backend/pkg/modelid/model_id.go b/backend/pkg/modelid/model_id.go new file mode 100644 index 0000000..8e56d4b --- /dev/null +++ b/backend/pkg/modelid/model_id.go @@ -0,0 +1,63 @@ +package modelid + +import ( + "errors" + "regexp" + "strings" +) + +var providerIDRegex = regexp.MustCompile(`^[a-zA-Z0-9_]+$`) + +// ParseUnifiedModelID 将 "provider_id/model_name" 格式的字符串解析为 providerID 和 modelName +// 在第一个 "/" 处分割,model_name 可以包含 "/" +func ParseUnifiedModelID(id string) (providerID, modelName string, err error) { + if id == "" { + return "", "", errors.New("统一模型 ID 不能为空") + } + + parts := strings.SplitN(id, "/", 2) + if len(parts) != 2 { + return "", "", errors.New("统一模型 ID 格式错误,缺少分隔符 \"/\"") + } + + providerID = parts[0] + modelName = parts[1] + + if providerID == "" { + return "", "", errors.New("provider_id 不能为空") + } + if modelName == "" { + return "", "", errors.New("model_name 不能为空") + } + + if !providerIDRegex.MatchString(providerID) { + return "", "", errors.New("provider_id 仅允许字母、数字、下划线") + } + + return providerID, modelName, nil +} + +// FormatUnifiedModelID 将 providerID 和 modelName 组合格式化为统一模型 ID +func FormatUnifiedModelID(providerID, modelName string) string { + return providerID + "/" + modelName +} + +// ValidateProviderID 校验 providerID 仅包含字母、数字、下划线,长度 1-64 +func ValidateProviderID(id string) error { + if id == "" { + return errors.New("provider_id 不能为空") + } + if len(id) > 64 { + return errors.New("provider_id 长度不能超过 64 个字符") + } + if !providerIDRegex.MatchString(id) { + return errors.New("provider_id 仅允许字母、数字、下划线") + } + return nil +} + +// IsValidUnifiedModelID 判断字符串是否为合法的统一模型 ID 格式 +func IsValidUnifiedModelID(id string) bool { + _, _, err := ParseUnifiedModelID(id) + return err == nil +} diff 
--git a/backend/pkg/modelid/model_id_test.go b/backend/pkg/modelid/model_id_test.go new file mode 100644 index 0000000..7889923 --- /dev/null +++ b/backend/pkg/modelid/model_id_test.go @@ -0,0 +1,96 @@ +package modelid + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseUnifiedModelID_StandardFormat(t *testing.T) { + providerID, modelName, err := ParseUnifiedModelID("openai/gpt-4") + assert.NoError(t, err) + assert.Equal(t, "openai", providerID) + assert.Equal(t, "gpt-4", modelName) +} + +func TestParseUnifiedModelID_ModelNameWithSlashes(t *testing.T) { + providerID, modelName, err := ParseUnifiedModelID("azure/accounts/org-123/models/gpt-4") + assert.NoError(t, err) + assert.Equal(t, "azure", providerID) + assert.Equal(t, "accounts/org-123/models/gpt-4", modelName) +} + +func TestParseUnifiedModelID_MissingSeparator(t *testing.T) { + _, _, err := ParseUnifiedModelID("gpt-4") + assert.Error(t, err) +} + +func TestParseUnifiedModelID_EmptyString(t *testing.T) { + _, _, err := ParseUnifiedModelID("") + assert.Error(t, err) +} + +func TestParseUnifiedModelID_OnlySeparator(t *testing.T) { + tests := []string{"/model", "provider/", "/"} + for _, tc := range tests { + _, _, err := ParseUnifiedModelID(tc) + assert.Error(t, err, "输入 %q 应返回错误", tc) + } +} + +func TestParseUnifiedModelID_InvalidProviderID(t *testing.T) { + tests := []string{ + "open-ai/gpt-4", + "open.ai/gpt-4", + "供应商/gpt-4", + "open ai/gpt-4", + } + for _, tc := range tests { + _, _, err := ParseUnifiedModelID(tc) + assert.Error(t, err, "providerID 含非法字符 %q 应返回错误", tc) + } +} + +func TestFormatUnifiedModelID(t *testing.T) { + assert.Equal(t, "openai/gpt-4", FormatUnifiedModelID("openai", "gpt-4")) + assert.Equal(t, "anthropic/claude-3-opus", FormatUnifiedModelID("anthropic", "claude-3-opus")) +} + +func TestValidateProviderID_Valid(t *testing.T) { + validIDs := []string{"openai", "deep_seek", "provider01", "OpenAI"} + for _, id := range validIDs { + 
assert.NoError(t, ValidateProviderID(id), "%q 应校验通过", id) + } +} + +func TestValidateProviderID_InvalidChars(t *testing.T) { + invalidIDs := []string{"open-ai", "open.ai", "open ai", "供应商", "open/ai"} + for _, id := range invalidIDs { + assert.Error(t, ValidateProviderID(id), "%q 应校验失败", id) + } +} + +func TestValidateProviderID_Empty(t *testing.T) { + assert.Error(t, ValidateProviderID("")) +} + +func TestValidateProviderID_TooLong(t *testing.T) { + longID := strings.Repeat("a", 65) + assert.Error(t, ValidateProviderID(longID)) + + exactly64 := strings.Repeat("a", 64) + assert.NoError(t, ValidateProviderID(exactly64)) +} + +func TestIsValidUnifiedModelID(t *testing.T) { + assert.True(t, IsValidUnifiedModelID("openai/gpt-4")) + assert.True(t, IsValidUnifiedModelID("anthropic/claude-3-opus-20240229")) + assert.True(t, IsValidUnifiedModelID("azure/accounts/org/models/gpt-4")) + + assert.False(t, IsValidUnifiedModelID("")) + assert.False(t, IsValidUnifiedModelID("gpt-4")) + assert.False(t, IsValidUnifiedModelID("open-ai/gpt-4")) + assert.False(t, IsValidUnifiedModelID("/model")) + assert.False(t, IsValidUnifiedModelID("provider/")) +} diff --git a/backend/tests/helpers.go b/backend/tests/helpers.go index 125bc63..19160a2 100644 --- a/backend/tests/helpers.go +++ b/backend/tests/helpers.go @@ -3,6 +3,7 @@ package tests import ( "fmt" "testing" + "time" "nex/backend/internal/config" @@ -11,26 +12,36 @@ import ( "gorm.io/gorm" ) -// SetupTestDB initializes a temporary SQLite database with auto-migration. +// SetupTestDB initializes an in-memory SQLite database with auto-migration. +// Uses :memory: mode with MaxOpenConns(1) to ensure all operations share the +// same connection, avoiding "database is closed" errors from connection pool. +// Enables foreign key constraints for SQLite. 
func SetupTestDB(t *testing.T) *gorm.DB { t.Helper() - dir := t.TempDir() - dsn := dir + "/test.db" - - db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{}) + db, err := gorm.Open(sqlite.Open(":memory:?_foreign_keys=on"), &gorm.Config{}) assert.NoError(t, err, "failed to open test database") + // 限制为单连接,确保 :memory: 数据库不被连接池丢弃 + sqlDB, err := db.DB() + assert.NoError(t, err, "failed to get underlying sql.DB") + sqlDB.SetMaxOpenConns(1) + sqlDB.SetConnMaxLifetime(0) + err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{}) assert.NoError(t, err, "failed to auto-migrate test database") return db } -// CleanupTestDB closes the database and removes the temp database file. +// CleanupTestDB closes the database after a brief delay to allow async +// goroutines (e.g. stats recording) to finish. func CleanupTestDB(t *testing.T, db *gorm.DB) { t.Helper() + // 等待异步 goroutine(如 statsService.Record)完成 + time.Sleep(50 * time.Millisecond) + sqlDB, err := db.DB() assert.NoError(t, err, "failed to get underlying sql.DB") @@ -57,7 +68,8 @@ func CreateTestProvider(t *testing.T, db *gorm.DB, id string) config.Provider { } // CreateTestModel creates a test model and returns it. -func CreateTestModel(t *testing.T, db *gorm.DB, id string, providerID string, modelName string) config.Model { +// Does NOT assert on error - returns the model and error for caller to verify. 
+func CreateTestModel(t *testing.T, db *gorm.DB, id string, providerID string, modelName string) (config.Model, error) { t.Helper() model := config.Model{ @@ -68,7 +80,5 @@ func CreateTestModel(t *testing.T, db *gorm.DB, id string, providerID string, mo } err := db.Create(&model).Error - assert.NoError(t, err, "failed to create test model") - - return model + return model, err } diff --git a/backend/tests/integration/conversion_test.go b/backend/tests/integration/conversion_test.go index 1cc345b..b5707f6 100644 --- a/backend/tests/integration/conversion_test.go +++ b/backend/tests/integration/conversion_test.go @@ -14,10 +14,8 @@ import ( "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gorm.io/driver/sqlite" "gorm.io/gorm" - "nex/backend/internal/config" "nex/backend/internal/conversion" "nex/backend/internal/conversion/anthropic" openaiConv "nex/backend/internal/conversion/openai" @@ -43,11 +41,7 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server) w.Write([]byte(`{"error":"not mocked"}`)) })) - dir := t.TempDir() - db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{}) - require.NoError(t, err) - err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{}) - require.NoError(t, err) + db := setupTestDB(t) t.Cleanup(func() { sqlDB, _ := db.DB() if sqlDB != nil { @@ -60,7 +54,7 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server) modelRepo := repository.NewModelRepository(db) statsRepo := repository.NewStatsRepository(db) - providerService := service.NewProviderService(providerRepo) + providerService := service.NewProviderService(providerRepo, modelRepo) modelService := service.NewModelService(modelRepo, providerRepo) routingService := service.NewRoutingService(modelRepo, providerRepo) statsService := service.NewStatsService(statsRepo) @@ -125,7 +119,7 @@ func createProviderAndModel(t *testing.T, r *gin.Engine, 
providerID, protocol, m require.Equal(t, 201, w.Code) modelBody, _ := json.Marshal(map[string]string{ - "id": modelName, + "provider_id": providerID, "model_name": modelName, }) @@ -156,7 +150,7 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) { "id": "msg_test", "type": "message", "role": "assistant", - "model": "claude-3-opus", + "model": "anthropic_p/claude-3-opus", "content": []map[string]any{ {"type": "text", "text": "Hello from Anthropic!"}, }, @@ -170,11 +164,11 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) { json.NewEncoder(w).Encode(resp) }) - createProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-3-opus", upstream.URL) + createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL) // 使用 OpenAI 格式发送请求 openaiReq := map[string]any{ - "model": "claude-3-opus", + "model": "anthropic_p/claude-3-opus", "messages": []map[string]any{ {"role": "user", "content": "Hello"}, }, @@ -233,10 +227,10 @@ func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) { json.NewEncoder(w).Encode(resp) }) - createProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL) + createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL) anthropicReq := map[string]any{ - "model": "gpt-4", + "model": "openai_p/gpt-4", "max_tokens": 1024, "messages": []map[string]any{ {"role": "user", "content": "Hello"}, @@ -273,16 +267,18 @@ func TestConversion_OpenAIToOpenAI_Passthrough(t *testing.T) { body, _ := io.ReadAll(r.Body) var req map[string]any json.Unmarshal(body, &req) + // Smart Passthrough: 请求体中的统一 ID 应被改写为上游模型名 assert.Equal(t, "gpt-4", req["model"]) w.Header().Set("Content-Type", "application/json") + // 上游返回上游模型名 w.Write([]byte(`{"id":"chatcmpl-pass","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"passthrough"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":1,"total_tokens":6}}`)) }) - 
createProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL) + createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL) reqBody := map[string]any{ - "model": "gpt-4", + "model": "openai_p/gpt-4", // 客户端发送统一 ID "messages": []map[string]any{{"role": "user", "content": "test"}}, } body, _ := json.Marshal(reqBody) @@ -293,7 +289,8 @@ func TestConversion_OpenAIToOpenAI_Passthrough(t *testing.T) { r.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) - assert.Contains(t, w.Body.String(), "passthrough") + // Smart Passthrough: 响应体中的上游模型名应被改写为统一 ID + assert.Contains(t, w.Body.String(), `"model":"openai_p/gpt-4"`) } func TestConversion_AnthropicToAnthropic_Passthrough(t *testing.T) { @@ -302,14 +299,21 @@ func TestConversion_AnthropicToAnthropic_Passthrough(t *testing.T) { upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "/v1/messages", r.URL.Path) + body, _ := io.ReadAll(r.Body) + var req map[string]any + json.Unmarshal(body, &req) + // Smart Passthrough: 请求体中的统一 ID 应被改写为上游模型名 + assert.Equal(t, "claude-3-opus", req["model"]) + + // 上游返回上游模型名 w.Header().Set("Content-Type", "application/json") w.Write([]byte(`{"id":"msg-pass","type":"message","role":"assistant","model":"claude-3-opus","content":[{"type":"text","text":"passthrough"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":1}}`)) }) - createProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-3-opus", upstream.URL) + createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL) reqBody := map[string]any{ - "model": "claude-3-opus", + "model": "anthropic_p/claude-3-opus", // 客户端发送统一 ID "max_tokens": 1024, "messages": []map[string]any{{"role": "user", "content": "test"}}, } @@ -321,7 +325,8 @@ func TestConversion_AnthropicToAnthropic_Passthrough(t *testing.T) { r.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) - assert.Contains(t, w.Body.String(), "passthrough") + // Smart Passthrough: 
响应体中的上游模型名应被改写为统一 ID + assert.Contains(t, w.Body.String(), `"model":"anthropic_p/claude-3-opus"`) } // ============ 流式转换测试 ============ @@ -349,10 +354,10 @@ func TestConversion_OpenAIToAnthropic_Stream(t *testing.T) { } }) - createProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-3-opus", upstream.URL) + createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL) openaiReq := map[string]any{ - "model": "claude-3-opus", + "model": "anthropic_p/claude-3-opus", "messages": []map[string]any{{"role": "user", "content": "Hello"}}, "stream": true, } @@ -390,10 +395,10 @@ func TestConversion_AnthropicToOpenAI_Stream(t *testing.T) { } }) - createProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL) + createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL) anthropicReq := map[string]any{ - "model": "gpt-4", + "model": "openai_p/gpt-4", "max_tokens": 1024, "messages": []map[string]any{{"role": "user", "content": "Hello"}}, "stream": true, @@ -512,7 +517,7 @@ func TestConversion_ProviderWithProtocol(t *testing.T) { // 创建带 protocol 字段的 provider providerBody := map[string]any{ - "id": "test-protocol", + "id": "test_protocol", "name": "Test Protocol", "api_key": "sk-test", "base_url": "https://test.com", @@ -533,7 +538,7 @@ func TestConversion_ProviderWithProtocol(t *testing.T) { // 获取时应包含 protocol w = httptest.NewRecorder() - req = httptest.NewRequest("GET", "/api/providers/test-protocol", nil) + req = httptest.NewRequest("GET", "/api/providers/test_protocol", nil) r.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) @@ -547,7 +552,7 @@ func TestConversion_ProviderDefaultProtocol(t *testing.T) { // 不指定 protocol,默认应为 openai providerBody := map[string]any{ - "id": "default-proto", + "id": "default_proto", "name": "Default", "api_key": "sk-test", "base_url": "https://test.com", diff --git a/backend/tests/integration/e2e_conversion_test.go b/backend/tests/integration/e2e_conversion_test.go index 086e08f..cb27cc0 
100644 --- a/backend/tests/integration/e2e_conversion_test.go +++ b/backend/tests/integration/e2e_conversion_test.go @@ -8,8 +8,6 @@ import ( "io" "net/http" "net/http/httptest" - "os" - "path/filepath" "strings" "testing" "time" @@ -17,10 +15,7 @@ import ( "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gorm.io/driver/sqlite" - "gorm.io/gorm" - "nex/backend/internal/config" "nex/backend/internal/conversion" "nex/backend/internal/conversion/anthropic" openaiConv "nex/backend/internal/conversion/openai" @@ -40,25 +35,20 @@ func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) { w.Write([]byte(`{"error":"not mocked"}`)) })) - dir, _ := os.MkdirTemp("", "e2e-test-*") - db, err := gorm.Open(sqlite.Open(filepath.Join(dir, "test.db")), &gorm.Config{}) - require.NoError(t, err) - err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{}) - require.NoError(t, err) + db := setupTestDB(t) t.Cleanup(func() { sqlDB, _ := db.DB() if sqlDB != nil { sqlDB.Close() } upstream.Close() - os.RemoveAll(dir) }) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) statsRepo := repository.NewStatsRepository(db) - providerService := service.NewProviderService(providerRepo) + providerService := service.NewProviderService(providerRepo, modelRepo) modelService := service.NewModelService(modelRepo, providerRepo) routingService := service.NewRoutingService(modelRepo, providerRepo) statsService := service.NewStatsService(statsRepo) @@ -105,7 +95,7 @@ func e2eCreateProviderAndModel(t *testing.T, r *gin.Engine, providerID, protocol require.Equal(t, 201, w.Code) modelBody, _ := json.Marshal(map[string]string{ - "id": modelName, "provider_id": providerID, "model_name": modelName, + "provider_id": providerID, "model_name": modelName, }) w = httptest.NewRecorder() req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody)) @@ -178,10 +168,10 @@ func 
TestE2E_OpenAI_NonStream_BasicText(t *testing.T) { }, }) }) - e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "gpt-4o", + "model": "openai_p/gpt-4o", "messages": []map[string]any{ {"role": "user", "content": "你好"}, }, @@ -195,7 +185,7 @@ func TestE2E_OpenAI_NonStream_BasicText(t *testing.T) { var resp map[string]any require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) assert.Equal(t, "chat.completion", resp["object"]) - assert.Equal(t, "gpt-4o", resp["model"]) + assert.Equal(t, "openai_p/gpt-4o", resp["model"]) choices := resp["choices"].([]any) require.Len(t, choices, 1) @@ -231,10 +221,10 @@ func TestE2E_OpenAI_NonStream_MultiTurn(t *testing.T) { "usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120}, }) }) - e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "gpt-4o", + "model": "openai_p/gpt-4o", "messages": []map[string]any{ {"role": "system", "content": "你是编程助手"}, {"role": "user", "content": "什么是interface?"}, @@ -279,10 +269,10 @@ func TestE2E_OpenAI_NonStream_ToolCalls(t *testing.T) { "usage": map[string]any{"prompt_tokens": 80, "completion_tokens": 18, "total_tokens": 98}, }) }) - e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "gpt-4o", + "model": "openai_p/gpt-4o", "messages": []map[string]any{ {"role": "user", "content": "北京天气"}, }, @@ -335,10 +325,10 @@ func TestE2E_OpenAI_NonStream_MaxTokens_Length(t *testing.T) { "usage": map[string]any{"prompt_tokens": 20, "completion_tokens": 30, "total_tokens": 50}, }) }) - e2eCreateProviderAndModel(t, 
r, "openai-p", "openai", "gpt-4o", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "gpt-4o", + "model": "openai_p/gpt-4o", "messages": []map[string]any{{"role": "user", "content": "介绍AI历史"}}, "max_tokens": 30, }) @@ -372,10 +362,10 @@ func TestE2E_OpenAI_NonStream_UsageWithReasoning(t *testing.T) { }, }) }) - e2eCreateProviderAndModel(t, r, "openai-p", "openai", "o3", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "o3", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "o3", + "model": "openai_p/o3", "messages": []map[string]any{{"role": "user", "content": "15+23*2=?"}}, }) w := httptest.NewRecorder() @@ -413,10 +403,10 @@ func TestE2E_OpenAI_NonStream_Refusal(t *testing.T) { "usage": map[string]any{"prompt_tokens": 12, "completion_tokens": 35, "total_tokens": 47}, }) }) - e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "gpt-4o", + "model": "openai_p/gpt-4o", "messages": []map[string]any{{"role": "user", "content": "做坏事"}}, }) w := httptest.NewRecorder() @@ -455,10 +445,10 @@ func TestE2E_OpenAI_Stream_Text(t *testing.T) { time.Sleep(10 * time.Millisecond) } }) - e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "gpt-4o", + "model": "openai_p/gpt-4o", "messages": []map[string]any{{"role": "user", "content": "你好"}}, "stream": true, }) @@ -499,10 +489,10 @@ func TestE2E_OpenAI_Stream_ToolCalls(t *testing.T) { time.Sleep(10 * time.Millisecond) } }) - e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) body, _ := 
json.Marshal(map[string]any{ - "model": "gpt-4o", + "model": "openai_p/gpt-4o", "messages": []map[string]any{{"role": "user", "content": "北京天气"}}, "tools": []map[string]any{{ "type": "function", @@ -548,10 +538,10 @@ func TestE2E_OpenAI_Stream_WithUsage(t *testing.T) { time.Sleep(10 * time.Millisecond) } }) - e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "gpt-4o", + "model": "openai_p/gpt-4o", "messages": []map[string]any{{"role": "user", "content": "hi"}}, "stream": true, }) @@ -583,10 +573,10 @@ func TestE2E_Anthropic_NonStream_BasicText(t *testing.T) { "usage": map[string]any{"input_tokens": 15, "output_tokens": 25}, }) }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-opus-4-7", "max_tokens": 1024, + "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "messages": []map[string]any{{"role": "user", "content": "你好"}}, }) w := httptest.NewRecorder() @@ -599,7 +589,7 @@ func TestE2E_Anthropic_NonStream_BasicText(t *testing.T) { require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) assert.Equal(t, "message", resp["type"]) assert.Equal(t, "assistant", resp["role"]) - assert.Equal(t, "claude-opus-4-7", resp["model"]) + assert.Equal(t, "anthropic_p/claude-opus-4-7", resp["model"]) assert.Equal(t, "end_turn", resp["stop_reason"]) content := resp["content"].([]any) @@ -629,10 +619,10 @@ func TestE2E_Anthropic_NonStream_WithSystem(t *testing.T) { "usage": map[string]any{"input_tokens": 30, "output_tokens": 15}, }) }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) 
body, _ := json.Marshal(map[string]any{ - "model": "claude-opus-4-7", "max_tokens": 1024, + "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "system": "你是编程助手", "messages": []map[string]any{{"role": "user", "content": "什么是递归?"}}, }) @@ -658,10 +648,10 @@ func TestE2E_Anthropic_NonStream_ToolUse(t *testing.T) { "usage": map[string]any{"input_tokens": 180, "output_tokens": 42}, }) }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-opus-4-7", "max_tokens": 1024, + "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "messages": []map[string]any{{"role": "user", "content": "北京天气"}}, "tools": []map[string]any{{ "name": "get_weather", "description": "获取天气", @@ -704,10 +694,10 @@ func TestE2E_Anthropic_NonStream_Thinking(t *testing.T) { "usage": map[string]any{"input_tokens": 95, "output_tokens": 280}, }) }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-opus-4-7", "max_tokens": 4096, + "model": "anthropic_p/claude-opus-4-7", "max_tokens": 4096, "messages": []map[string]any{{"role": "user", "content": "15+23*2=?"}}, "thinking": map[string]any{"type": "enabled", "budget_tokens": 2048}, }) @@ -736,10 +726,10 @@ func TestE2E_Anthropic_NonStream_MaxTokens(t *testing.T) { "usage": map[string]any{"input_tokens": 22, "output_tokens": 20}, }) }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-opus-4-7", "max_tokens": 20, + "model": "anthropic_p/claude-opus-4-7", 
"max_tokens": 20, "messages": []map[string]any{{"role": "user", "content": "介绍AI历史"}}, }) w := httptest.NewRecorder() @@ -764,10 +754,10 @@ func TestE2E_Anthropic_NonStream_StopSequence(t *testing.T) { "usage": map[string]any{"input_tokens": 22, "output_tokens": 10}, }) }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-opus-4-7", "max_tokens": 1024, + "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "messages": []map[string]any{{"role": "user", "content": "从1数到10"}}, "stop_sequences": []string{"5"}, }) @@ -800,10 +790,10 @@ func TestE2E_Anthropic_NonStream_MetadataUserID(t *testing.T) { "usage": map[string]any{"input_tokens": 12, "output_tokens": 5}, }) }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-opus-4-7", "max_tokens": 1024, + "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "messages": []map[string]any{{"role": "user", "content": "你好"}}, "metadata": map[string]any{"user_id": "user_12345"}, }) @@ -829,10 +819,10 @@ func TestE2E_Anthropic_NonStream_UsageWithCache(t *testing.T) { }, }) }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-opus-4-7", "max_tokens": 1024, + "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "system": []map[string]any{{"type": "text", "text": "你是编程助手。"}}, "messages": []map[string]any{{"role": "user", "content": "你好"}}, }) @@ -874,10 +864,10 @@ func TestE2E_Anthropic_Stream_Text(t *testing.T) { time.Sleep(10 * 
time.Millisecond) } }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-opus-4-7", "max_tokens": 1024, + "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "messages": []map[string]any{{"role": "user", "content": "你好"}}, "stream": true, }) @@ -921,10 +911,10 @@ func TestE2E_Anthropic_Stream_Thinking(t *testing.T) { time.Sleep(10 * time.Millisecond) } }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-opus-4-7", "max_tokens": 4096, + "model": "anthropic_p/claude-opus-4-7", "max_tokens": 4096, "messages": []map[string]any{{"role": "user", "content": "1+1=?"}}, "thinking": map[string]any{"type": "enabled", "budget_tokens": 1024}, "stream": true, @@ -970,10 +960,10 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_RequestFormat(t *testing.T) { "usage": map[string]any{"input_tokens": 10, "output_tokens": 5}, }) }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-model", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-model", + "model": "anthropic_p/claude-model", "messages": []map[string]any{{"role": "user", "content": "Hello"}}, }) w := httptest.NewRecorder() @@ -1011,10 +1001,10 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_RequestFormat(t *testing.T) { "usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18}, }) }) - e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL) body, _ := 
json.Marshal(map[string]any{ - "model": "gpt-4", "max_tokens": 1024, + "model": "openai_p/gpt-4", "max_tokens": 1024, "messages": []map[string]any{{"role": "user", "content": "Hello"}}, }) w := httptest.NewRecorder() @@ -1052,10 +1042,10 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_Stream(t *testing.T) { time.Sleep(10 * time.Millisecond) } }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-model", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-model", + "model": "anthropic_p/claude-model", "messages": []map[string]any{{"role": "user", "content": "Hello"}}, "stream": true, }) @@ -1092,10 +1082,10 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream(t *testing.T) { time.Sleep(10 * time.Millisecond) } }) - e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "gpt-4", "max_tokens": 1024, + "model": "openai_p/gpt-4", "max_tokens": 1024, "messages": []map[string]any{{"role": "user", "content": "Hello"}}, "stream": true, }) @@ -1130,10 +1120,10 @@ func TestE2E_OpenAI_ErrorResponse(t *testing.T) { }, }) }) - e2eCreateProviderAndModel(t, r, "openai-p", "openai", "nonexistent", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "nonexistent", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "nonexistent", + "model": "openai_p/nonexistent", "messages": []map[string]any{{"role": "user", "content": "test"}}, }) w := httptest.NewRecorder() @@ -1157,10 +1147,10 @@ func TestE2E_Anthropic_ErrorResponse(t *testing.T) { }, }) }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) body, _ := json.Marshal(map[string]any{ - 
"model": "claude-opus-4-7", "max_tokens": 1024, + "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "messages": []map[string]any{{"role": "user", "content": "test"}}, }) w := httptest.NewRecorder() @@ -1203,10 +1193,10 @@ func TestE2E_OpenAI_NonStream_ParallelToolCalls(t *testing.T) { "usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 36, "total_tokens": 136}, }) }) - e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "gpt-4o", + "model": "openai_p/gpt-4o", "messages": []map[string]any{{"role": "user", "content": "北京和上海的天气"}}, "tools": []map[string]any{{ "type": "function", @@ -1255,10 +1245,10 @@ func TestE2E_OpenAI_NonStream_StopSequence(t *testing.T) { "usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18}, }) }) - e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "gpt-4o", + "model": "openai_p/gpt-4o", "messages": []map[string]any{{"role": "user", "content": "从1数到10"}}, "stop": []string{"5"}, }) @@ -1293,10 +1283,10 @@ func TestE2E_OpenAI_NonStream_ContentFilter(t *testing.T) { "usage": map[string]any{"prompt_tokens": 8, "completion_tokens": 0, "total_tokens": 8}, }) }) - e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "gpt-4o", + "model": "openai_p/gpt-4o", "messages": []map[string]any{{"role": "user", "content": "危险内容"}}, }) w := httptest.NewRecorder() @@ -1325,10 +1315,10 @@ func TestE2E_Anthropic_NonStream_MultiToolUse(t *testing.T) { "usage": map[string]any{"input_tokens": 200, "output_tokens": 84}, }) }) - 
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-opus-4-7", "max_tokens": 1024, + "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "messages": []map[string]any{{"role": "user", "content": "北京和上海的天气"}}, "tools": []map[string]any{{ "name": "get_weather", "description": "获取天气", @@ -1374,10 +1364,10 @@ func TestE2E_Anthropic_NonStream_ToolChoiceAny(t *testing.T) { "usage": map[string]any{"input_tokens": 100, "output_tokens": 30}, }) }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-opus-4-7", "max_tokens": 1024, + "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "messages": []map[string]any{{"role": "user", "content": "现在几点了?"}}, "tools": []map[string]any{{ "name": "get_time", "description": "获取当前时间", @@ -1417,10 +1407,10 @@ func TestE2E_Anthropic_NonStream_ArraySystemPrompt(t *testing.T) { "usage": map[string]any{"input_tokens": 50, "output_tokens": 10}, }) }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-opus-4-7", "max_tokens": 1024, + "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "system": []map[string]any{ {"type": "text", "text": "你是编程助手。"}, {"type": "text", "text": "请用中文回答。"}, @@ -1454,10 +1444,10 @@ func TestE2E_Anthropic_NonStream_ToolResultMessage(t *testing.T) { "usage": map[string]any{"input_tokens": 150, "output_tokens": 20}, }) }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) + 
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-opus-4-7", "max_tokens": 1024, + "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "messages": []map[string]any{ {"role": "user", "content": "北京天气"}, {"role": "assistant", "content": []map[string]any{ @@ -1507,10 +1497,10 @@ func TestE2E_Anthropic_Stream_ToolCalls(t *testing.T) { time.Sleep(10 * time.Millisecond) } }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-opus-4-7", "max_tokens": 1024, + "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "messages": []map[string]any{{"role": "user", "content": "北京天气"}}, "tools": []map[string]any{{ "name": "get_weather", "description": "获取天气", @@ -1561,10 +1551,10 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_NonStream_ToolCalls(t *testing.T) { "usage": map[string]any{"input_tokens": 100, "output_tokens": 30}, }) }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-model", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-model", + "model": "anthropic_p/claude-model", "messages": []map[string]any{{"role": "user", "content": "北京天气"}}, "tools": []map[string]any{{ "type": "function", @@ -1613,10 +1603,10 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_NonStream_Thinking(t *testing.T) { "usage": map[string]any{"prompt_tokens": 50, "completion_tokens": 100, "total_tokens": 150}, }) }) - e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "gpt-4", "max_tokens": 4096, + "model": 
"openai_p/gpt-4", "max_tokens": 4096, "messages": []map[string]any{{"role": "user", "content": "宇宙的答案"}}, }) w := httptest.NewRecorder() @@ -1643,10 +1633,10 @@ func TestE2E_CrossProtocol_StopReasonMapping(t *testing.T) { "usage": map[string]any{"input_tokens": 10, "output_tokens": 20}, }) }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-model", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-model", + "model": "anthropic_p/claude-model", "messages": []map[string]any{{"role": "user", "content": "长文"}}, }) w := httptest.NewRecorder() @@ -1685,10 +1675,10 @@ func TestE2E_OpenAI_NonStream_AssistantWithToolResult(t *testing.T) { "usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120}, }) }) - e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "gpt-4o", + "model": "openai_p/gpt-4o", "messages": []map[string]any{ {"role": "user", "content": "北京天气"}, {"role": "assistant", "content": nil, "tool_calls": []map[string]any{{ @@ -1732,10 +1722,10 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream_ToolCalls(t *testing.T) { time.Sleep(10 * time.Millisecond) } }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-model", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-model", + "model": "anthropic_p/claude-model", "messages": []map[string]any{{"role": "user", "content": "北京天气"}}, "tools": []map[string]any{{ "type": "function", @@ -1781,10 +1771,10 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_Stream_ToolCalls(t *testing.T) { time.Sleep(10 * time.Millisecond) } }) - e2eCreateProviderAndModel(t, r, "openai-p", 
"openai", "gpt-4", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "gpt-4", "max_tokens": 1024, + "model": "openai_p/gpt-4", "max_tokens": 1024, "messages": []map[string]any{{"role": "user", "content": "北京天气"}}, "tools": []map[string]any{{ "name": "get_weather", "description": "获取天气", @@ -1819,10 +1809,10 @@ func TestE2E_OpenAI_Upstream5xx_ErrorPassthrough(t *testing.T) { }, }) }) - e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL) + e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "gpt-4o", + "model": "openai_p/gpt-4o", "messages": []map[string]any{{"role": "user", "content": "test"}}, }) w := httptest.NewRecorder() @@ -1851,10 +1841,10 @@ func TestE2E_Anthropic_Upstream5xx_ErrorPassthrough(t *testing.T) { }, }) }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-opus-4-7", "max_tokens": 1024, + "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "messages": []map[string]any{{"role": "user", "content": "test"}}, }) w := httptest.NewRecorder() @@ -1889,10 +1879,10 @@ func TestE2E_Anthropic_Stream_TruncatedSSE(t *testing.T) { time.Sleep(10 * time.Millisecond) } }) - e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) + e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL) body, _ := json.Marshal(map[string]any{ - "model": "claude-opus-4-7", "max_tokens": 1024, + "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024, "messages": []map[string]any{{"role": "user", "content": "test"}}, "stream": true, }) diff --git a/backend/tests/integration/integration_test.go 
b/backend/tests/integration/integration_test.go index ad754d6..518bb94 100644 --- a/backend/tests/integration/integration_test.go +++ b/backend/tests/integration/integration_test.go @@ -9,11 +9,8 @@ import ( "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gorm.io/driver/sqlite" "gorm.io/gorm" - "nex/backend/internal/config" "nex/backend/internal/domain" "nex/backend/internal/handler" "nex/backend/internal/handler/middleware" @@ -27,23 +24,13 @@ func init() { func setupIntegrationTest(t *testing.T) (*gin.Engine, *gorm.DB) { t.Helper() - dir := t.TempDir() - db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{}) - require.NoError(t, err) - err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{}) - require.NoError(t, err) - t.Cleanup(func() { - sqlDB, _ := db.DB() - if sqlDB != nil { - sqlDB.Close() - } - }) + db := setupTestDB(t) providerRepo := repository.NewProviderRepository(db) modelRepo := repository.NewModelRepository(db) statsRepo := repository.NewStatsRepository(db) - providerService := service.NewProviderService(providerRepo) + providerService := service.NewProviderService(providerRepo, modelRepo) modelService := service.NewModelService(modelRepo, providerRepo) _ = service.NewRoutingService(modelRepo, providerRepo) statsService := service.NewStatsService(statsRepo) @@ -97,13 +84,16 @@ func TestOpenAI_CompleteFlow(t *testing.T) { // 2. 创建 Model modelBody, _ := json.Marshal(map[string]string{ - "id": "gpt4", "provider_id": "openai", "model_name": "gpt-4", + "provider_id": "openai", "model_name": "gpt-4", }) w = httptest.NewRecorder() req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody)) req.Header.Set("Content-Type", "application/json") r.ServeHTTP(w, req) assert.Equal(t, 201, w.Code) + var createdModel domain.Model + json.Unmarshal(w.Body.Bytes(), &createdModel) + assert.NotEmpty(t, createdModel.ID) // 3. 
列出 Provider w = httptest.NewRecorder() @@ -135,7 +125,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) { // 6. 删除 Model w = httptest.NewRecorder() - req = httptest.NewRequest("DELETE", "/api/models/gpt4", nil) + req = httptest.NewRequest("DELETE", "/api/models/"+createdModel.ID, nil) r.ServeHTTP(w, req) assert.Equal(t, 204, w.Code) @@ -160,17 +150,19 @@ func TestAnthropic_ModelCreation(t *testing.T) { assert.Equal(t, 201, w.Code) modelBody, _ := json.Marshal(map[string]string{ - "id": "claude3", "provider_id": "anthropic", "model_name": "claude-3-opus-20240229", + "provider_id": "anthropic", "model_name": "claude-3-opus-20240229", }) w = httptest.NewRecorder() req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody)) req.Header.Set("Content-Type", "application/json") r.ServeHTTP(w, req) assert.Equal(t, 201, w.Code) + var createdModel domain.Model + json.Unmarshal(w.Body.Bytes(), &createdModel) // 验证创建成功 w = httptest.NewRecorder() - req = httptest.NewRequest("GET", "/api/models/claude3", nil) + req = httptest.NewRequest("GET", "/api/models/"+createdModel.ID, nil) r.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) } @@ -188,7 +180,7 @@ func TestStats_RecordingAndQuery(t *testing.T) { r.ServeHTTP(w, req) modelBody, _ := json.Marshal(map[string]string{ - "id": "m1", "provider_id": "p1", "model_name": "gpt-4", + "provider_id": "p1", "model_name": "gpt-4", }) w = httptest.NewRecorder() req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody)) diff --git a/backend/tests/integration/testhelper.go b/backend/tests/integration/testhelper.go new file mode 100644 index 0000000..f4e3984 --- /dev/null +++ b/backend/tests/integration/testhelper.go @@ -0,0 +1,37 @@ +package integration + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + + "nex/backend/internal/config" +) + +// setupTestDB 创建内存 SQLite 数据库并执行 AutoMigrate。 +// 使用 MaxOpenConns(1) 确保 :memory: 模式不会被连接池丢弃。 +func 
setupTestDB(t *testing.T) *gorm.DB { + t.Helper() + + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + require.NoError(t, err) + + sqlDB, err := db.DB() + require.NoError(t, err) + sqlDB.SetMaxOpenConns(1) + sqlDB.SetConnMaxLifetime(0) + + err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{}) + require.NoError(t, err) + + t.Cleanup(func() { + // 等待异步 goroutine(如 statsService.Record)完成 + time.Sleep(50 * time.Millisecond) + sqlDB.Close() + }) + + return db +} diff --git a/backend/tests/migration_test.go b/backend/tests/migration_test.go new file mode 100644 index 0000000..6f33b31 --- /dev/null +++ b/backend/tests/migration_test.go @@ -0,0 +1,79 @@ +package tests + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "nex/backend/internal/config" +) + +func TestMigration_ModelsUUIDPrimaryKey(t *testing.T) { + db := SetupTestDB(t) + defer CleanupTestDB(t, db) + + // 创建供应商 + _ = CreateTestProvider(t, db, "openai") + + // 创建模型使用 UUID 作为 id + model, err := CreateTestModel(t, db, "550e8400-e29b-41d4-a716-446655440000", "openai", "gpt-4") + require.NoError(t, err) + + // 通过 UUID 查询 + var result config.Model + require.NoError(t, db.First(&result, "id = ?", model.ID).Error) + assert.Equal(t, "gpt-4", result.ModelName) +} + +func TestMigration_UniqueProviderModelConstraint(t *testing.T) { + db := SetupTestDB(t) + defer CleanupTestDB(t, db) + + // 创建供应商 + _ = CreateTestProvider(t, db, "openai") + + // 创建第一个模型 + _, err := CreateTestModel(t, db, "uuid-1", "openai", "gpt-4") + require.NoError(t, err) + + // 尝试创建相同 (provider_id, model_name) 的模型应失败 + _, err = CreateTestModel(t, db, "uuid-2", "openai", "gpt-4") + assert.Error(t, err, "UNIQUE(provider_id, model_name) 约束应阻止重复") + + // 不同 provider_id 下相同 model_name 应成功 + _ = CreateTestProvider(t, db, "anthropic") + _, err = CreateTestModel(t, db, "uuid-3", "anthropic", "gpt-4") + require.NoError(t, err, "不同 provider_id 下相同 model_name 应允许") +} + 
+func TestMigration_CascadeDelete(t *testing.T) { + db := SetupTestDB(t) + defer CleanupTestDB(t, db) + + // 创建供应商和模型 + provider := CreateTestProvider(t, db, "openai") + _, err := CreateTestModel(t, db, "uuid-1", "openai", "gpt-4") + require.NoError(t, err) + + // 删除供应商应级联删除模型 + require.NoError(t, db.Delete(&provider).Error) + + var count int64 + db.Model(&config.Model{}).Where("provider_id = ?", "openai").Count(&count) + assert.Equal(t, int64(0), count, "删除供应商后其模型应被级联删除") +} + +func TestMigration_ModelDefaultEnabled(t *testing.T) { + db := SetupTestDB(t) + defer CleanupTestDB(t, db) + + _ = CreateTestProvider(t, db, "openai") + + // 创建模型不指定 enabled,应默认为 true + _, err := CreateTestModel(t, db, "uuid-1", "openai", "gpt-4") + require.NoError(t, err) + + var result config.Model + require.NoError(t, db.First(&result, "id = ?", "uuid-1").Error) + assert.True(t, result.Enabled, "enabled 字段默认应为 true") +} diff --git a/openspec/changes/unified-model-id/.openspec.yaml b/openspec/changes/unified-model-id/.openspec.yaml deleted file mode 100644 index c4036b7..0000000 --- a/openspec/changes/unified-model-id/.openspec.yaml +++ /dev/null @@ -1,2 +0,0 @@ -schema: spec-driven -created: 2026-04-20 diff --git a/openspec/changes/unified-model-id/design.md b/openspec/changes/unified-model-id/design.md deleted file mode 100644 index eecc53a..0000000 --- a/openspec/changes/unified-model-id/design.md +++ /dev/null @@ -1,137 +0,0 @@ -## Context - -Nex 是一个 AI 网关,屏蔽多个 AI 供应商(OpenAI、Anthropic 等)的差异,提供统一的 API 接口。当前后端直接透传上游供应商的原始模型名称(如 `gpt-4`),通过 `models` 表的 `model_name` 字段路由。`models` 表的 `id` 字段当前语义是用户自定义标识符,与上游模型名 `model_name` 之间没有明确的职责分离。 - -当前架构: -- `ProxyHandler` 从请求体中提取 `model` 字段 → `RoutingService.Route(modelName)` 按 `model_name` 查询 -- `GET /v1/models` 直接透传到第一个供应商的上游接口 -- `GET /v1/models/{id}` 直接透传到上游 -- `TargetProvider.ModelName` 在 encoder 中覆盖请求体的 `model` 字段 - -## Goals / Non-Goals - -**Goals:** -- 定义统一模型 ID 格式 `provider_id/model_name`,全局唯一标识一个模型 -- 拦截 `/v1/models` 和 
`/v1/models/{unified_id}` 接口,从数据库聚合返回,不再透传上游 -- 所有代理接口(Chat、Embeddings、Rerank)使用统一模型 ID 路由,响应中 `model` 字段覆写为统一 ID -- `models.id` 改为 UUID(内部标识),`models.model_name` 存储上游供应商的模型名称 -- `provider_id` 约束为 `[a-zA-Z0-9_]+`,防止特殊字符影响 URL 和 JSON 交互 -- 保持协议无关、供应商无关的设计 - -**Non-Goals:** -- 不支持供应商别名或模型别名 -- 不做上游模型列表自动同步(管理员手动配置可见模型) -- 不适配前端(后续统一适配) - -## Decisions - -### D1: 统一模型 ID 格式 — `provider_id/model_name` - -格式: `{provider_id}/{model_name}`,例如 `openai/gpt-4`、`anthropic/claude-3-opus-20240229`。 - -- 使用 `strings.SplitN(id, "/", 2)` 解析,只在第一个 `/` 处分割 -- `provider_id` 约束为 `[a-zA-Z0-9_]+`,保证不含 `/`,解析安全 -- `model_name`(上游模型名)不受字符约束,因为它不出现在管理 API 的 URL 主键中 - -选择此格式而非 `provider_id:model_name`(冒号分隔)的原因:斜杠在 JSON 字符串中天然安全,且在 URL 路径中语义清晰(`/v1/models/openai/gpt-4`),更符合 REST 风格。 - -### D2: models 表 schema 变更 - -``` -旧: id(TEXT PK, 用户自定义), provider_id, model_name(上游模型名), enabled, created_at -新: id(UUID PK, 自动生成), provider_id, model_name(上游模型名), enabled, created_at - UNIQUE(provider_id, model_name) -``` - -关键语义变化: -- `id` 从用户自定义标识符变为 UUID 内部主键(自动生成),用于管理接口 CRUD -- `model_name` 语义不变,始终存储上游供应商的模型名称,发给上游的实际值 -- 新增联合唯一约束 `UNIQUE(provider_id, model_name)` 保证同一供应商内模型不重复 - -选择保留 `id` 作为 PK 而非使用 `(provider_id, model_name)` 联合主键的原因:上游模型名可能含 `/` 等特殊字符(如 Azure OpenAI 的 deployment 路径),不适合作为管理接口的 URL 参数。`id` 为 UUID 可以避免所有特殊字符问题。 - -### D3: Models/ModelInfo 接口本地聚合 - -`GET /v1/models` 从数据库查询所有 `enabled` 的模型(JOIN providers),组装为 `CanonicalModelList`,`ID` 字段使用统一模型 ID,通过客户端协议的 adapter 编码返回。不请求上游。 - -`GET /v1/models/{provider_id}/{model_name}` 从 URL 提取统一模型 ID,解析后查询数据库,组装为 `CanonicalModelInfo` 返回。不请求上游。 - -选择纯 DB 聚合而非实时查询上游的原因: -1. 管理员通过 `/api/models` 控制哪些模型对用户可见,网关的意义在于控制可见性 -2. 响应速度快,不依赖上游可用性 -3. 
符合当前架构中管理员手动配置 provider 和 model 的设计哲学 - -### D4: 跨协议响应 model 字段覆写 - -跨协议场景下,上游返回的响应经过 decode → encode 全量转换。上游响应中的 `model` 字段是原生模型名(如 `gpt-4`),需要在返回给客户端前覆写为统一模型 ID。 - -实现位置:`ConversionEngine.ConvertHttpResponse` 新增 `modelOverride string` 参数。在解码上游响应到 canonical 后、编码客户端响应前,将 `canonical.Model` 设为 `modelOverride`。流式场景同理,`CreateStreamConverter` 同样接收 `modelOverride` 参数。 - -此方案仅在跨协议转换路径使用。选择在 canonical 层面处理的原因: -1. 跨协议必须全量 decode → encode,canonical 的 Model 字段天然可覆写 -2. 不侵入各协议 adapter 的实现 -3. 与 Smart Passthrough 互补——跨协议不可保真,canonical 覆写是自然的 - -### D5: ProtocolAdapter 接口扩展 - -在 `ProtocolAdapter` 接口新增四个方法,将所有协议相关的 model 字段知识归属到 adapter: - -1. `ExtractUnifiedModelID(nativePath string) (string, error)` — 从路径中提取统一模型 ID -2. `ExtractModelName(body []byte, ifaceType InterfaceType) (string, error)` — 从请求体中提取 model 值(所有流程复用,替代 handler 层硬编码的 `extractModelName`) -3. `RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)` — 最小化 JSON 改写请求体中的 model 字段(Smart Passthrough 请求方向) -4. 
`RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)` — 最小化 JSON 改写响应体中的 model 字段(Smart Passthrough 响应方向) - -拆分请求/响应方向的原因:请求体和响应体的 JSON 结构可能不同,model 字段的位置可能不同(当前 OpenAI/Anthropic 协议碰巧都在顶层 `"model"`,但未来协议不一定)。拆分后 adapter 各自独立实现,各自按 ifaceType 分派。 - -`ExtractModelName` 和两个 `Rewrite*` 方法均接收 `InterfaceType` 参数,因为不同接口类型的请求体/响应体结构可能不同,adapter 按 ifaceType 分派具体的定位和改写逻辑。 - -对于 `isModelInfoPath` 的调整:允许 suffix 中包含 `/`,因为统一模型 ID 格式为 `provider_id/model_name`。 - -将此方法放在适配器接口而非 handler 中通用实现的原因:不同协议的模型详情路径格式和请求体结构可能不同,各自拥有独立演进能力。 - -### D6: provider_id 字符集约束 - -创建供应商时校验 `id` 字段必须匹配 `^[a-zA-Z0-9_]+$`,长度 1-64。 - -选择严格限制而非仅排除 `/` 的原因:统一模型 ID 出现在 URL 路径和 JSON 中,`?`、`#`、`&`、`=` 等字符会在 URL 中引起解析问题。限制为字母数字下划线后,URL 中永远安全,不需要编码。 - -### D7: pkg/modelid 工具包 - -新增 `pkg/modelid` 包,提供: -- `ParseUnifiedModelID(id string) (providerID, modelName string, error)` — 解析 -- `FormatUnifiedModelID(providerID, modelName string) string` — 格式化 -- `ValidateProviderID(id string) error` — 校验供应商 ID -- `IsValidUnifiedModelID(id string) bool` — 校验统一模型 ID - -使用标准库 `strings.SplitN` 和 `regexp` 实现,不引入新依赖。 - -### D8: 同协议 Smart Passthrough - -当前同协议透传将请求体原样转发,跳过 decode → encode,保持参数完全保真。但统一模型 ID 要求改写 model 字段,原样透传无法满足。 - -**Smart Passthrough**:保留同协议透传的保真优势,通过 `json.RawMessage` 做最小化改写。 - -实现方式:adapter 的 `RewriteRequestModelName` 和 `RewriteResponseModelName` 方法各自解析 JSON 为 `map[string]json.RawMessage`,只替换 model 字段的 value,其余字段保留原始 bytes,不经过任何类型转换。参数保真、不丢精度、不改字段顺序。 - -各接口类型策略: -- Chat/Embedding/Rerank(同协议):Smart Passthrough — 请求改写 model(统一 ID → 上游名),响应改写 model(上游名 → 统一 ID) -- Chat/Embedding/Rerank(跨协议):全量 decode → encode + modelOverride -- Models/ModelInfo:本地数据库聚合,不请求上游 -- Passthrough(未知路径):原样透传,不改写 model - -选择让 adapter 拥有完整协议知识(而非通用 json hack)的原因: -1. 不同协议的 model 字段位置可能不同,adapter 按 InterfaceType 分派 -2. 请求和响应的 model 字段位置可能不同,拆分 RewriteRequestModelName/RewriteResponseModelName 各自独立实现 -3. adapter 内部实现 `ExtractModelName` 和两个 `Rewrite*` 方法可共享同一份"model 在哪"的定位逻辑 -4. 
所有流程复用 `ExtractModelName`,同协议额外复用 `RewriteRequestModelName` + `RewriteResponseModelName` - -## Risks / Trade-offs - -- **[BREAKING CHANGE]** 代理接口 model 字段格式变更,现有客户端必须适配 → 统一 ID 格式简单直观,服务尚未上线无旧客户端 -- **[联合唯一约束]** 同一供应商下相同 model_name 不允许重复 → 这是正确的行为,语义上就不应该重复 -- **[model_name 含特殊字符]** 上游模型名可能含 `/`(如 Azure deployment 路径)→ 解析用 `SplitN("/", 2)` 安全,管理接口用 `id` 定位不受影响,代理接口中统一 ID 出现在 JSON body 和 URL 路径中均安全 -- **[流式响应覆写]** 同协议流式场景需逐 SSE chunk 调用 RewriteResponseModelName → 每个 chunk 多一次轻量 JSON 解析,用 json.RawMessage 保证开销极小 - -## Open Questions - -无。所有关键决策已在探索阶段确认。 diff --git a/openspec/changes/unified-model-id/proposal.md b/openspec/changes/unified-model-id/proposal.md deleted file mode 100644 index e7db16e..0000000 --- a/openspec/changes/unified-model-id/proposal.md +++ /dev/null @@ -1,41 +0,0 @@ -## Why - -当前网关直接透传上游供应商的原始模型名称(如 `gpt-4`、`claude-3-opus`),无法在多供应商场景下唯一标识一个模型。不同供应商可能存在同名模型,客户端无法区分应路由到哪个供应商。网关作为屏蔽供应商差异的统一入口,需要定义自有的模型标识体系,让客户端通过统一的 model ID 访问任意供应商的模型,同时拦截 `/v1/models` 等模型查询接口,聚合所有供应商的模型信息返回。 - -## What Changes - -- **BREAKING**: 引入统一模型 ID 格式 `provider_id/model_name`(如 `openai/gpt-4`),所有代理接口(Chat、Embeddings、Rerank)的 `model` 字段必须使用此格式 -- **BREAKING**: `models` 表主键 `id` 改为 UUID 自动生成(不再由用户提供),`model_name` 字段语义保持不变(存储上游供应商模型名称),新增 `UNIQUE(provider_id, model_name)` 联合唯一约束 -- **BREAKING**: `provider_id` 限制为 `[a-zA-Z0-9_]+` 字符集,禁止特殊字符 -- `GET /v1/models` 改为从数据库聚合返回所有已启用模型,不再透传到上游供应商 -- `GET /v1/models/{unified_id}` 改为从数据库查询返回模型详情,不再透传到上游供应商 -- 同协议透传改为 Smart Passthrough:通过 `json.RawMessage` 最小化改写 model 字段,保持其余参数完全保真 -- 跨协议转换路径:通过 canonical 层面 modelOverride 参数覆写响应 model 字段 -- 管理 API (`/api/models`) 请求体字段适配,响应中新增 `unified_id` 字段 -- 新增 `pkg/modelid` 工具包,提供统一模型 ID 的解析、格式化、校验 -- ProtocolAdapter 接口新增 `ExtractUnifiedModelID`、`ExtractModelName`、`RewriteRequestModelName`、`RewriteResponseModelName` 方法,协议无关地处理 model 字段 - -## Capabilities - -### New Capabilities -- `unified-model-id`: 统一模型 ID 的解析、格式化、校验工具包,以及 `provider_id` 字符集约束 - -### Modified Capabilities -- `model-management`: 
模型表结构调整(id 改 UUID 自动生成、新增联合唯一约束),CRUD 接口字段变更(创建不再提供 id) -- `provider-management`: provider_id 创建时增加字符集校验(`[a-zA-Z0-9_]+`) -- `unified-proxy-handler`: 统一模型 ID 解析路由、Models/ModelInfo 接口改为本地聚合、同协议 Smart Passthrough、跨协议 modelOverride 覆写 -- `conversion-engine`: 跨协议场景下 ConvertHttpResponse 支持 model 覆写参数 -- `protocol-adapter-openai`: isModelInfoPath 适配含 `/` 路径、新增 ExtractUnifiedModelID、ExtractModelName、RewriteRequestModelName、RewriteResponseModelName -- `protocol-adapter-anthropic`: isModelInfoPath 适配含 `/` 路径、新增 ExtractUnifiedModelID、ExtractModelName、RewriteRequestModelName、RewriteResponseModelName -- `request-validation`: provider_id 字符集校验规则、模型创建校验适配 -- `database-migration`: models 表 schema 变更迁移(DROP + CREATE 重建) -- `usage-statistics`: 明确统计记录使用 providerID + modelName 的上游模型名 - -## Impact - -- **数据库**: models 表 schema 变更(DROP + CREATE 重建) -- **API 兼容性**: 代理接口 model 字段格式为 BREAKING CHANGE,需客户端适配 -- **管理 API**: `/api/models` 请求体变更(创建不再提供 id,自动生成 UUID),响应新增 unified_id 字段 -- **代码模块**: domain、repository、service、handler、conversion、adapter 层均有改动 -- **测试**: routing service、proxy handler、adapter、model handler 需要新增/更新测试 -- **前端**: 本次变更不涉及前端适配,前端后续统一适配 diff --git a/openspec/changes/unified-model-id/specs/conversion-engine/spec.md b/openspec/changes/unified-model-id/specs/conversion-engine/spec.md deleted file mode 100644 index 62ca31c..0000000 --- a/openspec/changes/unified-model-id/specs/conversion-engine/spec.md +++ /dev/null @@ -1,49 +0,0 @@ -## MODIFIED Requirements - -### Requirement: 跨协议响应转换支持 model 覆写 - -ConversionEngine SHALL 在跨协议响应转换时支持 model 字段覆写。 - -#### Scenario: ConvertHttpResponse 接收 modelOverride 参数 - -- **WHEN** 调用 `ConvertHttpResponse` 时传入 `modelOverride` 参数(跨协议场景,非空字符串) -- **THEN** SHALL 在解码上游响应到 canonical 后,将 `Model` 字段设为 `modelOverride` -- **THEN** SHALL 使用覆写后的 canonical 编码为客户端协议格式 - -#### Scenario: modelOverride 为空 - -- **WHEN** 调用 `ConvertHttpResponse` 时 `modelOverride` 为空字符串 -- **THEN** SHALL NOT 覆写 canonical 的 Model 字段,保持上游原始值 - -#### Scenario: Chat 响应 model 覆写 - 
-- **WHEN** 跨协议转换 Chat 类型响应且 `modelOverride` 非空 -- **THEN** `CanonicalResponse.Model` SHALL 被设为 `modelOverride` - -#### Scenario: Embedding 响应 model 覆写 - -- **WHEN** 跨协议转换 Embedding 类型响应且 `modelOverride` 非空 -- **THEN** `CanonicalEmbeddingResponse.Model` SHALL 被设为 `modelOverride` - -#### Scenario: Rerank 响应 model 覆写 - -- **WHEN** 跨协议转换 Rerank 类型响应且 `modelOverride` 非空 -- **THEN** `CanonicalRerankResponse.Model` SHALL 被设为 `modelOverride` - -### Requirement: 跨协议流式转换支持 model 覆写 - -ConversionEngine SHALL 在跨协议流式转换时支持 model 字段覆写。 - -#### Scenario: CreateStreamConverter 接收 modelOverride 参数 - -- **WHEN** 调用 `CreateStreamConverter` 时传入 `modelOverride` 参数(跨协议场景) -- **THEN** SHALL 在流式 canonical 事件中将 `Model` 字段设为 `modelOverride` - -### Requirement: TargetProvider 字段语义 - -TargetProvider 的 ModelName 字段 SHALL 存储上游供应商的模型名称(即 `model_name` 字段值),语义保持不变。 - -#### Scenario: encoder 使用 TargetProvider.ModelName - -- **WHEN** 协议适配器编码请求时 -- **THEN** SHALL 使用 `TargetProvider.ModelName` 作为发给上游的 `model` 字段值(值为路由结果中的 model_name) diff --git a/openspec/changes/unified-model-id/specs/database-migration/spec.md b/openspec/changes/unified-model-id/specs/database-migration/spec.md deleted file mode 100644 index 3d9dc2a..0000000 --- a/openspec/changes/unified-model-id/specs/database-migration/spec.md +++ /dev/null @@ -1,13 +0,0 @@ -## MODIFIED Requirements - -### Requirement: models 表 schema 变更 - -系统 SHALL 通过迁移脚本重建 models 表结构(服务未上线,无需考虑数据迁移)。 - -#### Scenario: 迁移后 models 表结构 - -- **WHEN** 执行迁移 -- **THEN** SHALL 先 DROP 已有的 models 表(无旧数据) -- **THEN** SHALL CREATE 新的 models 表,包含字段:id(TEXT PRIMARY KEY)、provider_id(TEXT NOT NULL)、model_name(TEXT NOT NULL)、enabled(INTEGER DEFAULT 1)、created_at(DATETIME) -- **THEN** SHALL 存在 UNIQUE(provider_id, model_name) 约束 -- **THEN** SHALL 存在 FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE diff --git a/openspec/changes/unified-model-id/specs/model-management/spec.md b/openspec/changes/unified-model-id/specs/model-management/spec.md deleted file mode 
100644 index 606be46..0000000 --- a/openspec/changes/unified-model-id/specs/model-management/spec.md +++ /dev/null @@ -1,105 +0,0 @@ -## MODIFIED Requirements - -### Requirement: 创建模型配置 - -网关 SHALL 允许为供应商创建新的模型配置。 - -#### Scenario: 使用有效数据创建模型 - -- **WHEN** 向 `/api/models` 发送 POST 请求,携带有效的模型数据(provider_id, model_name),不提供 id 字段 -- **THEN** 网关 SHALL 自动生成 UUID 作为模型 id -- **THEN** 网关 SHALL 在数据库中创建新的模型记录 -- **THEN** 网关 SHALL 返回创建的模型,状态码为 201 -- **THEN** 模型 SHALL 默认启用 -- **THEN** 返回的模型 SHALL 包含 `unified_id` 字段,值为 `{provider_id}/{model_name}` - -#### Scenario: 使用不存在的供应商创建模型 - -- **WHEN** 向 `/api/models` 发送 POST 请求,携带不存在的 provider_id -- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request) -- **THEN** 错误 SHALL 指示供应商不存在 - -#### Scenario: 创建重复模型 - -- **WHEN** 向 `/api/models` 发送 POST 请求,携带已存在的 provider_id + model_name 组合 -- **THEN** 网关 SHALL 返回错误,状态码为 409 (Conflict) -- **THEN** 错误 SHALL 指示该供应商下已存在相同模型 - -### Requirement: 列出所有模型 - -网关 SHALL 允许获取所有模型配置。 - -#### Scenario: 成功列出模型 - -- **WHEN** 向 `/api/models` 发送 GET 请求 -- **THEN** 网关 SHALL 返回所有模型的列表 -- **THEN** 每个模型 SHALL 包含 id, provider_id, model_name, unified_id, enabled, created_at - -**变更说明:** 响应新增 unified_id 字段,移除旧语义的 id 自定义输入。 - -### Requirement: 按供应商列出模型 - -网关 SHALL 允许获取特定供应商的模型。 - -#### Scenario: 列出存在供应商的模型 - -- **WHEN** 向 `/api/models?provider_id=` 发送 GET 请求 -- **THEN** 网关 SHALL 返回指定供应商的模型列表 -- **THEN** 每个模型 SHALL 包含 unified_id 字段 - -### Requirement: 更新模型配置 - -网关 SHALL 允许更新现有模型配置。 - -#### Scenario: 使用有效数据更新模型 - -- **WHEN** 向 `/api/models/:id` 发送 PUT 请求,携带有效的模型数据 -- **THEN** 网关 SHALL 更新数据库中的模型记录 -- **THEN** 网关 SHALL 返回更新后的模型 -- **THEN** 返回的模型 SHALL 包含更新后的 unified_id - -#### Scenario: 更新模型为重复组合 - -- **WHEN** 向 `/api/models/:id` 发送 PUT 请求,更新 provider_id 或 model_name 导致与已有记录重复 -- **THEN** 网关 SHALL 返回错误,状态码为 409 (Conflict) - -### Requirement: 删除模型配置 - -网关 SHALL 允许删除模型配置。 - -#### Scenario: 删除存在的模型 - -- **WHEN** 向 `/api/models/:id` 发送 DELETE 请求,携带有效的模型 ID -- **THEN** 网关 SHALL 删除模型记录 -- **THEN** 网关 SHALL 返回状态码 204 (No Content) - -### 
Requirement: 使用 service 层处理业务逻辑 - -Handler SHALL 通过 ModelService 处理业务逻辑。 - -#### Scenario: 调用 service 方法 - -- **WHEN** handler 收到请求 -- **THEN** SHALL 调用对应的 ModelService 方法(Create、Get、List、Update、Delete) -- **THEN** SHALL 使用 domain.Model 类型 -- **THEN** Create 时 SHALL 调用 `uuid.New()` 生成 id - -#### Scenario: 供应商验证和唯一性校验 - -- **WHEN** 创建或更新模型 -- **THEN** SHALL 在 service 层验证供应商存在 -- **THEN** SHALL 在 service 层验证 provider_id + model_name 联合唯一 - -### Requirement: 使用 repository 层访问数据 - -Service SHALL 通过 ModelRepository 访问数据。 - -#### Scenario: 联合查询 - -- **WHEN** service 需要按 provider 和 model_name 查询模型 -- **THEN** SHALL 调用 `FindByProviderAndModelName(providerID, modelName)` 方法 - -#### Scenario: 查询所有启用模型 - -- **WHEN** proxy handler 需要聚合模型列表 -- **THEN** SHALL 调用 `ListEnabled()` 方法,返回所有 enabled 的模型(关联 enabled 的供应商) diff --git a/openspec/changes/unified-model-id/specs/protocol-adapter-anthropic/spec.md b/openspec/changes/unified-model-id/specs/protocol-adapter-anthropic/spec.md deleted file mode 100644 index aad01fd..0000000 --- a/openspec/changes/unified-model-id/specs/protocol-adapter-anthropic/spec.md +++ /dev/null @@ -1,71 +0,0 @@ -## MODIFIED Requirements - -### Requirement: 模型详情路径识别 - -Anthropic 适配器 SHALL 正确识别包含统一模型 ID 的模型详情路径。 - -#### Scenario: 含斜杠的统一模型 ID 路径 - -- **WHEN** 路径为 `/v1/models/anthropic/claude-3-opus` -- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo` - -#### Scenario: 含多段斜杠的统一模型 ID 路径 - -- **WHEN** 路径为 `/v1/models/azure/deployments/gpt-4` -- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo` - -#### Scenario: 模型列表路径不受影响 - -- **WHEN** 路径为 `/v1/models` -- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModels` - -### Requirement: 提取统一模型 ID - -Anthropic 适配器 SHALL 从路径中提取统一模型 ID。 - -#### Scenario: 标准路径提取 - -- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/anthropic/claude-3-opus")` -- **THEN** SHALL 返回 `"anthropic/claude-3-opus"` - -#### Scenario: 复杂路径提取 - -- **WHEN** 调用 
`ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")` -- **THEN** SHALL 返回 `"azure/accounts/org/models/gpt-4"` - -#### Scenario: 非模型详情路径 - -- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models")` -- **THEN** SHALL 返回错误 - -### Requirement: 从请求体提取 model - -Anthropic 适配器 SHALL 按 InterfaceType 从请求体中提取 model 值。 - -#### Scenario: Chat 请求提取 model - -- **WHEN** 调用 `ExtractModelName(body, InterfaceTypeChat)`,body 为 `{"model":"anthropic/claude-3-opus","messages":[...]}` -- **THEN** SHALL 返回 `"anthropic/claude-3-opus"` - -#### Scenario: 无 model 字段 - -- **WHEN** 调用 `ExtractModelName(body, ifaceType)`,body 中不含 model 字段 -- **THEN** SHALL 返回空字符串,不返回错误 - -### Requirement: 最小化改写请求体 model 字段 - -Anthropic 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写请求体中的 model 字段,其余字段保持原始 bytes。 - -#### Scenario: 请求体 model 改写(统一 ID → 上游名) - -- **WHEN** 调用 `RewriteRequestModelName(body, "claude-3-opus-20240229", InterfaceTypeChat)` -- **THEN** SHALL 将请求体中 model 字段替换为 `"claude-3-opus-20240229"`,其余字段原样保留 - -### Requirement: 最小化改写响应体 model 字段 - -Anthropic 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写响应体中的 model 字段,其余字段保持原始 bytes。请求体和响应体的 model 字段位置可能不同,各自独立实现。 - -#### Scenario: 响应体 model 改写(上游名 → 统一 ID) - -- **WHEN** 调用 `RewriteResponseModelName(body, "anthropic/claude-3-opus", InterfaceTypeChat)` -- **THEN** SHALL 将响应体中 model 字段替换为 `"anthropic/claude-3-opus"`,其余字段原样保留 diff --git a/openspec/changes/unified-model-id/specs/protocol-adapter-openai/spec.md b/openspec/changes/unified-model-id/specs/protocol-adapter-openai/spec.md deleted file mode 100644 index 8b54356..0000000 --- a/openspec/changes/unified-model-id/specs/protocol-adapter-openai/spec.md +++ /dev/null @@ -1,89 +0,0 @@ -## MODIFIED Requirements - -### Requirement: 模型详情路径识别 - -OpenAI 适配器 SHALL 正确识别包含统一模型 ID 的模型详情路径。 - -#### Scenario: 含斜杠的统一模型 ID 路径 - -- **WHEN** 路径为 `/v1/models/openai/gpt-4` -- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo` - -#### Scenario: 含多段斜杠的统一模型 ID 路径 - -- **WHEN** 路径为 
`/v1/models/azure/accounts/org-123/models/gpt-4` -- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo` - -#### Scenario: 模型列表路径不受影响 - -- **WHEN** 路径为 `/v1/models` -- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModels` - -### Requirement: 提取统一模型 ID - -OpenAI 适配器 SHALL 从路径中提取统一模型 ID。 - -#### Scenario: 标准路径提取 - -- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/openai/gpt-4")` -- **THEN** SHALL 返回 `"openai/gpt-4"` - -#### Scenario: 复杂路径提取 - -- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")` -- **THEN** SHALL 返回 `"azure/accounts/org/models/gpt-4"` - -#### Scenario: 非模型详情路径 - -- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models")` -- **THEN** SHALL 返回错误 - -### Requirement: 从请求体提取 model - -OpenAI 适配器 SHALL 按 InterfaceType 从请求体中提取 model 值。 - -#### Scenario: Chat 请求提取 model - -- **WHEN** 调用 `ExtractModelName(body, InterfaceTypeChat)`,body 为 `{"model":"openai/gpt-4","messages":[...]}` -- **THEN** SHALL 返回 `"openai/gpt-4"` - -#### Scenario: Embedding 请求提取 model - -- **WHEN** 调用 `ExtractModelName(body, InterfaceTypeEmbeddings)`,body 为 `{"model":"openai/text-embedding-3","input":"text"}` -- **THEN** SHALL 返回 `"openai/text-embedding-3"` - -#### Scenario: 无 model 字段 - -- **WHEN** 调用 `ExtractModelName(body, ifaceType)`,body 中不含 model 字段 -- **THEN** SHALL 返回空字符串,不返回错误 - -### Requirement: 最小化改写请求体 model 字段 - -OpenAI 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写请求体中的 model 字段,其余字段保持原始 bytes。 - -#### Scenario: 请求体 model 改写(统一 ID → 上游名) - -- **WHEN** 调用 `RewriteRequestModelName(body, "gpt-4", InterfaceTypeChat)`,body 为 `{"model":"openai/gpt-4","messages":[...],"some_param":"value"}` -- **THEN** SHALL 返回 `{"model":"gpt-4","messages":[...],"some_param":"value"}` -- **THEN** 除 model 外的字段 SHALL 保持原始 bytes 不变 - -#### Scenario: 不同 InterfaceType 的请求改写 - -- **WHEN** 调用 `RewriteRequestModelName(body, "gpt-4", InterfaceTypeEmbeddings)` -- **THEN** SHALL 按 Embedding 接口的请求体 model 字段位置进行改写 -- **WHEN** 调用 `RewriteRequestModelName(body, 
"gpt-4", InterfaceTypeRerank)` -- **THEN** SHALL 按 Rerank 接口的请求体 model 字段位置进行改写 - -### Requirement: 最小化改写响应体 model 字段 - -OpenAI 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写响应体中的 model 字段,其余字段保持原始 bytes。请求体和响应体的 model 字段位置可能不同,各自独立实现。 - -#### Scenario: 响应体 model 改写(上游名 → 统一 ID) - -- **WHEN** 调用 `RewriteResponseModelName(body, "openai/gpt-4", InterfaceTypeChat)`,body 为上游 Chat 响应 -- **THEN** SHALL 将 model 字段替换为 `"openai/gpt-4"`,其余字段原样保留 - -#### Scenario: 不同 InterfaceType 的响应改写 - -- **WHEN** 调用 `RewriteResponseModelName(body, "openai/gpt-4", InterfaceTypeEmbeddings)` -- **THEN** SHALL 按 Embedding 接口的响应体 model 字段位置进行改写 diff --git a/openspec/changes/unified-model-id/specs/provider-management/spec.md b/openspec/changes/unified-model-id/specs/provider-management/spec.md deleted file mode 100644 index ae3b265..0000000 --- a/openspec/changes/unified-model-id/specs/provider-management/spec.md +++ /dev/null @@ -1,35 +0,0 @@ -## MODIFIED Requirements - -### Requirement: 创建供应商配置 - -网关 SHALL 允许通过管理 API 创建新的供应商配置。 - -#### Scenario: 使用有效数据创建供应商 - -- **WHEN** 向 `/api/providers` 发送 POST 请求,携带有效的供应商数据(id, name, api_key, base_url, protocol) -- **THEN** 网关 SHALL 在数据库中创建新的供应商记录 -- **THEN** 网关 SHALL 返回创建的供应商,状态码为 201 -- **THEN** 供应商 SHALL 默认启用 -- **THEN** protocol 字段 SHALL 默认为 "openai" - -#### Scenario: 使用重复 ID 创建供应商 - -- **WHEN** 向 `/api/providers` 发送 POST 请求,携带已存在的 ID -- **THEN** 网关 SHALL 返回错误,状态码为 409 (Conflict) - -#### Scenario: 创建供应商时缺少必需字段 - -- **WHEN** 向 `/api/providers` 发送 POST 请求,缺少必需字段(id, name, api_key 或 base_url) -- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request) -- **THEN** 错误 SHALL 指示缺少哪些字段 - -#### Scenario: 创建供应商时 ID 包含非法字符 - -- **WHEN** 向 `/api/providers` 发送 POST 请求,id 包含非 `[a-zA-Z0-9_]` 字符 -- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request) -- **THEN** 错误 SHALL 指示 id 仅允许字母、数字、下划线 - -#### Scenario: 创建供应商时 ID 过长 - -- **WHEN** 向 `/api/providers` 发送 POST 请求,id 长度超过 64 -- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request) diff --git 
a/openspec/changes/unified-model-id/specs/request-validation/spec.md b/openspec/changes/unified-model-id/specs/request-validation/spec.md deleted file mode 100644 index d652bf4..0000000 --- a/openspec/changes/unified-model-id/specs/request-validation/spec.md +++ /dev/null @@ -1,39 +0,0 @@ -## MODIFIED Requirements - -### Requirement: 供应商 ID 校验 - -创建供应商时,SHALL 对 `id` 字段进行字符集校验。 - -#### Scenario: 合法字符集 - -- **WHEN** 创建供应商,id 仅包含 `[a-zA-Z0-9_]` 字符 -- **THEN** SHALL 校验通过 - -#### Scenario: 非法字符 - -- **WHEN** 创建供应商,id 包含 `-`、`.`、`/`、空格、中文等非 `[a-zA-Z0-9_]` 字符 -- **THEN** SHALL 返回 400 错误 - -#### Scenario: 长度限制 - -- **WHEN** 创建供应商,id 长度超过 64 -- **THEN** SHALL 返回 400 错误 - -### Requirement: 模型创建校验 - -创建模型时,SHALL 对 `provider_id` + `model_name` 进行联合唯一性校验。 - -#### Scenario: 正常创建 - -- **WHEN** 创建模型,provider_id 存在且 provider_id + model_name 组合唯一 -- **THEN** SHALL 校验通过 - -#### Scenario: 联合唯一冲突 - -- **WHEN** 创建模型,provider_id + model_name 组合已存在 -- **THEN** SHALL 返回 409 错误 - -#### Scenario: model_name 为空 - -- **WHEN** 创建模型,未提供 model_name -- **THEN** SHALL 返回 400 错误 diff --git a/openspec/changes/unified-model-id/specs/unified-proxy-handler/spec.md b/openspec/changes/unified-model-id/specs/unified-proxy-handler/spec.md deleted file mode 100644 index 4fb1f2a..0000000 --- a/openspec/changes/unified-model-id/specs/unified-proxy-handler/spec.md +++ /dev/null @@ -1,123 +0,0 @@ -## MODIFIED Requirements - -### Requirement: 代理请求路由 - -ProxyHandler SHALL 使用统一模型 ID 路由所有代理请求。 - -#### Scenario: 提取统一模型 ID - -- **WHEN** 收到 Chat、Embeddings 或 Rerank 接口的 POST 请求(含请求体) -- **THEN** SHALL 调用客户端协议 adapter 的 `ExtractModelName(body, ifaceType)` 提取 model 值 -- **THEN** SHALL 调用 `ParseUnifiedModelID` 解析得到 providerID 和 modelName -- **THEN** SHALL 调用 `RoutingService.RouteByModelName(providerID, modelName)` 路由 - -#### Scenario: GET 请求或无请求体 - -- **WHEN** 收到 GET 请求或请求体为空 -- **THEN** SHALL 返回错误响应,状态码为 400,提示缺少 model 字段 - -#### Scenario: 无效的统一模型 ID - -- **WHEN** 请求体中 `model` 字段不是有效的统一模型 ID 格式 -- **THEN** SHALL 
返回错误响应,状态码为 400 - -#### Scenario: 模型不存在 - -- **WHEN** 解析统一模型 ID 后,数据库中找不到对应的 provider_id + model_name 组合 -- **THEN** SHALL 返回错误响应,状态码为 404 - -#### Scenario: 模型已禁用 - -- **WHEN** 解析统一模型 ID 后,对应的模型 enabled 为 false -- **THEN** SHALL 返回错误响应,状态码为 404 - -#### Scenario: 供应商已禁用 - -- **WHEN** 解析统一模型 ID 后,对应的供应商 enabled 为 false -- **THEN** SHALL 返回错误响应,状态码为 404 - -### Requirement: 同协议 Smart Passthrough - -当客户端协议与供应商协议相同时,ProxyHandler SHALL 使用 Smart Passthrough 处理 Chat、Embedding、Rerank 请求。 - -#### Scenario: 同协议非流式请求 - -- **WHEN** 客户端协议 == 供应商协议,且为非流式请求 -- **THEN** SHALL 调用 adapter 的 `RewriteRequestModelName(body, modelName, ifaceType)` 将请求体中 model 从统一 ID 改写为上游模型名 -- **THEN** SHALL 构建 URL 和 Headers(同当前透传逻辑) -- **THEN** SHALL 发送改写后的请求体到上游 -- **THEN** SHALL 调用 adapter 的 `RewriteResponseModelName(resp.Body, unifiedModelID, ifaceType)` 将响应中 model 从上游名改写为统一 ID -- **THEN** SHALL NOT 对 body 做全量 decode → encode,保持未改写字段的原始 bytes - -#### Scenario: 同协议流式请求 - -- **WHEN** 客户端协议 == 供应商协议,且为流式请求 -- **THEN** SHALL 对请求体做 `RewriteRequestModelName` 改写 model 字段 -- **THEN** SHALL 逐 SSE chunk 调用 `RewriteResponseModelName` 改写响应中 model 字段 -- **THEN** SHALL NOT 对 chunk 做全量 decode → encode - -#### Scenario: Smart Passthrough 保真性 - -- **WHEN** 客户端发送含未知参数的请求(如 `{"model":"openai/gpt-4","some_new_param":"value"}`) -- **THEN** 上游 SHALL 收到 `{"model":"gpt-4","some_new_param":"value"}` -- **THEN** `some_new_param` SHALL 保持原始值不变,不丢失、不改变类型 - -### Requirement: 跨协议完整转换 - -当客户端协议与供应商协议不同时,ProxyHandler SHALL 使用全量转换路径。 - -#### Scenario: 跨协议非流式请求 - -- **WHEN** 客户端协议 != 供应商协议 -- **THEN** SHALL 走 `ConvertHttpRequest` 全量转换,encoder 中 provider.ModelName 覆盖 model -- **THEN** SHALL 走 `ConvertHttpResponse` 全量转换,modelOverride 参数覆写 canonical.Model - -#### Scenario: 跨协议流式请求 - -- **WHEN** 客户端协议 != 供应商协议,且为流式请求 -- **THEN** SHALL 走 `CreateStreamConverter` 全量转换,modelOverride 参数覆写流式 canonical 事件中的 Model - -### Requirement: 模型列表本地聚合 - -ProxyHandler SHALL 从数据库聚合返回模型列表,不再透传上游。 - -#### Scenario: GET /v1/models - -- **WHEN** 收到 `GET 
/{protocol}/v1/models` 请求 -- **THEN** SHALL 从数据库查询所有 enabled 的模型(关联 enabled 的供应商) -- **THEN** SHALL 组装 `CanonicalModelList`,每个模型的 ID 字段为统一模型 ID(`provider_id/model_name`),Name 字段为 model_name,OwnedBy 字段为 provider_id -- **THEN** SHALL 使用客户端协议的 adapter 编码响应 -- **THEN** SHALL NOT 请求上游供应商 - -#### Scenario: 无可用模型 - -- **WHEN** 数据库中没有 enabled 的模型 -- **THEN** SHALL 返回空列表 - -### Requirement: 模型详情本地查询 - -ProxyHandler SHALL 从数据库查询返回模型详情,不再透传上游。 - -#### Scenario: GET /v1/models/{unified_id} - -- **WHEN** 收到 `GET /{protocol}/v1/models/{provider_id}/{model_name}` 请求 -- **THEN** SHALL 调用 adapter 的 `ExtractUnifiedModelID` 提取统一模型 ID -- **THEN** SHALL 解析统一模型 ID 得到 providerID 和 modelName -- **THEN** SHALL 从数据库查询对应的模型和供应商 -- **THEN** SHALL 组装 `CanonicalModelInfo`,ID 字段为统一模型 ID(`provider_id/model_name`),Name 字段为 model_name,OwnedBy 字段为 provider_id -- **THEN** SHALL 使用客户端协议的 adapter 编码响应 -- **THEN** SHALL NOT 请求上游供应商 - -#### Scenario: 模型详情不存在 - -- **WHEN** 统一模型 ID 对应的模型不存在或已禁用 -- **THEN** SHALL 返回错误响应,状态码为 404 - -### Requirement: 统计记录 - -ProxyHandler SHALL 使用 providerID 和 modelName 记录使用统计。 - -#### Scenario: 异步记录统计 - -- **WHEN** 代理请求成功完成 -- **THEN** SHALL 异步调用 `StatsService.Record(providerID, modelName)` diff --git a/openspec/changes/unified-model-id/specs/usage-statistics/spec.md b/openspec/changes/unified-model-id/specs/usage-statistics/spec.md deleted file mode 100644 index 5685682..0000000 --- a/openspec/changes/unified-model-id/specs/usage-statistics/spec.md +++ /dev/null @@ -1,16 +0,0 @@ -## ADDED Requirements - -### Requirement: 使用统计记录统一模型标识 - -系统 SHALL 使用 providerID 和 modelName(上游模型名)记录使用统计。 - -#### Scenario: 代理请求统计记录 - -- **WHEN** 代理请求成功完成 -- **THEN** SHALL 记录 provider_id 和 model_name 到 usage_stats 表(参数来自路由结果) -- **THEN** SHALL 异步执行,不阻塞响应 - -#### Scenario: 查询统计 - -- **WHEN** 查询统计数据 -- **THEN** 支持按 provider_id 和 model_name 过滤 diff --git a/openspec/changes/unified-model-id/tasks.md b/openspec/changes/unified-model-id/tasks.md deleted file mode 100644 index 02613f4..0000000 --- 
a/openspec/changes/unified-model-id/tasks.md +++ /dev/null @@ -1,53 +0,0 @@ -## 1. 数据库迁移 - -- [ ] 1.1 新增迁移脚本:DROP 旧 models 表 + CREATE 新 models 表(id UUID PK, provider_id, model_name, enabled, created_at),UNIQUE(provider_id, model_name) -- [ ] 1.2 更新 config/models.go:Model 结构体适配(id 改为 UUID 自动生成,model_name 保持不变) -- [ ] 1.3 编写迁移脚本测试 - -## 2. 统一模型 ID 工具包 - -- [ ] 2.1 新增 pkg/modelid/model_id.go:实现 ParseUnifiedModelID、FormatUnifiedModelID、ValidateProviderID、IsValidUnifiedModelID -- [ ] 2.2 新增 pkg/modelid/model_id_test.go:覆盖标准格式、含斜杠 model_name、空字符串、非法字符等边界情况 - -## 3. Domain 层适配 - -- [ ] 3.1 修改 domain/model.go:Model 结构体字段适配,新增 UnifiedModelID() 方法 -- [ ] 3.2 修改 domain/route.go:RouteResult 适配新字段 - -## 4. Repository 层适配 - -- [ ] 4.1 修改 repository/model_repo.go:接口变更 — GetByModelName 改为 FindByProviderAndModelName,新增 ListEnabled -- [ ] 4.2 修改 repository/model_repo_impl.go:实现 FindByProviderAndModelName(WHERE provider_id=? AND model_name=?)、ListEnabled(JOIN providers WHERE enabled) -- [ ] 4.3 编写 repository 层测试 - -## 5. Service 层适配 - -- [ ] 5.1 修改 service/routing_service.go:Route 接口改为 RouteByModelName(providerID, modelName string) -- [ ] 5.2 修改 service/routing_service_impl.go:调用 FindByProviderAndModelName 替代 GetByModelName -- [ ] 5.3 修改 service/model_service.go:Create 生成 UUID、新增联合唯一校验方法 -- [ ] 5.4 修改 service/model_service_impl.go:实现联合唯一校验、UUID 生成 -- [ ] 5.5 修改 service/provider_service_impl.go:Create 时调用 ValidateProviderID 校验 ID 字符集 -- [ ] 5.6 编写 service 层测试 - -## 6. 
Conversion 层适配 - -- [ ] 6.1 修改 conversion/adapter.go:ProtocolAdapter 接口新增 ExtractUnifiedModelID、ExtractModelName、RewriteRequestModelName、RewriteResponseModelName 四个方法 -- [ ] 6.2 修改 conversion/engine.go:ConvertHttpResponse 新增 modelOverride 参数(跨协议场景),各 convert*ResponseBody 中覆写 canonical Model;CreateStreamConverter 新增 modelOverride 参数 -- [ ] 6.3 修改 conversion/openai/adapter.go:实现 ExtractUnifiedModelID、ExtractModelName(按 ifaceType 提取 model)、RewriteRequestModelName 和 RewriteResponseModelName(json.RawMessage 最小化改写,按 ifaceType 定位 model 字段,请求/响应独立实现),修改 isModelInfoPath 允许 suffix 含 "/" -- [ ] 6.4 修改 conversion/anthropic/adapter.go:实现 ExtractUnifiedModelID、ExtractModelName、RewriteRequestModelName、RewriteResponseModelName,修改 isModelInfoPath 允许 suffix 含 "/" -- [ ] 6.5 编写 conversion 层测试:ExtractUnifiedModelID、ExtractModelName 各 ifaceType、RewriteRequestModelName/RewriteResponseModelName 保真性(含未知参数不丢失)、isModelInfoPath 含斜杠路径、modelOverride 覆写 - -## 7. Handler 层改造 - -- [ ] 7.1 修改 handler/proxy_handler.go:HandleProxy 按接口类型分发 — Models/ModelInfo 本地聚合;Chat/Embed/Rerank 用 adapter.ExtractModelName 提取统一 ID 路由,同协议走 Smart Passthrough(adapter.RewriteRequestModelName 改写请求、adapter.RewriteResponseModelName 改写响应),跨协议走全量转换(modelOverride);删除 forwardPassthrough 和硬编码的 extractModelName -- [ ] 7.2 修改 handler/model_handler.go:请求体字段适配(移除 id 输入、保留 provider_id 和 model_name),响应新增 unified_id,Create 使用 UUID -- [ ] 7.3 修改 handler/provider_handler.go:CreateProvider 校验 ID 字符集 -- [ ] 7.4 编写 handler 层测试:统一模型 ID 路由、同协议 Smart Passthrough 保真性、跨协议 modelOverride、Models 聚合、ModelInfo 查询、流式场景 model 覆写、provider ID 校验 - -## 8. 路由注册适配 - -- [ ] 8.1 修改 cmd/server/main.go:setupRoutes 适配 handler 签名变更,传递新增依赖 - -## 9. 
文档更新 - -- [ ] 9.1 按需更新 README.md:同步 models 表结构、API 接口字段、统一模型 ID 格式、Smart Passthrough 策略等变更说明 diff --git a/openspec/specs/conversion-engine/spec.md b/openspec/specs/conversion-engine/spec.md index 7ec5df0..2673310 100644 --- a/openspec/specs/conversion-engine/spec.md +++ b/openspec/specs/conversion-engine/spec.md @@ -277,4 +277,51 @@ ErrorCode SHALL 包含:INVALID_INPUT、MISSING_REQUIRED_FIELD、INCOMPATIBLE_F - **WHEN** Adapter 调用 buildHeaders(provider) - **THEN** SHALL 从 provider.api_key 提取认证信息 - **THEN** SHALL 从 provider.adapter_config 提取协议专属配置 -- **THEN** SHALL 使用 provider.model_name 覆盖请求中的 model 字段 \ No newline at end of file +- **THEN** SHALL 使用 provider.model_name 覆盖请求中的 model 字段 +### Requirement: 跨协议响应转换支持 model 覆写 + +ConversionEngine SHALL 在跨协议响应转换时支持 model 字段覆写。 + +#### Scenario: ConvertHttpResponse 接收 modelOverride 参数 + +- **WHEN** 调用 `ConvertHttpResponse` 时传入 `modelOverride` 参数(跨协议场景,非空字符串) +- **THEN** SHALL 在解码上游响应到 canonical 后,将 `Model` 字段设为 `modelOverride` +- **THEN** SHALL 使用覆写后的 canonical 编码为客户端协议格式 + +#### Scenario: modelOverride 为空 + +- **WHEN** 调用 `ConvertHttpResponse` 时 `modelOverride` 为空字符串 +- **THEN** SHALL NOT 覆写 canonical 的 Model 字段,保持上游原始值 + +#### Scenario: Chat 响应 model 覆写 + +- **WHEN** 跨协议转换 Chat 类型响应且 `modelOverride` 非空 +- **THEN** `CanonicalResponse.Model` SHALL 被设为 `modelOverride` + +#### Scenario: Embedding 响应 model 覆写 + +- **WHEN** 跨协议转换 Embedding 类型响应且 `modelOverride` 非空 +- **THEN** `CanonicalEmbeddingResponse.Model` SHALL 被设为 `modelOverride` + +#### Scenario: Rerank 响应 model 覆写 + +- **WHEN** 跨协议转换 Rerank 类型响应且 `modelOverride` 非空 +- **THEN** `CanonicalRerankResponse.Model` SHALL 被设为 `modelOverride` + +### Requirement: 跨协议流式转换支持 model 覆写 + +ConversionEngine SHALL 在跨协议流式转换时支持 model 字段覆写。 + +#### Scenario: CreateStreamConverter 接收 modelOverride 参数 + +- **WHEN** 调用 `CreateStreamConverter` 时传入 `modelOverride` 参数(跨协议场景) +- **THEN** SHALL 在流式 canonical 事件中将 `Model` 字段设为 `modelOverride` + +### Requirement: TargetProvider 字段语义 + +TargetProvider 的 
ModelName 字段 SHALL 存储上游供应商的模型名称(即 `model_name` 字段值),语义保持不变。 + +#### Scenario: encoder 使用 TargetProvider.ModelName + +- **WHEN** 协议适配器编码请求时 +- **THEN** SHALL 使用 `TargetProvider.ModelName` 作为发给上游的 `model` 字段值(值为路由结果中的 model_name) diff --git a/openspec/specs/database-migration/spec.md b/openspec/specs/database-migration/spec.md index e6286e2..733b4d1 100644 --- a/openspec/specs/database-migration/spec.md +++ b/openspec/specs/database-migration/spec.md @@ -28,9 +28,10 @@ #### Scenario: 初始迁移文件 - **WHEN** 创建初始迁移 -- **THEN** SHALL 创建 001_initial_schema.sql +- **THEN** SHALL 创建单个初始迁移文件(如 `20260421000001_initial_schema.sql`) - **THEN** SHALL 包含 providers、models、usage_stats 表的创建语句 - **THEN** SHALL 包含外键约束 +- **THEN** SHALL 包含索引创建语句 #### Scenario: Up 迁移 @@ -42,25 +43,19 @@ #### Scenario: Down 迁移 - **WHEN** 执行 down 迁移 -- **THEN** SHALL 删除所有表 +- **THEN** SHALL 删除所有表和索引 - **THEN** SHALL 按正确顺序删除(避免外键约束错误) -### Requirement: 添加索引迁移 +### Requirement: models 表 schema 变更 -系统 SHALL 创建索引迁移。 +系统 SHALL 在初始迁移脚本中直接创建新的 models 表结构(服务未上线,无需考虑数据迁移,迁移脚本已合并为单个初始迁移文件)。 -#### Scenario: 索引迁移文件 +#### Scenario: 初始迁移 models 表结构 -- **WHEN** 创建索引迁移 -- **THEN** SHALL 创建 002_add_indexes.sql -- **THEN** SHALL 为常用查询字段添加索引 - -#### Scenario: 索引定义 - -- **WHEN** 添加索引 -- **THEN** SHALL 为 models(provider_id) 添加索引 -- **THEN** SHALL 为 models(model_name) 添加索引 -- **THEN** SHALL 为 usage_stats(provider_id, model_name, date) 添加复合索引 +- **WHEN** 执行迁移 +- **THEN** SHALL CREATE models 表,包含字段:id(TEXT PRIMARY KEY,存储 UUID 字符串)、provider_id(TEXT NOT NULL)、model_name(TEXT NOT NULL)、enabled(INTEGER DEFAULT 1)、created_at(DATETIME) +- **THEN** SHALL 存在 UNIQUE(provider_id, model_name) 约束 +- **THEN** SHALL 存在 FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE ### Requirement: 迁移命令集成 diff --git a/openspec/specs/error-responses/spec.md b/openspec/specs/error-responses/spec.md new file mode 100644 index 0000000..cc45f46 --- /dev/null +++ b/openspec/specs/error-responses/spec.md @@ -0,0 +1,209 @@ +# Error Responses + +## 
Purpose + +定义系统统一的错误响应格式和各类错误场景,确保客户端能够一致地处理错误。 + +## Requirements + +### Requirement: 统一错误响应格式 + +系统 SHALL 使用统一的错误响应格式。 + +#### Scenario: 标准错误格式 + +- **WHEN** 返回错误响应 +- **THEN** SHALL 使用以下 JSON 格式: + ```json + { + "error": "错误描述", + "code": "ERROR_CODE" + } + ``` +- **THEN** `error` 字段 SHALL 包含人类可读的错误描述 +- **THEN** `code` 字段 MAY 包含机器可读的错误代码(该字段为可选字段) + +### Requirement: provider_id 校验错误 + +系统 SHALL 对 provider_id 校验错误返回明确的错误信息。 + +#### Scenario: provider_id 包含非法字符 + +- **WHEN** 创建或更新供应商时,provider_id 包含非 `[a-zA-Z0-9_]` 字符 +- **THEN** SHALL 返回 HTTP 400 Bad Request +- **THEN** SHALL 返回以下 JSON 格式: + ```json + { + "error": "供应商 ID 仅允许字母、数字、下划线", + "code": "INVALID_PROVIDER_ID" + } + ``` + +#### Scenario: provider_id 长度超限 + +- **WHEN** 创建或更新供应商时,provider_id 长度超过 64 +- **THEN** SHALL 返回 HTTP 400 Bad Request +- **THEN** SHALL 返回以下 JSON 格式: + ```json + { + "error": "供应商 ID 长度不能超过 64 个字符", + "code": "INVALID_PROVIDER_ID" + } + ``` + +### Requirement: 联合唯一约束冲突错误 + +系统 SHALL 对联合唯一约束冲突返回明确的错误信息。 + +#### Scenario: 创建模型时 provider_id + model_name 组合已存在 + +- **WHEN** 创建模型时,provider_id + model_name 组合已存在 +- **THEN** SHALL 返回 HTTP 409 Conflict +- **THEN** SHALL 返回以下 JSON 格式: + ```json + { + "error": "同一供应商下模型名称已存在", + "code": "duplicate_model" + } + ``` + +#### Scenario: 更新模型时导致 provider_id + model_name 组合冲突 + +- **WHEN** 更新模型时,修改 provider_id 或 model_name 导致与已有记录冲突 +- **THEN** SHALL 返回 HTTP 409 Conflict +- **THEN** SHALL 返回以下 JSON 格式: + ```json + { + "error": "同一供应商下模型名称已存在", + "code": "duplicate_model" + } + ``` + +### Requirement: 资源不存在错误 + +系统 SHALL 对资源不存在返回明确的错误信息。 + +#### Scenario: 模型不存在 + +- **WHEN** 查询或操作不存在的模型 +- **THEN** SHALL 返回 HTTP 404 Not Found +- **THEN** SHALL 返回以下 JSON 格式: + ```json + { + "error": "模型未找到" + } + ``` + +#### Scenario: 供应商不存在 + +- **WHEN** 创建模型时指定的供应商不存在 +- **THEN** SHALL 返回 HTTP 400 Bad Request +- **THEN** SHALL 返回以下 JSON 格式: + ```json + { + "error": "供应商不存在" + } + ``` + +### Requirement: 统一模型 ID 格式错误 + +系统 SHALL 对统一模型 ID 格式错误返回明确的错误信息。 + +#### Scenario: 
统一模型 ID 格式无效 + +- **WHEN** 代理请求中的 model 字段不是有效的统一模型 ID 格式 +- **THEN** 请求 SHALL 原样透传到上游(兼容未适配的客户端) +- **THEN** 不返回错误,保持与上游的兼容性 + +**设计理由:** +- 统一模型 ID 是 BREAKING CHANGE,部分旧客户端可能仍使用原始模型名 +- 透传策略允许上游自行判断并返回错误(如 404 model not found) +- 网关作为透明代理,不应拦截所有格式非法的请求 + +#### Scenario: 统一模型 ID 格式有效但对应模型不存在 + +- **WHEN** 代理请求中的 model 字段是有效的统一模型 ID 格式(含 `/`),但数据库中找不到对应的模型 +- **THEN** SHALL 返回 HTTP 404 Not Found +- **THEN** SHALL 返回以下 JSON 格式: + ```json + { + "error": "模型未找到" + } + ``` + +#### Scenario: 统一模型 ID 对应的模型不存在 + +- **WHEN** 解析统一模型 ID 后,数据库中找不到对应的模型 +- **THEN** SHALL 返回 HTTP 404 Not Found +- **THEN** SHALL 返回以下 JSON 格式: + ```json + { + "error": "模型未找到" + } + ``` + +#### Scenario: 统一模型 ID 对应的模型已禁用 + +- **WHEN** 解析统一模型 ID 后,对应的模型 enabled 为 false +- **THEN** SHALL 返回 HTTP 404 Not Found +- **THEN** SHALL 返回以下 JSON 格式: + ```json + { + "error": "模型未找到" + } + ``` + +#### Scenario: 统一模型 ID 对应的供应商已禁用 + +- **WHEN** 解析统一模型 ID 后,对应的供应商 enabled 为 false +- **THEN** SHALL 返回 HTTP 404 Not Found +- **THEN** SHALL 返回以下 JSON 格式: + ```json + { + "error": "模型未找到" + } + ``` + +### Requirement: JSON 格式错误 + +系统 SHALL 对请求体 JSON 格式错误返回明确的错误信息。 + +#### Scenario: 请求体 JSON 格式错误 + +- **WHEN** 代理请求的请求体不是有效的 JSON 格式 +- **THEN** SHALL 返回 HTTP 400 Bad Request +- **THEN** SHALL 返回以下 JSON 格式: + ```json + { + "error": "请求体 JSON 格式错误", + "code": "INVALID_JSON" + } + ``` + +#### Scenario: Smart Passthrough 时请求体 JSON 格式错误 + +- **WHEN** 同协议 Smart Passthrough 场景下,请求体 JSON 格式不正确 +- **THEN** SHALL 返回 HTTP 400 Bad Request +- **THEN** SHALL 返回以下 JSON 格式: + ```json + { + "error": "请求体 JSON 格式错误", + "code": "INVALID_JSON" + } + ``` + +### Requirement: 不可变字段错误 + +系统 SHALL 对尝试修改不可变字段返回明确的错误信息。 + +#### Scenario: 尝试修改供应商 ID + +- **WHEN** 更新供应商时,请求体中包含 `id` 字段 +- **THEN** SHALL 返回 HTTP 400 Bad Request +- **THEN** SHALL 返回以下 JSON 格式: + ```json + { + "error": "供应商 ID 不允许修改", + "code": "IMMUTABLE_FIELD" + } + ``` diff --git a/openspec/specs/model-management/spec.md b/openspec/specs/model-management/spec.md 
index 474e49a..12cb883 100644 --- a/openspec/specs/model-management/spec.md +++ b/openspec/specs/model-management/spec.md @@ -10,10 +10,12 @@ #### Scenario: 使用有效数据创建模型 -- **WHEN** 向 `/api/models` 发送 POST 请求,携带有效的模型数据(id, provider_id, model_name) +- **WHEN** 向 `/api/models` 发送 POST 请求,携带有效的模型数据(provider_id, model_name),不提供 id 字段 +- **THEN** 网关 SHALL 自动生成 UUID 作为模型 id - **THEN** 网关 SHALL 在数据库中创建新的模型记录 - **THEN** 网关 SHALL 返回创建的模型,状态码为 201 - **THEN** 模型 SHALL 默认启用 +- **THEN** 返回的模型 SHALL 包含 `unified_id` 字段,值为 `{provider_id}/{model_name}` #### Scenario: 使用不存在的供应商创建模型 @@ -21,7 +23,11 @@ - **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request) - **THEN** 错误 SHALL 指示供应商不存在 -**变更说明:** handler 通过 ModelService 调用,数据访问通过 ModelRepository 和 ProviderRepository。API 接口保持不变。 +#### Scenario: 创建重复模型 + +- **WHEN** 向 `/api/models` 发送 POST 请求,携带已存在的 provider_id + model_name 组合 +- **THEN** 网关 SHALL 返回错误,状态码为 409 (Conflict) +- **THEN** 错误 SHALL 指示同一供应商下模型名称已存在 ### Requirement: 列出所有模型 @@ -31,9 +37,7 @@ - **WHEN** 向 `/api/models` 发送 GET 请求 - **THEN** 网关 SHALL 返回所有模型的列表 -- **THEN** 每个模型 SHALL 包含 id, provider_id, model_name, enabled, created_at - -**变更说明:** 数据访问从 config 包迁移到 ModelRepository。API 接口保持不变。 +- **THEN** 每个模型 SHALL 包含 id, provider_id, model_name, unified_id, enabled, created_at ### Requirement: 按供应商列出模型 @@ -43,8 +47,7 @@ - **WHEN** 向 `/api/models?provider_id=` 发送 GET 请求 - **THEN** 网关 SHALL 返回指定供应商的模型列表 - -**变更说明:** 通过 ModelService 和 ModelRepository 实现。API 接口保持不变。 +- **THEN** 每个模型 SHALL 包含 unified_id 字段 ### Requirement: 更新模型配置 @@ -55,14 +58,12 @@ - **WHEN** 向 `/api/models/:id` 发送 PUT 请求,携带有效的模型数据 - **THEN** 网关 SHALL 更新数据库中的模型记录 - **THEN** 网关 SHALL 返回更新后的模型 +- **THEN** 返回的模型 SHALL 包含更新后的 unified_id -#### Scenario: 更新模型供应商 +#### Scenario: 更新模型为重复组合 -- **WHEN** 向 `/api/models/:id` 发送 PUT 请求,携带新的 provider_id -- **THEN** 网关 SHALL 验证新供应商是否存在 -- **THEN** 网关 SHALL 更新模型的供应商关联 - -**变更说明:** 通过 ModelService、ModelRepository 和 ProviderRepository 实现。API 接口保持不变。 +- **WHEN** 向 `/api/models/:id` 发送 PUT 请求,更新 
provider_id 或 model_name 导致与已有记录重复 +- **THEN** 网关 SHALL 返回错误,状态码为 409 (Conflict) ### Requirement: 删除模型配置 @@ -74,8 +75,6 @@ - **THEN** 网关 SHALL 删除模型记录 - **THEN** 网关 SHALL 返回状态码 204 (No Content) -**变更说明:** 通过 ModelService 和 ModelRepository 实现。API 接口保持不变。 - ### Requirement: 使用 service 层处理业务逻辑 Handler SHALL 通过 ModelService 处理业务逻辑。 @@ -85,25 +84,60 @@ Handler SHALL 通过 ModelService 处理业务逻辑。 - **WHEN** handler 收到请求 - **THEN** SHALL 调用对应的 ModelService 方法(Create、Get、List、Update、Delete) - **THEN** SHALL 使用 domain.Model 类型 +- **THEN** Create 时 SHALL 调用 `uuid.New()` 生成 id -#### Scenario: 供应商验证 +#### Scenario: 供应商验证和唯一性校验 - **WHEN** 创建或更新模型 - **THEN** SHALL 在 service 层验证供应商存在 -- **THEN** SHALL 通过 ProviderRepository 查询供应商 +- **THEN** SHALL 在 service 层验证 provider_id + model_name 联合唯一 + +### Requirement: 联合唯一约束并发处理 + +创建或更新模型时,SHALL 使用应用层校验 + 数据库约束双重保险处理联合唯一约束。 + +#### Scenario: 应用层快速失败 + +- **WHEN** 创建或更新模型前 +- **THEN** SHALL 先检查 provider_id + model_name 是否已存在 +- **THEN** 如已存在,SHALL 返回 HTTP 409 Conflict +- **THEN** SHALL 返回错误格式: + ```json + { + "error": "同一供应商下模型名称已存在", + "code": "duplicate_model" + } + ``` + +#### Scenario: 数据库约束兜底 + +- **WHEN** 并发创建导致应用层校验通过但数据库写入失败 +- **THEN** SHALL 捕获数据库 UNIQUE 约束错误 +- **THEN** SHALL 转换为 HTTP 409 Conflict 错误返回 +- **THEN** SHALL 返回错误格式: + ```json + { + "error": "同一供应商下模型名称已存在", + "code": "duplicate_model" + } + ``` + +#### Scenario: SQLite UNIQUE 约束错误检测 + +- **WHEN** 捕获数据库错误 +- **THEN** SHALL 检查错误信息是否包含 "UNIQUE constraint failed" +- **THEN** 如匹配,SHALL 识别为联合唯一约束冲突 ### Requirement: 使用 repository 层访问数据 Service SHALL 通过 ModelRepository 访问数据。 -#### Scenario: 调用 repository 方法 +#### Scenario: 联合查询 -- **WHEN** service 处理业务逻辑 -- **THEN** SHALL 调用对应的 ModelRepository 方法 -- **THEN** SHALL 使用 domain.Model 类型 +- **WHEN** service 需要按 provider 和 model_name 查询模型 +- **THEN** SHALL 调用 `FindByProviderAndModelName(providerID, modelName)` 方法 -#### Scenario: 数据验证 +#### Scenario: 查询所有启用模型 -- **WHEN** 创建或更新模型 -- **THEN** SHALL 在 service 层验证业务规则 -- **THEN** SHALL 在 
repository 层执行数据库操作 +- **WHEN** proxy handler 需要聚合模型列表 +- **THEN** SHALL 调用 `ListEnabled()` 方法,返回所有 enabled 的模型(关联 enabled 的供应商) diff --git a/openspec/specs/protocol-adapter-anthropic/spec.md b/openspec/specs/protocol-adapter-anthropic/spec.md index 985e82b..acb7297 100644 --- a/openspec/specs/protocol-adapter-anthropic/spec.md +++ b/openspec/specs/protocol-adapter-anthropic/spec.md @@ -270,4 +270,73 @@ Decoder 几乎 1:1 映射,维护最小状态机: - **WHEN** interfaceType 为 EMBEDDINGS 或 RERANK - **THEN** supportsInterface SHALL 返回 false -- **THEN** 引擎 SHALL 走透传或返回空响应 \ No newline at end of file +- **THEN** 引擎 SHALL 走透传或返回空响应 +### Requirement: 模型详情路径识别 + +Anthropic 适配器 SHALL 正确识别包含统一模型 ID 的模型详情路径。 + +#### Scenario: 含斜杠的统一模型 ID 路径 + +- **WHEN** 路径为 `/v1/models/anthropic/claude-3-opus` +- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo` + +#### Scenario: 含多段斜杠的统一模型 ID 路径 + +- **WHEN** 路径为 `/v1/models/azure/deployments/gpt-4` +- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo` + +#### Scenario: 模型列表路径不受影响 + +- **WHEN** 路径为 `/v1/models` +- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModels` + +### Requirement: 提取统一模型 ID + +Anthropic 适配器 SHALL 从路径中提取统一模型 ID。 + +#### Scenario: 标准路径提取 + +- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/anthropic/claude-3-opus")` +- **THEN** SHALL 返回 `"anthropic/claude-3-opus"` + +#### Scenario: 复杂路径提取 + +- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")` +- **THEN** SHALL 返回 `"azure/accounts/org/models/gpt-4"` + +#### Scenario: 非模型详情路径 + +- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models")` +- **THEN** SHALL 返回错误 + +### Requirement: 从请求体提取 model + +Anthropic 适配器 SHALL 按 InterfaceType 从请求体中提取 model 值。 + +#### Scenario: Chat 请求提取 model + +- **WHEN** 调用 `ExtractModelName(body, InterfaceTypeChat)`,body 为 `{"model":"anthropic/claude-3-opus","messages":[...]}` +- **THEN** SHALL 返回 `"anthropic/claude-3-opus"` + +#### Scenario: 无 model 字段 + +- **WHEN** 调用 `ExtractModelName(body, 
ifaceType)`,body 中不含 model 字段 +- **THEN** SHALL 返回错误 + +### Requirement: 最小化改写请求体 model 字段 + +Anthropic 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写请求体中的 model 字段,其余字段保持原始 bytes。 + +#### Scenario: 请求体 model 改写(统一 ID → 上游名) + +- **WHEN** 调用 `RewriteRequestModelName(body, "claude-3-opus-20240229", InterfaceTypeChat)` +- **THEN** SHALL 将请求体中 model 字段替换为 `"claude-3-opus-20240229"`,其余字段原样保留 + +### Requirement: 最小化改写响应体 model 字段 + +Anthropic 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写响应体中的 model 字段,其余字段保持原始 bytes。请求体和响应体的 model 字段位置可能不同,各自独立实现。 + +#### Scenario: 响应体 model 改写(上游名 → 统一 ID) + +- **WHEN** 调用 `RewriteResponseModelName(body, "anthropic/claude-3-opus", InterfaceTypeChat)` +- **THEN** SHALL 将响应体中 model 字段替换为 `"anthropic/claude-3-opus"`,其余字段原样保留 diff --git a/openspec/specs/protocol-adapter-openai/spec.md b/openspec/specs/protocol-adapter-openai/spec.md index e4c1928..7689ad8 100644 --- a/openspec/specs/protocol-adapter-openai/spec.md +++ b/openspec/specs/protocol-adapter-openai/spec.md @@ -269,4 +269,91 @@ Encoder SHALL 维护状态: #### Scenario: /rerank 接口 - **WHEN** 解码/编码 rerank 请求和响应 -- **THEN** SHALL 使用 CanonicalRerankRequest/Response 做字段映射 \ No newline at end of file +- **THEN** SHALL 使用 CanonicalRerankRequest/Response 做字段映射 +### Requirement: 模型详情路径识别 + +OpenAI 适配器 SHALL 正确识别包含统一模型 ID 的模型详情路径。 + +#### Scenario: 含斜杠的统一模型 ID 路径 + +- **WHEN** 路径为 `/v1/models/openai/gpt-4` +- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo` + +#### Scenario: 含多段斜杠的统一模型 ID 路径 + +- **WHEN** 路径为 `/v1/models/azure/accounts/org-123/models/gpt-4` +- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo` + +#### Scenario: 模型列表路径不受影响 + +- **WHEN** 路径为 `/v1/models` +- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModels` + +### Requirement: 提取统一模型 ID + +OpenAI 适配器 SHALL 从路径中提取统一模型 ID。 + +#### Scenario: 标准路径提取 + +- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/openai/gpt-4")` +- **THEN** SHALL 返回 `"openai/gpt-4"` + +#### Scenario: 复杂路径提取 + +- 
**WHEN** 调用 `ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")` +- **THEN** SHALL 返回 `"azure/accounts/org/models/gpt-4"` + +#### Scenario: 非模型详情路径 + +- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models")` +- **THEN** SHALL 返回错误 + +### Requirement: 从请求体提取 model + +OpenAI 适配器 SHALL 按 InterfaceType 从请求体中提取 model 值。 + +#### Scenario: Chat 请求提取 model + +- **WHEN** 调用 `ExtractModelName(body, InterfaceTypeChat)`,body 为 `{"model":"openai/gpt-4","messages":[...]}` +- **THEN** SHALL 返回 `"openai/gpt-4"` + +#### Scenario: Embedding 请求提取 model + +- **WHEN** 调用 `ExtractModelName(body, InterfaceTypeEmbeddings)`,body 为 `{"model":"openai/text-embedding-3","input":"text"}` +- **THEN** SHALL 返回 `"openai/text-embedding-3"` + +#### Scenario: 无 model 字段 + +- **WHEN** 调用 `ExtractModelName(body, ifaceType)`,body 中不含 model 字段 +- **THEN** SHALL 返回错误 + +### Requirement: 最小化改写请求体 model 字段 + +OpenAI 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写请求体中的 model 字段,其余字段保持原始 bytes。 + +#### Scenario: 请求体 model 改写(统一 ID → 上游名) + +- **WHEN** 调用 `RewriteRequestModelName(body, "gpt-4", InterfaceTypeChat)`,body 为 `{"model":"openai/gpt-4","messages":[...],"some_param":"value"}` +- **THEN** SHALL 返回 `{"model":"gpt-4","messages":[...],"some_param":"value"}` +- **THEN** 除 model 外的字段 SHALL 保持原始 bytes 不变 + +#### Scenario: 不同 InterfaceType 的请求改写 + +- **WHEN** 调用 `RewriteRequestModelName(body, "gpt-4", InterfaceTypeEmbeddings)` +- **THEN** SHALL 按 Embedding 接口的请求体 model 字段位置进行改写 +- **WHEN** 调用 `RewriteRequestModelName(body, "gpt-4", InterfaceTypeRerank)` +- **THEN** SHALL 按 Rerank 接口的请求体 model 字段位置进行改写 + +### Requirement: 最小化改写响应体 model 字段 + +OpenAI 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写响应体中的 model 字段,其余字段保持原始 bytes。请求体和响应体的 model 字段位置可能不同,各自独立实现。 + +#### Scenario: 响应体 model 改写(上游名 → 统一 ID) + +- **WHEN** 调用 `RewriteResponseModelName(body, "openai/gpt-4", InterfaceTypeChat)`,body 为上游 Chat 响应 +- **THEN** SHALL 将 model 字段替换为 `"openai/gpt-4"`,其余字段原样保留 + +#### Scenario: 不同 InterfaceType 
的响应改写 + +- **WHEN** 调用 `RewriteResponseModelName(body, "openai/gpt-4", InterfaceTypeEmbeddings)` +- **THEN** SHALL 按 Embedding 接口的响应体 model 字段位置进行改写 diff --git a/openspec/specs/provider-management/spec.md b/openspec/specs/provider-management/spec.md index 3ecf21f..47496e0 100644 --- a/openspec/specs/provider-management/spec.md +++ b/openspec/specs/provider-management/spec.md @@ -29,7 +29,44 @@ - **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request) - **THEN** 错误 SHALL 指示缺少哪些字段 -**变更说明:** handler 通过 ProviderService 调用,数据访问通过 ProviderRepository。API 接口保持不变。 +#### Scenario: 创建供应商时 ID 包含非法字符 + +- **WHEN** 向 `/api/providers` 发送 POST 请求,id 包含非 `[a-zA-Z0-9_]` 字符 +- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request) +- **THEN** 错误 SHALL 指示 id 仅允许字母、数字、下划线 + +#### Scenario: 创建供应商时 ID 过长 + +- **WHEN** 向 `/api/providers` 发送 POST 请求,id 长度超过 64 +- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request) + +### Requirement: 供应商 ID 不允许修改 + +供应商 ID 是主键,用于构建统一模型 ID,不允许修改。 + +#### Scenario: 尝试修改供应商 ID + +- **WHEN** 向 `/api/providers/:id` 发送 PUT 请求,请求体中包含 `id` 字段 +- **THEN** ProviderService.Update SHALL 在 service 层校验并返回错误 +- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request) +- **THEN** 错误 SHALL 指示供应商 ID 不允许修改 +- **THEN** 错误格式 SHALL 为: + ```json + { + "error": "供应商 ID 不允许修改", + "code": "IMMUTABLE_FIELD" + } + ``` + +**校验位置:** Service 层(`ProviderService.Update`) +- Service 层校验保证所有调用方(handler、CLI、未来 API)统一遵守 +- Handler 层负责捕获 `ErrImmutableField` 并转换为 HTTP 400 响应 + +**原因:** +- 供应商 ID 是主键,用于构建统一模型 ID(provider_id/model_name) +- 修改 ID 会导致所有统一模型 ID 失效 +- 客户端缓存的模型 ID 全部失效 +- 如需修改,应创建新供应商并迁移模型 ### Requirement: 列出所有供应商 diff --git a/openspec/specs/request-validation/spec.md b/openspec/specs/request-validation/spec.md index b22163c..e317b58 100644 --- a/openspec/specs/request-validation/spec.md +++ b/openspec/specs/request-validation/spec.md @@ -92,6 +92,44 @@ - **THEN** SHALL 验证至少提供一个可更新字段 - **THEN** SHALL 验证字段值有效性 +### Requirement: 供应商 ID 校验 + +创建供应商时,SHALL 对 `id` 字段进行字符集校验。 + +#### Scenario: 合法字符集 + +- **WHEN** 创建供应商,id 
仅包含 `[a-zA-Z0-9_]` 字符 +- **THEN** SHALL 校验通过 + +#### Scenario: 非法字符 + +- **WHEN** 创建供应商,id 包含 `-`、`.`、`/`、空格、中文等非 `[a-zA-Z0-9_]` 字符 +- **THEN** SHALL 返回 400 错误 + +#### Scenario: 长度限制 + +- **WHEN** 创建供应商,id 长度超过 64 +- **THEN** SHALL 返回 400 错误 + +### Requirement: 模型创建校验 + +创建模型时,SHALL 对 `provider_id` + `model_name` 进行联合唯一性校验。 + +#### Scenario: 正常创建 + +- **WHEN** 创建模型,provider_id 存在且 provider_id + model_name 组合唯一 +- **THEN** SHALL 校验通过 + +#### Scenario: 联合唯一冲突 + +- **WHEN** 创建模型,provider_id + model_name 组合已存在 +- **THEN** SHALL 返回 409 错误 + +#### Scenario: model_name 为空 + +- **WHEN** 创建模型,未提供 model_name +- **THEN** SHALL 返回 400 错误 + ### Requirement: 返回友好的验证错误 系统 SHALL 返回友好的验证错误响应。 diff --git a/openspec/changes/unified-model-id/specs/unified-model-id/spec.md b/openspec/specs/unified-model-id/spec.md similarity index 93% rename from openspec/changes/unified-model-id/specs/unified-model-id/spec.md rename to openspec/specs/unified-model-id/spec.md index 8ff028f..b7bc35c 100644 --- a/openspec/changes/unified-model-id/specs/unified-model-id/spec.md +++ b/openspec/specs/unified-model-id/spec.md @@ -1,4 +1,10 @@ -## ADDED Requirements +# Unified Model ID + +## Purpose + +定义统一模型 ID 的格式、解析、格式化和校验规则,确保跨协议的模型标识一致性。 + +## Requirements ### Requirement: 解析统一模型 ID diff --git a/openspec/specs/unified-proxy-handler/spec.md b/openspec/specs/unified-proxy-handler/spec.md index c11fb8a..dc37fb6 100644 --- a/openspec/specs/unified-proxy-handler/spec.md +++ b/openspec/specs/unified-proxy-handler/spec.md @@ -105,4 +105,125 @@ ProxyHandler SHALL 支持 GET 请求的扩展层接口代理。 - **THEN** SHALL 调用 engine.convertHttpRequest(GET 请求 body 为空) - **THEN** SHALL 调用 providerClient.Send 发送请求 - **THEN** SHALL 调用 engine.convertHttpResponse 转换响应格式 -- **THEN** SHALL 返回转换后的响应 \ No newline at end of file +- **THEN** SHALL 返回转换后的响应 +### Requirement: 代理请求路由 + +ProxyHandler SHALL 使用统一模型 ID 路由所有代理请求。 + +#### Scenario: 提取统一模型 ID + +- **WHEN** 收到 Chat、Embeddings 或 Rerank 接口的 POST 请求(含请求体) +- **THEN** SHALL 调用客户端协议 adapter 的 
`ExtractModelName(body, ifaceType)` 提取 model 值 +- **THEN** SHALL 调用 `ParseUnifiedModelID` 解析得到 providerID 和 modelName +- **THEN** SHALL 调用 `RoutingService.RouteByModelName(providerID, modelName)` 路由 + +#### Scenario: GET 请求或无请求体 + +- **WHEN** 收到 GET 请求或请求体为空或请求体中无法提取 model 字段 +- **THEN** SHALL 走 forwardPassthrough 透传到上游供应商(兼容未适配的客户端和无 body 请求) + +#### Scenario: 无效的统一模型 ID + +- **WHEN** 请求体中 `model` 字段不是有效的统一模型 ID 格式(不含 `/`) +- **THEN** SHALL 走 forwardPassthrough 透传到上游供应商(兼容使用原始模型名的客户端) + +#### Scenario: 模型不存在 + +- **WHEN** 解析统一模型 ID 后,数据库中找不到对应的 provider_id + model_name 组合 +- **THEN** SHALL 返回错误响应,状态码为 404 + +#### Scenario: 模型已禁用 + +- **WHEN** 解析统一模型 ID 后,对应的模型 enabled 为 false +- **THEN** SHALL 返回错误响应,状态码为 404 + +#### Scenario: 供应商已禁用 + +- **WHEN** 解析统一模型 ID 后,对应的供应商 enabled 为 false +- **THEN** SHALL 返回错误响应,状态码为 404 + +### Requirement: 同协议 Smart Passthrough + +当客户端协议与供应商协议相同时,ProxyHandler SHALL 使用 Smart Passthrough 处理 Chat、Embedding、Rerank 请求。 + +#### Scenario: 同协议非流式请求 + +- **WHEN** 客户端协议 == 供应商协议,且为非流式请求 +- **THEN** SHALL 调用 adapter 的 `RewriteRequestModelName(body, modelName, ifaceType)` 将请求体中 model 从统一 ID 改写为上游模型名 +- **THEN** SHALL 构建 URL 和 Headers(同当前透传逻辑) +- **THEN** SHALL 发送改写后的请求体到上游 +- **THEN** SHALL 调用 adapter 的 `RewriteResponseModelName(resp.Body, unifiedModelID, ifaceType)` 将响应中 model 从上游名改写为统一 ID +- **THEN** SHALL NOT 对 body 做全量 decode → encode,保持未改写字段的原始 bytes + +#### Scenario: 同协议流式请求 + +- **WHEN** 客户端协议 == 供应商协议,且为流式请求 +- **THEN** SHALL 对请求体做 `RewriteRequestModelName` 改写 model 字段 +- **THEN** SHALL 逐 SSE chunk 调用 `RewriteResponseModelName` 改写响应中 model 字段 +- **THEN** SHALL NOT 对 chunk 做全量 decode → encode + +#### Scenario: Smart Passthrough 保真性 + +- **WHEN** 客户端发送含未知参数的请求(如 `{"model":"openai/gpt-4","some_new_param":"value"}`) +- **THEN** 上游 SHALL 收到 `{"model":"gpt-4","some_new_param":"value"}` +- **THEN** `some_new_param` SHALL 保持原始值不变,不丢失、不改变类型 + +### Requirement: 跨协议完整转换 + +当客户端协议与供应商协议不同时,ProxyHandler SHALL 使用全量转换路径。 + +#### Scenario: 跨协议非流式请求 + +- 
**WHEN** 客户端协议 != 供应商协议 +- **THEN** SHALL 走 `ConvertHttpRequest` 全量转换,encoder 中 provider.ModelName 覆盖 model +- **THEN** SHALL 走 `ConvertHttpResponse` 全量转换,modelOverride 参数覆写 canonical.Model + +#### Scenario: 跨协议流式请求 + +- **WHEN** 客户端协议 != 供应商协议,且为流式请求 +- **THEN** SHALL 走 `CreateStreamConverter` 全量转换,modelOverride 参数覆写流式 canonical 事件中的 Model + +### Requirement: 模型列表本地聚合 + +ProxyHandler SHALL 从数据库聚合返回模型列表,不再透传上游。 + +#### Scenario: GET /v1/models + +- **WHEN** 收到 `GET /{protocol}/v1/models` 请求 +- **THEN** SHALL 从数据库查询所有 enabled 的模型(关联 enabled 的供应商) +- **THEN** SHALL 组装 `CanonicalModelList`,每个模型的 ID 字段为统一模型 ID(`provider_id/model_name`),Name 字段为 model_name,OwnedBy 字段为 provider_id +- **THEN** SHALL 使用客户端协议的 adapter 编码响应 +- **THEN** SHALL NOT 请求上游供应商 + +#### Scenario: 无可用模型 + +- **WHEN** 数据库中没有 enabled 的模型 +- **THEN** SHALL 返回空列表 + +### Requirement: 模型详情本地查询 + +ProxyHandler SHALL 从数据库查询返回模型详情,不再透传上游。 + +#### Scenario: GET /v1/models/{unified_id} + +- **WHEN** 收到 `GET /{protocol}/v1/models/{provider_id}/{model_name}` 请求 +- **THEN** SHALL 调用 adapter 的 `ExtractUnifiedModelID` 提取统一模型 ID +- **THEN** SHALL 解析统一模型 ID 得到 providerID 和 modelName +- **THEN** SHALL 从数据库查询对应的模型和供应商 +- **THEN** SHALL 组装 `CanonicalModelInfo`,ID 字段为统一模型 ID(`provider_id/model_name`),Name 字段为 model_name,OwnedBy 字段为 provider_id +- **THEN** SHALL 使用客户端协议的 adapter 编码响应 +- **THEN** SHALL NOT 请求上游供应商 + +#### Scenario: 模型详情不存在 + +- **WHEN** 统一模型 ID 对应的模型不存在或已禁用 +- **THEN** SHALL 返回错误响应,状态码为 404 + +### Requirement: 统计记录 + +ProxyHandler SHALL 使用 providerID 和 modelName 记录使用统计。 + +#### Scenario: 异步记录统计 + +- **WHEN** 代理请求成功完成 +- **THEN** SHALL 异步调用 `StatsService.Record(providerID, modelName)` diff --git a/openspec/specs/usage-statistics/spec.md b/openspec/specs/usage-statistics/spec.md index d1e58b3..7d7ff23 100644 --- a/openspec/specs/usage-statistics/spec.md +++ b/openspec/specs/usage-statistics/spec.md @@ -22,7 +22,20 @@ - **THEN** 网关 SHALL 增加该供应商和模型的请求计数 - **THEN** 网关 SHALL 在流结束后记录统计 -**变更说明:** 统计记录通过 
StatsService 调用,数据访问通过 StatsRepository。API 接口保持不变。 +### Requirement: 使用统计记录统一模型标识 + +系统 SHALL 使用 providerID 和 modelName(上游模型名)记录使用统计。 + +#### Scenario: 代理请求统计记录 + +- **WHEN** 代理请求成功完成 +- **THEN** SHALL 记录 provider_id 和 model_name 到 usage_stats 表(参数来自路由结果) +- **THEN** SHALL 异步执行,不阻塞响应 + +#### Scenario: 查询统计 + +- **WHEN** 查询统计数据 +- **THEN** SHALL 支持按 provider_id 和 model_name 过滤 ### Requirement: 按供应商查询统计