
feat: implement the unified model ID mechanism

Implement the unified model ID format (provider_id/model_name), supporting cross-protocol model identification and Smart Passthrough.

Core changes:
- New pkg/modelid package: parse, format, and validate unified model IDs
- Database migration: the models table now uses a UUID primary key plus a UNIQUE(provider_id, model_name) constraint
- Repository layer: FindByProviderAndModelName and ListEnabled methods
- Service layer: joint-uniqueness validation and provider ID character-set validation
- Conversion layer: ExtractModelName, RewriteRequestModelName, and RewriteResponseModelName methods
- Handler layer: unified model ID routing, Smart Passthrough, and local aggregation for the Models API
- New error-responses and unified-model-id specs

Test coverage:
- Unit tests: modelid, conversion, handler, service, repository
- Integration tests: unified model ID routing, Smart Passthrough fidelity, cross-protocol conversion
- Migration tests: UUID primary key, UNIQUE constraint, cascade delete

OpenSpec:
- Archive the unified-model-id change to archive/2026-04-21-unified-model-id
- Sync 11 delta specs into the main specs
- Add the error-responses and unified-model-id spec files
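The commit references a pkg/modelid package for parsing, formatting, and validating unified IDs. A hedged sketch of what such helpers could look like — the function names `Parse`/`Format`, the error wording, and splitting at the first slash are assumptions, not taken from the package (the adapter tests below do accept model names that themselves contain slashes):

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// providerIDPattern mirrors the commit's charset rule:
// letters, digits, underscores, length 1-64.
var providerIDPattern = regexp.MustCompile(`^[A-Za-z0-9_]{1,64}$`)

// Parse splits "provider_id/model_name" at the first slash,
// so model names may themselves contain slashes.
func Parse(unified string) (providerID, modelName string, err error) {
	idx := strings.Index(unified, "/")
	if idx <= 0 || idx == len(unified)-1 {
		return "", "", fmt.Errorf("invalid unified model ID: %q", unified)
	}
	providerID, modelName = unified[:idx], unified[idx+1:]
	if !providerIDPattern.MatchString(providerID) {
		return "", "", fmt.Errorf("invalid provider ID: %q", providerID)
	}
	return providerID, modelName, nil
}

// Format joins the two parts back into the unified form.
func Format(providerID, modelName string) string {
	return providerID + "/" + modelName
}

func main() {
	p, m, err := Parse("openai/gpt-4")
	fmt.Println(p, m, err) // openai gpt-4 <nil>
	fmt.Println(Format("openai", "gpt-4")) // openai/gpt-4
}
```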
2026-04-21 18:14:10 +08:00
parent 7f0f831226
commit 395887667d
73 changed files with 3360 additions and 1374 deletions

.gitignore vendored

@@ -405,3 +405,4 @@ openspec/changes/archive
 temp
 .agents
 skills-lock.json
+.worktrees


@@ -38,10 +38,11 @@ nex/
 ## Features
 - **Dual-protocol support**: OpenAI and Anthropic protocols side by side
-- **Transparent proxy**: passes requests through to OpenAI-compatible providers
+- **Unified model IDs**: the `provider_id/model_name` format globally and uniquely identifies a model (e.g. `openai/gpt-4`)
+- **Transparent proxy**: Smart Passthrough for OpenAI-compatible providers (minimal rewriting that preserves parameter fidelity)
 - **Streaming responses**: full SSE streaming support
 - **Function Calling**: tool invocation (Tools)
-- **Multi-provider management**: configure and manage multiple providers
+- **Multi-provider management**: configure and manage multiple providers (provider IDs may contain only letters, digits, and underscores)
 - **Usage statistics**: request counts by provider, model, and date
 - **Web admin UI**: provider and model configuration management
@@ -99,23 +100,27 @@ bun dev
 ### Proxy API (for external applications)
+The proxy API uses model IDs in the `provider_id/model_name` format (e.g. `openai/gpt-4`). Same-protocol requests take the Smart Passthrough path (minimal JSON rewriting that preserves parameter fidelity); cross-protocol requests go through a full decode/encode conversion.
 - `POST /v1/chat/completions` - OpenAI Chat Completions API
 - `POST /v1/messages` - Anthropic Messages API
+- `GET /v1/models` - model list (aggregated from the local database; upstream providers are not queried)
+- `GET /v1/models/{provider_id}/{model_name}` - model details (local database lookup)
 ### Admin API (for the frontend)
 #### Provider management
 - `GET /api/providers` - list all providers
-- `POST /api/providers` - create a provider
+- `POST /api/providers` - create a provider (`id` may contain only letters, digits, and underscores; length 1-64)
 - `GET /api/providers/:id` - get a provider
-- `PUT /api/providers/:id` - update a provider
+- `PUT /api/providers/:id` - update a provider (`id` is immutable)
 - `DELETE /api/providers/:id` - delete a provider
 #### Model management
 - `GET /api/models` - list models (supports `?provider_id=xxx` filtering)
-- `POST /api/models` - create a model
+- `POST /api/models` - create a model (`id` is a system-generated UUID; `provider_id` + `model_name` must be jointly unique)
-- `GET /api/models/:id` - get a model
+- `GET /api/models/:id` - get a model (the response includes a `unified_id` field in `provider_id/model_name` format)
-- `PUT /api/models/:id` - update a model
+- `PUT /api/models/:id` - update a model (`id` cannot be changed)
 - `DELETE /api/models/:id` - delete a model
 #### Stats queries
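The parameter-fidelity claim for Smart Passthrough rests on one trick: decoding the body into `map[string]json.RawMessage` leaves every field except `model` byte-for-byte untouched. A standalone illustration of the idea — `rewriteModel` is a hypothetical helper, not the adapter's actual function:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// rewriteModel swaps only the "model" field; all other fields stay
// as raw bytes and are re-emitted verbatim.
func rewriteModel(body []byte, newModel string) ([]byte, error) {
	var m map[string]json.RawMessage
	if err := json.Unmarshal(body, &m); err != nil {
		return nil, err
	}
	raw, err := json.Marshal(newModel)
	if err != nil {
		return nil, err
	}
	m["model"] = raw
	return json.Marshal(m)
}

func main() {
	// An unknown vendor extension ("x_custom") survives the rewrite unchanged.
	in := []byte(`{"model":"openai/gpt-4","temperature":0.7,"x_custom":{"a":[1,2]}}`)
	out, err := rewriteModel(in, "gpt-4")
	fmt.Println(string(out), err)
	// {"model":"gpt-4","temperature":0.7,"x_custom":{"a":[1,2]}} <nil>
}
```

Note that a full decode into typed structs would silently drop fields the structs do not declare, which is exactly what the minimal-rewrite path avoids.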


@@ -108,12 +108,13 @@ backend/
 │   │   ├── logger.go
 │   │   ├── rotate.go
 │   │   └── context.go
+│   ├── modelid/              # unified model ID utilities
+│   │   ├── model_id.go
+│   │   └── model_id_test.go
 │   └── validator/            # validators
 │       └── validator.go
 ├── migrations/               # database migrations
-│   ├── 20260401000001_initial_schema.sql
-│   ├── 20260401000002_add_indexes.sql
-│   └── 20260419000001_add_provider_protocol.sql
+│   └── 20260421000001_initial_schema.sql
 ├── tests/                    # integration tests
 │   ├── helpers.go
 │   └── integration/
@@ -292,6 +293,8 @@ GET /anthropic/v1/models
 **Protocol conversion**: the gateway converts bidirectionally between any pair of protocols. A client can send an OpenAI-protocol request to an upstream provider that speaks the Anthropic protocol (and vice versa). Same-protocol requests are passed through automatically, with zero serialization overhead.
+**Unified model IDs**: the `model` field of a proxy request uses the `provider_id/model_name` format (e.g. `openai/gpt-4`), which the gateway uses to route to the matching provider. For same-protocol requests the field is automatically rewritten to the upstream `model_name`; cross-protocol requests handle it through the full conversion path.
 ### Admin API
 #### Provider management
@@ -324,14 +327,30 @@ GET /anthropic/v1/models
 - `PUT /api/models/:id` - update a model
 - `DELETE /api/models/:id` - delete a model
-**Create request**:
+**Create request** (`id` is a system-generated UUID):
 ```json
 {
-  "id": "gpt-4",
   "provider_id": "openai",
   "model_name": "gpt-4"
 }
 ```
+**Response example**:
+```json
+{
+  "id": "550e8400-e29b-41d4-a716-446655440000",
+  "provider_id": "openai",
+  "model_name": "gpt-4",
+  "unified_id": "openai/gpt-4",
+  "enabled": true,
+  "created_at": "2026-04-21T00:00:00Z"
+}
+```
+**Unified model ID**: the `unified_id` field uses the `provider_id/model_name` format and is what clients pass as the `model` parameter in proxy requests.
 #### Stats queries
 - `GET /api/stats` - query statistics
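Routing a unified ID can be sketched end to end. `FindByProviderAndModelName` is taken from the commit message; the in-memory map repository and the `route` helper are stand-ins for the real database layer and handler code:

```go
package main

import (
	"fmt"
	"strings"
)

type Model struct {
	ID         string
	ProviderID string
	ModelName  string
	Enabled    bool
}

// memoryRepo stands in for the database, keyed by "provider_id/model_name".
type memoryRepo map[string]Model

func (r memoryRepo) FindByProviderAndModelName(providerID, modelName string) (Model, bool) {
	m, ok := r[providerID+"/"+modelName]
	return m, ok
}

// route splits the unified ID at the first slash and looks up the model.
func route(r memoryRepo, unified string) (Model, error) {
	i := strings.Index(unified, "/")
	if i <= 0 || i == len(unified)-1 {
		return Model{}, fmt.Errorf("invalid unified model ID: %q", unified)
	}
	m, ok := r.FindByProviderAndModelName(unified[:i], unified[i+1:])
	if !ok || !m.Enabled {
		return Model{}, fmt.Errorf("model not found or disabled: %s", unified)
	}
	return m, nil
}

func main() {
	repo := memoryRepo{
		"openai/gpt-4": {ID: "550e8400-e29b-41d4-a716-446655440000", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
	}
	m, err := route(repo, "openai/gpt-4")
	fmt.Println(m.ModelName, err) // gpt-4 <nil>
}
```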


@@ -68,7 +68,7 @@ func main() {
 	statsRepo := repository.NewStatsRepository(db)
 	// 5. Initialize the service layer
-	providerService := service.NewProviderService(providerRepo)
+	providerService := service.NewProviderService(providerRepo, modelRepo)
 	modelService := service.NewModelService(modelRepo, providerRepo)
 	routingService := service.NewRoutingService(modelRepo, providerRepo)
 	statsService := service.NewStatsService(statsRepo)


@@ -17,11 +17,11 @@ type Provider struct {
 	Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
 }
-// Model is a model configuration.
+// Model is a model configuration; id is an auto-generated UUID, with UNIQUE(provider_id, model_name).
 type Model struct {
 	ID         string    `gorm:"primaryKey" json:"id"`
-	ProviderID string    `gorm:"not null;index" json:"provider_id"`
-	ModelName  string    `gorm:"not null;index" json:"model_name"`
+	ProviderID string    `gorm:"not null;index;uniqueIndex:idx_provider_model" json:"provider_id"`
+	ModelName  string    `gorm:"not null;index;uniqueIndex:idx_provider_model" json:"model_name"`
 	Enabled    bool      `gorm:"default:true" json:"enabled"`
 	CreatedAt  time.Time `json:"created_at"`
 }
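The struct comment says `id` is an auto-generated UUID, but the diff does not show where generation happens. A stdlib-only sketch of v4 UUID generation; whether the project wires this into a GORM `BeforeCreate` hook or generates the ID in the service layer is not shown here:

```go
package main

import (
	"crypto/rand"
	"fmt"
)

// newUUIDv4 builds an RFC 4122 version-4 UUID from 16 random bytes.
func newUUIDv4() string {
	var b [16]byte
	if _, err := rand.Read(b[:]); err != nil {
		panic(err) // crypto/rand failure is unrecoverable here
	}
	b[6] = (b[6] & 0x0f) | 0x40 // set version 4
	b[8] = (b[8] & 0x3f) | 0x80 // set RFC 4122 variant
	return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
}

func main() {
	fmt.Println(newUUIDv4())
}
```

In a GORM setup this would typically run as `func (m *Model) BeforeCreate(tx *gorm.DB) error { if m.ID == "" { m.ID = newUUIDv4() }; return nil }` — again an assumption about this codebase, not something shown in the diff.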


@@ -40,6 +40,12 @@ type ProtocolAdapter interface {
 	EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *TargetProvider) ([]byte, error)
 	DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error)
 	EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error)
+	// Unified model ID methods
+	ExtractUnifiedModelID(nativePath string) (string, error)
+	ExtractModelName(body []byte, ifaceType InterfaceType) (string, error)
+	RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)
+	RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)
 }
 // AdapterRegistry is the adapter registry interface.


@@ -2,6 +2,7 @@ package anthropic
 import (
 	"encoding/json"
+	"fmt"
 	"strings"
 	"nex/backend/internal/conversion"
@@ -39,13 +40,13 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp
 	}
 }
-// isModelInfoPath reports whether the path is a model-details path (/v1/models/{id}).
+// isModelInfoPath reports whether the path is a model-details path (/v1/models/{id}); the id may contain /.
 func isModelInfoPath(path string) bool {
 	if !strings.HasPrefix(path, "/v1/models/") {
 		return false
 	}
 	suffix := path[len("/v1/models/"):]
-	return suffix != "" && !strings.Contains(suffix, "/")
+	return suffix != ""
 }
 // BuildUrl builds the URL for the given interface type.
@@ -203,3 +204,74 @@ func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankRe
 func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
 	return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic does not support the Rerank interface")
 }
+// ExtractUnifiedModelID extracts the unified model ID from the path (/v1/models/{provider_id}/{model_name}).
+func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
+	if !strings.HasPrefix(nativePath, "/v1/models/") {
+		return "", fmt.Errorf("not a model-details path: %s", nativePath)
+	}
+	suffix := nativePath[len("/v1/models/"):]
+	if suffix == "" {
+		return "", fmt.Errorf("path is missing the model ID")
+	}
+	return suffix, nil
+}
+// locateModelFieldInRequest finds the model field in the request body and returns a rewrite function for it.
+func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
+	var m map[string]json.RawMessage
+	if err := json.Unmarshal(body, &m); err != nil {
+		return "", nil, err
+	}
+	switch ifaceType {
+	case conversion.InterfaceTypeChat:
+		raw, exists := m["model"]
+		if !exists {
+			return "", nil, fmt.Errorf("request body is missing the model field")
+		}
+		var current string
+		if err := json.Unmarshal(raw, &current); err != nil {
+			return "", nil, err
+		}
+		rewriteFunc := func(newModel string) ([]byte, error) {
+			m["model"], _ = json.Marshal(newModel)
+			return json.Marshal(m)
+		}
+		return current, rewriteFunc, nil
+	default:
+		return "", nil, fmt.Errorf("unsupported interface type: %s", ifaceType)
+	}
+}
+// ExtractModelName extracts the model value from the request body.
+func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
+	model, _, err := locateModelFieldInRequest(body, ifaceType)
+	return model, err
+}
+// RewriteRequestModelName minimally rewrites the model field in the request body.
+func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
+	_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
+	if err != nil {
+		return nil, err
+	}
+	return rewriteFunc(newModel)
+}
+// RewriteResponseModelName minimally rewrites the model field in the response body.
+func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
+	var m map[string]json.RawMessage
+	if err := json.Unmarshal(body, &m); err != nil {
+		return nil, err
+	}
+	switch ifaceType {
+	case conversion.InterfaceTypeChat:
+		// Chat responses must carry a model field: rewrite it if present, add it if missing.
+		m["model"], _ = json.Marshal(newModel)
+		return json.Marshal(m)
+	default:
+		return body, nil
+	}
+}


@@ -0,0 +1,263 @@
package anthropic
import (
"encoding/json"
"testing"
"nex/backend/internal/conversion"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// ExtractUnifiedModelID
// ---------------------------------------------------------------------------
func TestExtractUnifiedModelID(t *testing.T) {
a := NewAdapter()
t.Run("standard_path", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/v1/models/anthropic/claude-3")
require.NoError(t, err)
assert.Equal(t, "anthropic/claude-3", id)
})
t.Run("multi_segment_path", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/v1/models/some/deep/nested/model")
require.NoError(t, err)
assert.Equal(t, "some/deep/nested/model", id)
})
t.Run("single_segment", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/v1/models/claude-3")
require.NoError(t, err)
assert.Equal(t, "claude-3", id)
})
t.Run("non_model_path", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/messages")
require.Error(t, err)
})
t.Run("empty_suffix", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/models/")
require.Error(t, err)
})
t.Run("models_list_no_slash", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/models")
require.Error(t, err)
})
t.Run("unrelated_path", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/other")
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// ExtractModelName (Chat only for Anthropic)
// ---------------------------------------------------------------------------
func TestExtractModelName(t *testing.T) {
a := NewAdapter()
t.Run("chat", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3","messages":[]}`)
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "anthropic/claude-3", model)
})
t.Run("chat_with_max_tokens", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3-opus","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`)
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "anthropic/claude-3-opus", model)
})
t.Run("no_model_field", func(t *testing.T) {
body := []byte(`{"messages":[]}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("invalid_json", func(t *testing.T) {
body := []byte(`{invalid}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("unsupported_interface_type_embedding", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3"}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings)
require.Error(t, err)
})
t.Run("unsupported_interface_type_rerank", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3"}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank)
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// RewriteRequestModelName (Chat only for Anthropic)
// ---------------------------------------------------------------------------
func TestRewriteRequestModelName(t *testing.T) {
a := NewAdapter()
t.Run("chat", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3","messages":[]}`)
rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "claude-3", m["model"])
msgs, ok := m["messages"]
require.True(t, ok)
msgsArr, ok := msgs.([]interface{})
require.True(t, ok)
assert.Len(t, msgsArr, 0)
})
t.Run("preserves_unknown_fields", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3","max_tokens":1024,"temperature":0.7}`)
rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "claude-3", m["model"])
assert.Equal(t, 0.7, m["temperature"])
// JSON numbers unmarshal into float64 when decoding into interface{}
maxTokens, ok := m["max_tokens"]
require.True(t, ok)
assert.Equal(t, float64(1024), maxTokens)
})
t.Run("no_model_field", func(t *testing.T) {
body := []byte(`{"messages":[]}`)
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("invalid_json", func(t *testing.T) {
body := []byte(`{invalid}`)
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("unsupported_interface_type", func(t *testing.T) {
body := []byte(`{"model":"anthropic/claude-3"}`)
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeEmbeddings)
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// RewriteResponseModelName (Chat only for Anthropic)
// ---------------------------------------------------------------------------
func TestRewriteResponseModelName(t *testing.T) {
a := NewAdapter()
t.Run("chat_existing_model", func(t *testing.T) {
body := []byte(`{"model":"claude-3","content":[],"stop_reason":"end_turn"}`)
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "anthropic/claude-3", m["model"])
// other fields preserved
_, hasContent := m["content"]
assert.True(t, hasContent)
assert.Equal(t, "end_turn", m["stop_reason"])
})
t.Run("chat_without_model_field_adds_it", func(t *testing.T) {
body := []byte(`{"content":[],"stop_reason":"end_turn"}`)
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "anthropic/claude-3", m["model"])
})
t.Run("passthrough_returns_body_unchanged", func(t *testing.T) {
body := []byte(`{"model":"claude-3"}`)
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypePassthrough)
require.NoError(t, err)
assert.Equal(t, string(body), string(rewritten))
})
t.Run("invalid_json", func(t *testing.T) {
body := []byte(`{invalid}`)
_, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// ExtractModelName and RewriteRequest consistency
// ---------------------------------------------------------------------------
func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) {
a := NewAdapter()
t.Run("chat_round_trip", func(t *testing.T) {
original := []byte(`{"model":"anthropic/claude-3","messages":[],"max_tokens":1024}`)
// Extract the unified model ID from the body
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "anthropic/claude-3", extracted)
// Rewrite to the native model name
rewritten, err := a.RewriteRequestModelName(original, "claude-3", conversion.InterfaceTypeChat)
require.NoError(t, err)
// Extract again from the rewritten body to verify the same location was targeted
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "claude-3", afterRewrite)
// Verify other fields are preserved
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, float64(1024), m["max_tokens"])
})
}
// ---------------------------------------------------------------------------
// isModelInfoPath (additional unified model ID cases)
// ---------------------------------------------------------------------------
func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
tests := []struct {
name string
path string
expected bool
}{
{"simple_model_id", "/v1/models/claude-3", true},
{"unified_model_id_with_slash", "/v1/models/anthropic/claude-3", true},
{"models_list", "/v1/models", false},
{"models_list_trailing_slash", "/v1/models/", false},
{"messages_path", "/v1/messages", false},
{"deeply_nested", "/v1/models/org/workspace/claude-3-opus", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
})
}
}


@@ -79,11 +79,29 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
 	if err != nil {
 		return nil, err
 	}
+	// Smart Passthrough: minimally rewrite the model field for same-protocol requests
+	interfaceType := providerAdapter.DetectInterfaceType(nativePath)
+	rewrittenBody := spec.Body
+	// Rewrite the model field in the request body for the Chat/Embedding/Rerank interfaces
+	if interfaceType == InterfaceTypeChat || interfaceType == InterfaceTypeEmbeddings || interfaceType == InterfaceTypeRerank {
+		if len(spec.Body) > 0 && provider.ModelName != "" {
+			rewrittenBody, err = providerAdapter.RewriteRequestModelName(spec.Body, provider.ModelName, interfaceType)
+			if err != nil {
+				e.logger.Warn("Smart Passthrough request rewrite failed; using the original request body",
+					zap.String("error", err.Error()),
+					zap.String("interface", string(interfaceType)))
+				rewrittenBody = spec.Body
+			}
+		}
+	}
 	return &HTTPRequestSpec{
 		URL:     provider.BaseURL + nativePath,
 		Method:  spec.Method,
 		Headers: providerAdapter.BuildHeaders(provider),
-		Body:    spec.Body,
+		Body:    rewrittenBody,
 	}, nil
 }
@@ -112,9 +130,30 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
 	}, nil
 }
-// ConvertHttpResponse converts an HTTP response.
+// ConvertHttpResponse converts an HTTP response; modelOverride overwrites the model field in cross-protocol scenarios.
-func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType) (*HTTPResponseSpec, error) {
+func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType, modelOverride string) (*HTTPResponseSpec, error) {
 	if e.IsPassthrough(clientProtocol, providerProtocol) {
+		// Smart Passthrough: minimally rewrite the model field for same-protocol responses
+		if modelOverride != "" && len(spec.Body) > 0 {
+			adapter, err := e.registry.Get(clientProtocol)
+			if err != nil {
+				return &spec, nil
+			}
+			rewrittenBody, err := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
+			if err != nil {
+				e.logger.Warn("Smart Passthrough response rewrite failed; using the original response body",
+					zap.String("error", err.Error()),
+					zap.String("interface", string(interfaceType)))
+				return &spec, nil
+			}
+			return &HTTPResponseSpec{
+				StatusCode: spec.StatusCode,
+				Headers:    spec.Headers,
+				Body:       rewrittenBody,
+			}, nil
+		}
 		return &spec, nil
 	}
@@ -127,7 +166,7 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
 		return nil, err
 	}
-	convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body)
+	convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body, modelOverride)
 	if err != nil {
 		return nil, err
 	}
@@ -139,9 +178,17 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
 	}, nil
 }
-// CreateStreamConverter creates a stream converter.
+// CreateStreamConverter creates a stream converter; modelOverride overwrites the model field in cross-protocol scenarios.
-func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string) (StreamConverter, error) {
+func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string, modelOverride string, interfaceType InterfaceType) (StreamConverter, error) {
 	if e.IsPassthrough(clientProtocol, providerProtocol) {
+		// Smart Passthrough: same-protocol streaming needs the model field rewritten chunk by chunk
+		if modelOverride != "" {
+			adapter, err := e.registry.Get(clientProtocol)
+			if err != nil {
+				return NewPassthroughStreamConverter(), nil
+			}
+			return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
+		}
 		return NewPassthroughStreamConverter(), nil
 	}
@@ -167,6 +214,7 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
 		ctx,
 		clientProtocol,
 		providerProtocol,
+		modelOverride,
 	), nil
 }
@@ -192,11 +240,11 @@ func (e *ConversionEngine) convertBody(interfaceType InterfaceType, clientAdapte
 	}
 }
-// convertResponseBody converts a response body.
+// convertResponseBody converts a response body; when modelOverride is non-empty, the Model field is overwritten at the canonical level.
-func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
+func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
 	switch interfaceType {
 	case InterfaceTypeChat:
-		return e.convertChatResponseBody(clientAdapter, providerAdapter, body)
+		return e.convertChatResponseBody(clientAdapter, providerAdapter, body, modelOverride)
 	case InterfaceTypeModels:
 		if !clientAdapter.SupportsInterface(InterfaceTypeModels) || !providerAdapter.SupportsInterface(InterfaceTypeModels) {
 			return body, nil
@@ -211,12 +259,12 @@ func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clie
 		if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) {
 			return body, nil
 		}
-		return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body)
+		return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body, modelOverride)
 	case InterfaceTypeRerank:
 		if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) {
 			return body, nil
 		}
-		return e.convertRerankResponseBody(clientAdapter, providerAdapter, body)
+		return e.convertRerankResponseBody(clientAdapter, providerAdapter, body, modelOverride)
 	default:
 		return body, nil
 	}
@@ -241,11 +289,14 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
 	return encoded, nil
 }
-func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
+func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
 	canonicalResp, err := providerAdapter.DecodeResponse(body)
 	if err != nil {
 		return nil, NewConversionError(ErrorCodeJSONParseError, "failed to decode response").WithCause(err)
 	}
+	if modelOverride != "" {
+		canonicalResp.Model = modelOverride
+	}
 	encoded, err := clientAdapter.EncodeResponse(canonicalResp)
 	if err != nil {
 		return nil, NewConversionError(ErrorCodeEncodingFailure, "failed to encode response").WithCause(err)
@@ -290,12 +341,15 @@ func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter P
 	return providerAdapter.EncodeEmbeddingRequest(req, provider)
 }
-func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
+func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
 	resp, err := providerAdapter.DecodeEmbeddingResponse(body)
 	if err != nil {
 		e.logger.Warn("failed to decode Embedding response; returning the raw response", zap.String("error", err.Error()))
 		return body, nil
 	}
+	if modelOverride != "" {
+		resp.Model = modelOverride
+	}
 	return clientAdapter.EncodeEmbeddingResponse(resp)
 }
@@ -308,11 +362,14 @@ func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter Prot
 	return providerAdapter.EncodeRerankRequest(req, provider)
 }
-func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
+func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
 	resp, err := providerAdapter.DecodeRerankResponse(body)
 	if err != nil {
 		return body, nil
 	}
+	if modelOverride != "" {
+		resp.Model = modelOverride
+	}
 	return clientAdapter.EncodeRerankResponse(resp)
 }
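The SmartPassthroughStreamConverter that CreateStreamConverter returns is not part of this diff. A minimal sketch of the per-chunk work it implies: rewrite the model field inside each SSE `data:` payload and leave `[DONE]` and non-JSON lines untouched. The SSE framing and the helper name here are illustrative, not the project's actual implementation:

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// rewriteSSEChunk rewrites the "model" field of every JSON payload carried
// on a "data:" line, passing [DONE] and malformed payloads through unchanged.
func rewriteSSEChunk(chunk, modelOverride string) string {
	lines := strings.Split(chunk, "\n")
	for i, line := range lines {
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		data := strings.TrimPrefix(line, "data: ")
		if data == "[DONE]" {
			continue
		}
		var m map[string]json.RawMessage
		if err := json.Unmarshal([]byte(data), &m); err != nil {
			continue // pass malformed payloads through unchanged
		}
		raw, _ := json.Marshal(modelOverride)
		m["model"] = raw
		out, err := json.Marshal(m)
		if err != nil {
			continue
		}
		lines[i] = "data: " + string(out)
	}
	return strings.Join(lines, "\n")
}

func main() {
	chunk := "data: {\"model\":\"gpt-4\",\"choices\":[]}\n\ndata: [DONE]\n"
	fmt.Print(rewriteSSEChunk(chunk, "openai/gpt-4"))
}
```

A real converter also has to handle chunks that split mid-event across reads, which this line-based sketch deliberately ignores.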


@@ -113,7 +113,7 @@ func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
 	result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
 		StatusCode: 200, Body: []byte(`{"id":"resp-1"}`),
-	}, "client", "provider", InterfaceTypeChat)
+	}, "client", "provider", InterfaceTypeChat, "")
 	require.NoError(t, err)
 	assert.Equal(t, 200, result.StatusCode)
 	assert.Contains(t, string(result.Body), "resp-1")
@@ -129,7 +129,7 @@ func TestConvertHttpResponse_DecodeError(t *testing.T) {
 	_ = engine.RegisterAdapter(providerAdapter)
 	_ = engine.RegisterAdapter(newMockAdapter("client", false))
-	_, err := engine.ConvertHttpResponse(HTTPResponseSpec{Body: []byte(`{}`)}, "client", "provider", InterfaceTypeChat)
+	_, err := engine.ConvertHttpResponse(HTTPResponseSpec{Body: []byte(`{}`)}, "client", "provider", InterfaceTypeChat, "")
 	assert.Error(t, err)
 }
@@ -189,7 +189,7 @@ func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
 	result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
 		StatusCode: 200, Body: []byte(`{"object":"list","data":[],"model":"test"}`),
-	}, "client", "provider", InterfaceTypeEmbeddings)
+	}, "client", "provider", InterfaceTypeEmbeddings, "")
 	require.NoError(t, err)
 	assert.NotNil(t, result)
 }
@@ -207,7 +207,7 @@ func TestConvertHttpResponse_RerankInterface(t *testing.T) {
 	result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
 		StatusCode: 200, Body: []byte(`{"results":[],"model":"test"}`),
-	}, "client", "provider", InterfaceTypeRerank)
+	}, "client", "provider", InterfaceTypeRerank, "")
 	require.NoError(t, err)
 	assert.NotNil(t, result)
 }
@@ -242,7 +242,7 @@ func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
 	result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
 		StatusCode: 200, Body: []byte(`{"object":"list","data":[]}`),
-	}, "client", "provider", InterfaceTypeModels)
+	}, "client", "provider", InterfaceTypeModels, "")
 	require.NoError(t, err)
 	assert.NotNil(t, result)
 }
@@ -259,7 +259,7 @@ func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) {
 	result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
 		StatusCode: 200, Body: []byte(`{"id":"gpt-4","object":"model"}`),
-	}, "client", "provider", InterfaceTypeModelInfo)
+	}, "client", "provider", InterfaceTypeModelInfo, "")
 	require.NoError(t, err)
 	assert.NotNil(t, result)
 }


@@ -13,16 +13,18 @@ import (
 // mockProtocolAdapter is a mock protocol adapter.
 type mockProtocolAdapter struct {
 	protocolName    string
 	passthrough     bool
 	ifaceType       InterfaceType
 	supportsIface   map[InterfaceType]bool
 	decodeReqFn     func([]byte) (*canonical.CanonicalRequest, error)
 	encodeReqFn     func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
 	decodeRespFn    func([]byte) (*canonical.CanonicalResponse, error)
 	encodeRespFn    func(*canonical.CanonicalResponse) ([]byte, error)
 	streamDecoderFn func() StreamDecoder
 	streamEncoderFn func() StreamEncoder
+	rewriteReqFn    func([]byte, string, InterfaceType) ([]byte, error)
+	rewriteRespFn   func([]byte, string, InterfaceType) ([]byte, error)
 }
 func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter { func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
@@ -155,6 +157,28 @@ func (m *mockProtocolAdapter) EncodeRerankResponse(resp *canonical.CanonicalRera
return json.Marshal(resp) return json.Marshal(resp)
} }
func (m *mockProtocolAdapter) ExtractUnifiedModelID(nativePath string) (string, error) {
return "", nil
}
func (m *mockProtocolAdapter) ExtractModelName(body []byte, ifaceType InterfaceType) (string, error) {
return "", nil
}
func (m *mockProtocolAdapter) RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
if m.rewriteReqFn != nil {
return m.rewriteReqFn(body, newModel, ifaceType)
}
return body, nil
}
func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
if m.rewriteRespFn != nil {
return m.rewriteRespFn(body, newModel, ifaceType)
}
return body, nil
}
// noopStreamDecoder 空流式解码器 // noopStreamDecoder 空流式解码器
type noopStreamDecoder struct{} type noopStreamDecoder struct{}
@@ -309,7 +333,7 @@ func TestConvertHttpResponse_Passthrough(t *testing.T) {
Body: []byte(`{"id":"123"}`), Body: []byte(`{"id":"123"}`),
} }
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat) result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat, "")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 200, result.StatusCode) assert.Equal(t, 200, result.StatusCode)
assert.Equal(t, spec.Body, result.Body) assert.Equal(t, spec.Body, result.Body)
@@ -320,7 +344,7 @@ func TestCreateStreamConverter_Passthrough(t *testing.T) {
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, nil)
_ = engine.RegisterAdapter(newMockAdapter("openai", true)) _ = engine.RegisterAdapter(newMockAdapter("openai", true))
converter, err := engine.CreateStreamConverter("openai", "openai") converter, err := engine.CreateStreamConverter("openai", "openai", "", InterfaceTypeChat)
require.NoError(t, err) require.NoError(t, err)
_, ok := converter.(*PassthroughStreamConverter) _, ok := converter.(*PassthroughStreamConverter)
assert.True(t, ok) assert.True(t, ok)
@@ -332,7 +356,7 @@ func TestCreateStreamConverter_Canonical(t *testing.T) {
_ = engine.RegisterAdapter(newMockAdapter("client", false)) _ = engine.RegisterAdapter(newMockAdapter("client", false))
_ = engine.RegisterAdapter(newMockAdapter("provider", false)) _ = engine.RegisterAdapter(newMockAdapter("provider", false))
converter, err := engine.CreateStreamConverter("client", "provider") converter, err := engine.CreateStreamConverter("client", "provider", "", InterfaceTypeChat)
require.NoError(t, err) require.NoError(t, err)
_, ok := converter.(*CanonicalStreamConverter) _, ok := converter.(*CanonicalStreamConverter)
assert.True(t, ok) assert.True(t, ok)
@@ -380,3 +404,230 @@ func TestRegistry_GetNonExistent(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "未找到适配器") assert.Contains(t, err.Error(), "未找到适配器")
} }
// ============ modelOverride 测试 ============
func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
clientAdapter := newMockAdapter("client", false)
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
return json.Marshal(map[string]any{"model": resp.Model})
}
_ = engine.RegisterAdapter(clientAdapter)
providerAdapter := newMockAdapter("provider", false)
providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
return &canonical.CanonicalResponse{ID: "test", Model: "native-model", Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}, nil
}
_ = engine.RegisterAdapter(providerAdapter)
spec := HTTPResponseSpec{
StatusCode: 200,
Body: []byte(`{"model":"native-model"}`),
}
result, err := engine.ConvertHttpResponse(spec, "client", "provider", InterfaceTypeChat, "provider/gpt-4")
require.NoError(t, err)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(result.Body, &resp))
assert.Equal(t, "provider/gpt-4", resp["model"])
}
func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
	// 使用支持 passthrough 的 mock adapter 验证 Smart Passthrough 改写
openaiAdapter := newMockAdapter("openai", true)
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
var m map[string]json.RawMessage
if err := json.Unmarshal(body, &m); err != nil {
return nil, err
}
m["model"], _ = json.Marshal(newModel)
return json.Marshal(m)
}
_ = engine.RegisterAdapter(openaiAdapter)
spec := HTTPResponseSpec{
StatusCode: 200,
Body: []byte(`{"id":"resp-1","model":"gpt-4"}`),
}
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat, "openai/gpt-4")
require.NoError(t, err)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(result.Body, &resp))
assert.Equal(t, "openai/gpt-4", resp["model"])
assert.Equal(t, "resp-1", resp["id"])
}
func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
openaiAdapter := newMockAdapter("openai", true)
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
var m map[string]json.RawMessage
if err := json.Unmarshal(body, &m); err != nil {
return nil, err
}
m["model"], _ = json.Marshal(newModel)
return json.Marshal(m)
}
_ = engine.RegisterAdapter(openaiAdapter)
converter, err := engine.CreateStreamConverter("openai", "openai", "openai/gpt-4", InterfaceTypeChat)
require.NoError(t, err)
_, ok := converter.(*SmartPassthroughStreamConverter)
assert.True(t, ok)
// 验证 chunk 改写
chunks := converter.ProcessChunk([]byte(`{"model":"gpt-4","choices":[]}`))
require.Len(t, chunks, 1)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(chunks[0], &resp))
assert.Equal(t, "openai/gpt-4", resp["model"])
}
func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
// provider adapter 解码出含 model 的流式事件
providerAdapter := newMockAdapter("provider", false)
providerAdapter.streamDecoderFn = func() StreamDecoder {
return &engineTestStreamDecoder{
processFn: func(raw []byte) []canonical.CanonicalStreamEvent {
return []canonical.CanonicalStreamEvent{
canonical.NewMessageStartEvent("msg-1", "native-model"),
canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "text", Text: "hi"}),
canonical.NewMessageStopEvent(),
}
},
}
}
_ = engine.RegisterAdapter(providerAdapter)
// client adapter 编码时输出 model 字段
clientAdapter := newMockAdapter("client", false)
clientAdapter.streamEncoderFn = func() StreamEncoder {
return &engineTestStreamEncoder{
encodeFn: func(event canonical.CanonicalStreamEvent) [][]byte {
if event.Message != nil {
data, _ := json.Marshal(map[string]string{
"type": string(event.Type),
"model": event.Message.Model,
})
return [][]byte{data}
}
data, _ := json.Marshal(map[string]string{"type": string(event.Type)})
return [][]byte{data}
},
}
}
_ = engine.RegisterAdapter(clientAdapter)
converter, err := engine.CreateStreamConverter("client", "provider", "provider/gpt-4", InterfaceTypeChat)
require.NoError(t, err)
// 验证类型是 CanonicalStreamConverter
_, ok := converter.(*CanonicalStreamConverter)
assert.True(t, ok)
// 处理一个 chunk验证 model 被覆写为统一模型 ID
chunks := converter.ProcessChunk([]byte("raw"))
require.Len(t, chunks, 3) // message_start + content_block_start + message_stop
var startEvent map[string]string
require.NoError(t, json.Unmarshal(chunks[0], &startEvent))
assert.Equal(t, "provider/gpt-4", startEvent["model"], "跨协议流式中 modelOverride 应覆写 Message.Model")
}
func TestCreateStreamConverter_ModelOverride_CrossProtocol_Empty(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
providerAdapter := newMockAdapter("provider", false)
providerAdapter.streamDecoderFn = func() StreamDecoder {
return &engineTestStreamDecoder{
processFn: func(raw []byte) []canonical.CanonicalStreamEvent {
return []canonical.CanonicalStreamEvent{
canonical.NewMessageStartEvent("msg-1", "native-model"),
}
},
}
}
_ = engine.RegisterAdapter(providerAdapter)
clientAdapter := newMockAdapter("client", false)
clientAdapter.streamEncoderFn = func() StreamEncoder {
return &engineTestStreamEncoder{
encodeFn: func(event canonical.CanonicalStreamEvent) [][]byte {
if event.Message != nil {
data, _ := json.Marshal(map[string]string{
"model": event.Message.Model,
})
return [][]byte{data}
}
return nil
},
}
}
_ = engine.RegisterAdapter(clientAdapter)
// modelOverride 为空,不应覆写
converter, err := engine.CreateStreamConverter("client", "provider", "", InterfaceTypeChat)
require.NoError(t, err)
chunks := converter.ProcessChunk([]byte("raw"))
require.Len(t, chunks, 1)
var resp map[string]string
require.NoError(t, json.Unmarshal(chunks[0], &resp))
assert.Equal(t, "native-model", resp["model"], "modelOverride 为空时不应覆写")
}
// engineTestStreamDecoder 可控的流式解码器(用于 engine_test
type engineTestStreamDecoder struct {
processFn func([]byte) []canonical.CanonicalStreamEvent
flushFn func() []canonical.CanonicalStreamEvent
}
func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.CanonicalStreamEvent {
if d.processFn != nil {
return d.processFn(raw)
}
return nil
}
func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
if d.flushFn != nil {
return d.flushFn()
}
return nil
}
// engineTestStreamEncoder 可控的流式编码器(用于 engine_test
type engineTestStreamEncoder struct {
encodeFn func(canonical.CanonicalStreamEvent) [][]byte
flushFn func() [][]byte
}
func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
if e.encodeFn != nil {
return e.encodeFn(event)
}
return nil
}
func (e *engineTestStreamEncoder) Flush() [][]byte {
if e.flushFn != nil {
return e.flushFn()
}
return nil
}

View File

@@ -2,6 +2,7 @@ package openai
 import (
 	"encoding/json"
+	"fmt"
 	"strings"

 	"nex/backend/internal/conversion"
@@ -43,13 +44,13 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp
 	}
 }

-// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id}
+// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id},允许 id 含 /
 func isModelInfoPath(path string) bool {
 	if !strings.HasPrefix(path, "/v1/models/") {
 		return false
 	}
 	suffix := path[len("/v1/models/"):]
-	return suffix != "" && !strings.Contains(suffix, "/")
+	return suffix != ""
 }

 // BuildUrl 根据接口类型构建 URL
@@ -216,3 +217,80 @@ func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankRe
 func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
 	return encodeRerankResponse(resp)
 }
// ExtractUnifiedModelID 从路径中提取统一模型 ID/v1/models/{provider_id}/{model_name}
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
if !strings.HasPrefix(nativePath, "/v1/models/") {
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
}
suffix := nativePath[len("/v1/models/"):]
if suffix == "" {
return "", fmt.Errorf("路径缺少模型 ID")
}
return suffix, nil
}
// locateModelFieldInRequest 定位请求体中 model 字段的值并提供改写函数
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
var m map[string]json.RawMessage
if err := json.Unmarshal(body, &m); err != nil {
return "", nil, err
}
switch ifaceType {
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
raw, exists := m["model"]
if !exists {
return "", nil, fmt.Errorf("请求体中缺少 model 字段")
}
var current string
if err := json.Unmarshal(raw, &current); err != nil {
return "", nil, err
}
rewriteFunc := func(newModel string) ([]byte, error) {
m["model"], _ = json.Marshal(newModel)
return json.Marshal(m)
}
return current, rewriteFunc, nil
default:
return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
}
}
// ExtractModelName 从请求体中提取 model 值
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
model, _, err := locateModelFieldInRequest(body, ifaceType)
return model, err
}
// RewriteRequestModelName 最小化改写请求体中的 model 字段
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
if err != nil {
return nil, err
}
return rewriteFunc(newModel)
}
// RewriteResponseModelName 最小化改写响应体中的 model 字段
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
var m map[string]json.RawMessage
if err := json.Unmarshal(body, &m); err != nil {
return nil, err
}
switch ifaceType {
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
m["model"], _ = json.Marshal(newModel)
return json.Marshal(m)
case conversion.InterfaceTypeRerank:
// Rerank 响应:存在 model 字段则改写,不存在则不添加
if _, exists := m["model"]; exists {
m["model"], _ = json.Marshal(newModel)
}
return json.Marshal(m)
default:
return body, nil
}
}
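The response-rewrite rule above is asymmetric: chat/embeddings responses always get the unified model ID (added if missing, since the protocol requires the field), while rerank responses are only rewritten when a model field is already present. A minimal standalone sketch of that rule — the helper name `rewriteResponseModel` and the `addIfMissing` flag are illustrative, not part of the adapter:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// rewriteResponseModel sketches the minimal-rewrite rule: decode into a
// map of raw messages so unrelated fields pass through byte-for-byte,
// then set "model" only when it exists or when the interface requires it.
func rewriteResponseModel(body []byte, newModel string, addIfMissing bool) ([]byte, error) {
	var m map[string]json.RawMessage
	if err := json.Unmarshal(body, &m); err != nil {
		return nil, err
	}
	if _, exists := m["model"]; exists || addIfMissing {
		raw, _ := json.Marshal(newModel)
		m["model"] = raw
	}
	return json.Marshal(m)
}

func main() {
	// Chat: model field added even when absent.
	chat, _ := rewriteResponseModel([]byte(`{"choices":[]}`), "openai/gpt-4", true)
	// Rerank: absent model field stays absent.
	rerank, _ := rewriteResponseModel([]byte(`{"results":[]}`), "openai/rerank", false)
	fmt.Println(string(chat))
	fmt.Println(string(rerank))
}
```

Using `map[string]json.RawMessage` (rather than typed structs) is what keeps the rewrite "minimal": unknown or future fields survive untouched.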

View File

@@ -121,7 +121,7 @@ func TestIsModelInfoPath(t *testing.T) {
 		{"model_info", "/v1/models/gpt-4", true},
 		{"model_info_with_dots", "/v1/models/gpt-4.1-preview", true},
 		{"models_list", "/v1/models", false},
-		{"nested_path", "/v1/models/gpt-4/versions", false},
+		{"nested_path", "/v1/models/gpt-4/versions", true},
 		{"empty_suffix", "/v1/models/", false},
 		{"unrelated", "/v1/chat/completions", false},
 		{"partial_prefix", "/v1/model", false},

View File

@@ -0,0 +1,360 @@
package openai
import (
"encoding/json"
"testing"
"nex/backend/internal/conversion"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// ExtractUnifiedModelID
// ---------------------------------------------------------------------------
func TestExtractUnifiedModelID(t *testing.T) {
a := NewAdapter()
t.Run("standard_path", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
require.NoError(t, err)
assert.Equal(t, "openai/gpt-4", id)
})
t.Run("multi_segment_path", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
require.NoError(t, err)
assert.Equal(t, "azure/accounts/org/models/gpt-4", id)
})
t.Run("single_segment", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/v1/models/gpt-4")
require.NoError(t, err)
assert.Equal(t, "gpt-4", id)
})
t.Run("non_model_path", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/chat/completions")
require.Error(t, err)
})
t.Run("empty_suffix", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/models/")
require.Error(t, err)
})
t.Run("models_list_no_slash", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/models")
require.Error(t, err)
})
t.Run("unrelated_path", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/other")
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// ExtractModelName
// ---------------------------------------------------------------------------
func TestExtractModelName(t *testing.T) {
a := NewAdapter()
t.Run("chat", func(t *testing.T) {
body := []byte(`{"model":"openai/gpt-4","messages":[]}`)
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "openai/gpt-4", model)
})
t.Run("embedding", func(t *testing.T) {
body := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
model, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
assert.Equal(t, "openai/text-embedding", model)
})
t.Run("rerank", func(t *testing.T) {
body := []byte(`{"model":"openai/rerank","query":"test"}`)
model, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank)
require.NoError(t, err)
assert.Equal(t, "openai/rerank", model)
})
t.Run("no_model_field", func(t *testing.T) {
body := []byte(`{"messages":[]}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("invalid_json", func(t *testing.T) {
body := []byte(`{invalid}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("unsupported_interface_type", func(t *testing.T) {
body := []byte(`{"model":"openai/gpt-4"}`)
_, err := a.ExtractModelName(body, conversion.InterfaceTypePassthrough)
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// RewriteRequestModelName
// ---------------------------------------------------------------------------
func TestRewriteRequestModelName(t *testing.T) {
a := NewAdapter()
t.Run("chat", func(t *testing.T) {
body := []byte(`{"model":"openai/gpt-4","messages":[]}`)
rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "gpt-4", m["model"])
// messages field preserved
msgs, ok := m["messages"]
require.True(t, ok)
msgsArr, ok := msgs.([]interface{})
require.True(t, ok)
assert.Len(t, msgsArr, 0)
})
t.Run("preserves_unknown_fields", func(t *testing.T) {
body := []byte(`{"model":"openai/gpt-4","temperature":0.7}`)
rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "gpt-4", m["model"])
assert.Equal(t, 0.7, m["temperature"])
})
t.Run("embedding", func(t *testing.T) {
body := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
rewritten, err := a.RewriteRequestModelName(body, "text-embedding", conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "text-embedding", m["model"])
assert.Equal(t, "hello", m["input"])
})
t.Run("rerank", func(t *testing.T) {
body := []byte(`{"model":"openai/rerank","query":"test"}`)
rewritten, err := a.RewriteRequestModelName(body, "rerank", conversion.InterfaceTypeRerank)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "rerank", m["model"])
assert.Equal(t, "test", m["query"])
})
t.Run("no_model_field", func(t *testing.T) {
body := []byte(`{"messages":[]}`)
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("invalid_json", func(t *testing.T) {
body := []byte(`{invalid}`)
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
require.Error(t, err)
})
t.Run("unsupported_interface_type", func(t *testing.T) {
body := []byte(`{"model":"openai/gpt-4"}`)
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypePassthrough)
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// RewriteResponseModelName
// ---------------------------------------------------------------------------
func TestRewriteResponseModelName(t *testing.T) {
a := NewAdapter()
t.Run("chat_existing_model", func(t *testing.T) {
body := []byte(`{"model":"gpt-4","choices":[]}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "openai/gpt-4", m["model"])
choices, ok := m["choices"]
require.True(t, ok)
choicesArr, ok := choices.([]interface{})
require.True(t, ok)
assert.Len(t, choicesArr, 0)
})
t.Run("chat_without_model_field", func(t *testing.T) {
body := []byte(`{"choices":[]}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "openai/gpt-4", m["model"])
choices, ok := m["choices"]
require.True(t, ok)
choicesArr, ok := choices.([]interface{})
require.True(t, ok)
assert.Len(t, choicesArr, 0)
})
t.Run("rerank_existing_model", func(t *testing.T) {
body := []byte(`{"model":"rerank","results":[]}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "openai/rerank", m["model"])
})
t.Run("rerank_without_model_field_should_not_add", func(t *testing.T) {
body := []byte(`{"results":[]}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
_, hasModel := m["model"]
assert.False(t, hasModel, "rerank response without model field should not have one added")
})
t.Run("embedding_existing_model", func(t *testing.T) {
body := []byte(`{"model":"text-embedding","data":[]}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "openai/text-embedding", m["model"])
})
t.Run("embedding_without_model_field_adds", func(t *testing.T) {
body := []byte(`{"data":[]}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, "openai/text-embedding", m["model"])
})
t.Run("passthrough_returns_body_unchanged", func(t *testing.T) {
body := []byte(`{"model":"gpt-4"}`)
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypePassthrough)
require.NoError(t, err)
assert.Equal(t, string(body), string(rewritten))
})
t.Run("invalid_json", func(t *testing.T) {
body := []byte(`{invalid}`)
_, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
require.Error(t, err)
})
}
// ---------------------------------------------------------------------------
// ExtractModelName and RewriteRequest consistency
// ---------------------------------------------------------------------------
func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) {
a := NewAdapter()
t.Run("chat_round_trip", func(t *testing.T) {
original := []byte(`{"model":"openai/gpt-4","messages":[],"temperature":0.7}`)
// Extract the unified model ID from the body
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "openai/gpt-4", extracted)
// Rewrite to the native model name
rewritten, err := a.RewriteRequestModelName(original, "gpt-4", conversion.InterfaceTypeChat)
require.NoError(t, err)
// Extract again from the rewritten body to verify the same location was targeted
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat)
require.NoError(t, err)
assert.Equal(t, "gpt-4", afterRewrite)
// Verify other fields are preserved
var m map[string]interface{}
require.NoError(t, json.Unmarshal(rewritten, &m))
assert.Equal(t, 0.7, m["temperature"])
})
t.Run("embedding_round_trip", func(t *testing.T) {
original := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
assert.Equal(t, "openai/text-embedding", extracted)
rewritten, err := a.RewriteRequestModelName(original, "text-embedding", conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeEmbeddings)
require.NoError(t, err)
assert.Equal(t, "text-embedding", afterRewrite)
})
t.Run("rerank_round_trip", func(t *testing.T) {
original := []byte(`{"model":"openai/rerank","query":"test"}`)
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeRerank)
require.NoError(t, err)
assert.Equal(t, "openai/rerank", extracted)
rewritten, err := a.RewriteRequestModelName(original, "rerank", conversion.InterfaceTypeRerank)
require.NoError(t, err)
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeRerank)
require.NoError(t, err)
assert.Equal(t, "rerank", afterRewrite)
})
}
// ---------------------------------------------------------------------------
// isModelInfoPath (additional unified model ID cases)
// ---------------------------------------------------------------------------
func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
tests := []struct {
name string
path string
expected bool
}{
{"simple_model_id", "/v1/models/gpt-4", true},
{"unified_model_id_with_slash", "/v1/models/openai/gpt-4", true},
{"models_list", "/v1/models", false},
{"models_list_trailing_slash", "/v1/models/", false},
{"chat_completions", "/v1/chat/completions", false},
{"deeply_nested", "/v1/models/azure/eastus/deployments/my-dept/models/gpt-4", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
})
}
}

View File

@@ -38,14 +38,52 @@ func (c *PassthroughStreamConverter) Flush() [][]byte {
 	return nil
 }

+// SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器
+// 逐 chunk 改写 model 字段
+type SmartPassthroughStreamConverter struct {
+	adapter       ProtocolAdapter
+	modelOverride string
+	interfaceType InterfaceType
+}
+
+// NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器
+func NewSmartPassthroughStreamConverter(adapter ProtocolAdapter, modelOverride string, interfaceType InterfaceType) *SmartPassthroughStreamConverter {
+	return &SmartPassthroughStreamConverter{
+		adapter:       adapter,
+		modelOverride: modelOverride,
+		interfaceType: interfaceType,
+	}
+}
+
+// ProcessChunk 改写 chunk 中的 model 字段
+func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
+	if len(rawChunk) == 0 {
+		return nil
+	}
+	rewrittenChunk, err := c.adapter.RewriteResponseModelName(rawChunk, c.modelOverride, c.interfaceType)
+	if err != nil {
+		// 改写失败,返回原始 chunk
+		return [][]byte{rawChunk}
+	}
+	return [][]byte{rewrittenChunk}
+}
+
+// Flush 无缓冲数据
+func (c *SmartPassthroughStreamConverter) Flush() [][]byte {
+	return nil
+}
+
 // CanonicalStreamConverter 跨协议规范流式转换器
 type CanonicalStreamConverter struct {
 	decoder          StreamDecoder
 	encoder          StreamEncoder
 	chain            *MiddlewareChain
 	ctx              ConversionContext
 	clientProtocol   string
 	providerProtocol string
+	modelOverride    string
 }

 // NewCanonicalStreamConverter 创建规范流式转换器
@@ -57,18 +95,19 @@ func NewCanonicalStreamConverter(decoder StreamDecoder, encoder StreamEncoder) *
 }

 // NewCanonicalStreamConverterWithMiddleware 创建带中间件的规范流式转换器
-func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol string) *CanonicalStreamConverter {
+func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol, modelOverride string) *CanonicalStreamConverter {
 	return &CanonicalStreamConverter{
 		decoder:          decoder,
 		encoder:          encoder,
 		chain:            chain,
 		ctx:              ctx,
 		clientProtocol:   clientProtocol,
 		providerProtocol: providerProtocol,
+		modelOverride:    modelOverride,
 	}
 }

-// ProcessChunk 解码 → 中间件 → 编码管道
+// ProcessChunk 解码 → 中间件 → modelOverride → 编码管道
 func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
 	events := c.decoder.ProcessChunk(rawChunk)
 	var result [][]byte
@@ -80,6 +119,7 @@ func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
 			}
 			events[i] = *processed
 		}
+		c.applyModelOverride(&events[i])
 		chunks := c.encoder.EncodeEvent(events[i])
 		result = append(result, chunks...)
 	}
@@ -98,6 +138,7 @@ func (c *CanonicalStreamConverter) Flush() [][]byte {
 			}
 			events[i] = *processed
 		}
+		c.applyModelOverride(&events[i])
 		chunks := c.encoder.EncodeEvent(events[i])
 		result = append(result, chunks...)
 	}
@@ -105,3 +146,10 @@ func (c *CanonicalStreamConverter) Flush() [][]byte {
 	result = append(result, encoderChunks...)
 	return result
 }
+
+// applyModelOverride 在跨协议场景下覆写流式事件中的 Model 字段
+func (c *CanonicalStreamConverter) applyModelOverride(event *canonical.CanonicalStreamEvent) {
+	if c.modelOverride != "" && event.Message != nil {
+		event.Message.Model = c.modelOverride
+	}
+}
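applyModelOverride above relies on two guards: an empty override is a no-op, and only events carrying a message payload have a Model field to overwrite. A standalone sketch of that guard — the trimmed `streamEvent` type here is illustrative, not the real canonical event type:

```go
package main

import "fmt"

// streamEvent is a trimmed stand-in for a canonical stream event: only
// message-bearing events (e.g. message_start) carry a Model field.
type streamEvent struct {
	Type    string
	Message *struct{ Model string }
}

// applyModelOverride mirrors the guard above: skip when the override is
// empty or the event has no message payload.
func applyModelOverride(e *streamEvent, override string) {
	if override != "" && e.Message != nil {
		e.Message.Model = override
	}
}

func main() {
	start := streamEvent{Type: "message_start", Message: &struct{ Model string }{Model: "native-model"}}
	stop := streamEvent{Type: "message_stop"} // no Message payload
	applyModelOverride(&start, "provider/gpt-4")
	applyModelOverride(&stop, "provider/gpt-4") // no-op, no nil dereference
	fmt.Println(start.Message.Model)
}
```

The nil check is what lets the converter apply the override unconditionally on every event in the decode loop without special-casing event types.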

View File

@@ -93,7 +93,7 @@ func TestCanonicalStreamConverter_WithMiddleware(t *testing.T) {
 	chain.Use(&recordingMiddleware{name: "mw1", records: &records})
 	ctx := NewConversionContext(InterfaceTypeChat)
-	converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
+	converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")

 	result := converter.ProcessChunk([]byte("raw"))
 	assert.Len(t, result, 1)
@@ -143,7 +143,7 @@ func TestCanonicalStreamConverter_MiddlewareError_Continue(t *testing.T) {
 	chain.Use(&errorMiddleware{})
 	ctx := NewConversionContext(InterfaceTypeChat)
-	converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
+	converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")

 	result := converter.ProcessChunk([]byte("raw"))
 	assert.Nil(t, result, "middleware error should cause the event to be skipped (continue)")
@@ -163,7 +163,7 @@ func TestCanonicalStreamConverter_Flush_MiddlewareError_Continue(t *testing.T) {
 	chain.Use(&errorMiddleware{})
 	ctx := NewConversionContext(InterfaceTypeChat)
-	converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
+	converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")

 	result := converter.Flush()
 	assert.Len(t, result, 1)

View File

@@ -1,8 +1,12 @@
 package domain

-import "time"
+import (
+    "time"
+
+    "nex/backend/pkg/modelid"
+)

-// Model 模型领域模型
+// Model 模型领域模型id 为 UUID 自动生成)
 type Model struct {
     ID         string `json:"id"`
     ProviderID string `json:"provider_id"`
@@ -10,3 +14,8 @@ type Model struct {
     Enabled   bool      `json:"enabled"`
     CreatedAt time.Time `json:"created_at"`
 }
+
+// UnifiedModelID 返回统一模型 ID格式provider_id/model_name
+func (m *Model) UnifiedModelID() string {
+    return modelid.FormatUnifiedModelID(m.ProviderID, m.ModelName)
+}

View File

@@ -113,7 +113,6 @@ func TestModelHandler_CreateModel_Success(t *testing.T) {
     h := NewModelHandler(&mockModelService{})

     body, _ := json.Marshal(map[string]string{
-        "id": "m1",
         "provider_id": "p1",
         "model_name": "gpt-4",
     })
@@ -127,7 +126,7 @@ func TestModelHandler_CreateModel_Success(t *testing.T) {
     var result domain.Model
     require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
-    assert.Equal(t, "m1", result.ID)
+    assert.NotEmpty(t, result.ID)
 }

 func TestModelHandler_GetModel(t *testing.T) {

View File

@@ -13,7 +13,6 @@ import (
     "github.com/gin-gonic/gin"
     "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/require"
-    "gorm.io/gorm"

     "nex/backend/internal/domain"
     "nex/backend/internal/provider"
@@ -31,7 +30,7 @@ type mockRoutingService struct {
     err error
 }

-func (m *mockRoutingService) Route(modelName string) (*domain.RouteResult, error) {
+func (m *mockRoutingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
     return m.result, m.err
 }
@@ -57,6 +56,14 @@ type mockProviderService struct {
     err error
 }

+func (m *mockProviderService) ListEnabledModels() ([]domain.Model, error) {
+    return nil, nil
+}
+
+func (m *mockProviderService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
+    return nil, nil
+}
+
 func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err }
 func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
     return m.provider, m.err
@@ -73,13 +80,21 @@ type mockModelService struct {
     err error
 }

-func (m *mockModelService) Create(model *domain.Model) error { return m.err }
+func (m *mockModelService) Create(model *domain.Model) error {
+    if m.err == nil {
+        model.ID = "mock-uuid-1234"
+    }
+    return m.err
+}
 func (m *mockModelService) Get(id string) (*domain.Model, error) {
     return m.model, m.err
 }
 func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
     return m.models, m.err
 }
+func (m *mockModelService) ListEnabled() ([]domain.Model, error) {
+    return []domain.Model{}, nil
+}
 func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
     return m.err
 }
@@ -163,8 +178,8 @@ func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
 func TestModelHandler_ListModels(t *testing.T) {
     h := NewModelHandler(&mockModelService{
         models: []domain.Model{
-            {ID: "m1", ModelName: "gpt-4"},
-            {ID: "m2", ModelName: "gpt-3.5"},
+            {ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
+            {ID: "m2", ProviderID: "anthropic", ModelName: "claude-3"},
         },
     })
@@ -174,6 +189,72 @@ func TestModelHandler_ListModels(t *testing.T) {
     h.ListModels(c)

     assert.Equal(t, 200, w.Code)
+
+    var result []modelResponse
+    require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
+    require.Len(t, result, 2)
+    assert.Equal(t, "openai/gpt-4", result[0].UnifiedModelID)
+    assert.Equal(t, "anthropic/claude-3", result[1].UnifiedModelID)
+}
+
+func TestModelHandler_GetModel_UnifiedID(t *testing.T) {
+    h := NewModelHandler(&mockModelService{
+        model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
+    })
+
+    w := httptest.NewRecorder()
+    c, _ := gin.CreateTestContext(w)
+    c.Params = gin.Params{{Key: "id", Value: "m1"}}
+    c.Request = httptest.NewRequest("GET", "/api/models/m1", nil)
+    h.GetModel(c)
+
+    assert.Equal(t, 200, w.Code)
+    var result modelResponse
+    require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
+    assert.Equal(t, "m1", result.ID)
+    assert.Equal(t, "openai/gpt-4", result.UnifiedModelID)
+}
+
+func TestModelHandler_CreateModel_UnifiedID(t *testing.T) {
+    h := NewModelHandler(&mockModelService{})
+
+    body, _ := json.Marshal(map[string]string{
+        "provider_id": "openai",
+        "model_name": "gpt-4",
+    })
+    w := httptest.NewRecorder()
+    c, _ := gin.CreateTestContext(w)
+    c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
+    c.Request.Header.Set("Content-Type", "application/json")
+    h.CreateModel(c)
+
+    assert.Equal(t, 201, w.Code)
+    var result modelResponse
+    require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
+    assert.Equal(t, "mock-uuid-1234", result.ID)
+    assert.Equal(t, "openai/gpt-4", result.UnifiedModelID)
+}
+
+func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) {
+    h := NewModelHandler(&mockModelService{
+        model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4-turbo"},
+    })
+
+    body, _ := json.Marshal(map[string]interface{}{"enabled": false})
+    w := httptest.NewRecorder()
+    c, _ := gin.CreateTestContext(w)
+    c.Params = gin.Params{{Key: "id", Value: "m1"}}
+    c.Request = httptest.NewRequest("PUT", "/api/models/m1", bytes.NewReader(body))
+    c.Request.Header.Set("Content-Type", "application/json")
+    h.UpdateModel(c)
+
+    assert.Equal(t, 200, w.Code)
+    var result modelResponse
+    require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
+    assert.Equal(t, "openai/gpt-4-turbo", result.UnifiedModelID)
 }

 // ============ Stats Handler 测试 ============
@@ -256,7 +337,7 @@ func formatMapErrors(errs map[string]string) string {
 func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
     h := NewProviderHandler(&mockProviderService{
-        err: gorm.ErrDuplicatedKey,
+        err: appErrors.ErrConflict,
     })
     body, _ := json.Marshal(map[string]string{

View File

@@ -1,6 +1,7 @@
 package handler

 import (
+    "errors"
     "net/http"

     "github.com/gin-gonic/gin"
@@ -22,23 +23,35 @@ func NewModelHandler(modelService service.ModelService) *ModelHandler {
     return &ModelHandler{modelService: modelService}
 }

+// modelResponse 模型响应 DTO扩展 unified_id 字段
+type modelResponse struct {
+    domain.Model
+    UnifiedModelID string `json:"unified_id"`
+}
+
+// newModelResponse 从 domain.Model 构造响应 DTO
+func newModelResponse(m *domain.Model) modelResponse {
+    return modelResponse{
+        Model:          *m,
+        UnifiedModelID: m.UnifiedModelID(),
+    }
+}
+
 // CreateModel 创建模型
 func (h *ModelHandler) CreateModel(c *gin.Context) {
     var req struct {
-        ID         string `json:"id" binding:"required"`
         ProviderID string `json:"provider_id" binding:"required"`
         ModelName  string `json:"model_name" binding:"required"`
     }

     if err := c.ShouldBindJSON(&req); err != nil {
         c.JSON(http.StatusBadRequest, gin.H{
-            "error": "缺少必需字段: id, provider_id, model_name",
+            "error": "缺少必需字段: provider_id, model_name",
         })
         return
     }

     model := &domain.Model{
-        ID:         req.ID,
         ProviderID: req.ProviderID,
         ModelName:  req.ModelName,
     }
@@ -51,11 +64,18 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
             })
             return
         }
+        if err == appErrors.ErrDuplicateModel {
+            c.JSON(http.StatusConflict, gin.H{
+                "error": "同一供应商下模型名称已存在",
+                "code":  appErrors.ErrDuplicateModel.Code,
+            })
+            return
+        }
         writeError(c, err)
         return
     }

-    c.JSON(http.StatusCreated, model)
+    c.JSON(http.StatusCreated, newModelResponse(model))
 }

 // ListModels 列出模型
@@ -68,7 +88,11 @@ func (h *ModelHandler) ListModels(c *gin.Context) {
         return
     }

-    c.JSON(http.StatusOK, models)
+    resp := make([]modelResponse, len(models))
+    for i, m := range models {
+        resp[i] = newModelResponse(&m)
+    }
+    c.JSON(http.StatusOK, resp)
 }

 // GetModel 获取模型
@@ -87,7 +111,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
         return
     }

-    c.JSON(http.StatusOK, model)
+    c.JSON(http.StatusOK, newModelResponse(model))
 }

 // UpdateModel 更新模型
@@ -104,18 +128,25 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
     err := h.modelService.Update(id, req)
     if err != nil {
-        if err == gorm.ErrRecordNotFound {
+        if errors.Is(err, appErrors.ErrModelNotFound) {
             c.JSON(http.StatusNotFound, gin.H{
                 "error": "模型未找到",
             })
             return
         }
-        if err == appErrors.ErrProviderNotFound {
+        if errors.Is(err, appErrors.ErrProviderNotFound) {
             c.JSON(http.StatusBadRequest, gin.H{
                 "error": "供应商不存在",
             })
             return
         }
+        if errors.Is(err, appErrors.ErrDuplicateModel) {
+            c.JSON(http.StatusConflict, gin.H{
+                "error": appErrors.ErrDuplicateModel.Message,
+                "code":  appErrors.ErrDuplicateModel.Code,
+            })
+            return
+        }
         writeError(c, err)
         return
     }
@@ -126,7 +157,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
         return
     }

-    c.JSON(http.StatusOK, model)
+    c.JSON(http.StatusOK, newModelResponse(model))
 }

 // DeleteModel 删除模型
View File

@@ -55,9 +55,10 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
     err := h.providerService.Create(provider)
     if err != nil {
-        if errors.Is(err, gorm.ErrDuplicatedKey) {
-            c.JSON(http.StatusConflict, gin.H{
-                "error": "供应商 ID 已存在",
+        if err == appErrors.ErrInvalidProviderID {
+            c.JSON(http.StatusBadRequest, gin.H{
+                "error": appErrors.ErrInvalidProviderID.Message,
+                "code":  appErrors.ErrInvalidProviderID.Code,
             })
             return
         }
@@ -119,6 +120,13 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
             })
             return
         }
+        if errors.Is(err, appErrors.ErrImmutableField) {
+            c.JSON(http.StatusBadRequest, gin.H{
+                "error": appErrors.ErrImmutableField.Message,
+                "code":  appErrors.ErrImmutableField.Code,
+            })
+            return
+        }
         writeError(c, err)
         return
     }

View File

@@ -11,9 +11,11 @@ import (
     "go.uber.org/zap"

     "nex/backend/internal/conversion"
+    "nex/backend/internal/conversion/canonical"
     "nex/backend/internal/domain"
     "nex/backend/internal/provider"
     "nex/backend/internal/service"
+    "nex/backend/pkg/modelid"
 )

 // ProxyHandler 统一代理处理器
@@ -54,6 +56,34 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
     }
     nativePath := "/v1/" + path

+    // 获取 client adapter
+    registry := h.engine.GetRegistry()
+    clientAdapter, err := registry.Get(clientProtocol)
+    if err != nil {
+        c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
+        return
+    }
+
+    // 检测接口类型
+    ifaceType := clientAdapter.DetectInterfaceType(nativePath)
+
+    // 处理 Models 接口:本地聚合
+    if ifaceType == conversion.InterfaceTypeModels {
+        h.handleModelsList(c, clientAdapter)
+        return
+    }
+
+    // 处理 ModelInfo 接口:本地查询
+    if ifaceType == conversion.InterfaceTypeModelInfo {
+        unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath)
+        if err != nil {
+            c.JSON(http.StatusBadRequest, gin.H{"error": "无效的模型 ID 格式"})
+            return
+        }
+        h.handleModelInfo(c, unifiedID, clientAdapter)
+        return
+    }
+
     // 读取请求体
     body, err := io.ReadAll(c.Request.Body)
     if err != nil {
@@ -61,10 +91,17 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
         return
     }

-    // 解析 model 名称(从 JSON body 中提取GET 请求无 body
-    modelName := ""
+    // 解析统一模型 ID使用 adapter.ExtractModelName
+    var providerID, modelName string
     if len(body) > 0 {
-        modelName = extractModelName(body)
+        unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
+        if err == nil && unifiedID != "" {
+            pid, mn, err := modelid.ParseUnifiedModelID(unifiedID)
+            if err == nil {
+                providerID = pid
+                modelName = mn
+            }
+        }
     }

     // 构建输入 HTTPRequestSpec
@@ -76,7 +113,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
     }

     // 路由
-    routeResult, err := h.routingService.Route(modelName)
+    routeResult, err := h.routingService.RouteByModelName(providerID, modelName)
     if err != nil {
         // GET 请求或无法提取 model 时,直接转发到上游
         if len(body) == 0 || modelName == "" {
@@ -94,24 +131,30 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
     }

     // 构建 TargetProvider
+    // 注意ModelName 字段用于 Smart Passthrough 场景改写请求体
+    // 同协议:请求体中的统一 ID 会被改写为 ModelName上游名
+    // 跨协议:全量转换时 ModelName 会被编码到请求体中
     targetProvider := conversion.NewTargetProvider(
         routeResult.Provider.BaseURL,
         routeResult.Provider.APIKey,
-        routeResult.Model.ModelName,
+        routeResult.Model.ModelName, // 上游模型名,用于请求改写
     )

     // 判断是否流式
     isStream := h.isStreamRequest(body, clientProtocol, nativePath)

+    // 计算统一模型 ID用于响应覆写
+    unifiedModelID := routeResult.Model.UnifiedModelID()
+
     if isStream {
-        h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
+        h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
     } else {
-        h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
+        h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
     }
 }

 // handleNonStream 处理非流式请求
-func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
+func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
     // 转换请求
     outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
     if err != nil {
@@ -128,9 +171,8 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
         return
     }

-    // 转换响应
-    interfaceType, _ := h.engine.DetectInterfaceType(inSpec.URL, clientProtocol)
-    convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, interfaceType)
+    // 转换响应,传入 modelOverride跨协议场景覆写 model 字段)
+    convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
     if err != nil {
         h.logger.Error("转换响应失败", zap.String("error", err.Error()))
         h.writeConversionError(c, err, clientProtocol)
@@ -153,7 +195,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
 }

 // handleStream 处理流式请求
-func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
+func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
     // 转换请求
     outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
     if err != nil {
@@ -161,8 +203,8 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
         return
     }

-    // 创建流式转换器
-    streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol)
+    // 创建流式转换器,传入 modelOverride跨协议场景覆写 model 字段)
+    streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
     if err != nil {
         h.writeConversionError(c, err, clientProtocol)
         return
@@ -224,6 +266,79 @@ func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath s
     return req.Stream
 }

+// handleModelsList 处理 GET /v1/models 本地聚合
+func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.ProtocolAdapter) {
+    // 从数据库查询所有启用的模型
+    models, err := h.providerService.ListEnabledModels()
+    if err != nil {
+        h.logger.Error("查询启用模型失败", zap.String("error", err.Error()))
+        c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"})
+        return
+    }
+
+    // 构建 CanonicalModelList
+    modelList := &canonical.CanonicalModelList{
+        Models: make([]canonical.CanonicalModel, 0, len(models)),
+    }
+    for _, m := range models {
+        modelList.Models = append(modelList.Models, canonical.CanonicalModel{
+            ID:      m.UnifiedModelID(),
+            Name:    m.ModelName,
+            Created: m.CreatedAt.Unix(),
+            OwnedBy: m.ProviderID,
+        })
+    }
+
+    // 使用 adapter 编码返回
+    body, err := adapter.EncodeModelsResponse(modelList)
+    if err != nil {
+        h.logger.Error("编码 Models 响应失败", zap.String("error", err.Error()))
+        c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
+        return
+    }
+    c.Data(http.StatusOK, "application/json", body)
+}
+
+// handleModelInfo 处理 GET /v1/models/{unified_id} 本地查询
+func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter conversion.ProtocolAdapter) {
+    // 解析统一模型 ID
+    providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
+    if err != nil {
+        c.JSON(http.StatusBadRequest, gin.H{
+            "error": "无效的统一模型 ID 格式",
+            "code":  "INVALID_MODEL_ID",
+        })
+        return
+    }
+
+    // 从数据库查询模型
+    model, err := h.providerService.GetModelByProviderAndName(providerID, modelName)
+    if err != nil {
+        c.JSON(http.StatusNotFound, gin.H{"error": "模型未找到"})
+        return
+    }
+
+    // 构建 CanonicalModelInfo
+    modelInfo := &canonical.CanonicalModelInfo{
+        ID:      model.UnifiedModelID(),
+        Name:    model.ModelName,
+        Created: model.CreatedAt.Unix(),
+        OwnedBy: model.ProviderID,
+    }
+
+    // 使用 adapter 编码返回
+    body, err := adapter.EncodeModelInfoResponse(modelInfo)
+    if err != nil {
+        h.logger.Error("编码 ModelInfo 响应失败", zap.String("error", err.Error()))
+        c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
+        return
+    }
+    c.Data(http.StatusOK, "application/json", body)
+}
+
 // writeConversionError 写入转换错误
 func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
     if convErr, ok := err.(*conversion.ConversionError); ok {
@@ -292,7 +407,7 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
         return
     }

-    convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType)
+    convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, "")
     if err != nil {
         h.writeConversionError(c, err, clientProtocol)
         return
@@ -307,17 +422,6 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
     c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
 }

-// extractModelName 从 JSON body 中提取 model
-func extractModelName(body []byte) string {
-    var req struct {
-        Model string `json:"model"`
-    }
-    if err := json.Unmarshal(body, &req); err != nil {
-        return ""
-    }
-    return req.Model
-}
-
 // extractHeaders 从 Gin context 提取请求头
 func extractHeaders(c *gin.Context) map[string]string {
     headers := make(map[string]string)
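The handler above relies on Smart Passthrough to rewrite only the `model` field of a same-protocol request (unified ID in, upstream model name out) while leaving every other parameter intact. The project's `RewriteRequestModelName` is not shown in this chunk; a minimal sketch of the idea, assuming a `json.RawMessage`-based approach that preserves other fields' values exactly (Go's map marshaling does sort keys, so byte order of keys is not preserved — the hypothetical `rewriteModelField` below illustrates the technique, not the actual implementation):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// rewriteModelField replaces only the top-level "model" key. Decoding into
// json.RawMessage keeps every other field's value encoding untouched.
func rewriteModelField(body []byte, newModel string) ([]byte, error) {
	var obj map[string]json.RawMessage
	if err := json.Unmarshal(body, &obj); err != nil {
		return nil, err
	}
	m, err := json.Marshal(newModel)
	if err != nil {
		return nil, err
	}
	obj["model"] = m
	return json.Marshal(obj)
}

func main() {
	in := []byte(`{"model":"openai_p/gpt-4","temperature":0.7}`)
	out, err := rewriteModelField(in, "gpt-4")
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
}
```

The same single-field rewrite runs in the opposite direction on responses, mapping the upstream model name back to the unified ID, which is what the fidelity tests below assert.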

View File

@@ -60,13 +60,23 @@ type mockProxyRoutingService struct {
err error err error
} }
func (m *mockProxyRoutingService) Route(modelName string) (*domain.RouteResult, error) { func (m *mockProxyRoutingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
return m.result, m.err return m.result, m.err
} }
type mockProxyProviderService struct { type mockProxyProviderService struct {
providers []domain.Provider providers []domain.Provider
err error err error
enabledModels []domain.Model
modelByProvName *domain.Model
}
func (m *mockProxyProviderService) ListEnabledModels() ([]domain.Model, error) {
return m.enabledModels, nil
}
func (m *mockProxyProviderService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
return m.modelByProvName, nil
} }
func (m *mockProxyProviderService) Create(p *domain.Provider) error { return nil } func (m *mockProxyProviderService) Create(p *domain.Provider) error { return nil }
@@ -319,7 +329,8 @@ func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) {
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 404, w.Code) // Models 接口现在本地聚合,返回空列表 200
assert.Equal(t, 200, w.Code)
} }
func TestExtractHeaders(t *testing.T) { func TestExtractHeaders(t *testing.T) {
@@ -716,58 +727,6 @@ func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) {
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
} }
// ============ extractModelName 测试 ============
func TestExtractModelName(t *testing.T) {
tests := []struct {
name string
body []byte
expected string
}{
{
name: "valid model",
body: []byte(`{"model": "gpt-4", "messages": []}`),
expected: "gpt-4",
},
{
name: "empty body",
body: []byte(`{}`),
expected: "",
},
{
name: "invalid json",
body: []byte(`{invalid}`),
expected: "",
},
{
name: "nested structure",
body: []byte(`{"model": "claude-3", "messages": [{"role": "user", "content": "hello"}]}`),
expected: "claude-3",
},
{
name: "model with special chars",
body: []byte(`{"model": "gpt-4-0125-preview", "stream": true}`),
expected: "gpt-4-0125-preview",
},
{
name: "empty body bytes",
body: []byte{},
expected: "",
},
{
name: "model is null",
body: []byte(`{"model": null}`),
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractModelName(tt.body)
assert.Equal(t, tt.expected, result)
})
}
}
// ============ isStreamRequest 测试 ============ // ============ isStreamRequest 测试 ============
func TestIsStreamRequest(t *testing.T) { func TestIsStreamRequest(t *testing.T) {
@@ -831,3 +790,270 @@ func TestIsStreamRequest(t *testing.T) {
}) })
} }
} }
// ============ Models / ModelInfo 本地聚合测试 ============
func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
engine := setupProxyEngine(t)
providerSvc := &mockProxyProviderService{
enabledModels: []domain.Model{
{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3", Enabled: true},
},
}
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, &mockProxyRoutingService{}, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
data, ok := resp["data"].([]interface{})
require.True(t, ok)
assert.Len(t, data, 2)
// 验证统一模型 ID 格式
first := data[0].(map[string]interface{})
assert.Equal(t, "openai/gpt-4", first["id"])
}
func TestProxyHandler_HandleProxy_ModelInfo_LocalQuery(t *testing.T) {
engine := setupProxyEngine(t)
providerSvc := &mockProxyProviderService{
modelByProvName: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, &mockProxyRoutingService{}, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/openai/gpt-4"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models/openai/gpt-4", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "openai/gpt-4", resp["id"])
}
func TestProxyHandler_HandleProxy_Models_EmptySuffix_ForwardPassthrough(t *testing.T) {
engine := setupProxyEngine(t)
providerSvc := &mockProxyProviderService{
providers: []domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai"},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Body: []byte(`{"object":"list","data":[]}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, &mockProxyRoutingService{err: appErrors.ErrModelNotFound}, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models/", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
}
// ============ Smart Passthrough 统一模型 ID 路由测试 ============
func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "openai_p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "openai_p", ModelName: "gpt-4", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
// 验证请求体中的 model 已被改写为上游模型名
var req map[string]interface{}
json.Unmarshal(spec.Body, &req)
assert.Equal(t, "gpt-4", req["model"])
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"id":"resp-1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8}}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
// 客户端发送统一模型 ID
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
// 验证响应中的 model 已被改写为统一模型 ID
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "openai_p/gpt-4", resp["model"])
}
// ============ 跨协议统一模型 ID 路由测试 ============
func TestProxyHandler_HandleProxy_CrossProtocol_NonStream_UnifiedID(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "anthropic_p", Name: "Anthropic", APIKey: "sk-test", BaseURL: "https://api.anthropic.com", Protocol: "anthropic", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"id":"msg-1","type":"message","role":"assistant","model":"claude-3","content":[{"type":"text","text":"Hello"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":3}}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
// OpenAI 客户端使用统一模型 ID 路由到 Anthropic 供应商
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
// 验证跨协议转换后响应中的 model 被覆写为统一模型 ID
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "anthropic_p/claude-3", resp["model"])
}
func TestProxyHandler_HandleProxy_CrossProtocol_Stream_UnifiedID(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "anthropic_p", Name: "Anthropic", APIKey: "sk-test", BaseURL: "https://api.anthropic.com", Protocol: "anthropic", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true},
},
}
client := &mockProxyProviderClient{
sendStreamFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
ch := make(chan provider.StreamEvent, 10)
go func() {
defer close(ch)
ch <- provider.StreamEvent{Data: []byte(`event: message_start
data: {"type":"message_start","message":{"id":"msg-1","type":"message","role":"assistant","model":"claude-3","content":[]}}
`)}
ch <- provider.StreamEvent{Data: []byte(`event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}
`)}
ch <- provider.StreamEvent{Data: []byte(`event: message_stop
data: {"type":"message_stop"}
`)}
ch <- provider.StreamEvent{Done: true}
}()
return ch, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type"))
body := w.Body.String()
// Verify that model in the cross-protocol stream is rewritten to the unified model ID
assert.Contains(t, body, "anthropic_p/claude-3", "model in cross-protocol streaming responses should be rewritten to the unified model ID")
}
func TestProxyHandler_HandleProxy_SmartPassthrough_Fidelity(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{
result: &domain.RouteResult{
Provider: &domain.Provider{ID: "openai_p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "openai_p", ModelName: "gpt-4", Enabled: true},
},
}
var capturedRequestBody []byte
client := &mockProxyProviderClient{
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
capturedRequestBody = spec.Body
return &conversion.HTTPResponseSpec{
StatusCode: 200,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"id":"resp-1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8},"unknown_field":"preserved"}`),
}, nil
},
}
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
// Include an unknown parameter to verify Smart Passthrough fidelity
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}],"custom_param":"should_be_preserved"}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
// Verify that model in the request is rewritten to the upstream model name while unknown parameters are preserved
var reqBody map[string]interface{}
require.NoError(t, json.Unmarshal(capturedRequestBody, &reqBody))
assert.Equal(t, "gpt-4", reqBody["model"], "model in the request should be rewritten to the upstream model name")
assert.Equal(t, "should_be_preserved", reqBody["custom_param"], "Smart Passthrough should preserve unknown parameters")
// Verify that model in the response is rewritten to the unified model ID while unknown fields are preserved
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "openai_p/gpt-4", resp["model"], "model in the response should be rewritten to the unified model ID")
assert.Equal(t, "preserved", resp["unknown_field"], "Smart Passthrough should preserve unknown response fields")
}
func TestProxyHandler_HandleProxy_UnifiedID_ModelNotFound(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := &mockProxyRoutingService{err: appErrors.ErrModelNotFound}
providerSvc := &mockProxyProviderService{
providers: []domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai"},
},
}
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, routingSvc, providerSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
// Use the unified model ID format with a model that does not exist
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 404, w.Code)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Contains(t, resp, "error")
}

View File

@@ -7,7 +7,8 @@ type ModelRepository interface {
 	Create(model *domain.Model) error
 	GetByID(id string) (*domain.Model, error)
 	List(providerID string) ([]domain.Model, error)
-	GetByModelName(modelName string) (*domain.Model, error)
+	FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error)
+	ListEnabled() ([]domain.Model, error)
 	Update(id string, updates map[string]interface{}) error
 	Delete(id string) error
 }

View File

@@ -52,9 +52,9 @@ func (r *modelRepository) List(providerID string) ([]domain.Model, error) {
 	return result, nil
 }
 
-func (r *modelRepository) GetByModelName(modelName string) (*domain.Model, error) {
+func (r *modelRepository) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
 	var m config.Model
-	err := r.db.Where("model_name = ?", modelName).First(&m).Error
+	err := r.db.Where("provider_id = ? AND model_name = ?", providerID, modelName).First(&m).Error
 	if err != nil {
 		return nil, err
 	}
@@ -62,6 +62,21 @@ func (r *modelRepository) GetByModelName(modelName string) (*domain.Model, error
 	return &d, nil
 }
 
+func (r *modelRepository) ListEnabled() ([]domain.Model, error) {
+	var models []config.Model
+	err := r.db.Joins("JOIN providers ON providers.id = models.provider_id").
+		Where("models.enabled = ? AND providers.enabled = ?", true, true).
+		Find(&models).Error
+	if err != nil {
+		return nil, err
+	}
+	result := make([]domain.Model, len(models))
+	for i := range models {
+		result[i] = toDomainModel(&models[i])
+	}
+	return result, nil
+}
+
 func (r *modelRepository) Update(id string, updates map[string]interface{}) error {
 	result := r.db.Model(&config.Model{}).Where("id = ?", id).Updates(updates)
 	if result.Error != nil {

View File

@@ -9,4 +9,7 @@ type ProviderRepository interface {
 	List() ([]domain.Provider, error)
 	Update(id string, updates map[string]interface{}) error
 	Delete(id string) error
+	// Unified model ID related methods
+	ListEnabledModels() ([]domain.Model, error)
+	FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error)
 }

View File

@@ -71,6 +71,25 @@ func (r *providerRepository) Delete(id string) error {
 	return nil
 }
 
+// ListEnabledModels returns all enabled models whose providers are also enabled
+func (r *providerRepository) ListEnabledModels() ([]domain.Model, error) {
+	var models []domain.Model
+	err := r.db.Joins("JOIN providers ON providers.id = models.provider_id").
+		Where("models.enabled = ? AND providers.enabled = ?", true, true).
+		Find(&models).Error
+	return models, err
+}
+
+// FindByProviderAndModelName looks up a model by provider_id and model_name
+func (r *providerRepository) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
+	var model domain.Model
+	err := r.db.Where("provider_id = ? AND model_name = ?", providerID, modelName).First(&model).Error
+	if err != nil {
+		return nil, err
+	}
+	return &model, nil
+}
+
 func toDomainProvider(p *config.Provider) domain.Provider {
 	return domain.Provider{
 		ID: p.ID,

View File

@@ -147,15 +147,36 @@ func TestModelRepository_GetByID(t *testing.T) {
 	assert.Equal(t, "gpt-4", result.ModelName)
 }
 
-func TestModelRepository_GetByModelName(t *testing.T) {
+func TestModelRepository_FindByProviderAndModelName(t *testing.T) {
 	db := setupTestDB(t)
 	repo := NewModelRepository(db)
 	repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
 
-	result, err := repo.GetByModelName("gpt-4")
+	result, err := repo.FindByProviderAndModelName("p1", "gpt-4")
 	require.NoError(t, err)
 	assert.Equal(t, "m1", result.ID)
+	assert.Equal(t, "p1", result.ProviderID)
+	assert.Equal(t, "gpt-4", result.ModelName)
+}
+
+func TestModelRepository_FindByProviderAndModelName_NotFound(t *testing.T) {
+	db := setupTestDB(t)
+	repo := NewModelRepository(db)
+	repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
+
+	// Wrong provider_id
+	_, err := repo.FindByProviderAndModelName("p2", "gpt-4")
+	assert.Error(t, err)
+
+	// Wrong model_name
+	_, err = repo.FindByProviderAndModelName("p1", "gpt-3.5")
+	assert.Error(t, err)
+
+	// Both wrong
+	_, err = repo.FindByProviderAndModelName("p2", "claude-3")
+	assert.Error(t, err)
 }
 
 func TestModelRepository_List(t *testing.T) {
@@ -175,6 +196,54 @@ func TestModelRepository_List(t *testing.T) {
 	assert.Len(t, p1Models, 2)
 }
 
+func TestModelRepository_ListEnabled(t *testing.T) {
+	db := setupTestDB(t)
+	providerRepo := NewProviderRepository(db)
+	modelRepo := NewModelRepository(db)
+
+	// Create two providers (both start enabled due to gorm:"default:true")
+	err := providerRepo.Create(&domain.Provider{
+		ID: "enabled-provider", Name: "Enabled Provider",
+		APIKey: "key1", BaseURL: "https://enabled.com", Enabled: true,
+	})
+	require.NoError(t, err)
+	err = providerRepo.Create(&domain.Provider{
+		ID: "disabled-provider", Name: "Disabled Provider",
+		APIKey: "key2", BaseURL: "https://disabled.com", Enabled: true,
+	})
+	require.NoError(t, err)
+
+	// Disable the second provider via Update (GORM default:true skips zero values on Create)
+	err = providerRepo.Update("disabled-provider", map[string]interface{}{"enabled": false})
+	require.NoError(t, err)
+
+	// Create models (all start enabled due to gorm:"default:true")
+	err = modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "enabled-provider", ModelName: "gpt-4", Enabled: true})
+	require.NoError(t, err)
+	err = modelRepo.Create(&domain.Model{ID: "m2", ProviderID: "enabled-provider", ModelName: "gpt-3.5", Enabled: true})
+	require.NoError(t, err)
+	err = modelRepo.Create(&domain.Model{ID: "m3", ProviderID: "disabled-provider", ModelName: "claude-3", Enabled: true})
+	require.NoError(t, err)
+	err = modelRepo.Create(&domain.Model{ID: "m4", ProviderID: "disabled-provider", ModelName: "claude-3.5", Enabled: true})
+	require.NoError(t, err)
+
+	// Disable m2 via Update
+	err = modelRepo.Update("m2", map[string]interface{}{"enabled": false})
+	require.NoError(t, err)
+
+	// ListEnabled should only return models where both model and provider are enabled:
+	// - m1: enabled provider + enabled model -> returned
+	// - m2: enabled provider + disabled model -> filtered out
+	// - m3: disabled provider + enabled model -> filtered out
+	// - m4: disabled provider + enabled model -> filtered out
+	enabled, err := modelRepo.ListEnabled()
+	require.NoError(t, err)
+	require.Len(t, enabled, 1)
+	assert.Equal(t, "m1", enabled[0].ID)
+	assert.Equal(t, "enabled-provider", enabled[0].ProviderID)
+	assert.Equal(t, "gpt-4", enabled[0].ModelName)
+}
+
 func TestModelRepository_Update(t *testing.T) {
 	db := setupTestDB(t)
 	repo := NewModelRepository(db)

View File

@@ -7,6 +7,7 @@ type ModelService interface {
 	Create(model *domain.Model) error
 	Get(id string) (*domain.Model, error)
 	List(providerID string) ([]domain.Model, error)
+	ListEnabled() ([]domain.Model, error)
 	Update(id string, updates map[string]interface{}) error
 	Delete(id string) error
 }

View File

@@ -1,6 +1,7 @@
 package service
 
 import (
+	"github.com/google/uuid"
 
 	appErrors "nex/backend/pkg/errors"
 	"nex/backend/internal/domain"
@@ -17,11 +18,18 @@ func NewModelService(modelRepo repository.ModelRepository, providerRepo reposito
 }
 
 func (s *modelService) Create(model *domain.Model) error {
 	// Verify provider exists
-	_, err := s.providerRepo.GetByID(model.ProviderID)
-	if err != nil {
+	if _, err := s.providerRepo.GetByID(model.ProviderID); err != nil {
 		return appErrors.ErrProviderNotFound
 	}
+
+	// Composite uniqueness check: model_name must be unique within a provider
+	if err := s.checkDuplicateModelName(model.ProviderID, model.ModelName, ""); err != nil {
+		return err
+	}
+
+	// Auto-generate a UUID as the id
+	model.ID = uuid.New().String()
 	model.Enabled = true
 	return s.modelRepo.Create(model)
 }
@@ -34,17 +42,57 @@ func (s *modelService) List(providerID string) ([]domain.Model, error) {
 	return s.modelRepo.List(providerID)
 }
 
+func (s *modelService) ListEnabled() ([]domain.Model, error) {
+	return s.modelRepo.ListEnabled()
+}
+
 func (s *modelService) Update(id string, updates map[string]interface{}) error {
-	// If updating provider_id, verify new provider exists
+	// Load the current model
+	current, err := s.modelRepo.GetByID(id)
+	if err != nil {
+		return appErrors.ErrModelNotFound
+	}
+
+	// If updating provider_id, verify the new provider exists
 	if providerID, ok := updates["provider_id"].(string); ok {
-		_, err := s.providerRepo.GetByID(providerID)
-		if err != nil {
+		if _, err := s.providerRepo.GetByID(providerID); err != nil {
 			return appErrors.ErrProviderNotFound
 		}
 	}
+
+	// Determine the post-update provider_id and model_name
+	newProviderID := current.ProviderID
+	if v, ok := updates["provider_id"].(string); ok {
+		newProviderID = v
+	}
+	newModelName := current.ModelName
+	if v, ok := updates["model_name"].(string); ok {
+		newModelName = v
+	}
+
+	// If provider_id or model_name changed, enforce the composite uniqueness check
+	if newProviderID != current.ProviderID || newModelName != current.ModelName {
+		if err := s.checkDuplicateModelName(newProviderID, newModelName, id); err != nil {
+			return err
+		}
+	}
+
 	return s.modelRepo.Update(id, updates)
 }
 
 func (s *modelService) Delete(id string) error {
 	return s.modelRepo.Delete(id)
 }
+
+// checkDuplicateModelName checks whether model_name already exists under the same provider.
+// excludeID excludes the model itself during updates.
+func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error {
+	existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
+	if err != nil {
+		return nil // not found, no duplicate
+	}
+	if excludeID != "" && existing.ID == excludeID {
+		return nil // exclude the model itself
+	}
+	return appErrors.ErrDuplicateModel
+}

View File

@@ -9,4 +9,7 @@ type ProviderService interface {
 	List() ([]domain.Provider, error)
 	Update(id string, updates map[string]interface{}) error
 	Delete(id string) error
+	// Unified model ID related methods
+	ListEnabledModels() ([]domain.Model, error)
+	GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error)
 }

View File

@@ -1,21 +1,35 @@
 package service
 
 import (
+	"strings"
+
+	"nex/backend/pkg/modelid"
 	"nex/backend/internal/domain"
 	"nex/backend/internal/repository"
+	appErrors "nex/backend/pkg/errors"
 )
 
 type providerService struct {
 	providerRepo repository.ProviderRepository
+	modelRepo    repository.ModelRepository
 }
 
-func NewProviderService(providerRepo repository.ProviderRepository) ProviderService {
-	return &providerService{providerRepo: providerRepo}
+func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository) ProviderService {
+	return &providerService{providerRepo: providerRepo, modelRepo: modelRepo}
 }
 
 func (s *providerService) Create(provider *domain.Provider) error {
+	// Validate the provider_id character set
+	if err := modelid.ValidateProviderID(provider.ID); err != nil {
+		return appErrors.ErrInvalidProviderID
+	}
+
 	provider.Enabled = true
-	return s.providerRepo.Create(provider)
+	err := s.providerRepo.Create(provider)
+	if err != nil && isUniqueConstraintError(err) {
+		return appErrors.ErrConflict
+	}
+	return err
 }
func (s *providerService) Get(id string, maskKey bool) (*domain.Provider, error) { func (s *providerService) Get(id string, maskKey bool) (*domain.Provider, error) {
@@ -41,9 +55,31 @@ func (s *providerService) List() ([]domain.Provider, error) {
 }
 
 func (s *providerService) Update(id string, updates map[string]interface{}) error {
+	if _, ok := updates["id"]; ok {
+		return appErrors.ErrImmutableField
+	}
 	return s.providerRepo.Update(id, updates)
 }
 
 func (s *providerService) Delete(id string) error {
 	return s.providerRepo.Delete(id)
 }
+
+// ListEnabledModels returns all enabled models (used by the Models API for local aggregation)
+func (s *providerService) ListEnabledModels() ([]domain.Model, error) {
+	return s.modelRepo.ListEnabled()
+}
+
+// GetModelByProviderAndName looks up a model by provider_id and model_name (used by the ModelInfo API for local lookup)
+func (s *providerService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
+	return s.modelRepo.FindByProviderAndModelName(providerID, modelName)
+}
+
+// isUniqueConstraintError reports whether err is a database unique-constraint violation
+func isUniqueConstraintError(err error) bool {
+	if err == nil {
+		return false
+	}
+	msg := strings.ToLower(err.Error())
+	return strings.Contains(msg, "unique constraint") || strings.Contains(msg, "duplicate")
+}

View File

@@ -4,5 +4,5 @@ import "nex/backend/internal/domain"
 // RoutingService is the routing service interface
 type RoutingService interface {
-	Route(modelName string) (*domain.RouteResult, error)
+	RouteByModelName(providerID, modelName string) (*domain.RouteResult, error)
 }

View File

@@ -16,8 +16,8 @@ func NewRoutingService(modelRepo repository.ModelRepository, providerRepo reposi
 	return &routingService{modelRepo: modelRepo, providerRepo: providerRepo}
 }
 
-func (s *routingService) Route(modelName string) (*domain.RouteResult, error) {
-	model, err := s.modelRepo.GetByModelName(modelName)
+func (s *routingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
+	model, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
 	if err != nil {
 		return nil, appErrors.ErrModelNotFound
 	}

View File

@@ -13,7 +13,8 @@ import (
 func TestProviderService_Update(t *testing.T) {
 	db := setupServiceTestDB(t)
 	repo := repository.NewProviderRepository(db)
-	svc := NewProviderService(repo)
+	modelRepo := repository.NewModelRepository(db)
+	svc := NewProviderService(repo, modelRepo)
 
 	svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"})
@@ -28,7 +29,8 @@ func TestProviderService_Update(t *testing.T) {
 func TestProviderService_Update_NotFound(t *testing.T) {
 	db := setupServiceTestDB(t)
 	repo := repository.NewProviderRepository(db)
-	svc := NewProviderService(repo)
+	modelRepo := repository.NewModelRepository(db)
+	svc := NewProviderService(repo, modelRepo)
 
 	err := svc.Update("nonexistent", map[string]interface{}{"name": "test"})
 	assert.Error(t, err)
@@ -41,11 +43,12 @@ func TestModelService_Get(t *testing.T) {
 	svc := NewModelService(modelRepo, providerRepo)
 
 	providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
-	svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
+	model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
+	require.NoError(t, svc.Create(model))
 
-	model, err := svc.Get("m1")
+	result, err := svc.Get(model.ID)
 	require.NoError(t, err)
-	assert.Equal(t, "gpt-4", model.ModelName)
+	assert.Equal(t, "gpt-4", result.ModelName)
 }
 
 func TestModelService_Update(t *testing.T) {
@@ -55,14 +58,15 @@ func TestModelService_Update(t *testing.T) {
 	svc := NewModelService(modelRepo, providerRepo)
 
 	providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
-	svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
+	model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
+	require.NoError(t, svc.Create(model))
 
-	err := svc.Update("m1", map[string]interface{}{"model_name": "gpt-4o"})
+	err := svc.Update(model.ID, map[string]interface{}{"model_name": "gpt-4o"})
 	require.NoError(t, err)
 
-	model, err := svc.Get("m1")
+	result, err := svc.Get(model.ID)
 	require.NoError(t, err)
-	assert.Equal(t, "gpt-4o", model.ModelName)
+	assert.Equal(t, "gpt-4o", result.ModelName)
 }
 
 func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
@@ -72,9 +76,10 @@ func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
 	svc := NewModelService(modelRepo, providerRepo)
 
 	providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
-	svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
+	model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
+	require.NoError(t, svc.Create(model))
 
-	err := svc.Update("m1", map[string]interface{}{"provider_id": "nonexistent"})
+	err := svc.Update(model.ID, map[string]interface{}{"provider_id": "nonexistent"})
 	assert.Error(t, err)
 }
@@ -85,12 +90,13 @@ func TestModelService_Delete(t *testing.T) {
 	svc := NewModelService(modelRepo, providerRepo)
 
 	providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
-	svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
+	model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
+	require.NoError(t, svc.Create(model))
 
-	err := svc.Delete("m1")
+	err := svc.Delete(model.ID)
 	require.NoError(t, err)
 
-	_, err = svc.Get("m1")
+	_, err = svc.Get(model.ID)
 	assert.Error(t, err)
 }

View File

@@ -1,8 +1,10 @@
 package service
 
 import (
+	"errors"
 	"testing"
 
+	"github.com/google/uuid"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"gorm.io/driver/sqlite"
@@ -11,6 +13,7 @@ import (
 	"nex/backend/internal/config"
 	"nex/backend/internal/domain"
 	"nex/backend/internal/repository"
+	appErrors "nex/backend/pkg/errors"
 )
func setupServiceTestDB(t *testing.T) *gorm.DB { func setupServiceTestDB(t *testing.T) *gorm.DB {
@@ -29,80 +32,106 @@ func setupServiceTestDB(t *testing.T) *gorm.DB {
 	return db
 }
 
-// ============ ProviderService tests ============
-
-func TestProviderService_Create(t *testing.T) {
+// ============ RoutingService - RouteByModelName tests ============
+
+func TestRoutingService_RouteByModelName_Success(t *testing.T) {
 	db := setupServiceTestDB(t)
-	repo := repository.NewProviderRepository(db)
-	svc := NewProviderService(repo)
+	providerRepo := repository.NewProviderRepository(db)
+	modelRepo := repository.NewModelRepository(db)
+	svc := NewRoutingService(modelRepo, providerRepo)
 
-	provider := &domain.Provider{
-		ID: "test-p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com",
-	}
-	err := svc.Create(provider)
+	// Create a provider and a model
+	providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
+	modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
+
+	result, err := svc.RouteByModelName("openai", "gpt-4")
 	require.NoError(t, err)
-	assert.True(t, provider.Enabled)
+	assert.Equal(t, "openai", result.Provider.ID)
+	assert.Equal(t, "gpt-4", result.Model.ModelName)
 }
 
-func TestProviderService_Get_MaskKey(t *testing.T) {
+func TestRoutingService_RouteByModelName_NotFound(t *testing.T) {
 	db := setupServiceTestDB(t)
-	repo := repository.NewProviderRepository(db)
-	svc := NewProviderService(repo)
+	providerRepo := repository.NewProviderRepository(db)
+	modelRepo := repository.NewModelRepository(db)
+	svc := NewRoutingService(modelRepo, providerRepo)
 
-	svc.Create(&domain.Provider{
-		ID: "p1", Name: "Test", APIKey: "sk-long-api-key-12345", BaseURL: "https://test.com",
-	})
-
-	result, err := svc.Get("p1", true)
-	require.NoError(t, err)
-	assert.Equal(t, "***2345", result.APIKey)
-
-	result, err = svc.Get("p1", false)
-	require.NoError(t, err)
-	assert.Equal(t, "sk-long-api-key-12345", result.APIKey)
+	_, err := svc.RouteByModelName("openai", "nonexistent-model")
+	assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
 }
 
-func TestProviderService_List(t *testing.T) {
+func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
 	db := setupServiceTestDB(t)
-	repo := repository.NewProviderRepository(db)
-	svc := NewProviderService(repo)
+	providerRepo := repository.NewProviderRepository(db)
+	modelRepo := repository.NewModelRepository(db)
+	svc := NewRoutingService(modelRepo, providerRepo)
 
-	svc.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key123", BaseURL: "https://a.com"})
-	svc.Create(&domain.Provider{ID: "p2", Name: "P2", APIKey: "key456", BaseURL: "https://b.com"})
+	// Create an enabled provider and a disabled model
+	providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
+	modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
+	modelRepo.Update("m1", map[string]interface{}{"enabled": false})
 
-	providers, err := svc.List()
-	require.NoError(t, err)
-	assert.Len(t, providers, 2)
-	assert.Contains(t, providers[0].APIKey, "***")
+	_, err := svc.RouteByModelName("openai", "gpt-4")
+	assert.True(t, errors.Is(err, appErrors.ErrModelDisabled))
 }
 
-func TestProviderService_Delete(t *testing.T) {
+func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
 	db := setupServiceTestDB(t)
-	repo := repository.NewProviderRepository(db)
-	svc := NewProviderService(repo)
+	providerRepo := repository.NewProviderRepository(db)
+	modelRepo := repository.NewModelRepository(db)
+	svc := NewRoutingService(modelRepo, providerRepo)
 
-	svc.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
-	err := svc.Delete("p1")
-	require.NoError(t, err)
-
-	_, err = svc.Get("p1", false)
-	assert.Error(t, err)
+	// Create an enabled provider and model, then disable the provider
+	providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
+	modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
+	providerRepo.Update("openai", map[string]interface{}{"enabled": false})
+
+	_, err := svc.RouteByModelName("openai", "gpt-4")
+	assert.True(t, errors.Is(err, appErrors.ErrProviderDisabled))
 }
 
-// ============ ModelService tests ============
-
-func TestModelService_Create(t *testing.T) {
+// ============ ModelService - Create with UUID tests ============
+
+func TestModelService_Create_GeneratesUUID(t *testing.T) {
 	db := setupServiceTestDB(t)
 	providerRepo := repository.NewProviderRepository(db)
 	modelRepo := repository.NewModelRepository(db)
 	svc := NewModelService(modelRepo, providerRepo)
 
-	providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
-	model := &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}
+	providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
+	model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
 	err := svc.Create(model)
 	require.NoError(t, err)
 	assert.True(t, model.Enabled)
+
+	// Verify the returned model carries a valid UUID
+	assert.NotEmpty(t, model.ID)
+	_, err = uuid.Parse(model.ID)
+	assert.NoError(t, err, "model.ID should be a valid UUID")
+
+	// Verify persistence via Get
+	stored, err := svc.Get(model.ID)
+	require.NoError(t, err)
+	assert.Equal(t, model.ID, stored.ID)
+	assert.Equal(t, "gpt-4", stored.ModelName)
+}
+
+func TestModelService_Create_DuplicateModelName(t *testing.T) {
+	db := setupServiceTestDB(t)
+	providerRepo := repository.NewProviderRepository(db)
+	modelRepo := repository.NewModelRepository(db)
+	svc := NewModelService(modelRepo, providerRepo)
+
+	providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
+	model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
+	err := svc.Create(model1)
+	require.NoError(t, err)
+
+	// Creating a second model with the same (providerID, modelName) should fail
+	model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
+	err = svc.Create(model2)
+	assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
 }
 
 func TestModelService_Create_ProviderNotFound(t *testing.T) {
@@ -111,160 +140,135 @@ func TestModelService_Create_ProviderNotFound(t *testing.T) {
 	modelRepo := repository.NewModelRepository(db)
 	svc := NewModelService(modelRepo, providerRepo)
-	model := &domain.Model{ID: "m1", ProviderID: "nonexistent", ModelName: "gpt-4"}
+	model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
 	err := svc.Create(model)
-	assert.Error(t, err)
+	assert.True(t, errors.Is(err, appErrors.ErrProviderNotFound))
 }
-func TestModelService_List(t *testing.T) {
-	db := setupServiceTestDB(t)
-	providerRepo := repository.NewProviderRepository(db)
-	modelRepo := repository.NewModelRepository(db)
-	svc := NewModelService(modelRepo, providerRepo)
-	providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
-	svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
-	svc.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})
-	models, err := svc.List("p1")
-	require.NoError(t, err)
-	assert.Len(t, models, 2)
-}
+// ============ ProviderService - Create with validation tests ============
+
+func TestProviderService_Create_InvalidID(t *testing.T) {
+	db := setupServiceTestDB(t)
+	repo := repository.NewProviderRepository(db)
+	modelRepo := repository.NewModelRepository(db)
+	svc := NewProviderService(repo, modelRepo)
+
+	provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
+	err := svc.Create(provider)
+	assert.True(t, errors.Is(err, appErrors.ErrInvalidProviderID))
+}
+
+func TestProviderService_Create_ValidID(t *testing.T) {
+	db := setupServiceTestDB(t)
+	repo := repository.NewProviderRepository(db)
+	modelRepo := repository.NewModelRepository(db)
+	svc := NewProviderService(repo, modelRepo)
+
+	provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
+	err := svc.Create(provider)
+	require.NoError(t, err)
+	assert.Equal(t, "openai", provider.ID)
+	assert.True(t, provider.Enabled)
+}
+
+// ============ ModelService - Update with duplicate check tests ============
+
+func TestModelService_Update_DuplicateModelName(t *testing.T) {
+	db := setupServiceTestDB(t)
+	providerRepo := repository.NewProviderRepository(db)
+	modelRepo := repository.NewModelRepository(db)
+	svc := NewModelService(modelRepo, providerRepo)
+
+	providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
+	providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"})
+
+	model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
+	err := svc.Create(model1)
+	require.NoError(t, err)
+
+	model2 := &domain.Model{ProviderID: "anthropic", ModelName: "claude-3"}
+	err = svc.Create(model2)
+	require.NoError(t, err)
+
+	// changing model2 to provider_id "openai" and model_name "gpt-4" conflicts with model1
+	err = svc.Update(model2.ID, map[string]interface{}{
+		"provider_id": "openai",
+		"model_name":  "gpt-4",
+	})
+	assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
+}
-// ============ RoutingService tests ============
-func TestRoutingService_Route(t *testing.T) {
-	db := setupServiceTestDB(t)
-	providerRepo := repository.NewProviderRepository(db)
-	modelRepo := repository.NewModelRepository(db)
-	svc := NewRoutingService(modelRepo, providerRepo)
-	providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
-	modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
-	result, err := svc.Route("gpt-4")
-	require.NoError(t, err)
-	assert.Equal(t, "p1", result.Provider.ID)
-	assert.Equal(t, "gpt-4", result.Model.ModelName)
-}
-func TestRoutingService_Route_ModelNotFound(t *testing.T) {
-	db := setupServiceTestDB(t)
-	providerRepo := repository.NewProviderRepository(db)
-	modelRepo := repository.NewModelRepository(db)
-	svc := NewRoutingService(modelRepo, providerRepo)
-	_, err := svc.Route("nonexistent-model")
-	assert.Error(t, err)
-}
-func TestRoutingService_Route_ModelDisabled(t *testing.T) {
-	db := setupServiceTestDB(t)
-	providerRepo := repository.NewProviderRepository(db)
-	modelRepo := repository.NewModelRepository(db)
-	svc := NewRoutingService(modelRepo, providerRepo)
-	providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
-	// create the model enabled first, then disable it via Update
-	modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
-	modelRepo.Update("m1", map[string]interface{}{"enabled": false})
-	_, err := svc.Route("gpt-4")
-	assert.Error(t, err)
-}
-func TestRoutingService_Route_ProviderDisabled(t *testing.T) {
-	db := setupServiceTestDB(t)
-	providerRepo := repository.NewProviderRepository(db)
-	modelRepo := repository.NewModelRepository(db)
-	svc := NewRoutingService(modelRepo, providerRepo)
-	// create the provider enabled first, then disable it
-	providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
-	providerRepo.Update("p1", map[string]interface{}{"enabled": false})
-	modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
-	_, err := svc.Route("gpt-4")
-	assert.Error(t, err)
-}
-// ============ StatsService tests ============
-func TestStatsService_RecordAndGet(t *testing.T) {
-	db := setupServiceTestDB(t)
-	statsRepo := repository.NewStatsRepository(db)
-	svc := NewStatsService(statsRepo)
-	err := svc.Record("p1", "gpt-4")
-	require.NoError(t, err)
-	stats, err := svc.Get("p1", "", nil, nil)
-	require.NoError(t, err)
-	assert.Len(t, stats, 1)
-}
-func TestStatsService_Aggregate_ByProvider(t *testing.T) {
-	statsRepo := repository.NewStatsRepository(nil)
-	svc := NewStatsService(statsRepo)
-	stats := []domain.UsageStats{
-		{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
-		{ProviderID: "p1", ModelName: "gpt-3.5", RequestCount: 5},
-		{ProviderID: "p2", ModelName: "claude-3", RequestCount: 8},
-	}
-	result := svc.Aggregate(stats, "provider")
-	assert.Len(t, result, 2)
-	p1Count := 0
-	p2Count := 0
-	for _, r := range result {
-		if r["provider_id"] == "p1" {
-			p1Count = r["request_count"].(int)
-		}
-		if r["provider_id"] == "p2" {
-			p2Count = r["request_count"].(int)
-		}
-	}
-	assert.Equal(t, 15, p1Count)
-	assert.Equal(t, 8, p2Count)
-}
-func TestStatsService_Aggregate_ByDate(t *testing.T) {
-	statsRepo := repository.NewStatsRepository(nil)
-	svc := NewStatsService(statsRepo)
-	stats := []domain.UsageStats{
-		{ProviderID: "p1", RequestCount: 10},
-		{ProviderID: "p2", RequestCount: 5},
-	}
-	result := svc.Aggregate(stats, "date")
-	assert.Len(t, result, 1)
-	assert.Equal(t, 15, result[0]["request_count"])
-}
+func TestModelService_Update_ModelNotFound(t *testing.T) {
+	db := setupServiceTestDB(t)
+	providerRepo := repository.NewProviderRepository(db)
+	modelRepo := repository.NewModelRepository(db)
+	svc := NewModelService(modelRepo, providerRepo)
+
+	err := svc.Update("nonexistent-id", map[string]interface{}{
+		"model_name": "gpt-4",
+	})
+	assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
+}
+
+func TestModelService_Update_Success(t *testing.T) {
+	db := setupServiceTestDB(t)
+	providerRepo := repository.NewProviderRepository(db)
+	modelRepo := repository.NewModelRepository(db)
+	svc := NewModelService(modelRepo, providerRepo)
+
+	providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
+
+	model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
+	err := svc.Create(model)
+	require.NoError(t, err)
+
+	// update model_name to a non-conflicting value
+	err = svc.Update(model.ID, map[string]interface{}{
+		"model_name": "gpt-4-turbo",
+	})
+	require.NoError(t, err)
+
+	updated, err := svc.Get(model.ID)
+	require.NoError(t, err)
+	assert.Equal(t, "gpt-4-turbo", updated.ModelName)
+}
+
+// ============ ProviderService - Update immutable ID tests ============
+
+func TestProviderService_Update_ImmutableID(t *testing.T) {
+	db := setupServiceTestDB(t)
+	repo := repository.NewProviderRepository(db)
+	modelRepo := repository.NewModelRepository(db)
+	svc := NewProviderService(repo, modelRepo)
+
+	provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
+	err := svc.Create(provider)
+	require.NoError(t, err)
+
+	// attempting to update the id field should fail
+	err = svc.Update("openai", map[string]interface{}{
+		"id": "new-id",
+	})
+	assert.True(t, errors.Is(err, appErrors.ErrImmutableField))
+}
+
+func TestProviderService_Update_Success(t *testing.T) {
+	db := setupServiceTestDB(t)
+	repo := repository.NewProviderRepository(db)
+	modelRepo := repository.NewModelRepository(db)
+	svc := NewProviderService(repo, modelRepo)
+
+	provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
+	err := svc.Create(provider)
+	require.NoError(t, err)
+
+	// update name
+	err = svc.Update("openai", map[string]interface{}{
+		"name": "OpenAI Updated",
+	})
+	require.NoError(t, err)
+
+	updated, err := svc.Get("openai", false)
+	require.NoError(t, err)
+	assert.Equal(t, "OpenAI Updated", updated.Name)
+}
+
+func TestStatsService_Aggregate_ByModel(t *testing.T) {
+	statsRepo := repository.NewStatsRepository(nil)
+	svc := NewStatsService(statsRepo)
+
+	stats := []domain.UsageStats{
+		{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10},
+		{ProviderID: "openai", ModelName: "gpt-3.5", RequestCount: 5},
+		{ProviderID: "anthropic", ModelName: "claude-3", RequestCount: 8},
+		{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 3},
+	}
+	result := svc.Aggregate(stats, "model")
+	assert.Len(t, result, 3)
+
+	// verify the count for each provider/model combination
+	counts := make(map[string]int)
+	for _, r := range result {
+		key := r["provider_id"].(string) + "/" + r["model_name"].(string)
+		counts[key] = r["request_count"].(int)
+	}
+	assert.Equal(t, 13, counts["openai/gpt-4"])
+	assert.Equal(t, 5, counts["openai/gpt-3.5"])
+	assert.Equal(t, 8, counts["anthropic/claude-3"])
+}
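The aggregation contract these stats tests pin down (sum request counts, keyed by the unified `provider_id/model_name` ID) can be sketched standalone. This is an illustration with an assumed struct shape, not the project's actual `StatsService.Aggregate`, which returns `[]map[string]any`:

```go
package main

import "fmt"

type usageStat struct {
	ProviderID   string
	ModelName    string
	RequestCount int
}

// aggregateByModel sums request counts per (provider_id, model_name) pair,
// keyed by the unified "provider_id/model_name" ID.
func aggregateByModel(stats []usageStat) map[string]int {
	out := make(map[string]int)
	for _, s := range stats {
		out[s.ProviderID+"/"+s.ModelName] += s.RequestCount
	}
	return out
}

func main() {
	counts := aggregateByModel([]usageStat{
		{"openai", "gpt-4", 10},
		{"openai", "gpt-3.5", 5},
		{"anthropic", "claude-3", 8},
		{"openai", "gpt-4", 3},
	})
	// 10 + 3 requests collapse onto the single key "openai/gpt-4"
	fmt.Println(counts["openai/gpt-4"], counts["openai/gpt-3.5"], counts["anthropic/claude-3"]) // 13 5 8
}
```

Because provider IDs cannot contain `/`, the joined key is unambiguous even when the model name itself contains slashes.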


@@ -1,9 +0,0 @@
-- +goose Up
CREATE INDEX IF NOT EXISTS idx_models_provider_id ON models(provider_id);
CREATE INDEX IF NOT EXISTS idx_models_model_name ON models(model_name);
CREATE INDEX IF NOT EXISTS idx_usage_stats_provider_model_date ON usage_stats(provider_id, model_name, date);
-- +goose Down
DROP INDEX IF EXISTS idx_usage_stats_provider_model_date;
DROP INDEX IF EXISTS idx_models_model_name;
DROP INDEX IF EXISTS idx_models_provider_id;


@@ -1,6 +0,0 @@
-- +goose Up
ALTER TABLE providers ADD COLUMN protocol TEXT DEFAULT 'openai';
-- +goose Down
-- SQLite does not support DROP COLUMN (before 3.35.0), but goose Down is usually not needed here
CREATE TABLE providers_backup AS SELECT id, name, api_key, base_url, enabled, created_at, updated_at FROM providers;


@@ -1,9 +1,13 @@
 -- +goose Up
+-- Unified initial migration: full schema for providers, models, usage_stats
+-- models uses a UUID primary key plus a UNIQUE(provider_id, model_name) constraint
 CREATE TABLE IF NOT EXISTS providers (
 	id TEXT PRIMARY KEY,
 	name TEXT NOT NULL,
 	api_key TEXT NOT NULL,
 	base_url TEXT NOT NULL,
+	protocol TEXT DEFAULT 'openai',
 	enabled INTEGER DEFAULT 1,
 	created_at DATETIME,
 	updated_at DATETIME
@@ -15,7 +19,8 @@ CREATE TABLE IF NOT EXISTS models (
 	model_name TEXT NOT NULL,
 	enabled INTEGER DEFAULT 1,
 	created_at DATETIME,
-	FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
+	FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE,
+	UNIQUE(provider_id, model_name)
 );

 CREATE TABLE IF NOT EXISTS usage_stats (
@@ -27,7 +32,14 @@ CREATE TABLE IF NOT EXISTS usage_stats (
 	UNIQUE(provider_id, model_name, date)
 );

+CREATE INDEX IF NOT EXISTS idx_models_provider_id ON models(provider_id);
+CREATE INDEX IF NOT EXISTS idx_models_model_name ON models(model_name);
+CREATE INDEX IF NOT EXISTS idx_usage_stats_provider_model_date ON usage_stats(provider_id, model_name, date);
+
 -- +goose Down
+DROP INDEX IF EXISTS idx_usage_stats_provider_model_date;
+DROP INDEX IF EXISTS idx_models_model_name;
+DROP INDEX IF EXISTS idx_models_provider_id;
 DROP TABLE IF EXISTS usage_stats;
 DROP TABLE IF EXISTS models;
 DROP TABLE IF EXISTS providers;


@@ -49,17 +49,20 @@ func NewAppError(code, message string, httpStatus int) *AppError {
 // Predefined errors
 var (
 	ErrModelNotFound     = NewAppError("model_not_found", "模型未找到", http.StatusNotFound)
 	ErrModelDisabled     = NewAppError("model_disabled", "模型已禁用", http.StatusNotFound)
 	ErrProviderNotFound  = NewAppError("provider_not_found", "供应商未找到", http.StatusNotFound)
 	ErrProviderDisabled  = NewAppError("provider_disabled", "供应商已禁用", http.StatusNotFound)
 	ErrInvalidRequest    = NewAppError("invalid_request", "无效的请求", http.StatusBadRequest)
 	ErrInternal          = NewAppError("internal_error", "内部错误", http.StatusInternalServerError)
 	ErrDatabaseNotInit   = NewAppError("database_not_initialized", "数据库未初始化", http.StatusInternalServerError)
 	ErrConflict          = NewAppError("conflict", "资源已存在", http.StatusConflict)
 	ErrRequestCreate     = NewAppError("request_create_error", "创建请求失败", http.StatusInternalServerError)
 	ErrRequestSend       = NewAppError("request_send_error", "发送请求失败", http.StatusBadGateway)
 	ErrResponseRead      = NewAppError("response_read_error", "读取响应失败", http.StatusBadGateway)
+	ErrInvalidProviderID = NewAppError("invalid_provider_id", "供应商 ID 仅允许字母、数字、下划线,长度 1-64", http.StatusBadRequest)
+	ErrDuplicateModel    = NewAppError("duplicate_model", "同一供应商下模型名称已存在", http.StatusConflict)
+	ErrImmutableField    = NewAppError("immutable_field", "供应商 ID 不允许修改", http.StatusBadRequest)
 )

 // AsAppError attempts to convert an error to *AppError


@@ -9,6 +9,14 @@ import (
 	"go.uber.org/zap/zapcore"
 )

+// stdoutWriter wraps os.Stdout and ignores Sync() errors.
+// In non-TTY environments (such as go test), os.Stdout is redirected to a
+// pipe and the underlying fsync returns "bad file descriptor". This is the
+// standard approach in the zap community.
+type stdoutWriter struct{}
+
+func (stdoutWriter) Write(p []byte) (int, error) { return os.Stdout.Write(p) }
+func (stdoutWriter) Sync() error                 { return nil }
+
 // Config holds the logger configuration
 type Config struct {
 	Level string // log level: debug, info, warn, error
@@ -46,7 +54,7 @@ func New(cfg Config) (*zap.Logger, error) {
 	stdoutCore := zapcore.NewCore(
 		stdoutEncoder,
-		zapcore.AddSync(os.Stdout),
+		zapcore.AddSync(stdoutWriter{}),
 		level,
 	)
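The no-op-Sync idea above is independent of zap itself. Below is a stdlib-only sketch, with a local interface standing in for zap's `zapcore.WriteSyncer` (assumed shape; the real interface lives in `go.uber.org/zap/zapcore`):

```go
package main

import (
	"fmt"
	"os"
)

// writeSyncer mirrors the shape of a sync-capable writer interface
// such as zap's zapcore.WriteSyncer.
type writeSyncer interface {
	Write(p []byte) (int, error)
	Sync() error
}

// stdoutWriter forwards writes to os.Stdout but makes Sync a no-op, so a
// redirected stdout (a pipe under `go test`) never fails on fsync.
type stdoutWriter struct{}

func (stdoutWriter) Write(p []byte) (int, error) { return os.Stdout.Write(p) }
func (stdoutWriter) Sync() error                 { return nil }

func main() {
	var ws writeSyncer = stdoutWriter{}
	ws.Write([]byte("hello\n"))
	// Sync never surfaces the pipe's fsync error
	fmt.Println("sync error:", ws.Sync()) // sync error: <nil>
}
```

The trade-off is deliberate: log durability on process crash is given up in exchange for logger teardown that never fails in test environments.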


@@ -0,0 +1,63 @@
package modelid
import (
"errors"
"regexp"
"strings"
)
var providerIDRegex = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
// ParseUnifiedModelID parses a "provider_id/model_name" string into providerID and modelName.
// It splits at the first "/", so model_name may itself contain "/".
func ParseUnifiedModelID(id string) (providerID, modelName string, err error) {
if id == "" {
return "", "", errors.New("统一模型 ID 不能为空")
}
parts := strings.SplitN(id, "/", 2)
if len(parts) != 2 {
return "", "", errors.New("统一模型 ID 格式错误,缺少分隔符 \"/\"")
}
providerID = parts[0]
modelName = parts[1]
if providerID == "" {
return "", "", errors.New("provider_id 不能为空")
}
if modelName == "" {
return "", "", errors.New("model_name 不能为空")
}
if !providerIDRegex.MatchString(providerID) {
return "", "", errors.New("provider_id 仅允许字母、数字、下划线")
}
return providerID, modelName, nil
}
// FormatUnifiedModelID joins providerID and modelName into a unified model ID
func FormatUnifiedModelID(providerID, modelName string) string {
return providerID + "/" + modelName
}
// ValidateProviderID checks that providerID contains only letters, digits, and underscores, with length 1-64
func ValidateProviderID(id string) error {
if id == "" {
return errors.New("provider_id 不能为空")
}
if len(id) > 64 {
return errors.New("provider_id 长度不能超过 64 个字符")
}
if !providerIDRegex.MatchString(id) {
return errors.New("provider_id 仅允许字母、数字、下划线")
}
return nil
}
// IsValidUnifiedModelID reports whether the string is a valid unified model ID
func IsValidUnifiedModelID(id string) bool {
_, _, err := ParseUnifiedModelID(id)
return err == nil
}
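The split-at-first-slash rule is the part most worth internalizing: it is what lets nested upstream names like `accounts/org-123/models/gpt-4` survive intact. A standalone sketch of that rule, reimplemented here with the standard library rather than importing `pkg/modelid`:

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// provider IDs are restricted to letters, digits, and underscores
var providerIDRe = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)

// parse splits a unified model ID at the FIRST "/" only,
// so the model name itself may contain slashes.
func parse(id string) (provider, model string, ok bool) {
	parts := strings.SplitN(id, "/", 2)
	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
		return "", "", false
	}
	if !providerIDRe.MatchString(parts[0]) {
		return "", "", false
	}
	return parts[0], parts[1], true
}

func main() {
	p, m, ok := parse("azure/accounts/org-123/models/gpt-4")
	fmt.Println(p, m, ok) // azure accounts/org-123/models/gpt-4 true

	// "-" is not in the provider ID character set, so this is rejected
	_, _, ok = parse("open-ai/gpt-4")
	fmt.Println(ok) // false
}
```

Restricting the provider ID character set (no `/`) is what makes the format unambiguous: the first slash is always the boundary.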


@@ -0,0 +1,96 @@
package modelid
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestParseUnifiedModelID_StandardFormat(t *testing.T) {
providerID, modelName, err := ParseUnifiedModelID("openai/gpt-4")
assert.NoError(t, err)
assert.Equal(t, "openai", providerID)
assert.Equal(t, "gpt-4", modelName)
}
func TestParseUnifiedModelID_ModelNameWithSlashes(t *testing.T) {
providerID, modelName, err := ParseUnifiedModelID("azure/accounts/org-123/models/gpt-4")
assert.NoError(t, err)
assert.Equal(t, "azure", providerID)
assert.Equal(t, "accounts/org-123/models/gpt-4", modelName)
}
func TestParseUnifiedModelID_MissingSeparator(t *testing.T) {
_, _, err := ParseUnifiedModelID("gpt-4")
assert.Error(t, err)
}
func TestParseUnifiedModelID_EmptyString(t *testing.T) {
_, _, err := ParseUnifiedModelID("")
assert.Error(t, err)
}
func TestParseUnifiedModelID_OnlySeparator(t *testing.T) {
tests := []string{"/model", "provider/", "/"}
for _, tc := range tests {
_, _, err := ParseUnifiedModelID(tc)
assert.Error(t, err, "input %q should return an error", tc)
}
}
func TestParseUnifiedModelID_InvalidProviderID(t *testing.T) {
tests := []string{
"open-ai/gpt-4",
"open.ai/gpt-4",
"供应商/gpt-4",
"open ai/gpt-4",
}
for _, tc := range tests {
_, _, err := ParseUnifiedModelID(tc)
assert.Error(t, err, "providerID containing invalid characters %q should return an error", tc)
}
}
func TestFormatUnifiedModelID(t *testing.T) {
assert.Equal(t, "openai/gpt-4", FormatUnifiedModelID("openai", "gpt-4"))
assert.Equal(t, "anthropic/claude-3-opus", FormatUnifiedModelID("anthropic", "claude-3-opus"))
}
func TestValidateProviderID_Valid(t *testing.T) {
validIDs := []string{"openai", "deep_seek", "provider01", "OpenAI"}
for _, id := range validIDs {
assert.NoError(t, ValidateProviderID(id), "%q should pass validation", id)
}
}
func TestValidateProviderID_InvalidChars(t *testing.T) {
invalidIDs := []string{"open-ai", "open.ai", "open ai", "供应商", "open/ai"}
for _, id := range invalidIDs {
assert.Error(t, ValidateProviderID(id), "%q should fail validation", id)
}
}
func TestValidateProviderID_Empty(t *testing.T) {
assert.Error(t, ValidateProviderID(""))
}
func TestValidateProviderID_TooLong(t *testing.T) {
longID := strings.Repeat("a", 65)
assert.Error(t, ValidateProviderID(longID))
exactly64 := strings.Repeat("a", 64)
assert.NoError(t, ValidateProviderID(exactly64))
}
func TestIsValidUnifiedModelID(t *testing.T) {
assert.True(t, IsValidUnifiedModelID("openai/gpt-4"))
assert.True(t, IsValidUnifiedModelID("anthropic/claude-3-opus-20240229"))
assert.True(t, IsValidUnifiedModelID("azure/accounts/org/models/gpt-4"))
assert.False(t, IsValidUnifiedModelID(""))
assert.False(t, IsValidUnifiedModelID("gpt-4"))
assert.False(t, IsValidUnifiedModelID("open-ai/gpt-4"))
assert.False(t, IsValidUnifiedModelID("/model"))
assert.False(t, IsValidUnifiedModelID("provider/"))
}


@@ -3,6 +3,7 @@ package tests
 import (
 	"fmt"
 	"testing"
+	"time"

 	"nex/backend/internal/config"
@@ -11,26 +12,36 @@ import (
 	"gorm.io/gorm"
 )

-// SetupTestDB initializes a temporary SQLite database with auto-migration.
+// SetupTestDB initializes an in-memory SQLite database with auto-migration.
+// Uses :memory: mode with MaxOpenConns(1) to ensure all operations share the
+// same connection, avoiding "database is closed" errors from connection pool.
+// Enables foreign key constraints for SQLite.
 func SetupTestDB(t *testing.T) *gorm.DB {
 	t.Helper()
-	dir := t.TempDir()
-	dsn := dir + "/test.db"
-	db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
+	db, err := gorm.Open(sqlite.Open(":memory:?_foreign_keys=on"), &gorm.Config{})
 	assert.NoError(t, err, "failed to open test database")

+	// limit to a single connection so the :memory: database is not discarded by the pool
+	sqlDB, err := db.DB()
+	assert.NoError(t, err, "failed to get underlying sql.DB")
+	sqlDB.SetMaxOpenConns(1)
+	sqlDB.SetConnMaxLifetime(0)
+
 	err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
 	assert.NoError(t, err, "failed to auto-migrate test database")
 	return db
 }

-// CleanupTestDB closes the database and removes the temp database file.
+// CleanupTestDB closes the database after a brief delay to allow async
+// goroutines (e.g. stats recording) to finish.
 func CleanupTestDB(t *testing.T, db *gorm.DB) {
 	t.Helper()
+	// wait for async goroutines (such as statsService.Record) to finish
+	time.Sleep(50 * time.Millisecond)
 	sqlDB, err := db.DB()
 	assert.NoError(t, err, "failed to get underlying sql.DB")
@@ -57,7 +68,8 @@ func CreateTestProvider(t *testing.T, db *gorm.DB, id string) config.Provider {
 }

 // CreateTestModel creates a test model and returns it.
+// Does NOT assert on error - returns the model and error for caller to verify.
-func CreateTestModel(t *testing.T, db *gorm.DB, id string, providerID string, modelName string) config.Model {
+func CreateTestModel(t *testing.T, db *gorm.DB, id string, providerID string, modelName string) (config.Model, error) {
 	t.Helper()
 	model := config.Model{
@@ -68,7 +80,5 @@ func CreateTestModel(t *testing.T, db *gorm.DB, id string, providerID string, mo
 	}
 	err := db.Create(&model).Error
-	assert.NoError(t, err, "failed to create test model")
-	return model
+	return model, err
 }


@@ -14,10 +14,8 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-	"gorm.io/driver/sqlite"
 	"gorm.io/gorm"

-	"nex/backend/internal/config"
 	"nex/backend/internal/conversion"
 	"nex/backend/internal/conversion/anthropic"
 	openaiConv "nex/backend/internal/conversion/openai"
@@ -43,11 +41,7 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server)
 		w.Write([]byte(`{"error":"not mocked"}`))
 	}))

-	dir := t.TempDir()
-	db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
-	require.NoError(t, err)
-	err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
-	require.NoError(t, err)
+	db := setupTestDB(t)
 	t.Cleanup(func() {
 		sqlDB, _ := db.DB()
 		if sqlDB != nil {
@@ -60,7 +54,7 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server)
 	modelRepo := repository.NewModelRepository(db)
 	statsRepo := repository.NewStatsRepository(db)

-	providerService := service.NewProviderService(providerRepo)
+	providerService := service.NewProviderService(providerRepo, modelRepo)
 	modelService := service.NewModelService(modelRepo, providerRepo)
 	routingService := service.NewRoutingService(modelRepo, providerRepo)
 	statsService := service.NewStatsService(statsRepo)
@@ -125,7 +119,7 @@ func createProviderAndModel(t *testing.T, r *gin.Engine, providerID, protocol, m
 	require.Equal(t, 201, w.Code)

 	modelBody, _ := json.Marshal(map[string]string{
-		"id":          modelName,
 		"provider_id": providerID,
 		"model_name":  modelName,
 	})
@@ -156,7 +150,7 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
 			"id":   "msg_test",
 			"type": "message",
 			"role": "assistant",
-			"model": "claude-3-opus",
+			"model": "anthropic_p/claude-3-opus",
 			"content": []map[string]any{
 				{"type": "text", "text": "Hello from Anthropic!"},
 			},
@@ -170,11 +164,11 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
 		json.NewEncoder(w).Encode(resp)
 	})

-	createProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-3-opus", upstream.URL)
+	createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL)

 	// send the request in OpenAI format
 	openaiReq := map[string]any{
-		"model": "claude-3-opus",
+		"model": "anthropic_p/claude-3-opus",
 		"messages": []map[string]any{
 			{"role": "user", "content": "Hello"},
 		},
@@ -233,10 +227,10 @@ func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) {
 		json.NewEncoder(w).Encode(resp)
 	})

-	createProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL)
+	createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)

 	anthropicReq := map[string]any{
-		"model": "gpt-4",
+		"model": "openai_p/gpt-4",
 		"max_tokens": 1024,
 		"messages": []map[string]any{
 			{"role": "user", "content": "Hello"},
@@ -273,16 +267,18 @@ func TestConversion_OpenAIToOpenAI_Passthrough(t *testing.T) {
 		body, _ := io.ReadAll(r.Body)
 		var req map[string]any
 		json.Unmarshal(body, &req)
+		// Smart Passthrough: the unified ID in the request body should be rewritten to the upstream model name
 		assert.Equal(t, "gpt-4", req["model"])
 		w.Header().Set("Content-Type", "application/json")
+		// the upstream returns the upstream model name
 		w.Write([]byte(`{"id":"chatcmpl-pass","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"passthrough"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":1,"total_tokens":6}}`))
 	})

-	createProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL)
+	createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)

 	reqBody := map[string]any{
-		"model": "gpt-4",
+		"model": "openai_p/gpt-4", // the client sends the unified ID
 		"messages": []map[string]any{{"role": "user", "content": "test"}},
 	}
 	body, _ := json.Marshal(reqBody)
@@ -293,7 +289,8 @@ func TestConversion_OpenAIToOpenAI_Passthrough(t *testing.T) {
 	r.ServeHTTP(w, req)

 	assert.Equal(t, 200, w.Code)
-	assert.Contains(t, w.Body.String(), "passthrough")
+	// Smart Passthrough: the upstream model name in the response body should be rewritten to the unified ID
+	assert.Contains(t, w.Body.String(), `"model":"openai_p/gpt-4"`)
 }

 func TestConversion_AnthropicToAnthropic_Passthrough(t *testing.T) {
@@ -302,14 +299,21 @@ func TestConversion_AnthropicToAnthropic_Passthrough(t *testing.T) {
 	upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		assert.Equal(t, "/v1/messages", r.URL.Path)
+		body, _ := io.ReadAll(r.Body)
+		var req map[string]any
+		json.Unmarshal(body, &req)
+		// Smart Passthrough: the unified ID in the request body should be rewritten to the upstream model name
+		assert.Equal(t, "claude-3-opus", req["model"])
+
+		// the upstream returns the upstream model name
 		w.Header().Set("Content-Type", "application/json")
 		w.Write([]byte(`{"id":"msg-pass","type":"message","role":"assistant","model":"claude-3-opus","content":[{"type":"text","text":"passthrough"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":1}}`))
 	})

-	createProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-3-opus", upstream.URL)
+	createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL)

 	reqBody := map[string]any{
-		"model": "claude-3-opus",
+		"model": "anthropic_p/claude-3-opus", // the client sends the unified ID
 		"max_tokens": 1024,
 		"messages": []map[string]any{{"role": "user", "content": "test"}},
 	}
@@ -321,7 +325,8 @@ func TestConversion_AnthropicToAnthropic_Passthrough(t *testing.T) {
 	r.ServeHTTP(w, req)

 	assert.Equal(t, 200, w.Code)
-	assert.Contains(t, w.Body.String(), "passthrough")
+	// Smart Passthrough: the upstream model name in the response body should be rewritten to the unified ID
+	assert.Contains(t, w.Body.String(), `"model":"anthropic_p/claude-3-opus"`)
 }
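Both passthrough tests pin the same pair of rewrites: unified ID → upstream name on the request, upstream name → unified ID on the response. A stdlib-only sketch of the idea follows. It is an illustration, not the project's conversion package, and it takes a shortcut the real Smart Passthrough must not: `json.Marshal` on a `map` reorders keys and may reformat values, so a production implementation would rewrite only the `model` field in place to keep the rest of the body byte-identical.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// rewriteRequestModel replaces the unified "provider/model" ID in a raw
// JSON request body with the bare upstream model name.
func rewriteRequestModel(body []byte) ([]byte, error) {
	var m map[string]any
	if err := json.Unmarshal(body, &m); err != nil {
		return nil, err
	}
	if id, ok := m["model"].(string); ok {
		if i := strings.Index(id, "/"); i > 0 {
			m["model"] = id[i+1:] // drop the provider prefix
		}
	}
	return json.Marshal(m)
}

// rewriteResponseModel restores the unified ID on the way back to the client.
func rewriteResponseModel(body []byte, providerID string) ([]byte, error) {
	var m map[string]any
	if err := json.Unmarshal(body, &m); err != nil {
		return nil, err
	}
	if name, ok := m["model"].(string); ok {
		m["model"] = providerID + "/" + name
	}
	return json.Marshal(m)
}

func main() {
	req, _ := rewriteRequestModel([]byte(`{"model":"openai_p/gpt-4","temperature":0.7}`))
	fmt.Println(string(req)) // {"model":"gpt-4","temperature":0.7}

	resp, _ := rewriteResponseModel([]byte(`{"model":"gpt-4"}`), "openai_p")
	fmt.Println(string(resp)) // {"model":"openai_p/gpt-4"}
}
```

Everything other than the `model` field passes through untouched, which is what keeps same-protocol requests parameter-faithful.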
 // ============ streaming conversion tests ============
@@ -349,10 +354,10 @@ func TestConversion_OpenAIToAnthropic_Stream(t *testing.T) {
 		}
 	})

-	createProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-3-opus", upstream.URL)
+	createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL)

 	openaiReq := map[string]any{
-		"model": "claude-3-opus",
+		"model": "anthropic_p/claude-3-opus",
 		"messages": []map[string]any{{"role": "user", "content": "Hello"}},
 		"stream": true,
 	}
@@ -390,10 +395,10 @@ func TestConversion_AnthropicToOpenAI_Stream(t *testing.T) {
 		}
 	})

-	createProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL)
+	createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)

 	anthropicReq := map[string]any{
-		"model": "gpt-4",
+		"model": "openai_p/gpt-4",
 		"max_tokens": 1024,
 		"messages": []map[string]any{{"role": "user", "content": "Hello"}},
 		"stream": true,
@@ -512,7 +517,7 @@ func TestConversion_ProviderWithProtocol(t *testing.T) {
 	// create a provider with the protocol field set
 	providerBody := map[string]any{
-		"id": "test-protocol",
+		"id": "test_protocol",
 		"name": "Test Protocol",
 		"api_key": "sk-test",
 		"base_url": "https://test.com",
@@ -533,7 +538,7 @@ func TestConversion_ProviderWithProtocol(t *testing.T) {
 	// the response should include protocol when fetched
 	w = httptest.NewRecorder()
-	req = httptest.NewRequest("GET", "/api/providers/test-protocol", nil)
+	req = httptest.NewRequest("GET", "/api/providers/test_protocol", nil)
 	r.ServeHTTP(w, req)
 	assert.Equal(t, 200, w.Code)
@@ -547,7 +552,7 @@ func TestConversion_ProviderDefaultProtocol(t *testing.T) {
 	// protocol defaults to openai when not specified
 	providerBody := map[string]any{
-		"id": "default-proto",
+		"id": "default_proto",
 		"name": "Default",
 		"api_key": "sk-test",
 		"base_url": "https://test.com",

View File

@@ -8,8 +8,6 @@ import (
 "io"
 "net/http"
 "net/http/httptest"
-"os"
-"path/filepath"
 "strings"
 "testing"
 "time"
@@ -17,10 +15,7 @@ import (
 "github.com/gin-gonic/gin"
 "github.com/stretchr/testify/assert"
 "github.com/stretchr/testify/require"
-"gorm.io/driver/sqlite"
-"gorm.io/gorm"
-"nex/backend/internal/config"
 "nex/backend/internal/conversion"
 "nex/backend/internal/conversion/anthropic"
 openaiConv "nex/backend/internal/conversion/openai"
@@ -40,25 +35,20 @@ func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) {
 w.Write([]byte(`{"error":"not mocked"}`))
 }))
-dir, _ := os.MkdirTemp("", "e2e-test-*")
-db, err := gorm.Open(sqlite.Open(filepath.Join(dir, "test.db")), &gorm.Config{})
-require.NoError(t, err)
-err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
-require.NoError(t, err)
+db := setupTestDB(t)
 t.Cleanup(func() {
 sqlDB, _ := db.DB()
 if sqlDB != nil {
 sqlDB.Close()
 }
 upstream.Close()
-os.RemoveAll(dir)
 })
 providerRepo := repository.NewProviderRepository(db)
 modelRepo := repository.NewModelRepository(db)
 statsRepo := repository.NewStatsRepository(db)
-providerService := service.NewProviderService(providerRepo)
+providerService := service.NewProviderService(providerRepo, modelRepo)
 modelService := service.NewModelService(modelRepo, providerRepo)
 routingService := service.NewRoutingService(modelRepo, providerRepo)
 statsService := service.NewStatsService(statsRepo)
@@ -105,7 +95,7 @@ func e2eCreateProviderAndModel(t *testing.T, r *gin.Engine, providerID, protocol
 require.Equal(t, 201, w.Code)
 modelBody, _ := json.Marshal(map[string]string{
-"id": modelName, "provider_id": providerID,
+"model_name": modelName, "provider_id": providerID,
 })
 w = httptest.NewRecorder()
 req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody))
@@ -178,10 +168,10 @@ func TestE2E_OpenAI_NonStream_BasicText(t *testing.T) {
 },
 })
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "gpt-4o",
+"model": "openai_p/gpt-4o",
 "messages": []map[string]any{
 {"role": "user", "content": "你好"},
 },
@@ -195,7 +185,7 @@ func TestE2E_OpenAI_NonStream_BasicText(t *testing.T) {
 var resp map[string]any
 require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
 assert.Equal(t, "chat.completion", resp["object"])
-assert.Equal(t, "gpt-4o", resp["model"])
+assert.Equal(t, "openai_p/gpt-4o", resp["model"])
 choices := resp["choices"].([]any)
 require.Len(t, choices, 1)
@@ -231,10 +221,10 @@ func TestE2E_OpenAI_NonStream_MultiTurn(t *testing.T) {
 "usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120},
 })
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "gpt-4o",
+"model": "openai_p/gpt-4o",
 "messages": []map[string]any{
 {"role": "system", "content": "你是编程助手"},
 {"role": "user", "content": "什么是interface?"},
@@ -279,10 +269,10 @@ func TestE2E_OpenAI_NonStream_ToolCalls(t *testing.T) {
 "usage": map[string]any{"prompt_tokens": 80, "completion_tokens": 18, "total_tokens": 98},
 })
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "gpt-4o",
+"model": "openai_p/gpt-4o",
 "messages": []map[string]any{
 {"role": "user", "content": "北京天气"},
 },
@@ -335,10 +325,10 @@ func TestE2E_OpenAI_NonStream_MaxTokens_Length(t *testing.T) {
 "usage": map[string]any{"prompt_tokens": 20, "completion_tokens": 30, "total_tokens": 50},
 })
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "gpt-4o",
+"model": "openai_p/gpt-4o",
 "messages": []map[string]any{{"role": "user", "content": "介绍AI历史"}},
 "max_tokens": 30,
 })
@@ -372,10 +362,10 @@ func TestE2E_OpenAI_NonStream_UsageWithReasoning(t *testing.T) {
 },
 })
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "o3", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "o3", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "o3",
+"model": "openai_p/o3",
 "messages": []map[string]any{{"role": "user", "content": "15+23*2=?"}},
 })
 w := httptest.NewRecorder()
@@ -413,10 +403,10 @@ func TestE2E_OpenAI_NonStream_Refusal(t *testing.T) {
 "usage": map[string]any{"prompt_tokens": 12, "completion_tokens": 35, "total_tokens": 47},
 })
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "gpt-4o",
+"model": "openai_p/gpt-4o",
 "messages": []map[string]any{{"role": "user", "content": "做坏事"}},
 })
 w := httptest.NewRecorder()
@@ -455,10 +445,10 @@ func TestE2E_OpenAI_Stream_Text(t *testing.T) {
 time.Sleep(10 * time.Millisecond)
 }
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "gpt-4o",
+"model": "openai_p/gpt-4o",
 "messages": []map[string]any{{"role": "user", "content": "你好"}},
 "stream": true,
 })
@@ -499,10 +489,10 @@ func TestE2E_OpenAI_Stream_ToolCalls(t *testing.T) {
 time.Sleep(10 * time.Millisecond)
 }
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "gpt-4o",
+"model": "openai_p/gpt-4o",
 "messages": []map[string]any{{"role": "user", "content": "北京天气"}},
 "tools": []map[string]any{{
 "type": "function",
@@ -548,10 +538,10 @@ func TestE2E_OpenAI_Stream_WithUsage(t *testing.T) {
 time.Sleep(10 * time.Millisecond)
 }
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "gpt-4o",
+"model": "openai_p/gpt-4o",
 "messages": []map[string]any{{"role": "user", "content": "hi"}},
 "stream": true,
 })
@@ -583,10 +573,10 @@ func TestE2E_Anthropic_NonStream_BasicText(t *testing.T) {
 "usage": map[string]any{"input_tokens": 15, "output_tokens": 25},
 })
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-opus-4-7", "max_tokens": 1024,
+"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
 "messages": []map[string]any{{"role": "user", "content": "你好"}},
 })
 w := httptest.NewRecorder()
@@ -599,7 +589,7 @@ func TestE2E_Anthropic_NonStream_BasicText(t *testing.T) {
 require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
 assert.Equal(t, "message", resp["type"])
 assert.Equal(t, "assistant", resp["role"])
-assert.Equal(t, "claude-opus-4-7", resp["model"])
+assert.Equal(t, "anthropic_p/claude-opus-4-7", resp["model"])
 assert.Equal(t, "end_turn", resp["stop_reason"])
 content := resp["content"].([]any)
@@ -629,10 +619,10 @@ func TestE2E_Anthropic_NonStream_WithSystem(t *testing.T) {
 "usage": map[string]any{"input_tokens": 30, "output_tokens": 15},
 })
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-opus-4-7", "max_tokens": 1024,
+"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
 "system": "你是编程助手",
 "messages": []map[string]any{{"role": "user", "content": "什么是递归?"}},
 })
@@ -658,10 +648,10 @@ func TestE2E_Anthropic_NonStream_ToolUse(t *testing.T) {
 "usage": map[string]any{"input_tokens": 180, "output_tokens": 42},
 })
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-opus-4-7", "max_tokens": 1024,
+"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
 "messages": []map[string]any{{"role": "user", "content": "北京天气"}},
 "tools": []map[string]any{{
 "name": "get_weather", "description": "获取天气",
@@ -704,10 +694,10 @@ func TestE2E_Anthropic_NonStream_Thinking(t *testing.T) {
 "usage": map[string]any{"input_tokens": 95, "output_tokens": 280},
 })
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-opus-4-7", "max_tokens": 4096,
+"model": "anthropic_p/claude-opus-4-7", "max_tokens": 4096,
 "messages": []map[string]any{{"role": "user", "content": "15+23*2=?"}},
 "thinking": map[string]any{"type": "enabled", "budget_tokens": 2048},
 })
@@ -736,10 +726,10 @@ func TestE2E_Anthropic_NonStream_MaxTokens(t *testing.T) {
 "usage": map[string]any{"input_tokens": 22, "output_tokens": 20},
 })
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-opus-4-7", "max_tokens": 20,
+"model": "anthropic_p/claude-opus-4-7", "max_tokens": 20,
 "messages": []map[string]any{{"role": "user", "content": "介绍AI历史"}},
 })
 w := httptest.NewRecorder()
@@ -764,10 +754,10 @@ func TestE2E_Anthropic_NonStream_StopSequence(t *testing.T) {
 "usage": map[string]any{"input_tokens": 22, "output_tokens": 10},
 })
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-opus-4-7", "max_tokens": 1024,
+"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
 "messages": []map[string]any{{"role": "user", "content": "从1数到10"}},
 "stop_sequences": []string{"5"},
 })
@@ -800,10 +790,10 @@ func TestE2E_Anthropic_NonStream_MetadataUserID(t *testing.T) {
 "usage": map[string]any{"input_tokens": 12, "output_tokens": 5},
 })
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-opus-4-7", "max_tokens": 1024,
+"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
 "messages": []map[string]any{{"role": "user", "content": "你好"}},
 "metadata": map[string]any{"user_id": "user_12345"},
 })
@@ -829,10 +819,10 @@ func TestE2E_Anthropic_NonStream_UsageWithCache(t *testing.T) {
 },
 })
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-opus-4-7", "max_tokens": 1024,
+"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
 "system": []map[string]any{{"type": "text", "text": "你是编程助手。"}},
 "messages": []map[string]any{{"role": "user", "content": "你好"}},
 })
@@ -874,10 +864,10 @@ func TestE2E_Anthropic_Stream_Text(t *testing.T) {
 time.Sleep(10 * time.Millisecond)
 }
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-opus-4-7", "max_tokens": 1024,
+"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
 "messages": []map[string]any{{"role": "user", "content": "你好"}},
 "stream": true,
 })
@@ -921,10 +911,10 @@ func TestE2E_Anthropic_Stream_Thinking(t *testing.T) {
 time.Sleep(10 * time.Millisecond)
 }
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-opus-4-7", "max_tokens": 4096,
+"model": "anthropic_p/claude-opus-4-7", "max_tokens": 4096,
 "messages": []map[string]any{{"role": "user", "content": "1+1=?"}},
 "thinking": map[string]any{"type": "enabled", "budget_tokens": 1024},
 "stream": true,
@@ -970,10 +960,10 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_RequestFormat(t *testing.T) {
 "usage": map[string]any{"input_tokens": 10, "output_tokens": 5},
 })
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-model", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-model",
+"model": "anthropic_p/claude-model",
 "messages": []map[string]any{{"role": "user", "content": "Hello"}},
 })
 w := httptest.NewRecorder()
@@ -1011,10 +1001,10 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_RequestFormat(t *testing.T) {
 "usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18},
 })
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "gpt-4", "max_tokens": 1024,
+"model": "openai_p/gpt-4", "max_tokens": 1024,
 "messages": []map[string]any{{"role": "user", "content": "Hello"}},
 })
 w := httptest.NewRecorder()
@@ -1052,10 +1042,10 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_Stream(t *testing.T) {
 time.Sleep(10 * time.Millisecond)
 }
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-model", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-model",
+"model": "anthropic_p/claude-model",
 "messages": []map[string]any{{"role": "user", "content": "Hello"}},
 "stream": true,
 })
@@ -1092,10 +1082,10 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream(t *testing.T) {
 time.Sleep(10 * time.Millisecond)
 }
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "gpt-4", "max_tokens": 1024,
+"model": "openai_p/gpt-4", "max_tokens": 1024,
 "messages": []map[string]any{{"role": "user", "content": "Hello"}},
 "stream": true,
 })
@@ -1130,10 +1120,10 @@ func TestE2E_OpenAI_ErrorResponse(t *testing.T) {
 },
 })
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "nonexistent", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "nonexistent", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "nonexistent",
+"model": "openai_p/nonexistent",
 "messages": []map[string]any{{"role": "user", "content": "test"}},
 })
 w := httptest.NewRecorder()
@@ -1157,10 +1147,10 @@ func TestE2E_Anthropic_ErrorResponse(t *testing.T) {
 },
 })
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-opus-4-7", "max_tokens": 1024,
+"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
 "messages": []map[string]any{{"role": "user", "content": "test"}},
 })
 w := httptest.NewRecorder()
@@ -1203,10 +1193,10 @@ func TestE2E_OpenAI_NonStream_ParallelToolCalls(t *testing.T) {
 "usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 36, "total_tokens": 136},
 })
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "gpt-4o",
+"model": "openai_p/gpt-4o",
 "messages": []map[string]any{{"role": "user", "content": "北京和上海的天气"}},
 "tools": []map[string]any{{
 "type": "function",
@@ -1255,10 +1245,10 @@ func TestE2E_OpenAI_NonStream_StopSequence(t *testing.T) {
 "usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18},
 })
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "gpt-4o",
+"model": "openai_p/gpt-4o",
 "messages": []map[string]any{{"role": "user", "content": "从1数到10"}},
 "stop": []string{"5"},
 })
@@ -1293,10 +1283,10 @@ func TestE2E_OpenAI_NonStream_ContentFilter(t *testing.T) {
 "usage": map[string]any{"prompt_tokens": 8, "completion_tokens": 0, "total_tokens": 8},
 })
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "gpt-4o",
+"model": "openai_p/gpt-4o",
 "messages": []map[string]any{{"role": "user", "content": "危险内容"}},
 })
 w := httptest.NewRecorder()
@@ -1325,10 +1315,10 @@ func TestE2E_Anthropic_NonStream_MultiToolUse(t *testing.T) {
 "usage": map[string]any{"input_tokens": 200, "output_tokens": 84},
 })
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-opus-4-7", "max_tokens": 1024,
+"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
 "messages": []map[string]any{{"role": "user", "content": "北京和上海的天气"}},
 "tools": []map[string]any{{
 "name": "get_weather", "description": "获取天气",
@@ -1374,10 +1364,10 @@ func TestE2E_Anthropic_NonStream_ToolChoiceAny(t *testing.T) {
 "usage": map[string]any{"input_tokens": 100, "output_tokens": 30},
 })
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-opus-4-7", "max_tokens": 1024,
+"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
 "messages": []map[string]any{{"role": "user", "content": "现在几点了?"}},
 "tools": []map[string]any{{
 "name": "get_time", "description": "获取当前时间",
@@ -1417,10 +1407,10 @@ func TestE2E_Anthropic_NonStream_ArraySystemPrompt(t *testing.T) {
 "usage": map[string]any{"input_tokens": 50, "output_tokens": 10},
 })
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-opus-4-7", "max_tokens": 1024,
+"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
 "system": []map[string]any{
 {"type": "text", "text": "你是编程助手。"},
 {"type": "text", "text": "请用中文回答。"},
@@ -1454,10 +1444,10 @@ func TestE2E_Anthropic_NonStream_ToolResultMessage(t *testing.T) {
 "usage": map[string]any{"input_tokens": 150, "output_tokens": 20},
 })
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-opus-4-7", "max_tokens": 1024,
+"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
 "messages": []map[string]any{
 {"role": "user", "content": "北京天气"},
 {"role": "assistant", "content": []map[string]any{
@@ -1507,10 +1497,10 @@ func TestE2E_Anthropic_Stream_ToolCalls(t *testing.T) {
 time.Sleep(10 * time.Millisecond)
 }
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-opus-4-7", "max_tokens": 1024,
+"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
 "messages": []map[string]any{{"role": "user", "content": "北京天气"}},
 "tools": []map[string]any{{
 "name": "get_weather", "description": "获取天气",
@@ -1561,10 +1551,10 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_NonStream_ToolCalls(t *testing.T) {
 "usage": map[string]any{"input_tokens": 100, "output_tokens": 30},
 })
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-model", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-model",
+"model": "anthropic_p/claude-model",
 "messages": []map[string]any{{"role": "user", "content": "北京天气"}},
 "tools": []map[string]any{{
 "type": "function",
@@ -1613,10 +1603,10 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_NonStream_Thinking(t *testing.T) {
 "usage": map[string]any{"prompt_tokens": 50, "completion_tokens": 100, "total_tokens": 150},
 })
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "gpt-4", "max_tokens": 4096,
+"model": "openai_p/gpt-4", "max_tokens": 4096,
 "messages": []map[string]any{{"role": "user", "content": "宇宙的答案"}},
 })
 w := httptest.NewRecorder()
@@ -1643,10 +1633,10 @@ func TestE2E_CrossProtocol_StopReasonMapping(t *testing.T) {
 "usage": map[string]any{"input_tokens": 10, "output_tokens": 20},
 })
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-model", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-model",
+"model": "anthropic_p/claude-model",
 "messages": []map[string]any{{"role": "user", "content": "长文"}},
 })
 w := httptest.NewRecorder()
@@ -1685,10 +1675,10 @@ func TestE2E_OpenAI_NonStream_AssistantWithToolResult(t *testing.T) {
 "usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120},
 })
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "gpt-4o",
+"model": "openai_p/gpt-4o",
 "messages": []map[string]any{
 {"role": "user", "content": "北京天气"},
 {"role": "assistant", "content": nil, "tool_calls": []map[string]any{{
@@ -1732,10 +1722,10 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream_ToolCalls(t *testing.T) {
 time.Sleep(10 * time.Millisecond)
 }
 })
-e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-model", upstream.URL)
+e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "claude-model",
+"model": "anthropic_p/claude-model",
 "messages": []map[string]any{{"role": "user", "content": "北京天气"}},
 "tools": []map[string]any{{
 "type": "function",
@@ -1781,10 +1771,10 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_Stream_ToolCalls(t *testing.T) {
 time.Sleep(10 * time.Millisecond)
 }
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
 body, _ := json.Marshal(map[string]any{
-"model": "gpt-4", "max_tokens": 1024,
+"model": "openai_p/gpt-4", "max_tokens": 1024,
 "messages": []map[string]any{{"role": "user", "content": "北京天气"}},
 "tools": []map[string]any{{
 "name": "get_weather", "description": "获取天气",
@@ -1819,10 +1809,10 @@ func TestE2E_OpenAI_Upstream5xx_ErrorPassthrough(t *testing.T) {
 },
 })
 })
-e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
+e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
 body, _ := json.Marshal(map[string]any{
"model": "gpt-4o", "model": "openai_p/gpt-4o",
"messages": []map[string]any{{"role": "user", "content": "test"}}, "messages": []map[string]any{{"role": "user", "content": "test"}},
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1851,10 +1841,10 @@ func TestE2E_Anthropic_Upstream5xx_ErrorPassthrough(t *testing.T) {
}, },
}) })
}) })
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "claude-opus-4-7", "max_tokens": 1024, "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
"messages": []map[string]any{{"role": "user", "content": "test"}}, "messages": []map[string]any{{"role": "user", "content": "test"}},
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1889,10 +1879,10 @@ func TestE2E_Anthropic_Stream_TruncatedSSE(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
}) })
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL) e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
body, _ := json.Marshal(map[string]any{ body, _ := json.Marshal(map[string]any{
"model": "claude-opus-4-7", "max_tokens": 1024, "model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
"messages": []map[string]any{{"role": "user", "content": "test"}}, "messages": []map[string]any{{"role": "user", "content": "test"}},
"stream": true, "stream": true,
}) })

View File

@@ -9,11 +9,8 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"nex/backend/internal/config"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/handler" "nex/backend/internal/handler"
"nex/backend/internal/handler/middleware" "nex/backend/internal/handler/middleware"
@@ -27,23 +24,13 @@ func init() {
func setupIntegrationTest(t *testing.T) (*gin.Engine, *gorm.DB) { func setupIntegrationTest(t *testing.T) (*gin.Engine, *gorm.DB) {
t.Helper() t.Helper()
dir := t.TempDir() db := setupTestDB(t)
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
require.NoError(t, err)
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
require.NoError(t, err)
t.Cleanup(func() {
sqlDB, _ := db.DB()
if sqlDB != nil {
sqlDB.Close()
}
})
providerRepo := repository.NewProviderRepository(db) providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db) modelRepo := repository.NewModelRepository(db)
statsRepo := repository.NewStatsRepository(db) statsRepo := repository.NewStatsRepository(db)
providerService := service.NewProviderService(providerRepo) providerService := service.NewProviderService(providerRepo, modelRepo)
modelService := service.NewModelService(modelRepo, providerRepo) modelService := service.NewModelService(modelRepo, providerRepo)
_ = service.NewRoutingService(modelRepo, providerRepo) _ = service.NewRoutingService(modelRepo, providerRepo)
statsService := service.NewStatsService(statsRepo) statsService := service.NewStatsService(statsRepo)
@@ -97,13 +84,16 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
// 2. 创建 Model // 2. 创建 Model
modelBody, _ := json.Marshal(map[string]string{ modelBody, _ := json.Marshal(map[string]string{
"id": "gpt4", "provider_id": "openai", "model_name": "gpt-4", "provider_id": "openai", "model_name": "gpt-4",
}) })
w = httptest.NewRecorder() w = httptest.NewRecorder()
req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody)) req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req) r.ServeHTTP(w, req)
assert.Equal(t, 201, w.Code) assert.Equal(t, 201, w.Code)
var createdModel domain.Model
json.Unmarshal(w.Body.Bytes(), &createdModel)
assert.NotEmpty(t, createdModel.ID)
// 3. 列出 Provider // 3. 列出 Provider
w = httptest.NewRecorder() w = httptest.NewRecorder()
@@ -135,7 +125,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
// 6. 删除 Model // 6. 删除 Model
w = httptest.NewRecorder() w = httptest.NewRecorder()
req = httptest.NewRequest("DELETE", "/api/models/gpt4", nil) req = httptest.NewRequest("DELETE", "/api/models/"+createdModel.ID, nil)
r.ServeHTTP(w, req) r.ServeHTTP(w, req)
assert.Equal(t, 204, w.Code) assert.Equal(t, 204, w.Code)
@@ -160,17 +150,19 @@ func TestAnthropic_ModelCreation(t *testing.T) {
assert.Equal(t, 201, w.Code) assert.Equal(t, 201, w.Code)
modelBody, _ := json.Marshal(map[string]string{ modelBody, _ := json.Marshal(map[string]string{
"id": "claude3", "provider_id": "anthropic", "model_name": "claude-3-opus-20240229", "provider_id": "anthropic", "model_name": "claude-3-opus-20240229",
}) })
w = httptest.NewRecorder() w = httptest.NewRecorder()
req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody)) req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req) r.ServeHTTP(w, req)
assert.Equal(t, 201, w.Code) assert.Equal(t, 201, w.Code)
var createdModel domain.Model
json.Unmarshal(w.Body.Bytes(), &createdModel)
// 验证创建成功 // 验证创建成功
w = httptest.NewRecorder() w = httptest.NewRecorder()
req = httptest.NewRequest("GET", "/api/models/claude3", nil) req = httptest.NewRequest("GET", "/api/models/"+createdModel.ID, nil)
r.ServeHTTP(w, req) r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
} }
@@ -188,7 +180,7 @@ func TestStats_RecordingAndQuery(t *testing.T) {
r.ServeHTTP(w, req) r.ServeHTTP(w, req)
modelBody, _ := json.Marshal(map[string]string{ modelBody, _ := json.Marshal(map[string]string{
"id": "m1", "provider_id": "p1", "model_name": "gpt-4", "provider_id": "p1", "model_name": "gpt-4",
}) })
w = httptest.NewRecorder() w = httptest.NewRecorder()
req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody)) req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody))

View File

@@ -0,0 +1,37 @@
package integration
import (
"testing"
"time"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"nex/backend/internal/config"
)
// setupTestDB 创建内存 SQLite 数据库并执行 AutoMigrate。
// 使用 MaxOpenConns(1) 确保 :memory: 模式不会被连接池丢弃。
func setupTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
require.NoError(t, err)
sqlDB, err := db.DB()
require.NoError(t, err)
sqlDB.SetMaxOpenConns(1)
sqlDB.SetConnMaxLifetime(0)
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
require.NoError(t, err)
t.Cleanup(func() {
// 等待异步 goroutine如 statsService.Record完成
time.Sleep(50 * time.Millisecond)
sqlDB.Close()
})
return db
}

View File

@@ -0,0 +1,79 @@
package tests
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/config"
)
func TestMigration_ModelsUUIDPrimaryKey(t *testing.T) {
db := SetupTestDB(t)
defer CleanupTestDB(t, db)
// 创建供应商
_ = CreateTestProvider(t, db, "openai")
// 创建模型使用 UUID 作为 id
model, err := CreateTestModel(t, db, "550e8400-e29b-41d4-a716-446655440000", "openai", "gpt-4")
require.NoError(t, err)
// 通过 UUID 查询
var result config.Model
require.NoError(t, db.First(&result, "id = ?", model.ID).Error)
assert.Equal(t, "gpt-4", result.ModelName)
}
func TestMigration_UniqueProviderModelConstraint(t *testing.T) {
db := SetupTestDB(t)
defer CleanupTestDB(t, db)
// 创建供应商
_ = CreateTestProvider(t, db, "openai")
// 创建第一个模型
_, err := CreateTestModel(t, db, "uuid-1", "openai", "gpt-4")
require.NoError(t, err)
// 尝试创建相同 (provider_id, model_name) 的模型应失败
_, err = CreateTestModel(t, db, "uuid-2", "openai", "gpt-4")
assert.Error(t, err, "UNIQUE(provider_id, model_name) 约束应阻止重复")
// 不同 provider_id 下相同 model_name 应成功
_ = CreateTestProvider(t, db, "anthropic")
_, err = CreateTestModel(t, db, "uuid-3", "anthropic", "gpt-4")
require.NoError(t, err, "不同 provider_id 下相同 model_name 应允许")
}
func TestMigration_CascadeDelete(t *testing.T) {
db := SetupTestDB(t)
defer CleanupTestDB(t, db)
// 创建供应商和模型
provider := CreateTestProvider(t, db, "openai")
_, err := CreateTestModel(t, db, "uuid-1", "openai", "gpt-4")
require.NoError(t, err)
// 删除供应商应级联删除模型
require.NoError(t, db.Delete(&provider).Error)
var count int64
db.Model(&config.Model{}).Where("provider_id = ?", "openai").Count(&count)
assert.Equal(t, int64(0), count, "删除供应商后其模型应被级联删除")
}
func TestMigration_ModelDefaultEnabled(t *testing.T) {
db := SetupTestDB(t)
defer CleanupTestDB(t, db)
_ = CreateTestProvider(t, db, "openai")
// 创建模型不指定 enabled应默认为 true
_, err := CreateTestModel(t, db, "uuid-1", "openai", "gpt-4")
require.NoError(t, err)
var result config.Model
require.NoError(t, db.First(&result, "id = ?", "uuid-1").Error)
assert.True(t, result.Enabled, "enabled 字段默认应为 true")
}

View File

@@ -1,2 +0,0 @@
schema: spec-driven
created: 2026-04-20

View File

@@ -1,137 +0,0 @@
## Context
Nex 是一个 AI 网关,屏蔽多个 AI 供应商OpenAI、Anthropic 等)的差异,提供统一的 API 接口。当前后端直接透传上游供应商的原始模型名称(如 `gpt-4`),通过 `models` 表的 `model_name` 字段路由。`models` 表的 `id` 字段当前语义是用户自定义标识符,与上游模型名 `model_name` 之间没有明确的职责分离。
当前架构:
- `ProxyHandler` 从请求体中提取 `model` 字段 → `RoutingService.Route(modelName)` 按 `model_name` 查询
- `GET /v1/models` 直接透传到第一个供应商的上游接口
- `GET /v1/models/{id}` 直接透传到上游
- `TargetProvider.ModelName` 在 encoder 中覆盖请求体的 `model` 字段
## Goals / Non-Goals
**Goals:**
- 定义统一模型 ID 格式 `provider_id/model_name`,全局唯一标识一个模型
- 拦截 `/v1/models``/v1/models/{unified_id}` 接口,从数据库聚合返回,不再透传上游
- 所有代理接口(Chat、Embeddings、Rerank)使用统一模型 ID 路由,响应中 `model` 字段覆写为统一 ID
- `models.id` 改为 UUID(内部标识),`models.model_name` 存储上游供应商的模型名称
- `provider_id` 约束为 `[a-zA-Z0-9_]+`,防止特殊字符影响 URL 和 JSON 交互
- 保持协议无关、供应商无关的设计
**Non-Goals:**
- 不支持供应商别名或模型别名
- 不做上游模型列表自动同步(管理员手动配置可见模型)
- 不适配前端(后续统一适配)
## Decisions
### D1: 统一模型 ID 格式 — `provider_id/model_name`
格式: `{provider_id}/{model_name}`,例如 `openai/gpt-4`、`anthropic/claude-3-opus-20240229`
- 使用 `strings.SplitN(id, "/", 2)` 解析,只在第一个 `/` 处分割
- `provider_id` 约束为 `[a-zA-Z0-9_]+`,保证不含 `/`,解析安全
- `model_name`(上游模型名)不受字符约束,因为它不出现在管理 API 的 URL 主键中
选择此格式而非 `provider_id:model_name`(冒号分隔)的原因:斜杠在 JSON 字符串中天然安全,且在 URL 路径中语义清晰(`/v1/models/openai/gpt-4`),更符合 REST 风格。
### D2: models 表 schema 变更
```
旧: id(TEXT PK, 用户自定义), provider_id, model_name(上游模型名), enabled, created_at
新: id(UUID PK, 自动生成), provider_id, model_name(上游模型名), enabled, created_at
UNIQUE(provider_id, model_name)
```
关键语义变化:
- `id` 从用户自定义标识符变为 UUID 内部主键(自动生成),用于管理接口 CRUD
- `model_name` 语义不变,始终存储上游供应商的模型名称,发给上游的实际值
- 新增联合唯一约束 `UNIQUE(provider_id, model_name)` 保证同一供应商内模型不重复
选择保留 `id` 作为 PK 而非使用 `(provider_id, model_name)` 联合主键的原因:上游模型名可能含 `/` 等特殊字符(如 Azure OpenAI 的 deployment 路径),不适合作为管理接口的 URL 参数。`id` 为 UUID 可以避免所有特殊字符问题。
### D3: Models/ModelInfo 接口本地聚合
`GET /v1/models` 从数据库查询所有 `enabled` 的模型(JOIN providers),组装为 `CanonicalModelList`,`ID` 字段使用统一模型 ID,通过客户端协议的 adapter 编码返回。不请求上游。
`GET /v1/models/{provider_id}/{model_name}` 从 URL 提取统一模型 ID,解析后查询数据库,组装为 `CanonicalModelInfo` 返回。不请求上游。
选择纯 DB 聚合而非实时查询上游的原因:
1. 管理员通过 `/api/models` 控制哪些模型对用户可见,网关的意义在于控制可见性
2. 响应速度快,不依赖上游可用性
3. 符合当前架构中管理员手动配置 provider 和 model 的设计哲学
### D4: 跨协议响应 model 字段覆写
跨协议场景下,上游返回的响应经过 decode → encode 全量转换。上游响应中的 `model` 字段是原生模型名(如 `gpt-4`),需要在返回给客户端前覆写为统一模型 ID。
实现位置:`ConversionEngine.ConvertHttpResponse` 新增 `modelOverride string` 参数。在解码上游响应到 canonical 后、编码客户端响应前,将 `canonical.Model` 设为 `modelOverride`。流式场景同理,`CreateStreamConverter` 同样接收 `modelOverride` 参数。
此方案仅在跨协议转换路径使用。选择在 canonical 层面处理的原因:
1. 跨协议必须全量 decode → encode,canonical 的 Model 字段天然可覆写
2. 不侵入各协议 adapter 的实现
3. 与 Smart Passthrough 互补——跨协议不可保真,canonical 覆写是自然的
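覆写步骤本身很小,可以用如下草图示意(`CanonicalResponse` 为示意用的简化结构,字段与函数名均为假设,真实定义以 conversion 包为准):

```go
package main

import "fmt"

// CanonicalResponse 示意用的简化 canonical 结构(假设)。
type CanonicalResponse struct {
	Model   string
	Content string
}

// applyModelOverride 在解码上游响应到 canonical 后、编码客户端响应前调用:
// modelOverride 为空字符串表示不覆写,保持上游原始值。
func applyModelOverride(resp *CanonicalResponse, modelOverride string) {
	if modelOverride != "" {
		resp.Model = modelOverride
	}
}

func main() {
	resp := CanonicalResponse{Model: "gpt-4", Content: "hi"}
	applyModelOverride(&resp, "openai/gpt-4")
	fmt.Println(resp.Model) // openai/gpt-4
}
```

流式场景下同一函数可逐事件应用于流式 canonical 事件。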
### D5: ProtocolAdapter 接口扩展
`ProtocolAdapter` 接口新增四个方法,将所有协议相关的 model 字段知识归属到 adapter
1. `ExtractUnifiedModelID(nativePath string) (string, error)` — 从路径中提取统一模型 ID
2. `ExtractModelName(body []byte, ifaceType InterfaceType) (string, error)` — 从请求体中提取 model 值(所有流程复用,替代 handler 层硬编码的 `extractModelName`)
3. `RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)` — 最小化 JSON 改写请求体中的 model 字段(Smart Passthrough 请求方向)
4. `RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)` — 最小化 JSON 改写响应体中的 model 字段(Smart Passthrough 响应方向)
拆分请求/响应方向的原因:请求体和响应体的 JSON 结构可能不同,model 字段的位置可能不同(当前 OpenAI/Anthropic 协议碰巧都在顶层 `"model"`,但未来协议不一定)。拆分后 adapter 各自独立实现,各自按 ifaceType 分派。
`ExtractModelName` 和两个 `Rewrite*` 方法均接收 `InterfaceType` 参数,因为不同接口类型的请求体/响应体结构可能不同,adapter 按 ifaceType 分派具体的定位和改写逻辑。
对于 `isModelInfoPath` 的调整:允许 suffix 中包含 `/`,因为统一模型 ID 格式为 `provider_id/model_name`
将此方法放在适配器接口而非 handler 中通用实现的原因:不同协议的模型详情路径格式和请求体结构可能不同,各自拥有独立演进能力。
### D6: provider_id 字符集约束
创建供应商时校验 `id` 字段必须匹配 `^[a-zA-Z0-9_]+$`,长度 1-64。
选择严格限制而非仅排除 `/` 的原因:统一模型 ID 出现在 URL 路径和 JSON 中,`?``#``&``=` 等字符会在 URL 中引起解析问题。限制为字母数字下划线后URL 中永远安全,不需要编码。
### D7: pkg/modelid 工具包
新增 `pkg/modelid` 包,提供:
- `ParseUnifiedModelID(id string) (providerID, modelName string, error)` — 解析
- `FormatUnifiedModelID(providerID, modelName string) string` — 格式化
- `ValidateProviderID(id string) error` — 校验供应商 ID
- `IsValidUnifiedModelID(id string) bool` — 校验统一模型 ID
使用标准库 `strings.SplitN``regexp` 实现,不引入新依赖。
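上述四个函数的一个最小实现草图(示意;错误消息与内部命名为假设,非仓库实际代码):

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// providerIDPattern 对应 D6:仅字母、数字、下划线,长度 1-64。
var providerIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_]{1,64}$`)

// ValidateProviderID 校验供应商 ID 字符集与长度。
func ValidateProviderID(id string) error {
	if !providerIDPattern.MatchString(id) {
		return fmt.Errorf("provider id 仅允许 [a-zA-Z0-9_]{1,64}: %q", id)
	}
	return nil
}

// ParseUnifiedModelID 只在第一个 "/" 处分割,model_name 可再含 "/"。
func ParseUnifiedModelID(id string) (providerID, modelName string, err error) {
	parts := strings.SplitN(id, "/", 2)
	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
		return "", "", fmt.Errorf("无效的统一模型 ID: %q", id)
	}
	if err := ValidateProviderID(parts[0]); err != nil {
		return "", "", err
	}
	return parts[0], parts[1], nil
}

// FormatUnifiedModelID 组装统一模型 ID。
func FormatUnifiedModelID(providerID, modelName string) string {
	return providerID + "/" + modelName
}

// IsValidUnifiedModelID 校验统一模型 ID 是否可解析。
func IsValidUnifiedModelID(id string) bool {
	_, _, err := ParseUnifiedModelID(id)
	return err == nil
}

func main() {
	p, m, _ := ParseUnifiedModelID("azure/deployments/gpt-4")
	fmt.Println(p, m) // azure deployments/gpt-4
}
```

注意 `SplitN(id, "/", 2)` 保证含 `/` 的上游模型名(如 Azure deployment 路径)能完整落入 modelName。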
### D8: 同协议 Smart Passthrough
当前同协议透传将请求体原样转发,跳过 decode → encode保持参数完全保真。但统一模型 ID 要求改写 model 字段,原样透传无法满足。
**Smart Passthrough**:保留同协议透传的保真优势,通过 `json.RawMessage` 做最小化改写。
实现方式:adapter 的 `RewriteRequestModelName` 和 `RewriteResponseModelName` 方法各自解析 JSON 为 `map[string]json.RawMessage`,只替换 model 字段的 value,其余字段保留原始 bytes,不经过任何类型转换。参数保真、不丢精度、不改字段顺序。
各接口类型策略:
- Chat/Embedding/Rerank(同协议):Smart Passthrough — 请求改写 model(统一 ID → 上游名),响应改写 model(上游名 → 统一 ID)
- Chat/Embedding/Rerank(跨协议):全量 decode → encode + modelOverride
- Models/ModelInfo:本地数据库聚合,不请求上游
- Passthrough(未知路径):原样透传,不改写 model
选择让 adapter 拥有完整协议知识(而非通用 json hack)的原因:
1. 不同协议的 model 字段位置可能不同adapter 按 InterfaceType 分派
2. 请求和响应的 model 字段位置可能不同,拆分 RewriteRequestModelName/RewriteResponseModelName 各自独立实现
3. adapter 内部实现 `ExtractModelName` 和两个 `Rewrite*` 方法可共享同一份"model 在哪"的定位逻辑
4. 所有流程复用 `ExtractModelName`,同协议额外复用 `RewriteRequestModelName` + `RewriteResponseModelName`
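`Rewrite*` 方法的核心改写逻辑可以这样示意(假设 model 在顶层;注意 Go 的 map 序列化按 key 字典序输出,严格保持字段顺序还需 token 级改写,这里仅演示 value 级保真):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// rewriteModelField 最小化改写:顶层解析为 map[string]json.RawMessage,
// 只替换 "model" 的 value,其余字段保留原始 bytes(不丢精度、不做类型转换)。
func rewriteModelField(body []byte, newModel string) ([]byte, error) {
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(body, &fields); err != nil {
		return nil, err
	}
	enc, err := json.Marshal(newModel)
	if err != nil {
		return nil, err
	}
	fields["model"] = enc
	return json.Marshal(fields)
}

func main() {
	in := []byte(`{"model":"openai/gpt-4","temperature":0.12345678901234567,"some_new_param":"value"}`)
	out, _ := rewriteModelField(in, "gpt-4")
	fmt.Println(string(out))
}
```

未被替换字段的 `json.RawMessage` 在重新序列化时按原始 bytes 写出,浮点精度与未知参数因此得以保留。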
## Risks / Trade-offs
- **[BREAKING CHANGE]** 代理接口 model 字段格式变更,现有客户端必须适配 → 统一 ID 格式简单直观,服务尚未上线,无旧客户端
- **[联合唯一约束]** 同一供应商下相同 model_name 不允许重复 → 这是正确的行为,语义上就不应该重复
- **[model_name 含特殊字符]** 上游模型名可能含 `/`(如 Azure deployment 路径)→ 解析用 `SplitN("/", 2)` 安全,管理接口用 `id` 定位不受影响,代理接口中统一 ID 出现在 JSON body 和 URL 路径中均安全
- **[流式响应覆写]** 同协议流式场景需逐 SSE chunk 调用 RewriteResponseModelName → 每个 chunk 多一次轻量 JSON 解析,用 json.RawMessage 保证开销极小
## Open Questions
无。所有关键决策已在探索阶段确认。

View File

@@ -1,41 +0,0 @@
## Why
当前网关直接透传上游供应商的原始模型名称(如 `gpt-4`、`claude-3-opus`),无法在多供应商场景下唯一标识一个模型。不同供应商可能存在同名模型,客户端无法区分应路由到哪个供应商。网关作为屏蔽供应商差异的统一入口,需要定义自有的模型标识体系,让客户端通过统一的 model ID 访问任意供应商的模型,同时拦截 `/v1/models` 等模型查询接口,聚合所有供应商的模型信息返回。
## What Changes
- **BREAKING**: 引入统一模型 ID 格式 `provider_id/model_name`(如 `openai/gpt-4`),所有代理接口(Chat、Embeddings、Rerank)的 `model` 字段必须使用此格式
- **BREAKING**: `models` 表主键 `id` 改为 UUID 自动生成(不再由用户提供),`model_name` 字段语义保持不变(存储上游供应商模型名称),新增 `UNIQUE(provider_id, model_name)` 联合唯一约束
- **BREAKING**: `provider_id` 限制为 `[a-zA-Z0-9_]+` 字符集,禁止特殊字符
- `GET /v1/models` 改为从数据库聚合返回所有已启用模型,不再透传到上游供应商
- `GET /v1/models/{unified_id}` 改为从数据库查询返回模型详情,不再透传到上游供应商
- 同协议透传改为 Smart Passthrough通过 `json.RawMessage` 最小化改写 model 字段,保持其余参数完全保真
- 跨协议转换路径:通过 canonical 层面 modelOverride 参数覆写响应 model 字段
- 管理 API (`/api/models`) 请求体字段适配,响应中新增 `unified_id` 字段
- 新增 `pkg/modelid` 工具包,提供统一模型 ID 的解析、格式化、校验
- ProtocolAdapter 接口新增 `ExtractUnifiedModelID``ExtractModelName``RewriteRequestModelName``RewriteResponseModelName` 方法,协议无关地处理 model 字段
## Capabilities
### New Capabilities
- `unified-model-id`: 统一模型 ID 的解析、格式化、校验工具包,以及 `provider_id` 字符集约束
### Modified Capabilities
- `model-management`: 模型表结构调整(id 改 UUID 自动生成、新增联合唯一约束),CRUD 接口字段变更(创建不再提供 id)
- `provider-management`: provider_id 创建时增加字符集校验(`[a-zA-Z0-9_]+`)
- `unified-proxy-handler`: 统一模型 ID 解析路由、Models/ModelInfo 接口改为本地聚合、同协议 Smart Passthrough、跨协议 modelOverride 覆写
- `conversion-engine`: 跨协议场景下 ConvertHttpResponse 支持 model 覆写参数
- `protocol-adapter-openai`: isModelInfoPath 适配含 `/` 路径、新增 ExtractUnifiedModelID、ExtractModelName、RewriteRequestModelName、RewriteResponseModelName
- `protocol-adapter-anthropic`: isModelInfoPath 适配含 `/` 路径、新增 ExtractUnifiedModelID、ExtractModelName、RewriteRequestModelName、RewriteResponseModelName
- `request-validation`: provider_id 字符集校验规则、模型创建校验适配
- `database-migration`: models 表 schema 变更迁移(DROP + CREATE 重建)
- `usage-statistics`: 明确统计记录使用 providerID + modelName 的上游模型名
## Impact
- **数据库**: models 表 schema 变更(DROP + CREATE 重建)
- **API 兼容性**: 代理接口 model 字段格式为 BREAKING CHANGE,需客户端适配
- **管理 API**: `/api/models` 请求体变更(创建不再提供 id,自动生成 UUID),响应新增 unified_id 字段
- **代码模块**: domain、repository、service、handler、conversion、adapter 层均有改动
- **测试**: routing service、proxy handler、adapter、model handler 需要新增/更新测试
- **前端**: 本次变更不涉及前端适配,前端后续统一适配

View File

@@ -1,49 +0,0 @@
## MODIFIED Requirements
### Requirement: 跨协议响应转换支持 model 覆写
ConversionEngine SHALL 在跨协议响应转换时支持 model 字段覆写。
#### Scenario: ConvertHttpResponse 接收 modelOverride 参数
- **WHEN** 调用 `ConvertHttpResponse` 时传入 `modelOverride` 参数(跨协议场景,非空字符串)
- **THEN** SHALL 在解码上游响应到 canonical 后,将 `Model` 字段设为 `modelOverride`
- **THEN** SHALL 使用覆写后的 canonical 编码为客户端协议格式
#### Scenario: modelOverride 为空
- **WHEN** 调用 `ConvertHttpResponse``modelOverride` 为空字符串
- **THEN** SHALL NOT 覆写 canonical 的 Model 字段,保持上游原始值
#### Scenario: Chat 响应 model 覆写
- **WHEN** 跨协议转换 Chat 类型响应且 `modelOverride` 非空
- **THEN** `CanonicalResponse.Model` SHALL 被设为 `modelOverride`
#### Scenario: Embedding 响应 model 覆写
- **WHEN** 跨协议转换 Embedding 类型响应且 `modelOverride` 非空
- **THEN** `CanonicalEmbeddingResponse.Model` SHALL 被设为 `modelOverride`
#### Scenario: Rerank 响应 model 覆写
- **WHEN** 跨协议转换 Rerank 类型响应且 `modelOverride` 非空
- **THEN** `CanonicalRerankResponse.Model` SHALL 被设为 `modelOverride`
### Requirement: 跨协议流式转换支持 model 覆写
ConversionEngine SHALL 在跨协议流式转换时支持 model 字段覆写。
#### Scenario: CreateStreamConverter 接收 modelOverride 参数
- **WHEN** 调用 `CreateStreamConverter` 时传入 `modelOverride` 参数(跨协议场景)
- **THEN** SHALL 在流式 canonical 事件中将 `Model` 字段设为 `modelOverride`
### Requirement: TargetProvider 字段语义
TargetProvider 的 ModelName 字段 SHALL 存储上游供应商的模型名称(即 `model_name` 字段值),语义保持不变。
#### Scenario: encoder 使用 TargetProvider.ModelName
- **WHEN** 协议适配器编码请求时
- **THEN** SHALL 使用 `TargetProvider.ModelName` 作为发给上游的 `model` 字段值(值为路由结果中的 model_name)

View File

@@ -1,13 +0,0 @@
## MODIFIED Requirements
### Requirement: models 表 schema 变更
系统 SHALL 通过迁移脚本重建 models 表结构(服务未上线,无需考虑数据迁移)。
#### Scenario: 迁移后 models 表结构
- **WHEN** 执行迁移
- **THEN** SHALL 先 DROP 已有的 models 表(无旧数据)
- **THEN** SHALL CREATE 新的 models 表,包含字段:id(TEXT PRIMARY KEY)、provider_id(TEXT NOT NULL)、model_name(TEXT NOT NULL)、enabled(INTEGER DEFAULT 1)、created_at(DATETIME)
- **THEN** SHALL 存在 UNIQUE(provider_id, model_name) 约束
- **THEN** SHALL 存在 FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
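上述表结构可用如下 SQLite DDL 示意(草图;列序与排版为假设,以实际迁移脚本为准):

```sql
-- 示意 DDL:UUID 主键 + 联合唯一约束 + 级联删除
DROP TABLE IF EXISTS models;
CREATE TABLE models (
    id          TEXT PRIMARY KEY,   -- UUID,创建时自动生成
    provider_id TEXT NOT NULL,
    model_name  TEXT NOT NULL,      -- 上游供应商的模型名
    enabled     INTEGER DEFAULT 1,
    created_at  DATETIME,
    UNIQUE (provider_id, model_name),
    FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
);
```

注意 SQLite 需 `PRAGMA foreign_keys = ON` 才会执行级联删除。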

View File

@@ -1,105 +0,0 @@
## MODIFIED Requirements
### Requirement: 创建模型配置
网关 SHALL 允许为供应商创建新的模型配置。
#### Scenario: 使用有效数据创建模型
- **WHEN** 向 `/api/models` 发送 POST 请求,携带有效的模型数据(provider_id, model_name),不提供 id 字段
- **THEN** 网关 SHALL 自动生成 UUID 作为模型 id
- **THEN** 网关 SHALL 在数据库中创建新的模型记录
- **THEN** 网关 SHALL 返回创建的模型,状态码为 201
- **THEN** 模型 SHALL 默认启用
- **THEN** 返回的模型 SHALL 包含 `unified_id` 字段,值为 `{provider_id}/{model_name}`
#### Scenario: 使用不存在的供应商创建模型
- **WHEN** 向 `/api/models` 发送 POST 请求,携带不存在的 provider_id
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
- **THEN** 错误 SHALL 指示供应商不存在
#### Scenario: 创建重复模型
- **WHEN** 向 `/api/models` 发送 POST 请求,携带已存在的 provider_id + model_name 组合
- **THEN** 网关 SHALL 返回错误,状态码为 409 (Conflict)
- **THEN** 错误 SHALL 指示该供应商下已存在相同模型
### Requirement: 列出所有模型
网关 SHALL 允许获取所有模型配置。
#### Scenario: 成功列出模型
- **WHEN** 向 `/api/models` 发送 GET 请求
- **THEN** 网关 SHALL 返回所有模型的列表
- **THEN** 每个模型 SHALL 包含 id, provider_id, model_name, unified_id, enabled, created_at
**变更说明:** 响应新增 unified_id 字段,移除旧语义的 id 自定义输入。
### Requirement: 按供应商列出模型
网关 SHALL 允许获取特定供应商的模型。
#### Scenario: 列出存在供应商的模型
- **WHEN** 向 `/api/models?provider_id=<provider_id>` 发送 GET 请求
- **THEN** 网关 SHALL 返回指定供应商的模型列表
- **THEN** 每个模型 SHALL 包含 unified_id 字段
### Requirement: 更新模型配置
网关 SHALL 允许更新现有模型配置。
#### Scenario: 使用有效数据更新模型
- **WHEN** 向 `/api/models/:id` 发送 PUT 请求,携带有效的模型数据
- **THEN** 网关 SHALL 更新数据库中的模型记录
- **THEN** 网关 SHALL 返回更新后的模型
- **THEN** 返回的模型 SHALL 包含更新后的 unified_id
#### Scenario: 更新模型为重复组合
- **WHEN** 向 `/api/models/:id` 发送 PUT 请求,更新 provider_id 或 model_name 导致与已有记录重复
- **THEN** 网关 SHALL 返回错误,状态码为 409 (Conflict)
### Requirement: 删除模型配置
网关 SHALL 允许删除模型配置。
#### Scenario: 删除存在的模型
- **WHEN** 向 `/api/models/:id` 发送 DELETE 请求,携带有效的模型 ID
- **THEN** 网关 SHALL 删除模型记录
- **THEN** 网关 SHALL 返回状态码 204 (No Content)
### Requirement: 使用 service 层处理业务逻辑
Handler SHALL 通过 ModelService 处理业务逻辑。
#### Scenario: 调用 service 方法
- **WHEN** handler 收到请求
- **THEN** SHALL 调用对应的 ModelService 方法(Create、Get、List、Update、Delete)
- **THEN** SHALL 使用 domain.Model 类型
- **THEN** Create 时 SHALL 调用 `uuid.New()` 生成 id
#### Scenario: 供应商验证和唯一性校验
- **WHEN** 创建或更新模型
- **THEN** SHALL 在 service 层验证供应商存在
- **THEN** SHALL 在 service 层验证 provider_id + model_name 联合唯一
### Requirement: 使用 repository 层访问数据
Service SHALL 通过 ModelRepository 访问数据。
#### Scenario: 联合查询
- **WHEN** service 需要按 provider 和 model_name 查询模型
- **THEN** SHALL 调用 `FindByProviderAndModelName(providerID, modelName)` 方法
#### Scenario: 查询所有启用模型
- **WHEN** proxy handler 需要聚合模型列表
- **THEN** SHALL 调用 `ListEnabled()` 方法,返回所有 enabled 的模型(关联 enabled 的供应商)

View File

@@ -1,71 +0,0 @@
## MODIFIED Requirements
### Requirement: 模型详情路径识别
Anthropic 适配器 SHALL 正确识别包含统一模型 ID 的模型详情路径。
#### Scenario: 含斜杠的统一模型 ID 路径
- **WHEN** 路径为 `/v1/models/anthropic/claude-3-opus`
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo`
#### Scenario: 含多段斜杠的统一模型 ID 路径
- **WHEN** 路径为 `/v1/models/azure/deployments/gpt-4`
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo`
#### Scenario: 模型列表路径不受影响
- **WHEN** 路径为 `/v1/models`
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModels`
### Requirement: 提取统一模型 ID
Anthropic 适配器 SHALL 从路径中提取统一模型 ID。
#### Scenario: 标准路径提取
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/anthropic/claude-3-opus")`
- **THEN** SHALL 返回 `"anthropic/claude-3-opus"`
#### Scenario: 复杂路径提取
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")`
- **THEN** SHALL 返回 `"azure/accounts/org/models/gpt-4"`
#### Scenario: 非模型详情路径
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models")`
- **THEN** SHALL 返回错误
### Requirement: 从请求体提取 model
Anthropic 适配器 SHALL 按 InterfaceType 从请求体中提取 model 值。
#### Scenario: Chat 请求提取 model
- **WHEN** 调用 `ExtractModelName(body, InterfaceTypeChat)`,body 为 `{"model":"anthropic/claude-3-opus","messages":[...]}`
- **THEN** SHALL 返回 `"anthropic/claude-3-opus"`
#### Scenario: 无 model 字段
- **WHEN** 调用 `ExtractModelName(body, ifaceType)`,body 中不含 model 字段
- **THEN** SHALL 返回空字符串,不返回错误
### Requirement: 最小化改写请求体 model 字段
Anthropic 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写请求体中的 model 字段,其余字段保持原始 bytes。
#### Scenario: 请求体 model 改写(统一 ID → 上游名)
- **WHEN** 调用 `RewriteRequestModelName(body, "claude-3-opus-20240229", InterfaceTypeChat)`
- **THEN** SHALL 将请求体中 model 字段替换为 `"claude-3-opus-20240229"`,其余字段原样保留
### Requirement: 最小化改写响应体 model 字段
Anthropic 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写响应体中的 model 字段,其余字段保持原始 bytes。请求体和响应体的 model 字段位置可能不同,各自独立实现。
#### Scenario: 响应体 model 改写(上游名 → 统一 ID)
- **WHEN** 调用 `RewriteResponseModelName(body, "anthropic/claude-3-opus", InterfaceTypeChat)`
- **THEN** SHALL 将响应体中 model 字段替换为 `"anthropic/claude-3-opus"`,其余字段原样保留

View File

@@ -1,89 +0,0 @@
## MODIFIED Requirements
### Requirement: 模型详情路径识别
OpenAI 适配器 SHALL 正确识别包含统一模型 ID 的模型详情路径。
#### Scenario: 含斜杠的统一模型 ID 路径
- **WHEN** 路径为 `/v1/models/openai/gpt-4`
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo`
#### Scenario: 含多段斜杠的统一模型 ID 路径
- **WHEN** 路径为 `/v1/models/azure/accounts/org-123/models/gpt-4`
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo`
#### Scenario: 模型列表路径不受影响
- **WHEN** 路径为 `/v1/models`
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModels`
### Requirement: 提取统一模型 ID
OpenAI 适配器 SHALL 从路径中提取统一模型 ID。
#### Scenario: 标准路径提取
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/openai/gpt-4")`
- **THEN** SHALL 返回 `"openai/gpt-4"`
#### Scenario: 复杂路径提取
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")`
- **THEN** SHALL 返回 `"azure/accounts/org/models/gpt-4"`
#### Scenario: 非模型详情路径
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models")`
- **THEN** SHALL 返回错误
### Requirement: 从请求体提取 model
OpenAI 适配器 SHALL 按 InterfaceType 从请求体中提取 model 值。
#### Scenario: Chat 请求提取 model
- **WHEN** 调用 `ExtractModelName(body, InterfaceTypeChat)`,body 为 `{"model":"openai/gpt-4","messages":[...]}`
- **THEN** SHALL 返回 `"openai/gpt-4"`
#### Scenario: Embedding 请求提取 model
- **WHEN** 调用 `ExtractModelName(body, InterfaceTypeEmbeddings)`,body 为 `{"model":"openai/text-embedding-3","input":"text"}`
- **THEN** SHALL 返回 `"openai/text-embedding-3"`
#### Scenario: 无 model 字段
- **WHEN** 调用 `ExtractModelName(body, ifaceType)`,body 中不含 model 字段
- **THEN** SHALL 返回空字符串,不返回错误
### Requirement: 最小化改写请求体 model 字段
OpenAI 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写请求体中的 model 字段,其余字段保持原始 bytes。
#### Scenario: 请求体 model 改写(统一 ID → 上游名)
- **WHEN** 调用 `RewriteRequestModelName(body, "gpt-4", InterfaceTypeChat)`,body 为 `{"model":"openai/gpt-4","messages":[...],"some_param":"value"}`
- **THEN** SHALL 返回 `{"model":"gpt-4","messages":[...],"some_param":"value"}`
- **THEN** 除 model 外的字段 SHALL 保持原始 bytes 不变
#### Scenario: 不同 InterfaceType 的请求改写
- **WHEN** 调用 `RewriteRequestModelName(body, "gpt-4", InterfaceTypeEmbeddings)`
- **THEN** SHALL 按 Embedding 接口的请求体 model 字段位置进行改写
- **WHEN** 调用 `RewriteRequestModelName(body, "gpt-4", InterfaceTypeRerank)`
- **THEN** SHALL 按 Rerank 接口的请求体 model 字段位置进行改写
### Requirement: 最小化改写响应体 model 字段
OpenAI 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写响应体中的 model 字段,其余字段保持原始 bytes。请求体和响应体的 model 字段位置可能不同,各自独立实现。
#### Scenario: 响应体 model 改写(上游名 → 统一 ID)
- **WHEN** 调用 `RewriteResponseModelName(body, "openai/gpt-4", InterfaceTypeChat)`,body 为上游 Chat 响应
- **THEN** SHALL 将 model 字段替换为 `"openai/gpt-4"`,其余字段原样保留
#### Scenario: 不同 InterfaceType 的响应改写
- **WHEN** 调用 `RewriteResponseModelName(body, "openai/gpt-4", InterfaceTypeEmbeddings)`
- **THEN** SHALL 按 Embedding 接口的响应体 model 字段位置进行改写

View File

@@ -1,35 +0,0 @@
## MODIFIED Requirements
### Requirement: 创建供应商配置
网关 SHALL 允许通过管理 API 创建新的供应商配置。
#### Scenario: 使用有效数据创建供应商
- **WHEN** 向 `/api/providers` 发送 POST 请求,携带有效的供应商数据(id, name, api_key, base_url, protocol)
- **THEN** 网关 SHALL 在数据库中创建新的供应商记录
- **THEN** 网关 SHALL 返回创建的供应商,状态码为 201
- **THEN** 供应商 SHALL 默认启用
- **THEN** protocol 字段 SHALL 默认为 "openai"
#### Scenario: 使用重复 ID 创建供应商
- **WHEN** 向 `/api/providers` 发送 POST 请求,携带已存在的 ID
- **THEN** 网关 SHALL 返回错误,状态码为 409 (Conflict)
#### Scenario: 创建供应商时缺少必需字段
- **WHEN** 向 `/api/providers` 发送 POST 请求,缺少必需字段(id, name, api_key 或 base_url)
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
- **THEN** 错误 SHALL 指示缺少哪些字段
#### Scenario: 创建供应商时 ID 包含非法字符
- **WHEN** 向 `/api/providers` 发送 POST 请求,id 包含非 `[a-zA-Z0-9_]` 字符
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
- **THEN** 错误 SHALL 指示 id 仅允许字母、数字、下划线
#### Scenario: 创建供应商时 ID 过长
- **WHEN** 向 `/api/providers` 发送 POST 请求,id 长度超过 64
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)

View File

@@ -1,39 +0,0 @@
## MODIFIED Requirements
### Requirement: 供应商 ID 校验
创建供应商时SHALL 对 `id` 字段进行字符集校验。
#### Scenario: 合法字符集
- **WHEN** 创建供应商,id 仅包含 `[a-zA-Z0-9_]` 字符
- **THEN** SHALL 校验通过
#### Scenario: 非法字符
- **WHEN** 创建供应商,id 包含 `-`、`.`、`/`、空格、中文等非 `[a-zA-Z0-9_]` 字符
- **THEN** SHALL 返回 400 错误
#### Scenario: 长度限制
- **WHEN** 创建供应商,id 长度超过 64
- **THEN** SHALL 返回 400 错误
### Requirement: 模型创建校验
创建模型时SHALL 对 `provider_id` + `model_name` 进行联合唯一性校验。
#### Scenario: 正常创建
- **WHEN** 创建模型,provider_id 存在且 provider_id + model_name 组合唯一
- **THEN** SHALL 校验通过
#### Scenario: 联合唯一冲突
- **WHEN** 创建模型,provider_id + model_name 组合已存在
- **THEN** SHALL 返回 409 错误
#### Scenario: model_name 为空
- **WHEN** 创建模型,未提供 model_name
- **THEN** SHALL 返回 400 错误

View File

@@ -1,123 +0,0 @@
## MODIFIED Requirements
### Requirement: 代理请求路由
ProxyHandler SHALL 使用统一模型 ID 路由所有代理请求。
#### Scenario: 提取统一模型 ID
- **WHEN** 收到 Chat、Embeddings 或 Rerank 接口的 POST 请求(含请求体)
- **THEN** SHALL 调用客户端协议 adapter 的 `ExtractModelName(body, ifaceType)` 提取 model 值
- **THEN** SHALL 调用 `ParseUnifiedModelID` 解析得到 providerID 和 modelName
- **THEN** SHALL 调用 `RoutingService.RouteByModelName(providerID, modelName)` 路由
#### Scenario: GET 请求或无请求体
- **WHEN** 收到 GET 请求或请求体为空
- **THEN** SHALL 返回错误响应,状态码为 400,提示缺少 model 字段
#### Scenario: 无效的统一模型 ID
- **WHEN** 请求体中 `model` 字段不是有效的统一模型 ID 格式
- **THEN** SHALL 返回错误响应,状态码为 400
#### Scenario: 模型不存在
- **WHEN** 解析统一模型 ID 后,数据库中找不到对应的 provider_id + model_name 组合
- **THEN** SHALL 返回错误响应,状态码为 404
#### Scenario: 模型已禁用
- **WHEN** 解析统一模型 ID 后,对应的模型 enabled 为 false
- **THEN** SHALL 返回错误响应,状态码为 404
#### Scenario: 供应商已禁用
- **WHEN** 解析统一模型 ID 后,对应的供应商 enabled 为 false
- **THEN** SHALL 返回错误响应,状态码为 404
### Requirement: 同协议 Smart Passthrough
当客户端协议与供应商协议相同时ProxyHandler SHALL 使用 Smart Passthrough 处理 Chat、Embedding、Rerank 请求。
#### Scenario: 同协议非流式请求
- **WHEN** 客户端协议 == 供应商协议,且为非流式请求
- **THEN** SHALL 调用 adapter 的 `RewriteRequestModelName(body, modelName, ifaceType)` 将请求体中 model 从统一 ID 改写为上游模型名
- **THEN** SHALL 构建 URL 和 Headers(同当前透传逻辑)
- **THEN** SHALL 发送改写后的请求体到上游
- **THEN** SHALL 调用 adapter 的 `RewriteResponseModelName(resp.Body, unifiedModelID, ifaceType)` 将响应中 model 从上游名改写为统一 ID
- **THEN** SHALL NOT 对 body 做全量 decode → encode保持未改写字段的原始 bytes
#### Scenario: 同协议流式请求
- **WHEN** 客户端协议 == 供应商协议,且为流式请求
- **THEN** SHALL 对请求体做 `RewriteRequestModelName` 改写 model 字段
- **THEN** SHALL 逐 SSE chunk 调用 `RewriteResponseModelName` 改写响应中 model 字段
- **THEN** SHALL NOT 对 chunk 做全量 decode → encode
#### Scenario: Smart Passthrough 保真性
- **WHEN** 客户端发送含未知参数的请求(如 `{"model":"openai/gpt-4","some_new_param":"value"}`
- **THEN** 上游 SHALL 收到 `{"model":"gpt-4","some_new_param":"value"}`
- **THEN** `some_new_param` SHALL 保持原始值不变,不丢失、不改变类型
### Requirement: 跨协议完整转换
当客户端协议与供应商协议不同时ProxyHandler SHALL 使用全量转换路径。
#### Scenario: 跨协议非流式请求
- **WHEN** 客户端协议 != 供应商协议
- **THEN** SHALL 走 `ConvertHttpRequest` 全量转换encoder 中 provider.ModelName 覆盖 model
- **THEN** SHALL 走 `ConvertHttpResponse` 全量转换modelOverride 参数覆写 canonical.Model
#### Scenario: 跨协议流式请求
- **WHEN** 客户端协议 != 供应商协议,且为流式请求
- **THEN** SHALL 走 `CreateStreamConverter` 全量转换modelOverride 参数覆写流式 canonical 事件中的 Model
### Requirement: 模型列表本地聚合
ProxyHandler SHALL 从数据库聚合返回模型列表,不再透传上游。
#### Scenario: GET /v1/models
- **WHEN** 收到 `GET /{protocol}/v1/models` 请求
- **THEN** SHALL 从数据库查询所有 enabled 的模型(关联 enabled 的供应商)
- **THEN** SHALL 组装 `CanonicalModelList`,每个模型的 ID 字段为统一模型 ID`provider_id/model_name`Name 字段为 model_nameOwnedBy 字段为 provider_id
- **THEN** SHALL 使用客户端协议的 adapter 编码响应
- **THEN** SHALL NOT 请求上游供应商
#### Scenario: 无可用模型
- **WHEN** 数据库中没有 enabled 的模型
- **THEN** SHALL 返回空列表
### Requirement: 模型详情本地查询
ProxyHandler SHALL 从数据库查询返回模型详情,不再透传上游。
#### Scenario: GET /v1/models/{unified_id}
- **WHEN** 收到 `GET /{protocol}/v1/models/{provider_id}/{model_name}` 请求
- **THEN** SHALL 调用 adapter 的 `ExtractUnifiedModelID` 提取统一模型 ID
- **THEN** SHALL 解析统一模型 ID 得到 providerID 和 modelName
- **THEN** SHALL 从数据库查询对应的模型和供应商
- **THEN** SHALL 组装 `CanonicalModelInfo`ID 字段为统一模型 ID`provider_id/model_name`Name 字段为 model_nameOwnedBy 字段为 provider_id
- **THEN** SHALL 使用客户端协议的 adapter 编码响应
- **THEN** SHALL NOT 请求上游供应商
#### Scenario: 模型详情不存在
- **WHEN** 统一模型 ID 对应的模型不存在或已禁用
- **THEN** SHALL 返回错误响应,状态码为 404
### Requirement: 统计记录
ProxyHandler SHALL 使用 providerID 和 modelName 记录使用统计。
#### Scenario: 异步记录统计
- **WHEN** 代理请求成功完成
- **THEN** SHALL 异步调用 `StatsService.Record(providerID, modelName)`


@@ -1,16 +0,0 @@
## ADDED Requirements
### Requirement: 使用统计记录统一模型标识
系统 SHALL 使用 providerID 和 modelName上游模型名记录使用统计。
#### Scenario: 代理请求统计记录
- **WHEN** 代理请求成功完成
- **THEN** SHALL 记录 provider_id 和 model_name 到 usage_stats 表(参数来自路由结果)
- **THEN** SHALL 异步执行,不阻塞响应
#### Scenario: 查询统计
- **WHEN** 查询统计数据
- **THEN** 支持按 provider_id 和 model_name 过滤


@@ -1,53 +0,0 @@
## 1. 数据库迁移
- [ ] 1.1 新增迁移脚本DROP 旧 models 表 + CREATE 新 models 表id UUID PK, provider_id, model_name, enabled, created_atUNIQUE(provider_id, model_name)
- [ ] 1.2 更新 config/models.goModel 结构体适配id 改为 UUID 自动生成model_name 保持不变)
- [ ] 1.3 编写迁移脚本测试
## 2. 统一模型 ID 工具包
- [ ] 2.1 新增 pkg/modelid/model_id.go实现 ParseUnifiedModelID、FormatUnifiedModelID、ValidateProviderID、IsValidUnifiedModelID
- [ ] 2.2 新增 pkg/modelid/model_id_test.go覆盖标准格式、含斜杠 model_name、空字符串、非法字符等边界情况
## 3. Domain 层适配
- [ ] 3.1 修改 domain/model.goModel 结构体字段适配,新增 UnifiedModelID() 方法
- [ ] 3.2 修改 domain/route.goRouteResult 适配新字段
## 4. Repository 层适配
- [ ] 4.1 修改 repository/model_repo.go接口变更 — GetByModelName 改为 FindByProviderAndModelName新增 ListEnabled
- [ ] 4.2 修改 repository/model_repo_impl.go实现 FindByProviderAndModelNameWHERE provider_id=? AND model_name=?、ListEnabledJOIN providers WHERE enabled
- [ ] 4.3 编写 repository 层测试
## 5. Service 层适配
- [ ] 5.1 修改 service/routing_service.goRoute 接口改为 RouteByModelName(providerID, modelName string)
- [ ] 5.2 修改 service/routing_service_impl.go调用 FindByProviderAndModelName 替代 GetByModelName
- [ ] 5.3 修改 service/model_service.goCreate 生成 UUID、新增联合唯一校验方法
- [ ] 5.4 修改 service/model_service_impl.go实现联合唯一校验、UUID 生成
- [ ] 5.5 修改 service/provider_service_impl.goCreate 时调用 ValidateProviderID 校验 ID 字符集
- [ ] 5.6 编写 service 层测试
## 6. Conversion 层适配
- [ ] 6.1 修改 conversion/adapter.goProtocolAdapter 接口新增 ExtractUnifiedModelID、ExtractModelName、RewriteRequestModelName、RewriteResponseModelName 四个方法
- [ ] 6.2 修改 conversion/engine.goConvertHttpResponse 新增 modelOverride 参数(跨协议场景),各 convert*ResponseBody 中覆写 canonical ModelCreateStreamConverter 新增 modelOverride 参数
- [ ] 6.3 修改 conversion/openai/adapter.go实现 ExtractUnifiedModelID、ExtractModelName按 ifaceType 提取 model、RewriteRequestModelName 和 RewriteResponseModelNamejson.RawMessage 最小化改写,按 ifaceType 定位 model 字段,请求/响应独立实现),修改 isModelInfoPath 允许 suffix 含 "/"
- [ ] 6.4 修改 conversion/anthropic/adapter.go实现 ExtractUnifiedModelID、ExtractModelName、RewriteRequestModelName、RewriteResponseModelName修改 isModelInfoPath 允许 suffix 含 "/"
- [ ] 6.5 编写 conversion 层测试ExtractUnifiedModelID、ExtractModelName 各 ifaceType、RewriteRequestModelName/RewriteResponseModelName 保真性含未知参数不丢失、isModelInfoPath 含斜杠路径、modelOverride 覆写
## 7. Handler 层改造
- [ ] 7.1 修改 handler/proxy_handler.goHandleProxy 按接口类型分发 — Models/ModelInfo 本地聚合Chat/Embed/Rerank 用 adapter.ExtractModelName 提取统一 ID 路由,同协议走 Smart Passthroughadapter.RewriteRequestModelName 改写请求、adapter.RewriteResponseModelName 改写响应跨协议走全量转换modelOverride删除 forwardPassthrough 和硬编码的 extractModelName
- [ ] 7.2 修改 handler/model_handler.go请求体字段适配移除 id 输入、保留 provider_id 和 model_name响应新增 unified_idCreate 使用 UUID
- [ ] 7.3 修改 handler/provider_handler.goCreateProvider 校验 ID 字符集
- [ ] 7.4 编写 handler 层测试:统一模型 ID 路由、同协议 Smart Passthrough 保真性、跨协议 modelOverride、Models 聚合、ModelInfo 查询、流式场景 model 覆写、provider ID 校验
## 8. 路由注册适配
- [ ] 8.1 修改 cmd/server/main.gosetupRoutes 适配 handler 签名变更,传递新增依赖
## 9. 文档更新
- [ ] 9.1 按需更新 README.md同步 models 表结构、API 接口字段、统一模型 ID 格式、Smart Passthrough 策略等变更说明


@@ -278,3 +278,50 @@ ErrorCode SHALL 包含INVALID_INPUT、MISSING_REQUIRED_FIELD、INCOMPATIBLE_F
- **THEN** SHALL 从 provider.api_key 提取认证信息
- **THEN** SHALL 从 provider.adapter_config 提取协议专属配置
- **THEN** SHALL 使用 provider.model_name 覆盖请求中的 model 字段
### Requirement: 跨协议响应转换支持 model 覆写
ConversionEngine SHALL 在跨协议响应转换时支持 model 字段覆写。
#### Scenario: ConvertHttpResponse 接收 modelOverride 参数
- **WHEN** 调用 `ConvertHttpResponse` 时传入 `modelOverride` 参数(跨协议场景,非空字符串)
- **THEN** SHALL 在解码上游响应到 canonical 后,将 `Model` 字段设为 `modelOverride`
- **THEN** SHALL 使用覆写后的 canonical 编码为客户端协议格式
#### Scenario: modelOverride 为空
- **WHEN** 调用 `ConvertHttpResponse``modelOverride` 为空字符串
- **THEN** SHALL NOT 覆写 canonical 的 Model 字段,保持上游原始值
#### Scenario: Chat 响应 model 覆写
- **WHEN** 跨协议转换 Chat 类型响应且 `modelOverride` 非空
- **THEN** `CanonicalResponse.Model` SHALL 被设为 `modelOverride`
#### Scenario: Embedding 响应 model 覆写
- **WHEN** 跨协议转换 Embedding 类型响应且 `modelOverride` 非空
- **THEN** `CanonicalEmbeddingResponse.Model` SHALL 被设为 `modelOverride`
#### Scenario: Rerank 响应 model 覆写
- **WHEN** 跨协议转换 Rerank 类型响应且 `modelOverride` 非空
- **THEN** `CanonicalRerankResponse.Model` SHALL 被设为 `modelOverride`
### Requirement: 跨协议流式转换支持 model 覆写
ConversionEngine SHALL 在跨协议流式转换时支持 model 字段覆写。
#### Scenario: CreateStreamConverter 接收 modelOverride 参数
- **WHEN** 调用 `CreateStreamConverter` 时传入 `modelOverride` 参数(跨协议场景)
- **THEN** SHALL 在流式 canonical 事件中将 `Model` 字段设为 `modelOverride`
### Requirement: TargetProvider 字段语义
TargetProvider 的 ModelName 字段 SHALL 存储上游供应商的模型名称(即 `model_name` 字段值),语义保持不变。
#### Scenario: encoder 使用 TargetProvider.ModelName
- **WHEN** 协议适配器编码请求时
- **THEN** SHALL 使用 `TargetProvider.ModelName` 作为发给上游的 `model` 字段值(值为路由结果中的 model_name


@@ -28,9 +28,10 @@
#### Scenario: 初始迁移文件
- **WHEN** 创建初始迁移
- **THEN** SHALL 创建单个初始迁移文件(如 `20260421000001_initial_schema.sql`
- **THEN** SHALL 包含 providers、models、usage_stats 表的创建语句
- **THEN** SHALL 包含外键约束
- **THEN** SHALL 包含索引创建语句
#### Scenario: Up 迁移
@@ -42,25 +43,19 @@
#### Scenario: Down 迁移
- **WHEN** 执行 down 迁移
- **THEN** SHALL 删除所有表和索引
- **THEN** SHALL 按正确顺序删除(避免外键约束错误)
### Requirement: models 表 schema 变更
系统 SHALL 在初始迁移脚本中直接创建新的 models 表结构(服务未上线,无需考虑数据迁移,迁移脚本已合并为单个初始迁移文件)。
#### Scenario: 初始迁移 models 表结构
- **WHEN** 执行迁移
- **THEN** SHALL CREATE models 表包含字段idTEXT PRIMARY KEY存储 UUID 字符串、provider_idTEXT NOT NULL、model_nameTEXT NOT NULL、enabledINTEGER DEFAULT 1、created_atDATETIME
- **THEN** SHALL 存在 UNIQUE(provider_id, model_name) 约束
- **THEN** SHALL 存在 FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
### Requirement: 迁移命令集成


@@ -0,0 +1,209 @@
# Error Responses
## Purpose
定义系统统一的错误响应格式和各类错误场景,确保客户端能够一致地处理错误。
## Requirements
### Requirement: 统一错误响应格式
系统 SHALL 使用统一的错误响应格式。
#### Scenario: 标准错误格式
- **WHEN** 返回错误响应
- **THEN** SHALL 使用以下 JSON 格式:
```json
{
"error": "错误描述",
"code": "ERROR_CODE"
}
```
- **THEN** `error` 字段 SHALL 包含人类可读的错误描述
- **THEN** `code` 字段 SHALL 包含机器可读的错误代码(可选)
### Requirement: provider_id 校验错误
系统 SHALL 对 provider_id 校验错误返回明确的错误信息。
#### Scenario: provider_id 包含非法字符
- **WHEN** 创建或更新供应商时provider_id 包含非 `[a-zA-Z0-9_]` 字符
- **THEN** SHALL 返回 HTTP 400 Bad Request
- **THEN** SHALL 返回以下 JSON 格式:
```json
{
"error": "供应商 ID 仅允许字母、数字、下划线",
"code": "INVALID_PROVIDER_ID"
}
```
#### Scenario: provider_id 长度超限
- **WHEN** 创建或更新供应商时provider_id 长度超过 64
- **THEN** SHALL 返回 HTTP 400 Bad Request
- **THEN** SHALL 返回以下 JSON 格式:
```json
{
"error": "供应商 ID 长度不能超过 64 个字符",
"code": "INVALID_PROVIDER_ID"
}
```
### Requirement: 联合唯一约束冲突错误
系统 SHALL 对联合唯一约束冲突返回明确的错误信息。
#### Scenario: 创建模型时 provider_id + model_name 组合已存在
- **WHEN** 创建模型时provider_id + model_name 组合已存在
- **THEN** SHALL 返回 HTTP 409 Conflict
- **THEN** SHALL 返回以下 JSON 格式:
```json
{
"error": "同一供应商下模型名称已存在",
"code": "duplicate_model"
}
```
#### Scenario: 更新模型时导致 provider_id + model_name 组合冲突
- **WHEN** 更新模型时,修改 provider_id 或 model_name 导致与已有记录冲突
- **THEN** SHALL 返回 HTTP 409 Conflict
- **THEN** SHALL 返回以下 JSON 格式:
```json
{
"error": "同一供应商下模型名称已存在",
"code": "duplicate_model"
}
```
### Requirement: 资源不存在错误
系统 SHALL 对资源不存在返回明确的错误信息。
#### Scenario: 模型不存在
- **WHEN** 查询或操作不存在的模型
- **THEN** SHALL 返回 HTTP 404 Not Found
- **THEN** SHALL 返回以下 JSON 格式:
```json
{
"error": "模型未找到"
}
```
#### Scenario: 供应商不存在
- **WHEN** 创建模型时指定的供应商不存在
- **THEN** SHALL 返回 HTTP 400 Bad Request
- **THEN** SHALL 返回以下 JSON 格式:
```json
{
"error": "供应商不存在"
}
```
### Requirement: 统一模型 ID 格式错误
系统 SHALL 对统一模型 ID 格式错误返回明确的错误信息。
#### Scenario: 统一模型 ID 格式无效
- **WHEN** 代理请求中的 model 字段不是有效的统一模型 ID 格式
- **THEN** 请求 SHALL 走 forwardPassthrough 透传到上游(兼容未适配的客户端)
- **THEN** 不返回错误,保持与上游的兼容性
**设计理由:**
- 统一模型 ID 是 BREAKING CHANGE部分旧客户端可能仍使用原始模型名
- 透传策略允许上游自行判断并返回错误(如 404 model not found
- 网关作为透明代理,不应拦截所有格式非法的请求
#### Scenario: 统一模型 ID 格式有效但对应模型不存在
- **WHEN** 代理请求中的 model 字段是有效的统一模型 ID 格式(含 `/`),但数据库中找不到对应的模型
- **THEN** SHALL 返回 HTTP 404 Not Found
- **THEN** SHALL 返回以下 JSON 格式:
```json
{
"error": "模型未找到"
}
```
#### Scenario: 统一模型 ID 对应的模型不存在
- **WHEN** 解析统一模型 ID 后,数据库中找不到对应的模型
- **THEN** SHALL 返回 HTTP 404 Not Found
- **THEN** SHALL 返回以下 JSON 格式:
```json
{
"error": "模型未找到"
}
```
#### Scenario: 统一模型 ID 对应的模型已禁用
- **WHEN** 解析统一模型 ID 后,对应的模型 enabled 为 false
- **THEN** SHALL 返回 HTTP 404 Not Found
- **THEN** SHALL 返回以下 JSON 格式:
```json
{
"error": "模型未找到"
}
```
#### Scenario: 统一模型 ID 对应的供应商已禁用
- **WHEN** 解析统一模型 ID 后,对应的供应商 enabled 为 false
- **THEN** SHALL 返回 HTTP 404 Not Found
- **THEN** SHALL 返回以下 JSON 格式:
```json
{
"error": "模型未找到"
}
```
### Requirement: JSON 格式错误
系统 SHALL 对请求体 JSON 格式错误返回明确的错误信息。
#### Scenario: 请求体 JSON 格式错误
- **WHEN** 代理请求的请求体不是有效的 JSON 格式
- **THEN** SHALL 返回 HTTP 400 Bad Request
- **THEN** SHALL 返回以下 JSON 格式:
```json
{
"error": "请求体 JSON 格式错误",
"code": "INVALID_JSON"
}
```
#### Scenario: Smart Passthrough 时请求体 JSON 格式错误
- **WHEN** 同协议 Smart Passthrough 场景下,请求体 JSON 格式不正确
- **THEN** SHALL 返回 HTTP 400 Bad Request
- **THEN** SHALL 返回以下 JSON 格式:
```json
{
"error": "请求体 JSON 格式错误",
"code": "INVALID_JSON"
}
```
### Requirement: 不可变字段错误
系统 SHALL 对尝试修改不可变字段返回明确的错误信息。
#### Scenario: 尝试修改供应商 ID
- **WHEN** 更新供应商时,请求体中包含 `id` 字段
- **THEN** SHALL 返回 HTTP 400 Bad Request
- **THEN** SHALL 返回以下 JSON 格式:
```json
{
"error": "供应商 ID 不允许修改",
"code": "IMMUTABLE_FIELD"
}
```


@@ -10,10 +10,12 @@
#### Scenario: 使用有效数据创建模型
- **WHEN** 向 `/api/models` 发送 POST 请求携带有效的模型数据provider_id, model_name不提供 id 字段
- **THEN** 网关 SHALL 自动生成 UUID 作为模型 id
- **THEN** 网关 SHALL 在数据库中创建新的模型记录
- **THEN** 网关 SHALL 返回创建的模型,状态码为 201
- **THEN** 模型 SHALL 默认启用
- **THEN** 返回的模型 SHALL 包含 `unified_id` 字段,值为 `{provider_id}/{model_name}`
#### Scenario: 使用不存在的供应商创建模型
@@ -21,7 +23,11 @@
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
- **THEN** 错误 SHALL 指示供应商不存在
#### Scenario: 创建重复模型
- **WHEN** 向 `/api/models` 发送 POST 请求,携带已存在的 provider_id + model_name 组合
- **THEN** 网关 SHALL 返回错误,状态码为 409 (Conflict)
- **THEN** 错误 SHALL 指示同一供应商下模型名称已存在
### Requirement: 列出所有模型
@@ -31,9 +37,7 @@
- **WHEN** 向 `/api/models` 发送 GET 请求
- **THEN** 网关 SHALL 返回所有模型的列表
- **THEN** 每个模型 SHALL 包含 id, provider_id, model_name, unified_id, enabled, created_at
### Requirement: 按供应商列出模型
@@ -43,8 +47,7 @@
- **WHEN** 向 `/api/models?provider_id=<provider_id>` 发送 GET 请求
- **THEN** 网关 SHALL 返回指定供应商的模型列表
- **THEN** 每个模型 SHALL 包含 unified_id 字段
### Requirement: 更新模型配置
@@ -55,14 +58,12 @@
- **WHEN** 向 `/api/models/:id` 发送 PUT 请求,携带有效的模型数据
- **THEN** 网关 SHALL 更新数据库中的模型记录
- **THEN** 网关 SHALL 返回更新后的模型
- **THEN** 返回的模型 SHALL 包含更新后的 unified_id
#### Scenario: 更新模型为重复组合
- **WHEN** 向 `/api/models/:id` 发送 PUT 请求,更新 provider_id 或 model_name 导致与已有记录重复
- **THEN** 网关 SHALL 返回错误,状态码为 409 (Conflict)
### Requirement: 删除模型配置
@@ -74,8 +75,6 @@
- **THEN** 网关 SHALL 删除模型记录
- **THEN** 网关 SHALL 返回状态码 204 (No Content)
### Requirement: 使用 service 层处理业务逻辑
Handler SHALL 通过 ModelService 处理业务逻辑。
@@ -85,25 +84,60 @@ Handler SHALL 通过 ModelService 处理业务逻辑。
- **WHEN** handler 收到请求
- **THEN** SHALL 调用对应的 ModelService 方法Create、Get、List、Update、Delete
- **THEN** SHALL 使用 domain.Model 类型
- **THEN** Create 时 SHALL 调用 `uuid.New()` 生成 id
#### Scenario: 供应商验证和唯一性校验
- **WHEN** 创建或更新模型
- **THEN** SHALL 在 service 层验证供应商存在
- **THEN** SHALL 在 service 层验证 provider_id + model_name 联合唯一
### Requirement: 联合唯一约束并发处理
创建或更新模型时SHALL 使用应用层校验 + 数据库约束双重保险处理联合唯一约束。
#### Scenario: 应用层快速失败
- **WHEN** 创建或更新模型前
- **THEN** SHALL 先检查 provider_id + model_name 是否已存在
- **THEN** 如已存在SHALL 返回 HTTP 409 Conflict
- **THEN** SHALL 返回错误格式:
```json
{
"error": "同一供应商下模型名称已存在",
"code": "duplicate_model"
}
```
#### Scenario: 数据库约束兜底
- **WHEN** 并发创建导致应用层校验通过但数据库写入失败
- **THEN** SHALL 捕获数据库 UNIQUE 约束错误
- **THEN** SHALL 转换为 HTTP 409 Conflict 错误返回
- **THEN** SHALL 返回错误格式:
```json
{
"error": "同一供应商下模型名称已存在",
"code": "duplicate_model"
}
```
#### Scenario: SQLite UNIQUE 约束错误检测
- **WHEN** 捕获数据库错误
- **THEN** SHALL 检查错误信息是否包含 "UNIQUE constraint failed"
- **THEN** 如匹配SHALL 识别为联合唯一约束冲突
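上述"应用层快速失败 + 数据库约束兜底"中的错误检测与映射可以用如下示意表达Python 伪实现,非项目实际代码;函数名为假设,实际 Go 实现检查的是同样的 SQLite 错误信息):

```python
def is_unique_violation(err_msg: str) -> bool:
    """检测 SQLite 错误信息是否为 UNIQUE 约束冲突(并发写入的兜底路径)。"""
    return "UNIQUE constraint failed" in err_msg


def to_http_error(err_msg: str):
    """将数据库层错误映射为 (status, body);命中联合唯一约束时返回 409。"""
    if is_unique_violation(err_msg):
        return 409, {"error": "同一供应商下模型名称已存在", "code": "duplicate_model"}
    # 其他数据库错误不属于本规范定义的冲突场景,按内部错误处理
    return 500, {"error": "internal error"}
```

这样无论冲突是被应用层预检发现,还是在并发窗口内由数据库约束拦截,客户端收到的都是同一个 409 响应体。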
### Requirement: 使用 repository 层访问数据
Service SHALL 通过 ModelRepository 访问数据。
#### Scenario: 联合查询
- **WHEN** service 需要按 provider 和 model_name 查询模型
- **THEN** SHALL 调用 `FindByProviderAndModelName(providerID, modelName)` 方法
#### Scenario: 查询所有启用模型
- **WHEN** proxy handler 需要聚合模型列表
- **THEN** SHALL 调用 `ListEnabled()` 方法,返回所有 enabled 的模型(关联 enabled 的供应商)


@@ -271,3 +271,72 @@ Decoder 几乎 1:1 映射,维护最小状态机:
- **WHEN** interfaceType 为 EMBEDDINGS 或 RERANK
- **THEN** supportsInterface SHALL 返回 false
- **THEN** 引擎 SHALL 走透传或返回空响应
### Requirement: 模型详情路径识别
Anthropic 适配器 SHALL 正确识别包含统一模型 ID 的模型详情路径。
#### Scenario: 含斜杠的统一模型 ID 路径
- **WHEN** 路径为 `/v1/models/anthropic/claude-3-opus`
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo`
#### Scenario: 含多段斜杠的统一模型 ID 路径
- **WHEN** 路径为 `/v1/models/azure/deployments/gpt-4`
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo`
#### Scenario: 模型列表路径不受影响
- **WHEN** 路径为 `/v1/models`
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModels`
### Requirement: 提取统一模型 ID
Anthropic 适配器 SHALL 从路径中提取统一模型 ID。
#### Scenario: 标准路径提取
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/anthropic/claude-3-opus")`
- **THEN** SHALL 返回 `"anthropic/claude-3-opus"`
#### Scenario: 复杂路径提取
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")`
- **THEN** SHALL 返回 `"azure/accounts/org/models/gpt-4"`
#### Scenario: 非模型详情路径
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models")`
- **THEN** SHALL 返回错误
### Requirement: 从请求体提取 model
Anthropic 适配器 SHALL 按 InterfaceType 从请求体中提取 model 值。
#### Scenario: Chat 请求提取 model
- **WHEN** 调用 `ExtractModelName(body, InterfaceTypeChat)`body 为 `{"model":"anthropic/claude-3-opus","messages":[...]}`
- **THEN** SHALL 返回 `"anthropic/claude-3-opus"`
#### Scenario: 无 model 字段
- **WHEN** 调用 `ExtractModelName(body, ifaceType)`body 中不含 model 字段
- **THEN** SHALL 返回错误
### Requirement: 最小化改写请求体 model 字段
Anthropic 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写请求体中的 model 字段,其余字段保持原始 bytes。
#### Scenario: 请求体 model 改写(统一 ID → 上游名)
- **WHEN** 调用 `RewriteRequestModelName(body, "claude-3-opus-20240229", InterfaceTypeChat)`
- **THEN** SHALL 将请求体中 model 字段替换为 `"claude-3-opus-20240229"`,其余字段原样保留
### Requirement: 最小化改写响应体 model 字段
Anthropic 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写响应体中的 model 字段,其余字段保持原始 bytes。请求体和响应体的 model 字段位置可能不同,各自独立实现。
#### Scenario: 响应体 model 改写(上游名 → 统一 ID
- **WHEN** 调用 `RewriteResponseModelName(body, "anthropic/claude-3-opus", InterfaceTypeChat)`
- **THEN** SHALL 将响应体中 model 字段替换为 `"anthropic/claude-3-opus"`,其余字段原样保留


@@ -270,3 +270,90 @@ Encoder SHALL 维护状态:
- **WHEN** 解码/编码 rerank 请求和响应
- **THEN** SHALL 使用 CanonicalRerankRequest/Response 做字段映射
### Requirement: 模型详情路径识别
OpenAI 适配器 SHALL 正确识别包含统一模型 ID 的模型详情路径。
#### Scenario: 含斜杠的统一模型 ID 路径
- **WHEN** 路径为 `/v1/models/openai/gpt-4`
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo`
#### Scenario: 含多段斜杠的统一模型 ID 路径
- **WHEN** 路径为 `/v1/models/azure/accounts/org-123/models/gpt-4`
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo`
#### Scenario: 模型列表路径不受影响
- **WHEN** 路径为 `/v1/models`
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModels`
### Requirement: 提取统一模型 ID
OpenAI 适配器 SHALL 从路径中提取统一模型 ID。
#### Scenario: 标准路径提取
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/openai/gpt-4")`
- **THEN** SHALL 返回 `"openai/gpt-4"`
#### Scenario: 复杂路径提取
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")`
- **THEN** SHALL 返回 `"azure/accounts/org/models/gpt-4"`
#### Scenario: 非模型详情路径
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models")`
- **THEN** SHALL 返回错误
### Requirement: 从请求体提取 model
OpenAI 适配器 SHALL 按 InterfaceType 从请求体中提取 model 值。
#### Scenario: Chat 请求提取 model
- **WHEN** 调用 `ExtractModelName(body, InterfaceTypeChat)`body 为 `{"model":"openai/gpt-4","messages":[...]}`
- **THEN** SHALL 返回 `"openai/gpt-4"`
#### Scenario: Embedding 请求提取 model
- **WHEN** 调用 `ExtractModelName(body, InterfaceTypeEmbeddings)`body 为 `{"model":"openai/text-embedding-3","input":"text"}`
- **THEN** SHALL 返回 `"openai/text-embedding-3"`
#### Scenario: 无 model 字段
- **WHEN** 调用 `ExtractModelName(body, ifaceType)`body 中不含 model 字段
- **THEN** SHALL 返回错误
### Requirement: 最小化改写请求体 model 字段
OpenAI 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写请求体中的 model 字段,其余字段保持原始 bytes。
#### Scenario: 请求体 model 改写(统一 ID → 上游名)
- **WHEN** 调用 `RewriteRequestModelName(body, "gpt-4", InterfaceTypeChat)`body 为 `{"model":"openai/gpt-4","messages":[...],"some_param":"value"}`
- **THEN** SHALL 返回 `{"model":"gpt-4","messages":[...],"some_param":"value"}`
- **THEN** 除 model 外的字段 SHALL 保持原始 bytes 不变
#### Scenario: 不同 InterfaceType 的请求改写
- **WHEN** 调用 `RewriteRequestModelName(body, "gpt-4", InterfaceTypeEmbeddings)`
- **THEN** SHALL 按 Embedding 接口的请求体 model 字段位置进行改写
- **WHEN** 调用 `RewriteRequestModelName(body, "gpt-4", InterfaceTypeRerank)`
- **THEN** SHALL 按 Rerank 接口的请求体 model 字段位置进行改写
### Requirement: 最小化改写响应体 model 字段
OpenAI 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写响应体中的 model 字段,其余字段保持原始 bytes。请求体和响应体的 model 字段位置可能不同,各自独立实现。
#### Scenario: 响应体 model 改写(上游名 → 统一 ID
- **WHEN** 调用 `RewriteResponseModelName(body, "openai/gpt-4", InterfaceTypeChat)`body 为上游 Chat 响应
- **THEN** SHALL 将 model 字段替换为 `"openai/gpt-4"`,其余字段原样保留
#### Scenario: 不同 InterfaceType 的响应改写
- **WHEN** 调用 `RewriteResponseModelName(body, "openai/gpt-4", InterfaceTypeEmbeddings)`
- **THEN** SHALL 按 Embedding 接口的响应体 model 字段位置进行改写
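"最小化改写"的效果可以用如下示意表达Python 伪实现,仅覆盖顶层 `model` 为简单字符串字面量、值内不含转义引号的常见情形;实际 Go 实现基于 `json.RawMessage` 做字段级改写):

```python
import re

# 匹配首个 "model" 字符串字面量,分组保留前后引号以便原样拼回
_MODEL_FIELD_RE = re.compile(r'("model"\s*:\s*")([^"]*)(")')


def rewrite_model_field(body: bytes, new_model: str) -> bytes:
    """仅替换请求/响应体中首个 "model" 值,其余字节(含未知参数)原样保留。"""
    text = body.decode()
    rewritten = _MODEL_FIELD_RE.sub(
        lambda m: m.group(1) + new_model + m.group(3), text, count=1
    )
    return rewritten.encode()
```

例如请求方向把统一 ID 改写为上游名:`rewrite_model_field(b'{"model":"openai/gpt-4","some_new_param":"value"}', "gpt-4")` 产出的字节中 `some_new_param` 保持原始值不变。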


@@ -29,7 +29,44 @@
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
- **THEN** 错误 SHALL 指示缺少哪些字段
#### Scenario: 创建供应商时 ID 包含非法字符
- **WHEN** 向 `/api/providers` 发送 POST 请求id 包含非 `[a-zA-Z0-9_]` 字符
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
- **THEN** 错误 SHALL 指示 id 仅允许字母、数字、下划线
#### Scenario: 创建供应商时 ID 过长
- **WHEN** 向 `/api/providers` 发送 POST 请求id 长度超过 64
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
### Requirement: 供应商 ID 不允许修改
供应商 ID 是主键,用于构建统一模型 ID不允许修改。
#### Scenario: 尝试修改供应商 ID
- **WHEN** 向 `/api/providers/:id` 发送 PUT 请求,请求体中包含 `id` 字段
- **THEN** ProviderService.Update SHALL 在 service 层校验并返回错误
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
- **THEN** 错误 SHALL 指示供应商 ID 不允许修改
- **THEN** 错误格式 SHALL 为:
```json
{
"error": "供应商 ID 不允许修改",
"code": "IMMUTABLE_FIELD"
}
```
**校验位置:** Service 层(`ProviderService.Update`
- Service 层校验保证所有调用方handler、CLI、未来 API统一遵守
- Handler 层负责捕获 `ErrImmutableField` 并转换为 HTTP 400 响应
**原因:**
- 供应商 ID 是主键,用于构建统一模型 IDprovider_id/model_name
- 修改 ID 会导致所有统一模型 ID 失效
- 客户端缓存的模型 ID 全部失效
- 如需修改,应创建新供应商并迁移模型
### Requirement: 列出所有供应商


@@ -92,6 +92,44 @@
- **THEN** SHALL 验证至少提供一个可更新字段
- **THEN** SHALL 验证字段值有效性
### Requirement: 供应商 ID 校验
创建供应商时SHALL 对 `id` 字段进行字符集校验。
#### Scenario: 合法字符集
- **WHEN** 创建供应商id 仅包含 `[a-zA-Z0-9_]` 字符
- **THEN** SHALL 校验通过
#### Scenario: 非法字符
- **WHEN** 创建供应商id 包含 `-`、`.`、`/`、空格、中文等非 `[a-zA-Z0-9_]` 字符
- **THEN** SHALL 返回 400 错误
#### Scenario: 长度限制
- **WHEN** 创建供应商id 长度超过 64
- **THEN** SHALL 返回 400 错误
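上述字符集和长度规则合起来就是一条正则Python 伪实现,函数名为假设,实际校验在 Go 侧的 ValidateProviderID

```python
import re

# 仅允许字母、数字、下划线,长度 1-64;显式字符类避免 \w 的 Unicode 语义
_PROVIDER_ID_RE = re.compile(r"^[A-Za-z0-9_]{1,64}$")


def validate_provider_id(provider_id: str) -> bool:
    """校验供应商 ID不合法时调用方应返回 400 与 INVALID_PROVIDER_ID。"""
    return _PROVIDER_ID_RE.fullmatch(provider_id) is not None
```

禁止 `/` 尤其关键:供应商 ID 是统一模型 ID 中 `/` 之前的部分,含 `/` 会使解析产生歧义。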
### Requirement: 模型创建校验
创建模型时SHALL 对 `provider_id` + `model_name` 进行联合唯一性校验。
#### Scenario: 正常创建
- **WHEN** 创建模型provider_id 存在且 provider_id + model_name 组合唯一
- **THEN** SHALL 校验通过
#### Scenario: 联合唯一冲突
- **WHEN** 创建模型provider_id + model_name 组合已存在
- **THEN** SHALL 返回 409 错误
#### Scenario: model_name 为空
- **WHEN** 创建模型,未提供 model_name
- **THEN** SHALL 返回 400 错误
### Requirement: 返回友好的验证错误 ### Requirement: 返回友好的验证错误
系统 SHALL 返回友好的验证错误响应。 系统 SHALL 返回友好的验证错误响应。


@@ -1,4 +1,10 @@
# Unified Model ID
## Purpose
定义统一模型 ID 的格式、解析、格式化和校验规则,确保跨协议的模型标识一致性。
## Requirements
### Requirement: 解析统一模型 ID
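解析与格式化规则的一个极简示意Python 伪实现,非 pkg/modelid 的实际代码):按第一个 `/` 拆分provider_id 不含 `/`model_name 可以包含多级 `/`

```python
def parse_unified_model_id(unified_id: str):
    """按第一个 '/' 将统一模型 ID 拆为 (provider_id, model_name);格式非法返回 None。"""
    sep = unified_id.find("/")
    if sep <= 0 or sep == len(unified_id) - 1:
        return None  # 无 '/'、provider_id 为空或 model_name 为空
    return unified_id[:sep], unified_id[sep + 1:]


def format_unified_model_id(provider_id: str, model_name: str) -> str:
    """拼接统一模型 IDprovider_id/model_name。"""
    return f"{provider_id}/{model_name}"
```

只在第一个 `/` 处切分,保证 `azure/accounts/org/models/gpt-4` 这类含多段斜杠的 model_name 能无损往返。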


@@ -106,3 +106,124 @@ ProxyHandler SHALL 支持 GET 请求的扩展层接口代理。
- **THEN** SHALL 调用 providerClient.Send 发送请求
- **THEN** SHALL 调用 engine.convertHttpResponse 转换响应格式
- **THEN** SHALL 返回转换后的响应
### Requirement: 代理请求路由
ProxyHandler SHALL 使用统一模型 ID 路由所有代理请求。
#### Scenario: 提取统一模型 ID
- **WHEN** 收到 Chat、Embeddings 或 Rerank 接口的 POST 请求(含请求体)
- **THEN** SHALL 调用客户端协议 adapter 的 `ExtractModelName(body, ifaceType)` 提取 model 值
- **THEN** SHALL 调用 `ParseUnifiedModelID` 解析得到 providerID 和 modelName
- **THEN** SHALL 调用 `RoutingService.RouteByModelName(providerID, modelName)` 路由
#### Scenario: GET 请求或无请求体
- **WHEN** 收到 GET 请求或请求体为空或请求体中无法提取 model 字段
- **THEN** SHALL 走 forwardPassthrough 透传到上游供应商(兼容未适配的客户端和无 body 请求)
#### Scenario: 无效的统一模型 ID
- **WHEN** 请求体中 `model` 字段不是有效的统一模型 ID 格式(不含 `/`
- **THEN** SHALL 走 forwardPassthrough 透传到上游供应商(兼容使用原始模型名的客户端)
#### Scenario: 模型不存在
- **WHEN** 解析统一模型 ID 后,数据库中找不到对应的 provider_id + model_name 组合
- **THEN** SHALL 返回错误响应,状态码为 404
#### Scenario: 模型已禁用
- **WHEN** 解析统一模型 ID 后,对应的模型 enabled 为 false
- **THEN** SHALL 返回错误响应,状态码为 404
#### Scenario: 供应商已禁用
- **WHEN** 解析统一模型 ID 后,对应的供应商 enabled 为 false
- **THEN** SHALL 返回错误响应,状态码为 404
### Requirement: 同协议 Smart Passthrough
当客户端协议与供应商协议相同时ProxyHandler SHALL 使用 Smart Passthrough 处理 Chat、Embedding、Rerank 请求。
#### Scenario: 同协议非流式请求
- **WHEN** 客户端协议 == 供应商协议,且为非流式请求
- **THEN** SHALL 调用 adapter 的 `RewriteRequestModelName(body, modelName, ifaceType)` 将请求体中 model 从统一 ID 改写为上游模型名
- **THEN** SHALL 构建 URL 和 Headers同当前透传逻辑
- **THEN** SHALL 发送改写后的请求体到上游
- **THEN** SHALL 调用 adapter 的 `RewriteResponseModelName(resp.Body, unifiedModelID, ifaceType)` 将响应中 model 从上游名改写为统一 ID
- **THEN** SHALL NOT 对 body 做全量 decode → encode保持未改写字段的原始 bytes
#### Scenario: 同协议流式请求
- **WHEN** 客户端协议 == 供应商协议,且为流式请求
- **THEN** SHALL 对请求体做 `RewriteRequestModelName` 改写 model 字段
- **THEN** SHALL 逐 SSE chunk 调用 `RewriteResponseModelName` 改写响应中 model 字段
- **THEN** SHALL NOT 对 chunk 做全量 decode → encode
#### Scenario: Smart Passthrough 保真性
- **WHEN** 客户端发送含未知参数的请求(如 `{"model":"openai/gpt-4","some_new_param":"value"}`
- **THEN** 上游 SHALL 收到 `{"model":"gpt-4","some_new_param":"value"}`
- **THEN** `some_new_param` SHALL 保持原始值不变,不丢失、不改变类型
### Requirement: 跨协议完整转换
当客户端协议与供应商协议不同时ProxyHandler SHALL 使用全量转换路径。
#### Scenario: 跨协议非流式请求
- **WHEN** 客户端协议 != 供应商协议
- **THEN** SHALL 走 `ConvertHttpRequest` 全量转换encoder 中 provider.ModelName 覆盖 model
- **THEN** SHALL 走 `ConvertHttpResponse` 全量转换modelOverride 参数覆写 canonical.Model
#### Scenario: 跨协议流式请求
- **WHEN** 客户端协议 != 供应商协议,且为流式请求
- **THEN** SHALL 走 `CreateStreamConverter` 全量转换modelOverride 参数覆写流式 canonical 事件中的 Model
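同协议与跨协议两条路径的分发逻辑可以归纳为一个极简示意Python 伪代码,路径名为假设,仅说明决策结构):

```python
def choose_proxy_path(client_protocol: str, provider_protocol: str, stream: bool) -> str:
    """协议相同走 Smart Passthrough最小化改写否则走全量 decode/encode 转换。"""
    if client_protocol == provider_protocol:
        return "smart_passthrough_stream" if stream else "smart_passthrough"
    return "convert_stream" if stream else "convert"
```

两条路径对 model 字段的处理点不同Smart Passthrough 在原始 bytes 上改写,全量转换则在 canonical 层用 modelOverride 覆写。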
### Requirement: 模型列表本地聚合
ProxyHandler SHALL 从数据库聚合返回模型列表,不再透传上游。
#### Scenario: GET /v1/models
- **WHEN** 收到 `GET /{protocol}/v1/models` 请求
- **THEN** SHALL 从数据库查询所有 enabled 的模型(关联 enabled 的供应商)
- **THEN** SHALL 组装 `CanonicalModelList`,每个模型的 ID 字段为统一模型 ID`provider_id/model_name`Name 字段为 model_nameOwnedBy 字段为 provider_id
- **THEN** SHALL 使用客户端协议的 adapter 编码响应
- **THEN** SHALL NOT 请求上游供应商
#### Scenario: 无可用模型
- **WHEN** 数据库中没有 enabled 的模型
- **THEN** SHALL 返回空列表
### Requirement: 模型详情本地查询
ProxyHandler SHALL 从数据库查询返回模型详情,不再透传上游。
#### Scenario: GET /v1/models/{unified_id}
- **WHEN** 收到 `GET /{protocol}/v1/models/{provider_id}/{model_name}` 请求
- **THEN** SHALL 调用 adapter 的 `ExtractUnifiedModelID` 提取统一模型 ID
- **THEN** SHALL 解析统一模型 ID 得到 providerID 和 modelName
- **THEN** SHALL 从数据库查询对应的模型和供应商
- **THEN** SHALL 组装 `CanonicalModelInfo`ID 字段为统一模型 ID`provider_id/model_name`Name 字段为 model_nameOwnedBy 字段为 provider_id
- **THEN** SHALL 使用客户端协议的 adapter 编码响应
- **THEN** SHALL NOT 请求上游供应商
#### Scenario: 模型详情不存在
- **WHEN** 统一模型 ID 对应的模型不存在或已禁用
- **THEN** SHALL 返回错误响应,状态码为 404
### Requirement: 统计记录
ProxyHandler SHALL 使用 providerID 和 modelName 记录使用统计。
#### Scenario: 异步记录统计
- **WHEN** 代理请求成功完成
- **THEN** SHALL 异步调用 `StatsService.Record(providerID, modelName)`


@@ -22,7 +22,20 @@
- **THEN** 网关 SHALL 增加该供应商和模型的请求计数
- **THEN** 网关 SHALL 在流结束后记录统计
### Requirement: 使用统计记录统一模型标识
系统 SHALL 使用 providerID 和 modelName上游模型名记录使用统计。
#### Scenario: 代理请求统计记录
- **WHEN** 代理请求成功完成
- **THEN** SHALL 记录 provider_id 和 model_name 到 usage_stats 表(参数来自路由结果)
- **THEN** SHALL 异步执行,不阻塞响应
#### Scenario: 查询统计
- **WHEN** 查询统计数据
- **THEN** 支持按 provider_id 和 model_name 过滤
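"异步执行、不阻塞响应"的记录方式可以用如下示意表达Python 伪实现Go 侧通常直接用 goroutine 实现同样的 fire-and-forget 语义;类名和 write_fn 为假设):

```python
import queue
import threading


class StatsRecorder:
    """请求线程只入队立即返回,后台线程负责落库(写 usage_stats 表)。"""

    def __init__(self, write_fn):
        self._q = queue.Queue()
        self._write = write_fn  # 实际实现中为数据库写入
        threading.Thread(target=self._loop, daemon=True).start()

    def record(self, provider_id, model_name):
        self._q.put((provider_id, model_name))  # 不等待落库,代理响应不被阻塞

    def _loop(self):
        while True:
            provider_id, model_name = self._q.get()
            self._write(provider_id, model_name)
            self._q.task_done()
```

record 的参数来自路由结果model_name 为上游模型名而非统一 ID因此统计查询可以直接按 provider_id 和 model_name 过滤。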
### Requirement: 按供应商查询统计