feat: 实现统一模型 ID 机制
实现统一模型 ID 格式 (provider_id/model_name),支持跨协议模型标识和 Smart Passthrough。 核心变更: - 新增 pkg/modelid 包:解析、格式化、校验统一模型 ID - 数据库迁移:models 表使用 UUID 主键 + UNIQUE(provider_id, model_name) 约束 - Repository 层:FindByProviderAndModelName、ListEnabled 方法 - Service 层:联合唯一校验、provider ID 字符集校验 - Conversion 层:ExtractModelName、RewriteRequestModelName/RewriteResponseModelName 方法 - Handler 层:统一模型 ID 路由、Smart Passthrough、Models API 本地聚合 - 新增 error-responses、unified-model-id 规范 测试覆盖: - 单元测试:modelid、conversion、handler、service、repository - 集成测试:统一模型 ID 路由、Smart Passthrough 保真性、跨协议转换 - 迁移测试:UUID 主键、UNIQUE 约束、级联删除 OpenSpec: - 归档 unified-model-id 变更到 archive/2026-04-21-unified-model-id - 同步 11 个 delta specs 到 main specs - 新增 error-responses、unified-model-id 规范文件
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -404,4 +404,5 @@ cython_debug/
|
|||||||
openspec/changes/archive
|
openspec/changes/archive
|
||||||
temp
|
temp
|
||||||
.agents
|
.agents
|
||||||
skills-lock.json
|
skills-lock.json
|
||||||
|
.worktrees
|
||||||
19
README.md
19
README.md
@@ -38,10 +38,11 @@ nex/
|
|||||||
## 功能特性
|
## 功能特性
|
||||||
|
|
||||||
- **双协议支持**:同时支持 OpenAI 和 Anthropic 协议
|
- **双协议支持**:同时支持 OpenAI 和 Anthropic 协议
|
||||||
- **透明代理**:对 OpenAI 兼容供应商透传请求
|
- **统一模型 ID**:`provider_id/model_name` 格式全局唯一标识模型(如 `openai/gpt-4`)
|
||||||
|
- **透明代理**:对 OpenAI 兼容供应商 Smart Passthrough,最小化改写保持参数保真
|
||||||
- **流式响应**:完整支持 SSE 流式传输
|
- **流式响应**:完整支持 SSE 流式传输
|
||||||
- **Function Calling**:支持工具调用(Tools)
|
- **Function Calling**:支持工具调用(Tools)
|
||||||
- **多供应商管理**:配置和管理多个供应商
|
- **多供应商管理**:配置和管理多个供应商(供应商 ID 仅限字母、数字、下划线)
|
||||||
- **用量统计**:按供应商、模型、日期统计请求数量
|
- **用量统计**:按供应商、模型、日期统计请求数量
|
||||||
- **Web 配置界面**:提供供应商和模型配置管理
|
- **Web 配置界面**:提供供应商和模型配置管理
|
||||||
|
|
||||||
@@ -99,23 +100,27 @@ bun dev
|
|||||||
|
|
||||||
### 代理接口(对外部应用)
|
### 代理接口(对外部应用)
|
||||||
|
|
||||||
|
代理接口统一使用 `provider_id/model_name` 格式的模型 ID(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough,最小化 JSON 改写保持参数保真;跨协议请求走完整 decode/encode 转换。
|
||||||
|
|
||||||
- `POST /v1/chat/completions` - OpenAI Chat Completions API
|
- `POST /v1/chat/completions` - OpenAI Chat Completions API
|
||||||
- `POST /v1/messages` - Anthropic Messages API
|
- `POST /v1/messages` - Anthropic Messages API
|
||||||
|
- `GET /v1/models` - 模型列表(本地数据库聚合,不请求上游)
|
||||||
|
- `GET /v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
|
||||||
|
|
||||||
### 管理接口(对前端)
|
### 管理接口(对前端)
|
||||||
|
|
||||||
#### 供应商管理
|
#### 供应商管理
|
||||||
- `GET /api/providers` - 列出所有供应商
|
- `GET /api/providers` - 列出所有供应商
|
||||||
- `POST /api/providers` - 创建供应商
|
- `POST /api/providers` - 创建供应商(`id` 仅限字母、数字、下划线,长度 1-64)
|
||||||
- `GET /api/providers/:id` - 获取供应商
|
- `GET /api/providers/:id` - 获取供应商
|
||||||
- `PUT /api/providers/:id` - 更新供应商
|
- `PUT /api/providers/:id` - 更新供应商(`id` 不可修改)
|
||||||
- `DELETE /api/providers/:id` - 删除供应商
|
- `DELETE /api/providers/:id` - 删除供应商
|
||||||
|
|
||||||
#### 模型管理
|
#### 模型管理
|
||||||
- `GET /api/models` - 列出模型(支持 `?provider_id=xxx` 过滤)
|
- `GET /api/models` - 列出模型(支持 `?provider_id=xxx` 过滤)
|
||||||
- `POST /api/models` - 创建模型
|
- `POST /api/models` - 创建模型(`id` 由系统自动生成 UUID,`provider_id` + `model_name` 联合唯一)
|
||||||
- `GET /api/models/:id` - 获取模型
|
- `GET /api/models/:id` - 获取模型(响应含 `unified_id` 字段,格式 `provider_id/model_name`)
|
||||||
- `PUT /api/models/:id` - 更新模型
|
- `PUT /api/models/:id` - 更新模型(不可修改 `id`)
|
||||||
- `DELETE /api/models/:id` - 删除模型
|
- `DELETE /api/models/:id` - 删除模型
|
||||||
|
|
||||||
#### 统计查询
|
#### 统计查询
|
||||||
|
|||||||
@@ -108,12 +108,13 @@ backend/
|
|||||||
│ │ ├── logger.go
|
│ │ ├── logger.go
|
||||||
│ │ ├── rotate.go
|
│ │ ├── rotate.go
|
||||||
│ │ └── context.go
|
│ │ └── context.go
|
||||||
|
│ ├── modelid/ # 统一模型 ID 工具包
|
||||||
|
│ │ ├── model_id.go
|
||||||
|
│ │ └── model_id_test.go
|
||||||
│ └── validator/ # 验证器
|
│ └── validator/ # 验证器
|
||||||
│ └── validator.go
|
│ └── validator.go
|
||||||
├── migrations/ # 数据库迁移
|
├── migrations/ # 数据库迁移
|
||||||
│ ├── 20260401000001_initial_schema.sql
|
│ └── 20260421000001_initial_schema.sql
|
||||||
│ ├── 20260401000002_add_indexes.sql
|
|
||||||
│ └── 20260419000001_add_provider_protocol.sql
|
|
||||||
├── tests/ # 集成测试
|
├── tests/ # 集成测试
|
||||||
│ ├── helpers.go
|
│ ├── helpers.go
|
||||||
│ └── integration/
|
│ └── integration/
|
||||||
@@ -292,6 +293,8 @@ GET /anthropic/v1/models
|
|||||||
|
|
||||||
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传,零序列化开销。
|
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传,零序列化开销。
|
||||||
|
|
||||||
|
**统一模型 ID**:代理请求中的 `model` 字段使用 `provider_id/model_name` 格式(如 `openai/gpt-4`),网关据此路由到对应供应商。同协议时自动改写为上游 `model_name`,跨协议时通过全量转换处理。
|
||||||
|
|
||||||
### 管理接口
|
### 管理接口
|
||||||
|
|
||||||
#### 供应商管理
|
#### 供应商管理
|
||||||
@@ -324,14 +327,30 @@ GET /anthropic/v1/models
|
|||||||
- `PUT /api/models/:id` - 更新模型
|
- `PUT /api/models/:id` - 更新模型
|
||||||
- `DELETE /api/models/:id` - 删除模型
|
- `DELETE /api/models/:id` - 删除模型
|
||||||
|
|
||||||
|
**创建请求**(id 由系统自动生成 UUID):
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"id": "gpt-4",
|
|
||||||
"provider_id": "openai",
|
"provider_id": "openai",
|
||||||
"model_name": "gpt-4"
|
"model_name": "gpt-4"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**响应示例**:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
"provider_id": "openai",
|
||||||
|
"model_name": "gpt-4",
|
||||||
|
"unified_id": "openai/gpt-4",
|
||||||
|
"enabled": true,
|
||||||
|
"created_at": "2026-04-21T00:00:00Z"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**统一模型 ID**:`unified_id` 字段为 `provider_id/model_name` 格式,用于代理请求的 `model` 参数。
|
||||||
|
|
||||||
#### 统计查询
|
#### 统计查询
|
||||||
|
|
||||||
- `GET /api/stats` - 查询统计
|
- `GET /api/stats` - 查询统计
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ func main() {
|
|||||||
statsRepo := repository.NewStatsRepository(db)
|
statsRepo := repository.NewStatsRepository(db)
|
||||||
|
|
||||||
// 5. 初始化 service 层
|
// 5. 初始化 service 层
|
||||||
providerService := service.NewProviderService(providerRepo)
|
providerService := service.NewProviderService(providerRepo, modelRepo)
|
||||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||||
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
||||||
statsService := service.NewStatsService(statsRepo)
|
statsService := service.NewStatsService(statsRepo)
|
||||||
|
|||||||
@@ -17,11 +17,11 @@ type Provider struct {
|
|||||||
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
|
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Model 模型配置
|
// Model 模型配置(id 为 UUID 自动生成,UNIQUE(provider_id, model_name))
|
||||||
type Model struct {
|
type Model struct {
|
||||||
ID string `gorm:"primaryKey" json:"id"`
|
ID string `gorm:"primaryKey" json:"id"`
|
||||||
ProviderID string `gorm:"not null;index" json:"provider_id"`
|
ProviderID string `gorm:"not null;index;uniqueIndex:idx_provider_model" json:"provider_id"`
|
||||||
ModelName string `gorm:"not null;index" json:"model_name"`
|
ModelName string `gorm:"not null;index;uniqueIndex:idx_provider_model" json:"model_name"`
|
||||||
Enabled bool `gorm:"default:true" json:"enabled"`
|
Enabled bool `gorm:"default:true" json:"enabled"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,6 +40,12 @@ type ProtocolAdapter interface {
|
|||||||
EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *TargetProvider) ([]byte, error)
|
EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *TargetProvider) ([]byte, error)
|
||||||
DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error)
|
DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error)
|
||||||
EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error)
|
EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error)
|
||||||
|
|
||||||
|
// 统一模型 ID 相关方法
|
||||||
|
ExtractUnifiedModelID(nativePath string) (string, error)
|
||||||
|
ExtractModelName(body []byte, ifaceType InterfaceType) (string, error)
|
||||||
|
RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)
|
||||||
|
RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AdapterRegistry 适配器注册表接口
|
// AdapterRegistry 适配器注册表接口
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package anthropic
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"nex/backend/internal/conversion"
|
"nex/backend/internal/conversion"
|
||||||
@@ -39,13 +40,13 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id})
|
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id},允许 id 含 /)
|
||||||
func isModelInfoPath(path string) bool {
|
func isModelInfoPath(path string) bool {
|
||||||
if !strings.HasPrefix(path, "/v1/models/") {
|
if !strings.HasPrefix(path, "/v1/models/") {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
suffix := path[len("/v1/models/"):]
|
suffix := path[len("/v1/models/"):]
|
||||||
return suffix != "" && !strings.Contains(suffix, "/")
|
return suffix != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildUrl 根据接口类型构建 URL
|
// BuildUrl 根据接口类型构建 URL
|
||||||
@@ -203,3 +204,74 @@ func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankRe
|
|||||||
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
|
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
|
||||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
|
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExtractUnifiedModelID 从路径中提取统一模型 ID(/v1/models/{provider_id}/{model_name})
|
||||||
|
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
|
||||||
|
if !strings.HasPrefix(nativePath, "/v1/models/") {
|
||||||
|
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
|
||||||
|
}
|
||||||
|
suffix := nativePath[len("/v1/models/"):]
|
||||||
|
if suffix == "" {
|
||||||
|
return "", fmt.Errorf("路径缺少模型 ID")
|
||||||
|
}
|
||||||
|
return suffix, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// locateModelFieldInRequest 定位请求体中 model 字段的值并提供改写函数
|
||||||
|
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
|
||||||
|
var m map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(body, &m); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ifaceType {
|
||||||
|
case conversion.InterfaceTypeChat:
|
||||||
|
raw, exists := m["model"]
|
||||||
|
if !exists {
|
||||||
|
return "", nil, fmt.Errorf("请求体中缺少 model 字段")
|
||||||
|
}
|
||||||
|
var current string
|
||||||
|
if err := json.Unmarshal(raw, ¤t); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
rewriteFunc := func(newModel string) ([]byte, error) {
|
||||||
|
m["model"], _ = json.Marshal(newModel)
|
||||||
|
return json.Marshal(m)
|
||||||
|
}
|
||||||
|
return current, rewriteFunc, nil
|
||||||
|
default:
|
||||||
|
return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractModelName 从请求体中提取 model 值
|
||||||
|
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
|
||||||
|
model, _, err := locateModelFieldInRequest(body, ifaceType)
|
||||||
|
return model, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// RewriteRequestModelName 最小化改写请求体中的 model 字段
|
||||||
|
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||||
|
_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return rewriteFunc(newModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RewriteResponseModelName 最小化改写响应体中的 model 字段
|
||||||
|
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||||
|
var m map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(body, &m); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ifaceType {
|
||||||
|
case conversion.InterfaceTypeChat:
|
||||||
|
// Chat 响应必须有 model 字段,存在则改写,不存在则添加
|
||||||
|
m["model"], _ = json.Marshal(newModel)
|
||||||
|
return json.Marshal(m)
|
||||||
|
default:
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
263
backend/internal/conversion/anthropic/adapter_unified_test.go
Normal file
263
backend/internal/conversion/anthropic/adapter_unified_test.go
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
package anthropic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"nex/backend/internal/conversion"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// ExtractUnifiedModelID
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestExtractUnifiedModelID(t *testing.T) {
|
||||||
|
a := NewAdapter()
|
||||||
|
|
||||||
|
t.Run("standard_path", func(t *testing.T) {
|
||||||
|
id, err := a.ExtractUnifiedModelID("/v1/models/anthropic/claude-3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "anthropic/claude-3", id)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multi_segment_path", func(t *testing.T) {
|
||||||
|
id, err := a.ExtractUnifiedModelID("/v1/models/some/deep/nested/model")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "some/deep/nested/model", id)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("single_segment", func(t *testing.T) {
|
||||||
|
id, err := a.ExtractUnifiedModelID("/v1/models/claude-3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "claude-3", id)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non_model_path", func(t *testing.T) {
|
||||||
|
_, err := a.ExtractUnifiedModelID("/v1/messages")
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty_suffix", func(t *testing.T) {
|
||||||
|
_, err := a.ExtractUnifiedModelID("/v1/models/")
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("models_list_no_slash", func(t *testing.T) {
|
||||||
|
_, err := a.ExtractUnifiedModelID("/v1/models")
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unrelated_path", func(t *testing.T) {
|
||||||
|
_, err := a.ExtractUnifiedModelID("/v1/other")
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// ExtractModelName (Chat only for Anthropic)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestExtractModelName(t *testing.T) {
|
||||||
|
a := NewAdapter()
|
||||||
|
|
||||||
|
t.Run("chat", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"anthropic/claude-3","messages":[]}`)
|
||||||
|
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "anthropic/claude-3", model)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("chat_with_max_tokens", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"anthropic/claude-3-opus","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`)
|
||||||
|
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "anthropic/claude-3-opus", model)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no_model_field", func(t *testing.T) {
|
||||||
|
body := []byte(`{"messages":[]}`)
|
||||||
|
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid_json", func(t *testing.T) {
|
||||||
|
body := []byte(`{invalid}`)
|
||||||
|
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unsupported_interface_type_embedding", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"anthropic/claude-3"}`)
|
||||||
|
_, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unsupported_interface_type_rerank", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"anthropic/claude-3"}`)
|
||||||
|
_, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// RewriteRequestModelName (Chat only for Anthropic)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestRewriteRequestModelName(t *testing.T) {
|
||||||
|
a := NewAdapter()
|
||||||
|
|
||||||
|
t.Run("chat", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"anthropic/claude-3","messages":[]}`)
|
||||||
|
rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var m map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||||
|
assert.Equal(t, "claude-3", m["model"])
|
||||||
|
|
||||||
|
msgs, ok := m["messages"]
|
||||||
|
require.True(t, ok)
|
||||||
|
msgsArr, ok := msgs.([]interface{})
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, msgsArr, 0)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves_unknown_fields", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"anthropic/claude-3","max_tokens":1024,"temperature":0.7}`)
|
||||||
|
rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var m map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||||
|
assert.Equal(t, "claude-3", m["model"])
|
||||||
|
assert.Equal(t, 0.7, m["temperature"])
|
||||||
|
|
||||||
|
// max_tokens is encoded as float in JSON numbers
|
||||||
|
maxTokens, ok := m["max_tokens"]
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, float64(1024), maxTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no_model_field", func(t *testing.T) {
|
||||||
|
body := []byte(`{"messages":[]}`)
|
||||||
|
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid_json", func(t *testing.T) {
|
||||||
|
body := []byte(`{invalid}`)
|
||||||
|
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unsupported_interface_type", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"anthropic/claude-3"}`)
|
||||||
|
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeEmbeddings)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// RewriteResponseModelName (Chat only for Anthropic)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestRewriteResponseModelName(t *testing.T) {
|
||||||
|
a := NewAdapter()
|
||||||
|
|
||||||
|
t.Run("chat_existing_model", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"claude-3","content":[],"stop_reason":"end_turn"}`)
|
||||||
|
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var m map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||||
|
assert.Equal(t, "anthropic/claude-3", m["model"])
|
||||||
|
|
||||||
|
// other fields preserved
|
||||||
|
_, hasContent := m["content"]
|
||||||
|
assert.True(t, hasContent)
|
||||||
|
assert.Equal(t, "end_turn", m["stop_reason"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("chat_without_model_field_adds_it", func(t *testing.T) {
|
||||||
|
body := []byte(`{"content":[],"stop_reason":"end_turn"}`)
|
||||||
|
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var m map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||||
|
assert.Equal(t, "anthropic/claude-3", m["model"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("passthrough_returns_body_unchanged", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"claude-3"}`)
|
||||||
|
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypePassthrough)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, string(body), string(rewritten))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid_json", func(t *testing.T) {
|
||||||
|
body := []byte(`{invalid}`)
|
||||||
|
_, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// ExtractModelName and RewriteRequest consistency
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) {
|
||||||
|
a := NewAdapter()
|
||||||
|
|
||||||
|
t.Run("chat_round_trip", func(t *testing.T) {
|
||||||
|
original := []byte(`{"model":"anthropic/claude-3","messages":[],"max_tokens":1024}`)
|
||||||
|
|
||||||
|
// Extract the unified model ID from the body
|
||||||
|
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "anthropic/claude-3", extracted)
|
||||||
|
|
||||||
|
// Rewrite to the native model name
|
||||||
|
rewritten, err := a.RewriteRequestModelName(original, "claude-3", conversion.InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Extract again from the rewritten body to verify the same location was targeted
|
||||||
|
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "claude-3", afterRewrite)
|
||||||
|
|
||||||
|
// Verify other fields are preserved
|
||||||
|
var m map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||||
|
assert.Equal(t, float64(1024), m["max_tokens"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// isModelInfoPath (additional unified model ID cases)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"simple_model_id", "/v1/models/claude-3", true},
|
||||||
|
{"unified_model_id_with_slash", "/v1/models/anthropic/claude-3", true},
|
||||||
|
{"models_list", "/v1/models", false},
|
||||||
|
{"models_list_trailing_slash", "/v1/models/", false},
|
||||||
|
{"messages_path", "/v1/messages", false},
|
||||||
|
{"deeply_nested", "/v1/models/org/workspace/claude-3-opus", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -79,11 +79,29 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Smart Passthrough: 同协议时最小化改写 model 字段
|
||||||
|
interfaceType := providerAdapter.DetectInterfaceType(nativePath)
|
||||||
|
rewrittenBody := spec.Body
|
||||||
|
|
||||||
|
// 对于 Chat/Embedding/Rerank 接口,改写请求体中的 model 字段
|
||||||
|
if interfaceType == InterfaceTypeChat || interfaceType == InterfaceTypeEmbeddings || interfaceType == InterfaceTypeRerank {
|
||||||
|
if len(spec.Body) > 0 && provider.ModelName != "" {
|
||||||
|
rewrittenBody, err = providerAdapter.RewriteRequestModelName(spec.Body, provider.ModelName, interfaceType)
|
||||||
|
if err != nil {
|
||||||
|
e.logger.Warn("Smart Passthrough 改写请求失败,使用原始请求体",
|
||||||
|
zap.String("error", err.Error()),
|
||||||
|
zap.String("interface", string(interfaceType)))
|
||||||
|
rewrittenBody = spec.Body
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &HTTPRequestSpec{
|
return &HTTPRequestSpec{
|
||||||
URL: provider.BaseURL + nativePath,
|
URL: provider.BaseURL + nativePath,
|
||||||
Method: spec.Method,
|
Method: spec.Method,
|
||||||
Headers: providerAdapter.BuildHeaders(provider),
|
Headers: providerAdapter.BuildHeaders(provider),
|
||||||
Body: spec.Body,
|
Body: rewrittenBody,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,9 +130,30 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertHttpResponse 转换 HTTP 响应
|
// ConvertHttpResponse 转换 HTTP 响应,modelOverride 用于跨协议场景覆写 model 字段
|
||||||
func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType) (*HTTPResponseSpec, error) {
|
func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType, modelOverride string) (*HTTPResponseSpec, error) {
|
||||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||||
|
// Smart Passthrough: 同协议时最小化改写 model 字段
|
||||||
|
if modelOverride != "" && len(spec.Body) > 0 {
|
||||||
|
adapter, err := e.registry.Get(clientProtocol)
|
||||||
|
if err != nil {
|
||||||
|
return &spec, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rewrittenBody, err := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
|
||||||
|
if err != nil {
|
||||||
|
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
|
||||||
|
zap.String("error", err.Error()),
|
||||||
|
zap.String("interface", string(interfaceType)))
|
||||||
|
return &spec, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &HTTPResponseSpec{
|
||||||
|
StatusCode: spec.StatusCode,
|
||||||
|
Headers: spec.Headers,
|
||||||
|
Body: rewrittenBody,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
return &spec, nil
|
return &spec, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,7 +166,7 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body)
|
convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body, modelOverride)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -139,9 +178,17 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateStreamConverter 创建流式转换器
|
// CreateStreamConverter 创建流式转换器,modelOverride 用于跨协议场景覆写 model 字段
|
||||||
func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string) (StreamConverter, error) {
|
func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string, modelOverride string, interfaceType InterfaceType) (StreamConverter, error) {
|
||||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||||
|
// Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段
|
||||||
|
if modelOverride != "" {
|
||||||
|
adapter, err := e.registry.Get(clientProtocol)
|
||||||
|
if err != nil {
|
||||||
|
return NewPassthroughStreamConverter(), nil
|
||||||
|
}
|
||||||
|
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
|
||||||
|
}
|
||||||
return NewPassthroughStreamConverter(), nil
|
return NewPassthroughStreamConverter(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,6 +214,7 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
|
|||||||
ctx,
|
ctx,
|
||||||
clientProtocol,
|
clientProtocol,
|
||||||
providerProtocol,
|
providerProtocol,
|
||||||
|
modelOverride,
|
||||||
), nil
|
), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,11 +240,11 @@ func (e *ConversionEngine) convertBody(interfaceType InterfaceType, clientAdapte
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// convertResponseBody 转换响应体
|
// convertResponseBody 转换响应体,modelOverride 非空时在 canonical 层面覆写 Model 字段
|
||||||
func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||||
switch interfaceType {
|
switch interfaceType {
|
||||||
case InterfaceTypeChat:
|
case InterfaceTypeChat:
|
||||||
return e.convertChatResponseBody(clientAdapter, providerAdapter, body)
|
return e.convertChatResponseBody(clientAdapter, providerAdapter, body, modelOverride)
|
||||||
case InterfaceTypeModels:
|
case InterfaceTypeModels:
|
||||||
if !clientAdapter.SupportsInterface(InterfaceTypeModels) || !providerAdapter.SupportsInterface(InterfaceTypeModels) {
|
if !clientAdapter.SupportsInterface(InterfaceTypeModels) || !providerAdapter.SupportsInterface(InterfaceTypeModels) {
|
||||||
return body, nil
|
return body, nil
|
||||||
@@ -211,12 +259,12 @@ func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clie
|
|||||||
if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) {
|
if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body)
|
return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body, modelOverride)
|
||||||
case InterfaceTypeRerank:
|
case InterfaceTypeRerank:
|
||||||
if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) {
|
if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
return e.convertRerankResponseBody(clientAdapter, providerAdapter, body)
|
return e.convertRerankResponseBody(clientAdapter, providerAdapter, body, modelOverride)
|
||||||
default:
|
default:
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
@@ -241,11 +289,14 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
|
|||||||
return encoded, nil
|
return encoded, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||||
canonicalResp, err := providerAdapter.DecodeResponse(body)
|
canonicalResp, err := providerAdapter.DecodeResponse(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(err)
|
return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(err)
|
||||||
}
|
}
|
||||||
|
if modelOverride != "" {
|
||||||
|
canonicalResp.Model = modelOverride
|
||||||
|
}
|
||||||
encoded, err := clientAdapter.EncodeResponse(canonicalResp)
|
encoded, err := clientAdapter.EncodeResponse(canonicalResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, NewConversionError(ErrorCodeEncodingFailure, "编码响应失败").WithCause(err)
|
return nil, NewConversionError(ErrorCodeEncodingFailure, "编码响应失败").WithCause(err)
|
||||||
@@ -290,12 +341,15 @@ func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter P
|
|||||||
return providerAdapter.EncodeEmbeddingRequest(req, provider)
|
return providerAdapter.EncodeEmbeddingRequest(req, provider)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||||
resp, err := providerAdapter.DecodeEmbeddingResponse(body)
|
resp, err := providerAdapter.DecodeEmbeddingResponse(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
|
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
if modelOverride != "" {
|
||||||
|
resp.Model = modelOverride
|
||||||
|
}
|
||||||
return clientAdapter.EncodeEmbeddingResponse(resp)
|
return clientAdapter.EncodeEmbeddingResponse(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -308,11 +362,14 @@ func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter Prot
|
|||||||
return providerAdapter.EncodeRerankRequest(req, provider)
|
return providerAdapter.EncodeRerankRequest(req, provider)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||||
resp, err := providerAdapter.DecodeRerankResponse(body)
|
resp, err := providerAdapter.DecodeRerankResponse(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
if modelOverride != "" {
|
||||||
|
resp.Model = modelOverride
|
||||||
|
}
|
||||||
return clientAdapter.EncodeRerankResponse(resp)
|
return clientAdapter.EncodeRerankResponse(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
|
|||||||
|
|
||||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||||
StatusCode: 200, Body: []byte(`{"id":"resp-1"}`),
|
StatusCode: 200, Body: []byte(`{"id":"resp-1"}`),
|
||||||
}, "client", "provider", InterfaceTypeChat)
|
}, "client", "provider", InterfaceTypeChat, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 200, result.StatusCode)
|
assert.Equal(t, 200, result.StatusCode)
|
||||||
assert.Contains(t, string(result.Body), "resp-1")
|
assert.Contains(t, string(result.Body), "resp-1")
|
||||||
@@ -129,7 +129,7 @@ func TestConvertHttpResponse_DecodeError(t *testing.T) {
|
|||||||
_ = engine.RegisterAdapter(providerAdapter)
|
_ = engine.RegisterAdapter(providerAdapter)
|
||||||
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
||||||
|
|
||||||
_, err := engine.ConvertHttpResponse(HTTPResponseSpec{Body: []byte(`{}`)}, "client", "provider", InterfaceTypeChat)
|
_, err := engine.ConvertHttpResponse(HTTPResponseSpec{Body: []byte(`{}`)}, "client", "provider", InterfaceTypeChat, "")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,7 +189,7 @@ func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
|
|||||||
|
|
||||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||||
StatusCode: 200, Body: []byte(`{"object":"list","data":[],"model":"test"}`),
|
StatusCode: 200, Body: []byte(`{"object":"list","data":[],"model":"test"}`),
|
||||||
}, "client", "provider", InterfaceTypeEmbeddings)
|
}, "client", "provider", InterfaceTypeEmbeddings, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, result)
|
assert.NotNil(t, result)
|
||||||
}
|
}
|
||||||
@@ -207,7 +207,7 @@ func TestConvertHttpResponse_RerankInterface(t *testing.T) {
|
|||||||
|
|
||||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||||
StatusCode: 200, Body: []byte(`{"results":[],"model":"test"}`),
|
StatusCode: 200, Body: []byte(`{"results":[],"model":"test"}`),
|
||||||
}, "client", "provider", InterfaceTypeRerank)
|
}, "client", "provider", InterfaceTypeRerank, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, result)
|
assert.NotNil(t, result)
|
||||||
}
|
}
|
||||||
@@ -242,7 +242,7 @@ func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
|
|||||||
|
|
||||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||||
StatusCode: 200, Body: []byte(`{"object":"list","data":[]}`),
|
StatusCode: 200, Body: []byte(`{"object":"list","data":[]}`),
|
||||||
}, "client", "provider", InterfaceTypeModels)
|
}, "client", "provider", InterfaceTypeModels, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, result)
|
assert.NotNil(t, result)
|
||||||
}
|
}
|
||||||
@@ -259,7 +259,7 @@ func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) {
|
|||||||
|
|
||||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||||
StatusCode: 200, Body: []byte(`{"id":"gpt-4","object":"model"}`),
|
StatusCode: 200, Body: []byte(`{"id":"gpt-4","object":"model"}`),
|
||||||
}, "client", "provider", InterfaceTypeModelInfo)
|
}, "client", "provider", InterfaceTypeModelInfo, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, result)
|
assert.NotNil(t, result)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,16 +13,18 @@ import (
|
|||||||
|
|
||||||
// mockProtocolAdapter 模拟协议适配器
|
// mockProtocolAdapter 模拟协议适配器
|
||||||
type mockProtocolAdapter struct {
|
type mockProtocolAdapter struct {
|
||||||
protocolName string
|
protocolName string
|
||||||
passthrough bool
|
passthrough bool
|
||||||
ifaceType InterfaceType
|
ifaceType InterfaceType
|
||||||
supportsIface map[InterfaceType]bool
|
supportsIface map[InterfaceType]bool
|
||||||
decodeReqFn func([]byte) (*canonical.CanonicalRequest, error)
|
decodeReqFn func([]byte) (*canonical.CanonicalRequest, error)
|
||||||
encodeReqFn func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
|
encodeReqFn func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
|
||||||
decodeRespFn func([]byte) (*canonical.CanonicalResponse, error)
|
decodeRespFn func([]byte) (*canonical.CanonicalResponse, error)
|
||||||
encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error)
|
encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error)
|
||||||
streamDecoderFn func() StreamDecoder
|
streamDecoderFn func() StreamDecoder
|
||||||
streamEncoderFn func() StreamEncoder
|
streamEncoderFn func() StreamEncoder
|
||||||
|
rewriteReqFn func([]byte, string, InterfaceType) ([]byte, error)
|
||||||
|
rewriteRespFn func([]byte, string, InterfaceType) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
|
func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
|
||||||
@@ -155,6 +157,28 @@ func (m *mockProtocolAdapter) EncodeRerankResponse(resp *canonical.CanonicalRera
|
|||||||
return json.Marshal(resp)
|
return json.Marshal(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockProtocolAdapter) ExtractUnifiedModelID(nativePath string) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProtocolAdapter) ExtractModelName(body []byte, ifaceType InterfaceType) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProtocolAdapter) RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||||
|
if m.rewriteReqFn != nil {
|
||||||
|
return m.rewriteReqFn(body, newModel, ifaceType)
|
||||||
|
}
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||||
|
if m.rewriteRespFn != nil {
|
||||||
|
return m.rewriteRespFn(body, newModel, ifaceType)
|
||||||
|
}
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
// noopStreamDecoder 空流式解码器
|
// noopStreamDecoder 空流式解码器
|
||||||
type noopStreamDecoder struct{}
|
type noopStreamDecoder struct{}
|
||||||
|
|
||||||
@@ -309,7 +333,7 @@ func TestConvertHttpResponse_Passthrough(t *testing.T) {
|
|||||||
Body: []byte(`{"id":"123"}`),
|
Body: []byte(`{"id":"123"}`),
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat)
|
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 200, result.StatusCode)
|
assert.Equal(t, 200, result.StatusCode)
|
||||||
assert.Equal(t, spec.Body, result.Body)
|
assert.Equal(t, spec.Body, result.Body)
|
||||||
@@ -320,7 +344,7 @@ func TestCreateStreamConverter_Passthrough(t *testing.T) {
|
|||||||
engine := NewConversionEngine(registry, nil)
|
engine := NewConversionEngine(registry, nil)
|
||||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||||
|
|
||||||
converter, err := engine.CreateStreamConverter("openai", "openai")
|
converter, err := engine.CreateStreamConverter("openai", "openai", "", InterfaceTypeChat)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, ok := converter.(*PassthroughStreamConverter)
|
_, ok := converter.(*PassthroughStreamConverter)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
@@ -332,7 +356,7 @@ func TestCreateStreamConverter_Canonical(t *testing.T) {
|
|||||||
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
||||||
_ = engine.RegisterAdapter(newMockAdapter("provider", false))
|
_ = engine.RegisterAdapter(newMockAdapter("provider", false))
|
||||||
|
|
||||||
converter, err := engine.CreateStreamConverter("client", "provider")
|
converter, err := engine.CreateStreamConverter("client", "provider", "", InterfaceTypeChat)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, ok := converter.(*CanonicalStreamConverter)
|
_, ok := converter.(*CanonicalStreamConverter)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
@@ -380,3 +404,230 @@ func TestRegistry_GetNonExistent(t *testing.T) {
|
|||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "未找到适配器")
|
assert.Contains(t, err.Error(), "未找到适配器")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============ modelOverride 测试 ============
|
||||||
|
|
||||||
|
func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) {
|
||||||
|
registry := NewMemoryRegistry()
|
||||||
|
engine := NewConversionEngine(registry, nil)
|
||||||
|
|
||||||
|
clientAdapter := newMockAdapter("client", false)
|
||||||
|
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||||||
|
return json.Marshal(map[string]any{"model": resp.Model})
|
||||||
|
}
|
||||||
|
_ = engine.RegisterAdapter(clientAdapter)
|
||||||
|
|
||||||
|
providerAdapter := newMockAdapter("provider", false)
|
||||||
|
providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
|
||||||
|
return &canonical.CanonicalResponse{ID: "test", Model: "native-model", Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}, nil
|
||||||
|
}
|
||||||
|
_ = engine.RegisterAdapter(providerAdapter)
|
||||||
|
|
||||||
|
spec := HTTPResponseSpec{
|
||||||
|
StatusCode: 200,
|
||||||
|
Body: []byte(`{"model":"native-model"}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := engine.ConvertHttpResponse(spec, "client", "provider", InterfaceTypeChat, "provider/gpt-4")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(result.Body, &resp))
|
||||||
|
assert.Equal(t, "provider/gpt-4", resp["model"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) {
|
||||||
|
registry := NewMemoryRegistry()
|
||||||
|
engine := NewConversionEngine(registry, nil)
|
||||||
|
|
||||||
|
// 使用真实 OpenAI adapter 验证 Smart Passthrough 改写
|
||||||
|
openaiAdapter := newMockAdapter("openai", true)
|
||||||
|
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||||
|
var m map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(body, &m); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m["model"], _ = json.Marshal(newModel)
|
||||||
|
return json.Marshal(m)
|
||||||
|
}
|
||||||
|
_ = engine.RegisterAdapter(openaiAdapter)
|
||||||
|
|
||||||
|
spec := HTTPResponseSpec{
|
||||||
|
StatusCode: 200,
|
||||||
|
Body: []byte(`{"id":"resp-1","model":"gpt-4"}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat, "openai/gpt-4")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(result.Body, &resp))
|
||||||
|
assert.Equal(t, "openai/gpt-4", resp["model"])
|
||||||
|
assert.Equal(t, "resp-1", resp["id"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
|
||||||
|
registry := NewMemoryRegistry()
|
||||||
|
engine := NewConversionEngine(registry, nil)
|
||||||
|
|
||||||
|
openaiAdapter := newMockAdapter("openai", true)
|
||||||
|
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||||
|
var m map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(body, &m); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m["model"], _ = json.Marshal(newModel)
|
||||||
|
return json.Marshal(m)
|
||||||
|
}
|
||||||
|
_ = engine.RegisterAdapter(openaiAdapter)
|
||||||
|
|
||||||
|
converter, err := engine.CreateStreamConverter("openai", "openai", "openai/gpt-4", InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, ok := converter.(*SmartPassthroughStreamConverter)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
// 验证 chunk 改写
|
||||||
|
chunks := converter.ProcessChunk([]byte(`{"model":"gpt-4","choices":[]}`))
|
||||||
|
require.Len(t, chunks, 1)
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(chunks[0], &resp))
|
||||||
|
assert.Equal(t, "openai/gpt-4", resp["model"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) {
|
||||||
|
registry := NewMemoryRegistry()
|
||||||
|
engine := NewConversionEngine(registry, nil)
|
||||||
|
|
||||||
|
// provider adapter 解码出含 model 的流式事件
|
||||||
|
providerAdapter := newMockAdapter("provider", false)
|
||||||
|
providerAdapter.streamDecoderFn = func() StreamDecoder {
|
||||||
|
return &engineTestStreamDecoder{
|
||||||
|
processFn: func(raw []byte) []canonical.CanonicalStreamEvent {
|
||||||
|
return []canonical.CanonicalStreamEvent{
|
||||||
|
canonical.NewMessageStartEvent("msg-1", "native-model"),
|
||||||
|
canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "text", Text: "hi"}),
|
||||||
|
canonical.NewMessageStopEvent(),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = engine.RegisterAdapter(providerAdapter)
|
||||||
|
|
||||||
|
// client adapter 编码时输出 model 字段
|
||||||
|
clientAdapter := newMockAdapter("client", false)
|
||||||
|
clientAdapter.streamEncoderFn = func() StreamEncoder {
|
||||||
|
return &engineTestStreamEncoder{
|
||||||
|
encodeFn: func(event canonical.CanonicalStreamEvent) [][]byte {
|
||||||
|
if event.Message != nil {
|
||||||
|
data, _ := json.Marshal(map[string]string{
|
||||||
|
"type": string(event.Type),
|
||||||
|
"model": event.Message.Model,
|
||||||
|
})
|
||||||
|
return [][]byte{data}
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(map[string]string{"type": string(event.Type)})
|
||||||
|
return [][]byte{data}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = engine.RegisterAdapter(clientAdapter)
|
||||||
|
|
||||||
|
converter, err := engine.CreateStreamConverter("client", "provider", "provider/gpt-4", InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 验证类型是 CanonicalStreamConverter
|
||||||
|
_, ok := converter.(*CanonicalStreamConverter)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
// 处理一个 chunk,验证 model 被覆写为统一模型 ID
|
||||||
|
chunks := converter.ProcessChunk([]byte("raw"))
|
||||||
|
require.Len(t, chunks, 3) // message_start + content_block_start + message_stop
|
||||||
|
|
||||||
|
var startEvent map[string]string
|
||||||
|
require.NoError(t, json.Unmarshal(chunks[0], &startEvent))
|
||||||
|
assert.Equal(t, "provider/gpt-4", startEvent["model"], "跨协议流式中 modelOverride 应覆写 Message.Model")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateStreamConverter_ModelOverride_CrossProtocol_Empty(t *testing.T) {
|
||||||
|
registry := NewMemoryRegistry()
|
||||||
|
engine := NewConversionEngine(registry, nil)
|
||||||
|
|
||||||
|
providerAdapter := newMockAdapter("provider", false)
|
||||||
|
providerAdapter.streamDecoderFn = func() StreamDecoder {
|
||||||
|
return &engineTestStreamDecoder{
|
||||||
|
processFn: func(raw []byte) []canonical.CanonicalStreamEvent {
|
||||||
|
return []canonical.CanonicalStreamEvent{
|
||||||
|
canonical.NewMessageStartEvent("msg-1", "native-model"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = engine.RegisterAdapter(providerAdapter)
|
||||||
|
|
||||||
|
clientAdapter := newMockAdapter("client", false)
|
||||||
|
clientAdapter.streamEncoderFn = func() StreamEncoder {
|
||||||
|
return &engineTestStreamEncoder{
|
||||||
|
encodeFn: func(event canonical.CanonicalStreamEvent) [][]byte {
|
||||||
|
if event.Message != nil {
|
||||||
|
data, _ := json.Marshal(map[string]string{
|
||||||
|
"model": event.Message.Model,
|
||||||
|
})
|
||||||
|
return [][]byte{data}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = engine.RegisterAdapter(clientAdapter)
|
||||||
|
|
||||||
|
// modelOverride 为空,不应覆写
|
||||||
|
converter, err := engine.CreateStreamConverter("client", "provider", "", InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
chunks := converter.ProcessChunk([]byte("raw"))
|
||||||
|
require.Len(t, chunks, 1)
|
||||||
|
|
||||||
|
var resp map[string]string
|
||||||
|
require.NoError(t, json.Unmarshal(chunks[0], &resp))
|
||||||
|
assert.Equal(t, "native-model", resp["model"], "modelOverride 为空时不应覆写")
|
||||||
|
}
|
||||||
|
|
||||||
|
// engineTestStreamDecoder 可控的流式解码器(用于 engine_test)
|
||||||
|
type engineTestStreamDecoder struct {
|
||||||
|
processFn func([]byte) []canonical.CanonicalStreamEvent
|
||||||
|
flushFn func() []canonical.CanonicalStreamEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.CanonicalStreamEvent {
|
||||||
|
if d.processFn != nil {
|
||||||
|
return d.processFn(raw)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
|
||||||
|
if d.flushFn != nil {
|
||||||
|
return d.flushFn()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// engineTestStreamEncoder 可控的流式编码器(用于 engine_test)
|
||||||
|
type engineTestStreamEncoder struct {
|
||||||
|
encodeFn func(canonical.CanonicalStreamEvent) [][]byte
|
||||||
|
flushFn func() [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
|
||||||
|
if e.encodeFn != nil {
|
||||||
|
return e.encodeFn(event)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (e *engineTestStreamEncoder) Flush() [][]byte {
|
||||||
|
if e.flushFn != nil {
|
||||||
|
return e.flushFn()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"nex/backend/internal/conversion"
|
"nex/backend/internal/conversion"
|
||||||
@@ -43,13 +44,13 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id})
|
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id},允许 id 含 /)
|
||||||
func isModelInfoPath(path string) bool {
|
func isModelInfoPath(path string) bool {
|
||||||
if !strings.HasPrefix(path, "/v1/models/") {
|
if !strings.HasPrefix(path, "/v1/models/") {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
suffix := path[len("/v1/models/"):]
|
suffix := path[len("/v1/models/"):]
|
||||||
return suffix != "" && !strings.Contains(suffix, "/")
|
return suffix != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildUrl 根据接口类型构建 URL
|
// BuildUrl 根据接口类型构建 URL
|
||||||
@@ -216,3 +217,80 @@ func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankRe
|
|||||||
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
|
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
|
||||||
return encodeRerankResponse(resp)
|
return encodeRerankResponse(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExtractUnifiedModelID 从路径中提取统一模型 ID(/v1/models/{provider_id}/{model_name})
|
||||||
|
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
|
||||||
|
if !strings.HasPrefix(nativePath, "/v1/models/") {
|
||||||
|
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
|
||||||
|
}
|
||||||
|
suffix := nativePath[len("/v1/models/"):]
|
||||||
|
if suffix == "" {
|
||||||
|
return "", fmt.Errorf("路径缺少模型 ID")
|
||||||
|
}
|
||||||
|
return suffix, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// locateModelFieldInRequest 定位请求体中 model 字段的值并提供改写函数
|
||||||
|
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
|
||||||
|
var m map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(body, &m); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ifaceType {
|
||||||
|
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
|
||||||
|
raw, exists := m["model"]
|
||||||
|
if !exists {
|
||||||
|
return "", nil, fmt.Errorf("请求体中缺少 model 字段")
|
||||||
|
}
|
||||||
|
var current string
|
||||||
|
if err := json.Unmarshal(raw, ¤t); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
rewriteFunc := func(newModel string) ([]byte, error) {
|
||||||
|
m["model"], _ = json.Marshal(newModel)
|
||||||
|
return json.Marshal(m)
|
||||||
|
}
|
||||||
|
return current, rewriteFunc, nil
|
||||||
|
default:
|
||||||
|
return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractModelName 从请求体中提取 model 值
|
||||||
|
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
|
||||||
|
model, _, err := locateModelFieldInRequest(body, ifaceType)
|
||||||
|
return model, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// RewriteRequestModelName 最小化改写请求体中的 model 字段
|
||||||
|
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||||
|
_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return rewriteFunc(newModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RewriteResponseModelName 最小化改写响应体中的 model 字段
|
||||||
|
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||||
|
var m map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(body, &m); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ifaceType {
|
||||||
|
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
|
||||||
|
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
|
||||||
|
m["model"], _ = json.Marshal(newModel)
|
||||||
|
return json.Marshal(m)
|
||||||
|
case conversion.InterfaceTypeRerank:
|
||||||
|
// Rerank 响应:存在 model 字段则改写,不存在则不添加
|
||||||
|
if _, exists := m["model"]; exists {
|
||||||
|
m["model"], _ = json.Marshal(newModel)
|
||||||
|
}
|
||||||
|
return json.Marshal(m)
|
||||||
|
default:
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ func TestIsModelInfoPath(t *testing.T) {
|
|||||||
{"model_info", "/v1/models/gpt-4", true},
|
{"model_info", "/v1/models/gpt-4", true},
|
||||||
{"model_info_with_dots", "/v1/models/gpt-4.1-preview", true},
|
{"model_info_with_dots", "/v1/models/gpt-4.1-preview", true},
|
||||||
{"models_list", "/v1/models", false},
|
{"models_list", "/v1/models", false},
|
||||||
{"nested_path", "/v1/models/gpt-4/versions", false},
|
{"nested_path", "/v1/models/gpt-4/versions", true},
|
||||||
{"empty_suffix", "/v1/models/", false},
|
{"empty_suffix", "/v1/models/", false},
|
||||||
{"unrelated", "/v1/chat/completions", false},
|
{"unrelated", "/v1/chat/completions", false},
|
||||||
{"partial_prefix", "/v1/model", false},
|
{"partial_prefix", "/v1/model", false},
|
||||||
|
|||||||
360
backend/internal/conversion/openai/adapter_unified_test.go
Normal file
360
backend/internal/conversion/openai/adapter_unified_test.go
Normal file
@@ -0,0 +1,360 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"nex/backend/internal/conversion"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// ExtractUnifiedModelID
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestExtractUnifiedModelID(t *testing.T) {
|
||||||
|
a := NewAdapter()
|
||||||
|
|
||||||
|
t.Run("standard_path", func(t *testing.T) {
|
||||||
|
id, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "openai/gpt-4", id)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multi_segment_path", func(t *testing.T) {
|
||||||
|
id, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "azure/accounts/org/models/gpt-4", id)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("single_segment", func(t *testing.T) {
|
||||||
|
id, err := a.ExtractUnifiedModelID("/v1/models/gpt-4")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "gpt-4", id)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non_model_path", func(t *testing.T) {
|
||||||
|
_, err := a.ExtractUnifiedModelID("/v1/chat/completions")
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty_suffix", func(t *testing.T) {
|
||||||
|
_, err := a.ExtractUnifiedModelID("/v1/models/")
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("models_list_no_slash", func(t *testing.T) {
|
||||||
|
_, err := a.ExtractUnifiedModelID("/v1/models")
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unrelated_path", func(t *testing.T) {
|
||||||
|
_, err := a.ExtractUnifiedModelID("/v1/other")
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// ExtractModelName
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestExtractModelName(t *testing.T) {
|
||||||
|
a := NewAdapter()
|
||||||
|
|
||||||
|
t.Run("chat", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"openai/gpt-4","messages":[]}`)
|
||||||
|
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "openai/gpt-4", model)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("embedding", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
|
||||||
|
model, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "openai/text-embedding", model)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rerank", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"openai/rerank","query":"test"}`)
|
||||||
|
model, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "openai/rerank", model)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no_model_field", func(t *testing.T) {
|
||||||
|
body := []byte(`{"messages":[]}`)
|
||||||
|
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid_json", func(t *testing.T) {
|
||||||
|
body := []byte(`{invalid}`)
|
||||||
|
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unsupported_interface_type", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"openai/gpt-4"}`)
|
||||||
|
_, err := a.ExtractModelName(body, conversion.InterfaceTypePassthrough)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// RewriteRequestModelName
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestRewriteRequestModelName(t *testing.T) {
|
||||||
|
a := NewAdapter()
|
||||||
|
|
||||||
|
t.Run("chat", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"openai/gpt-4","messages":[]}`)
|
||||||
|
rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var m map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||||
|
assert.Equal(t, "gpt-4", m["model"])
|
||||||
|
|
||||||
|
// messages field preserved
|
||||||
|
msgs, ok := m["messages"]
|
||||||
|
require.True(t, ok)
|
||||||
|
msgsArr, ok := msgs.([]interface{})
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, msgsArr, 0)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves_unknown_fields", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"openai/gpt-4","temperature":0.7}`)
|
||||||
|
rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var m map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||||
|
assert.Equal(t, "gpt-4", m["model"])
|
||||||
|
assert.Equal(t, 0.7, m["temperature"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("embedding", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
|
||||||
|
rewritten, err := a.RewriteRequestModelName(body, "text-embedding", conversion.InterfaceTypeEmbeddings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var m map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||||
|
assert.Equal(t, "text-embedding", m["model"])
|
||||||
|
assert.Equal(t, "hello", m["input"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rerank", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"openai/rerank","query":"test"}`)
|
||||||
|
rewritten, err := a.RewriteRequestModelName(body, "rerank", conversion.InterfaceTypeRerank)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var m map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||||
|
assert.Equal(t, "rerank", m["model"])
|
||||||
|
assert.Equal(t, "test", m["query"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no_model_field", func(t *testing.T) {
|
||||||
|
body := []byte(`{"messages":[]}`)
|
||||||
|
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid_json", func(t *testing.T) {
|
||||||
|
body := []byte(`{invalid}`)
|
||||||
|
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unsupported_interface_type", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"openai/gpt-4"}`)
|
||||||
|
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypePassthrough)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// RewriteResponseModelName
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestRewriteResponseModelName(t *testing.T) {
|
||||||
|
a := NewAdapter()
|
||||||
|
|
||||||
|
t.Run("chat_existing_model", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"gpt-4","choices":[]}`)
|
||||||
|
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var m map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||||
|
assert.Equal(t, "openai/gpt-4", m["model"])
|
||||||
|
|
||||||
|
choices, ok := m["choices"]
|
||||||
|
require.True(t, ok)
|
||||||
|
choicesArr, ok := choices.([]interface{})
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, choicesArr, 0)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("chat_without_model_field", func(t *testing.T) {
|
||||||
|
body := []byte(`{"choices":[]}`)
|
||||||
|
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var m map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||||
|
assert.Equal(t, "openai/gpt-4", m["model"])
|
||||||
|
|
||||||
|
choices, ok := m["choices"]
|
||||||
|
require.True(t, ok)
|
||||||
|
choicesArr, ok := choices.([]interface{})
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, choicesArr, 0)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rerank_existing_model", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"rerank","results":[]}`)
|
||||||
|
rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var m map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||||
|
assert.Equal(t, "openai/rerank", m["model"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rerank_without_model_field_should_not_add", func(t *testing.T) {
|
||||||
|
body := []byte(`{"results":[]}`)
|
||||||
|
rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var m map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||||
|
_, hasModel := m["model"]
|
||||||
|
assert.False(t, hasModel, "rerank response without model field should not have one added")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("embedding_existing_model", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"text-embedding","data":[]}`)
|
||||||
|
rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var m map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||||
|
assert.Equal(t, "openai/text-embedding", m["model"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("embedding_without_model_field_adds", func(t *testing.T) {
|
||||||
|
body := []byte(`{"data":[]}`)
|
||||||
|
rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var m map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||||
|
assert.Equal(t, "openai/text-embedding", m["model"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("passthrough_returns_body_unchanged", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"gpt-4"}`)
|
||||||
|
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypePassthrough)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, string(body), string(rewritten))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid_json", func(t *testing.T) {
|
||||||
|
body := []byte(`{invalid}`)
|
||||||
|
_, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// ExtractModelName and RewriteRequest consistency
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) {
|
||||||
|
a := NewAdapter()
|
||||||
|
|
||||||
|
t.Run("chat_round_trip", func(t *testing.T) {
|
||||||
|
original := []byte(`{"model":"openai/gpt-4","messages":[],"temperature":0.7}`)
|
||||||
|
|
||||||
|
// Extract the unified model ID from the body
|
||||||
|
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "openai/gpt-4", extracted)
|
||||||
|
|
||||||
|
// Rewrite to the native model name
|
||||||
|
rewritten, err := a.RewriteRequestModelName(original, "gpt-4", conversion.InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Extract again from the rewritten body to verify the same location was targeted
|
||||||
|
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "gpt-4", afterRewrite)
|
||||||
|
|
||||||
|
// Verify other fields are preserved
|
||||||
|
var m map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||||
|
assert.Equal(t, 0.7, m["temperature"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("embedding_round_trip", func(t *testing.T) {
|
||||||
|
original := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
|
||||||
|
|
||||||
|
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeEmbeddings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "openai/text-embedding", extracted)
|
||||||
|
|
||||||
|
rewritten, err := a.RewriteRequestModelName(original, "text-embedding", conversion.InterfaceTypeEmbeddings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeEmbeddings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "text-embedding", afterRewrite)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rerank_round_trip", func(t *testing.T) {
|
||||||
|
original := []byte(`{"model":"openai/rerank","query":"test"}`)
|
||||||
|
|
||||||
|
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeRerank)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "openai/rerank", extracted)
|
||||||
|
|
||||||
|
rewritten, err := a.RewriteRequestModelName(original, "rerank", conversion.InterfaceTypeRerank)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeRerank)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "rerank", afterRewrite)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// isModelInfoPath (additional unified model ID cases)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"simple_model_id", "/v1/models/gpt-4", true},
|
||||||
|
{"unified_model_id_with_slash", "/v1/models/openai/gpt-4", true},
|
||||||
|
{"models_list", "/v1/models", false},
|
||||||
|
{"models_list_trailing_slash", "/v1/models/", false},
|
||||||
|
{"chat_completions", "/v1/chat/completions", false},
|
||||||
|
{"deeply_nested", "/v1/models/azure/eastus/deployments/my-dept/models/gpt-4", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -38,14 +38,52 @@ func (c *PassthroughStreamConverter) Flush() [][]byte {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器
|
||||||
|
// 逐 chunk 改写 model 字段
|
||||||
|
type SmartPassthroughStreamConverter struct {
|
||||||
|
adapter ProtocolAdapter
|
||||||
|
modelOverride string
|
||||||
|
interfaceType InterfaceType
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器
|
||||||
|
func NewSmartPassthroughStreamConverter(adapter ProtocolAdapter, modelOverride string, interfaceType InterfaceType) *SmartPassthroughStreamConverter {
|
||||||
|
return &SmartPassthroughStreamConverter{
|
||||||
|
adapter: adapter,
|
||||||
|
modelOverride: modelOverride,
|
||||||
|
interfaceType: interfaceType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessChunk 改写 chunk 中的 model 字段
|
||||||
|
func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
||||||
|
if len(rawChunk) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rewrittenChunk, err := c.adapter.RewriteResponseModelName(rawChunk, c.modelOverride, c.interfaceType)
|
||||||
|
if err != nil {
|
||||||
|
// 改写失败,返回原始 chunk
|
||||||
|
return [][]byte{rawChunk}
|
||||||
|
}
|
||||||
|
|
||||||
|
return [][]byte{rewrittenChunk}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush 无缓冲数据
|
||||||
|
func (c *SmartPassthroughStreamConverter) Flush() [][]byte {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// CanonicalStreamConverter 跨协议规范流式转换器
|
// CanonicalStreamConverter 跨协议规范流式转换器
|
||||||
type CanonicalStreamConverter struct {
|
type CanonicalStreamConverter struct {
|
||||||
decoder StreamDecoder
|
decoder StreamDecoder
|
||||||
encoder StreamEncoder
|
encoder StreamEncoder
|
||||||
chain *MiddlewareChain
|
chain *MiddlewareChain
|
||||||
ctx ConversionContext
|
ctx ConversionContext
|
||||||
clientProtocol string
|
clientProtocol string
|
||||||
providerProtocol string
|
providerProtocol string
|
||||||
|
modelOverride string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCanonicalStreamConverter 创建规范流式转换器
|
// NewCanonicalStreamConverter 创建规范流式转换器
|
||||||
@@ -57,18 +95,19 @@ func NewCanonicalStreamConverter(decoder StreamDecoder, encoder StreamEncoder) *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewCanonicalStreamConverterWithMiddleware 创建带中间件的规范流式转换器
|
// NewCanonicalStreamConverterWithMiddleware 创建带中间件的规范流式转换器
|
||||||
func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol string) *CanonicalStreamConverter {
|
func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol, modelOverride string) *CanonicalStreamConverter {
|
||||||
return &CanonicalStreamConverter{
|
return &CanonicalStreamConverter{
|
||||||
decoder: decoder,
|
decoder: decoder,
|
||||||
encoder: encoder,
|
encoder: encoder,
|
||||||
chain: chain,
|
chain: chain,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
clientProtocol: clientProtocol,
|
clientProtocol: clientProtocol,
|
||||||
providerProtocol: providerProtocol,
|
providerProtocol: providerProtocol,
|
||||||
|
modelOverride: modelOverride,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProcessChunk 解码 → 中间件 → 编码管道
|
// ProcessChunk 解码 → 中间件 → modelOverride → 编码管道
|
||||||
func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
||||||
events := c.decoder.ProcessChunk(rawChunk)
|
events := c.decoder.ProcessChunk(rawChunk)
|
||||||
var result [][]byte
|
var result [][]byte
|
||||||
@@ -80,6 +119,7 @@ func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
|||||||
}
|
}
|
||||||
events[i] = *processed
|
events[i] = *processed
|
||||||
}
|
}
|
||||||
|
c.applyModelOverride(&events[i])
|
||||||
chunks := c.encoder.EncodeEvent(events[i])
|
chunks := c.encoder.EncodeEvent(events[i])
|
||||||
result = append(result, chunks...)
|
result = append(result, chunks...)
|
||||||
}
|
}
|
||||||
@@ -98,6 +138,7 @@ func (c *CanonicalStreamConverter) Flush() [][]byte {
|
|||||||
}
|
}
|
||||||
events[i] = *processed
|
events[i] = *processed
|
||||||
}
|
}
|
||||||
|
c.applyModelOverride(&events[i])
|
||||||
chunks := c.encoder.EncodeEvent(events[i])
|
chunks := c.encoder.EncodeEvent(events[i])
|
||||||
result = append(result, chunks...)
|
result = append(result, chunks...)
|
||||||
}
|
}
|
||||||
@@ -105,3 +146,10 @@ func (c *CanonicalStreamConverter) Flush() [][]byte {
|
|||||||
result = append(result, encoderChunks...)
|
result = append(result, encoderChunks...)
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// applyModelOverride 在跨协议场景下覆写流式事件中的 Model 字段
|
||||||
|
func (c *CanonicalStreamConverter) applyModelOverride(event *canonical.CanonicalStreamEvent) {
|
||||||
|
if c.modelOverride != "" && event.Message != nil {
|
||||||
|
event.Message.Model = c.modelOverride
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ func TestCanonicalStreamConverter_WithMiddleware(t *testing.T) {
|
|||||||
chain.Use(&recordingMiddleware{name: "mw1", records: &records})
|
chain.Use(&recordingMiddleware{name: "mw1", records: &records})
|
||||||
ctx := NewConversionContext(InterfaceTypeChat)
|
ctx := NewConversionContext(InterfaceTypeChat)
|
||||||
|
|
||||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
|
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
|
||||||
result := converter.ProcessChunk([]byte("raw"))
|
result := converter.ProcessChunk([]byte("raw"))
|
||||||
|
|
||||||
assert.Len(t, result, 1)
|
assert.Len(t, result, 1)
|
||||||
@@ -143,7 +143,7 @@ func TestCanonicalStreamConverter_MiddlewareError_Continue(t *testing.T) {
|
|||||||
chain.Use(&errorMiddleware{})
|
chain.Use(&errorMiddleware{})
|
||||||
ctx := NewConversionContext(InterfaceTypeChat)
|
ctx := NewConversionContext(InterfaceTypeChat)
|
||||||
|
|
||||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
|
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
|
||||||
result := converter.ProcessChunk([]byte("raw"))
|
result := converter.ProcessChunk([]byte("raw"))
|
||||||
|
|
||||||
assert.Nil(t, result, "middleware error should cause the event to be skipped (continue)")
|
assert.Nil(t, result, "middleware error should cause the event to be skipped (continue)")
|
||||||
@@ -163,7 +163,7 @@ func TestCanonicalStreamConverter_Flush_MiddlewareError_Continue(t *testing.T) {
|
|||||||
chain.Use(&errorMiddleware{})
|
chain.Use(&errorMiddleware{})
|
||||||
ctx := NewConversionContext(InterfaceTypeChat)
|
ctx := NewConversionContext(InterfaceTypeChat)
|
||||||
|
|
||||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
|
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
|
||||||
result := converter.Flush()
|
result := converter.Flush()
|
||||||
|
|
||||||
assert.Len(t, result, 1)
|
assert.Len(t, result, 1)
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
package domain
|
package domain
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
// Model 模型领域模型
|
"nex/backend/pkg/modelid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Model 模型领域模型(id 为 UUID 自动生成)
|
||||||
type Model struct {
|
type Model struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
ProviderID string `json:"provider_id"`
|
ProviderID string `json:"provider_id"`
|
||||||
@@ -10,3 +14,8 @@ type Model struct {
|
|||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnifiedModelID 返回统一模型 ID(格式:provider_id/model_name)
|
||||||
|
func (m *Model) UnifiedModelID() string {
|
||||||
|
return modelid.FormatUnifiedModelID(m.ProviderID, m.ModelName)
|
||||||
|
}
|
||||||
|
|||||||
@@ -113,7 +113,6 @@ func TestModelHandler_CreateModel_Success(t *testing.T) {
|
|||||||
h := NewModelHandler(&mockModelService{})
|
h := NewModelHandler(&mockModelService{})
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]string{
|
body, _ := json.Marshal(map[string]string{
|
||||||
"id": "m1",
|
|
||||||
"provider_id": "p1",
|
"provider_id": "p1",
|
||||||
"model_name": "gpt-4",
|
"model_name": "gpt-4",
|
||||||
})
|
})
|
||||||
@@ -127,7 +126,7 @@ func TestModelHandler_CreateModel_Success(t *testing.T) {
|
|||||||
|
|
||||||
var result domain.Model
|
var result domain.Model
|
||||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||||
assert.Equal(t, "m1", result.ID)
|
assert.NotEmpty(t, result.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestModelHandler_GetModel(t *testing.T) {
|
func TestModelHandler_GetModel(t *testing.T) {
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"gorm.io/gorm"
|
|
||||||
|
|
||||||
"nex/backend/internal/domain"
|
"nex/backend/internal/domain"
|
||||||
"nex/backend/internal/provider"
|
"nex/backend/internal/provider"
|
||||||
@@ -31,7 +30,7 @@ type mockRoutingService struct {
|
|||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockRoutingService) Route(modelName string) (*domain.RouteResult, error) {
|
func (m *mockRoutingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
|
||||||
return m.result, m.err
|
return m.result, m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,6 +56,14 @@ type mockProviderService struct {
|
|||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockProviderService) ListEnabledModels() ([]domain.Model, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProviderService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err }
|
func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err }
|
||||||
func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
|
func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
|
||||||
return m.provider, m.err
|
return m.provider, m.err
|
||||||
@@ -73,13 +80,21 @@ type mockModelService struct {
|
|||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockModelService) Create(model *domain.Model) error { return m.err }
|
func (m *mockModelService) Create(model *domain.Model) error {
|
||||||
|
if m.err == nil {
|
||||||
|
model.ID = "mock-uuid-1234"
|
||||||
|
}
|
||||||
|
return m.err
|
||||||
|
}
|
||||||
func (m *mockModelService) Get(id string) (*domain.Model, error) {
|
func (m *mockModelService) Get(id string) (*domain.Model, error) {
|
||||||
return m.model, m.err
|
return m.model, m.err
|
||||||
}
|
}
|
||||||
func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
|
func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
|
||||||
return m.models, m.err
|
return m.models, m.err
|
||||||
}
|
}
|
||||||
|
func (m *mockModelService) ListEnabled() ([]domain.Model, error) {
|
||||||
|
return []domain.Model{}, nil
|
||||||
|
}
|
||||||
func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
|
func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
|
||||||
return m.err
|
return m.err
|
||||||
}
|
}
|
||||||
@@ -163,8 +178,8 @@ func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
|
|||||||
func TestModelHandler_ListModels(t *testing.T) {
|
func TestModelHandler_ListModels(t *testing.T) {
|
||||||
h := NewModelHandler(&mockModelService{
|
h := NewModelHandler(&mockModelService{
|
||||||
models: []domain.Model{
|
models: []domain.Model{
|
||||||
{ID: "m1", ModelName: "gpt-4"},
|
{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
|
||||||
{ID: "m2", ModelName: "gpt-3.5"},
|
{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3"},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -174,6 +189,72 @@ func TestModelHandler_ListModels(t *testing.T) {
|
|||||||
|
|
||||||
h.ListModels(c)
|
h.ListModels(c)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
|
var result []modelResponse
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||||
|
require.Len(t, result, 2)
|
||||||
|
assert.Equal(t, "openai/gpt-4", result[0].UnifiedModelID)
|
||||||
|
assert.Equal(t, "anthropic/claude-3", result[1].UnifiedModelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelHandler_GetModel_UnifiedID(t *testing.T) {
|
||||||
|
h := NewModelHandler(&mockModelService{
|
||||||
|
model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Params = gin.Params{{Key: "id", Value: "m1"}}
|
||||||
|
c.Request = httptest.NewRequest("GET", "/api/models/m1", nil)
|
||||||
|
|
||||||
|
h.GetModel(c)
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
|
var result modelResponse
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||||
|
assert.Equal(t, "m1", result.ID)
|
||||||
|
assert.Equal(t, "openai/gpt-4", result.UnifiedModelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelHandler_CreateModel_UnifiedID(t *testing.T) {
|
||||||
|
h := NewModelHandler(&mockModelService{})
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]string{
|
||||||
|
"provider_id": "openai",
|
||||||
|
"model_name": "gpt-4",
|
||||||
|
})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.CreateModel(c)
|
||||||
|
assert.Equal(t, 201, w.Code)
|
||||||
|
|
||||||
|
var result modelResponse
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||||
|
assert.Equal(t, "mock-uuid-1234", result.ID)
|
||||||
|
assert.Equal(t, "openai/gpt-4", result.UnifiedModelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) {
|
||||||
|
h := NewModelHandler(&mockModelService{
|
||||||
|
model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4-turbo"},
|
||||||
|
})
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]interface{}{"enabled": false})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Params = gin.Params{{Key: "id", Value: "m1"}}
|
||||||
|
c.Request = httptest.NewRequest("PUT", "/api/models/m1", bytes.NewReader(body))
|
||||||
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
h.UpdateModel(c)
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
|
var result modelResponse
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||||
|
assert.Equal(t, "openai/gpt-4-turbo", result.UnifiedModelID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============ Stats Handler 测试 ============
|
// ============ Stats Handler 测试 ============
|
||||||
@@ -256,7 +337,7 @@ func formatMapErrors(errs map[string]string) string {
|
|||||||
|
|
||||||
func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
|
func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
|
||||||
h := NewProviderHandler(&mockProviderService{
|
h := NewProviderHandler(&mockProviderService{
|
||||||
err: gorm.ErrDuplicatedKey,
|
err: appErrors.ErrConflict,
|
||||||
})
|
})
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]string{
|
body, _ := json.Marshal(map[string]string{
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -22,23 +23,35 @@ func NewModelHandler(modelService service.ModelService) *ModelHandler {
|
|||||||
return &ModelHandler{modelService: modelService}
|
return &ModelHandler{modelService: modelService}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// modelResponse 模型响应 DTO,扩展 unified_id 字段
|
||||||
|
type modelResponse struct {
|
||||||
|
domain.Model
|
||||||
|
UnifiedModelID string `json:"unified_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// newModelResponse 从 domain.Model 构造响应 DTO
|
||||||
|
func newModelResponse(m *domain.Model) modelResponse {
|
||||||
|
return modelResponse{
|
||||||
|
Model: *m,
|
||||||
|
UnifiedModelID: m.UnifiedModelID(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CreateModel 创建模型
|
// CreateModel 创建模型
|
||||||
func (h *ModelHandler) CreateModel(c *gin.Context) {
|
func (h *ModelHandler) CreateModel(c *gin.Context) {
|
||||||
var req struct {
|
var req struct {
|
||||||
ID string `json:"id" binding:"required"`
|
|
||||||
ProviderID string `json:"provider_id" binding:"required"`
|
ProviderID string `json:"provider_id" binding:"required"`
|
||||||
ModelName string `json:"model_name" binding:"required"`
|
ModelName string `json:"model_name" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
"error": "缺少必需字段: id, provider_id, model_name",
|
"error": "缺少必需字段: provider_id, model_name",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
model := &domain.Model{
|
model := &domain.Model{
|
||||||
ID: req.ID,
|
|
||||||
ProviderID: req.ProviderID,
|
ProviderID: req.ProviderID,
|
||||||
ModelName: req.ModelName,
|
ModelName: req.ModelName,
|
||||||
}
|
}
|
||||||
@@ -51,11 +64,18 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if err == appErrors.ErrDuplicateModel {
|
||||||
|
c.JSON(http.StatusConflict, gin.H{
|
||||||
|
"error": "同一供应商下模型名称已存在",
|
||||||
|
"code": appErrors.ErrDuplicateModel.Code,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
writeError(c, err)
|
writeError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusCreated, model)
|
c.JSON(http.StatusCreated, newModelResponse(model))
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListModels 列出模型
|
// ListModels 列出模型
|
||||||
@@ -68,7 +88,11 @@ func (h *ModelHandler) ListModels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, models)
|
resp := make([]modelResponse, len(models))
|
||||||
|
for i, m := range models {
|
||||||
|
resp[i] = newModelResponse(&m)
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModel 获取模型
|
// GetModel 获取模型
|
||||||
@@ -87,7 +111,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, model)
|
c.JSON(http.StatusOK, newModelResponse(model))
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateModel 更新模型
|
// UpdateModel 更新模型
|
||||||
@@ -104,18 +128,25 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
|
|||||||
|
|
||||||
err := h.modelService.Update(id, req)
|
err := h.modelService.Update(id, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == gorm.ErrRecordNotFound {
|
if errors.Is(err, appErrors.ErrModelNotFound) {
|
||||||
c.JSON(http.StatusNotFound, gin.H{
|
c.JSON(http.StatusNotFound, gin.H{
|
||||||
"error": "模型未找到",
|
"error": "模型未找到",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err == appErrors.ErrProviderNotFound {
|
if errors.Is(err, appErrors.ErrProviderNotFound) {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
"error": "供应商不存在",
|
"error": "供应商不存在",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if errors.Is(err, appErrors.ErrDuplicateModel) {
|
||||||
|
c.JSON(http.StatusConflict, gin.H{
|
||||||
|
"error": appErrors.ErrDuplicateModel.Message,
|
||||||
|
"code": appErrors.ErrDuplicateModel.Code,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
writeError(c, err)
|
writeError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -126,7 +157,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, model)
|
c.JSON(http.StatusOK, newModelResponse(model))
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteModel 删除模型
|
// DeleteModel 删除模型
|
||||||
|
|||||||
@@ -55,9 +55,10 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
|||||||
|
|
||||||
err := h.providerService.Create(provider)
|
err := h.providerService.Create(provider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
if err == appErrors.ErrInvalidProviderID {
|
||||||
c.JSON(http.StatusConflict, gin.H{
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
"error": "供应商 ID 已存在",
|
"error": appErrors.ErrInvalidProviderID.Message,
|
||||||
|
"code": appErrors.ErrInvalidProviderID.Code,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -119,6 +120,13 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if errors.Is(err, appErrors.ErrImmutableField) {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
|
"error": appErrors.ErrImmutableField.Message,
|
||||||
|
"code": appErrors.ErrImmutableField.Code,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
writeError(c, err)
|
writeError(c, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,9 +11,11 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
"nex/backend/internal/conversion"
|
"nex/backend/internal/conversion"
|
||||||
|
"nex/backend/internal/conversion/canonical"
|
||||||
"nex/backend/internal/domain"
|
"nex/backend/internal/domain"
|
||||||
"nex/backend/internal/provider"
|
"nex/backend/internal/provider"
|
||||||
"nex/backend/internal/service"
|
"nex/backend/internal/service"
|
||||||
|
"nex/backend/pkg/modelid"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProxyHandler 统一代理处理器
|
// ProxyHandler 统一代理处理器
|
||||||
@@ -54,6 +56,34 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
nativePath := "/v1/" + path
|
nativePath := "/v1/" + path
|
||||||
|
|
||||||
|
// 获取 client adapter
|
||||||
|
registry := h.engine.GetRegistry()
|
||||||
|
clientAdapter, err := registry.Get(clientProtocol)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检测接口类型
|
||||||
|
ifaceType := clientAdapter.DetectInterfaceType(nativePath)
|
||||||
|
|
||||||
|
// 处理 Models 接口:本地聚合
|
||||||
|
if ifaceType == conversion.InterfaceTypeModels {
|
||||||
|
h.handleModelsList(c, clientAdapter)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理 ModelInfo 接口:本地查询
|
||||||
|
if ifaceType == conversion.InterfaceTypeModelInfo {
|
||||||
|
unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的模型 ID 格式"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.handleModelInfo(c, unifiedID, clientAdapter)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 读取请求体
|
// 读取请求体
|
||||||
body, err := io.ReadAll(c.Request.Body)
|
body, err := io.ReadAll(c.Request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -61,10 +91,17 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析 model 名称(从 JSON body 中提取,GET 请求无 body)
|
// 解析统一模型 ID(使用 adapter.ExtractModelName)
|
||||||
modelName := ""
|
var providerID, modelName string
|
||||||
if len(body) > 0 {
|
if len(body) > 0 {
|
||||||
modelName = extractModelName(body)
|
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
|
||||||
|
if err == nil && unifiedID != "" {
|
||||||
|
pid, mn, err := modelid.ParseUnifiedModelID(unifiedID)
|
||||||
|
if err == nil {
|
||||||
|
providerID = pid
|
||||||
|
modelName = mn
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建输入 HTTPRequestSpec
|
// 构建输入 HTTPRequestSpec
|
||||||
@@ -76,7 +113,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 路由
|
// 路由
|
||||||
routeResult, err := h.routingService.Route(modelName)
|
routeResult, err := h.routingService.RouteByModelName(providerID, modelName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// GET 请求或无法提取 model 时,直接转发到上游
|
// GET 请求或无法提取 model 时,直接转发到上游
|
||||||
if len(body) == 0 || modelName == "" {
|
if len(body) == 0 || modelName == "" {
|
||||||
@@ -94,24 +131,30 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 构建 TargetProvider
|
// 构建 TargetProvider
|
||||||
|
// 注意:ModelName 字段用于 Smart Passthrough 场景改写请求体
|
||||||
|
// 同协议:请求体中的统一 ID 会被改写为 ModelName(上游名)
|
||||||
|
// 跨协议:全量转换时 ModelName 会被编码到请求体中
|
||||||
targetProvider := conversion.NewTargetProvider(
|
targetProvider := conversion.NewTargetProvider(
|
||||||
routeResult.Provider.BaseURL,
|
routeResult.Provider.BaseURL,
|
||||||
routeResult.Provider.APIKey,
|
routeResult.Provider.APIKey,
|
||||||
routeResult.Model.ModelName,
|
routeResult.Model.ModelName, // 上游模型名,用于请求改写
|
||||||
)
|
)
|
||||||
|
|
||||||
// 判断是否流式
|
// 判断是否流式
|
||||||
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
|
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
|
||||||
|
|
||||||
|
// 计算统一模型 ID(用于响应覆写)
|
||||||
|
unifiedModelID := routeResult.Model.UnifiedModelID()
|
||||||
|
|
||||||
if isStream {
|
if isStream {
|
||||||
h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
|
h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
|
||||||
} else {
|
} else {
|
||||||
h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
|
h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleNonStream 处理非流式请求
|
// handleNonStream 处理非流式请求
|
||||||
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
|
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
|
||||||
// 转换请求
|
// 转换请求
|
||||||
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -128,9 +171,8 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 转换响应
|
// 转换响应,传入 modelOverride(跨协议场景覆写 model 字段)
|
||||||
interfaceType, _ := h.engine.DetectInterfaceType(inSpec.URL, clientProtocol)
|
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
|
||||||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, interfaceType)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("转换响应失败", zap.String("error", err.Error()))
|
h.logger.Error("转换响应失败", zap.String("error", err.Error()))
|
||||||
h.writeConversionError(c, err, clientProtocol)
|
h.writeConversionError(c, err, clientProtocol)
|
||||||
@@ -153,7 +195,7 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleStream 处理流式请求
|
// handleStream 处理流式请求
|
||||||
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
|
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
|
||||||
// 转换请求
|
// 转换请求
|
||||||
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -161,8 +203,8 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建流式转换器
|
// 创建流式转换器,传入 modelOverride(跨协议场景覆写 model 字段)
|
||||||
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol)
|
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.writeConversionError(c, err, clientProtocol)
|
h.writeConversionError(c, err, clientProtocol)
|
||||||
return
|
return
|
||||||
@@ -224,6 +266,79 @@ func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath s
|
|||||||
return req.Stream
|
return req.Stream
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleModelsList 处理 GET /v1/models 本地聚合
|
||||||
|
func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.ProtocolAdapter) {
|
||||||
|
// 从数据库查询所有启用的模型
|
||||||
|
models, err := h.providerService.ListEnabledModels()
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("查询启用模型失败", zap.String("error", err.Error()))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建 CanonicalModelList
|
||||||
|
modelList := &canonical.CanonicalModelList{
|
||||||
|
Models: make([]canonical.CanonicalModel, 0, len(models)),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range models {
|
||||||
|
modelList.Models = append(modelList.Models, canonical.CanonicalModel{
|
||||||
|
ID: m.UnifiedModelID(),
|
||||||
|
Name: m.ModelName,
|
||||||
|
Created: m.CreatedAt.Unix(),
|
||||||
|
OwnedBy: m.ProviderID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 adapter 编码返回
|
||||||
|
body, err := adapter.EncodeModelsResponse(modelList)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("编码 Models 响应失败", zap.String("error", err.Error()))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Data(http.StatusOK, "application/json", body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleModelInfo 处理 GET /v1/models/{unified_id} 本地查询
|
||||||
|
func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter conversion.ProtocolAdapter) {
|
||||||
|
// 解析统一模型 ID
|
||||||
|
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
|
"error": "无效的统一模型 ID 格式",
|
||||||
|
"code": "INVALID_MODEL_ID",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从数据库查询模型
|
||||||
|
model, err := h.providerService.GetModelByProviderAndName(providerID, modelName)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "模型未找到"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建 CanonicalModelInfo
|
||||||
|
modelInfo := &canonical.CanonicalModelInfo{
|
||||||
|
ID: model.UnifiedModelID(),
|
||||||
|
Name: model.ModelName,
|
||||||
|
Created: model.CreatedAt.Unix(),
|
||||||
|
OwnedBy: model.ProviderID,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 adapter 编码返回
|
||||||
|
body, err := adapter.EncodeModelInfoResponse(modelInfo)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("编码 ModelInfo 响应失败", zap.String("error", err.Error()))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Data(http.StatusOK, "application/json", body)
|
||||||
|
}
|
||||||
|
|
||||||
// writeConversionError 写入转换错误
|
// writeConversionError 写入转换错误
|
||||||
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
|
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
|
||||||
if convErr, ok := err.(*conversion.ConversionError); ok {
|
if convErr, ok := err.(*conversion.ConversionError); ok {
|
||||||
@@ -292,7 +407,7 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType)
|
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.writeConversionError(c, err, clientProtocol)
|
h.writeConversionError(c, err, clientProtocol)
|
||||||
return
|
return
|
||||||
@@ -307,17 +422,6 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
|
|||||||
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
|
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractModelName 从 JSON body 中提取 model
|
|
||||||
func extractModelName(body []byte) string {
|
|
||||||
var req struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(body, &req); err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return req.Model
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractHeaders 从 Gin context 提取请求头
|
// extractHeaders 从 Gin context 提取请求头
|
||||||
func extractHeaders(c *gin.Context) map[string]string {
|
func extractHeaders(c *gin.Context) map[string]string {
|
||||||
headers := make(map[string]string)
|
headers := make(map[string]string)
|
||||||
|
|||||||
@@ -60,13 +60,23 @@ type mockProxyRoutingService struct {
|
|||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockProxyRoutingService) Route(modelName string) (*domain.RouteResult, error) {
|
func (m *mockProxyRoutingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
|
||||||
return m.result, m.err
|
return m.result, m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockProxyProviderService struct {
|
type mockProxyProviderService struct {
|
||||||
providers []domain.Provider
|
providers []domain.Provider
|
||||||
err error
|
err error
|
||||||
|
enabledModels []domain.Model
|
||||||
|
modelByProvName *domain.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProxyProviderService) ListEnabledModels() ([]domain.Model, error) {
|
||||||
|
return m.enabledModels, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProxyProviderService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
|
||||||
|
return m.modelByProvName, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockProxyProviderService) Create(p *domain.Provider) error { return nil }
|
func (m *mockProxyProviderService) Create(p *domain.Provider) error { return nil }
|
||||||
@@ -319,7 +329,8 @@ func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) {
|
|||||||
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
|
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
|
||||||
|
|
||||||
h.HandleProxy(c)
|
h.HandleProxy(c)
|
||||||
assert.Equal(t, 404, w.Code)
|
// Models 接口现在本地聚合,返回空列表 200
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExtractHeaders(t *testing.T) {
|
func TestExtractHeaders(t *testing.T) {
|
||||||
@@ -716,58 +727,6 @@ func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) {
|
|||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============ extractModelName 测试 ============
|
|
||||||
|
|
||||||
func TestExtractModelName(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
body []byte
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid model",
|
|
||||||
body: []byte(`{"model": "gpt-4", "messages": []}`),
|
|
||||||
expected: "gpt-4",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty body",
|
|
||||||
body: []byte(`{}`),
|
|
||||||
expected: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid json",
|
|
||||||
body: []byte(`{invalid}`),
|
|
||||||
expected: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "nested structure",
|
|
||||||
body: []byte(`{"model": "claude-3", "messages": [{"role": "user", "content": "hello"}]}`),
|
|
||||||
expected: "claude-3",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "model with special chars",
|
|
||||||
body: []byte(`{"model": "gpt-4-0125-preview", "stream": true}`),
|
|
||||||
expected: "gpt-4-0125-preview",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty body bytes",
|
|
||||||
body: []byte{},
|
|
||||||
expected: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "model is null",
|
|
||||||
body: []byte(`{"model": null}`),
|
|
||||||
expected: "",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := extractModelName(tt.body)
|
|
||||||
assert.Equal(t, tt.expected, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============ isStreamRequest 测试 ============
|
// ============ isStreamRequest 测试 ============
|
||||||
|
|
||||||
func TestIsStreamRequest(t *testing.T) {
|
func TestIsStreamRequest(t *testing.T) {
|
||||||
@@ -831,3 +790,270 @@ func TestIsStreamRequest(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============ Models / ModelInfo 本地聚合测试 ============
|
||||||
|
|
||||||
|
func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
|
||||||
|
engine := setupProxyEngine(t)
|
||||||
|
providerSvc := &mockProxyProviderService{
|
||||||
|
enabledModels: []domain.Model{
|
||||||
|
{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
||||||
|
{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3", Enabled: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, &mockProxyRoutingService{}, providerSvc)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
|
||||||
|
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
|
||||||
|
|
||||||
|
h.HandleProxy(c)
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
|
data, ok := resp["data"].([]interface{})
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, data, 2)
|
||||||
|
|
||||||
|
// 验证统一模型 ID 格式
|
||||||
|
first := data[0].(map[string]interface{})
|
||||||
|
assert.Equal(t, "openai/gpt-4", first["id"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyHandler_HandleProxy_ModelInfo_LocalQuery(t *testing.T) {
|
||||||
|
engine := setupProxyEngine(t)
|
||||||
|
providerSvc := &mockProxyProviderService{
|
||||||
|
modelByProvName: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
||||||
|
}
|
||||||
|
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, &mockProxyRoutingService{}, providerSvc)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/openai/gpt-4"}}
|
||||||
|
c.Request = httptest.NewRequest("GET", "/openai/v1/models/openai/gpt-4", nil)
|
||||||
|
|
||||||
|
h.HandleProxy(c)
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
|
assert.Equal(t, "openai/gpt-4", resp["id"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyHandler_HandleProxy_Models_EmptySuffix_ForwardPassthrough(t *testing.T) {
|
||||||
|
engine := setupProxyEngine(t)
|
||||||
|
providerSvc := &mockProxyProviderService{
|
||||||
|
providers: []domain.Provider{
|
||||||
|
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := &mockProxyProviderClient{
|
||||||
|
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
||||||
|
return &conversion.HTTPResponseSpec{
|
||||||
|
StatusCode: 200,
|
||||||
|
Body: []byte(`{"object":"list","data":[]}`),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := newTestProxyHandler(engine, client, &mockProxyRoutingService{err: appErrors.ErrModelNotFound}, providerSvc)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/"}}
|
||||||
|
c.Request = httptest.NewRequest("GET", "/openai/v1/models/", nil)
|
||||||
|
|
||||||
|
h.HandleProxy(c)
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ Smart Passthrough 统一模型 ID 路由测试 ============
|
||||||
|
|
||||||
|
func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
|
||||||
|
engine := setupProxyEngine(t)
|
||||||
|
routingSvc := &mockProxyRoutingService{
|
||||||
|
result: &domain.RouteResult{
|
||||||
|
Provider: &domain.Provider{ID: "openai_p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||||
|
Model: &domain.Model{ID: "m1", ProviderID: "openai_p", ModelName: "gpt-4", Enabled: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := &mockProxyProviderClient{
|
||||||
|
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
||||||
|
// 验证请求体中的 model 已被改写为上游模型名
|
||||||
|
var req map[string]interface{}
|
||||||
|
json.Unmarshal(spec.Body, &req)
|
||||||
|
assert.Equal(t, "gpt-4", req["model"])
|
||||||
|
|
||||||
|
return &conversion.HTTPResponseSpec{
|
||||||
|
StatusCode: 200,
|
||||||
|
Headers: map[string]string{"Content-Type": "application/json"},
|
||||||
|
Body: []byte(`{"id":"resp-1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8}}`),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||||
|
// 客户端发送统一模型 ID
|
||||||
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||||
|
|
||||||
|
h.HandleProxy(c)
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
|
// 验证响应中的 model 已被改写为统一模型 ID
|
||||||
|
var resp map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
|
assert.Equal(t, "openai_p/gpt-4", resp["model"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ 跨协议统一模型 ID 路由测试 ============
|
||||||
|
|
||||||
|
func TestProxyHandler_HandleProxy_CrossProtocol_NonStream_UnifiedID(t *testing.T) {
|
||||||
|
engine := setupProxyEngine(t)
|
||||||
|
routingSvc := &mockProxyRoutingService{
|
||||||
|
result: &domain.RouteResult{
|
||||||
|
Provider: &domain.Provider{ID: "anthropic_p", Name: "Anthropic", APIKey: "sk-test", BaseURL: "https://api.anthropic.com", Protocol: "anthropic", Enabled: true},
|
||||||
|
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := &mockProxyProviderClient{
|
||||||
|
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
||||||
|
return &conversion.HTTPResponseSpec{
|
||||||
|
StatusCode: 200,
|
||||||
|
Headers: map[string]string{"Content-Type": "application/json"},
|
||||||
|
Body: []byte(`{"id":"msg-1","type":"message","role":"assistant","model":"claude-3","content":[{"type":"text","text":"Hello"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":3}}`),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||||
|
// OpenAI 客户端使用统一模型 ID 路由到 Anthropic 供应商
|
||||||
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}]}`)))
|
||||||
|
|
||||||
|
h.HandleProxy(c)
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
|
// 验证跨协议转换后响应中的 model 被覆写为统一模型 ID
|
||||||
|
var resp map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
|
assert.Equal(t, "anthropic_p/claude-3", resp["model"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyHandler_HandleProxy_CrossProtocol_Stream_UnifiedID(t *testing.T) {
|
||||||
|
engine := setupProxyEngine(t)
|
||||||
|
routingSvc := &mockProxyRoutingService{
|
||||||
|
result: &domain.RouteResult{
|
||||||
|
Provider: &domain.Provider{ID: "anthropic_p", Name: "Anthropic", APIKey: "sk-test", BaseURL: "https://api.anthropic.com", Protocol: "anthropic", Enabled: true},
|
||||||
|
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := &mockProxyProviderClient{
|
||||||
|
sendStreamFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
|
||||||
|
ch := make(chan provider.StreamEvent, 10)
|
||||||
|
go func() {
|
||||||
|
defer close(ch)
|
||||||
|
ch <- provider.StreamEvent{Data: []byte(`event: message_start
|
||||||
|
data: {"type":"message_start","message":{"id":"msg-1","type":"message","role":"assistant","model":"claude-3","content":[]}}
|
||||||
|
|
||||||
|
`)}
|
||||||
|
ch <- provider.StreamEvent{Data: []byte(`event: content_block_delta
|
||||||
|
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}
|
||||||
|
|
||||||
|
`)}
|
||||||
|
ch <- provider.StreamEvent{Data: []byte(`event: message_stop
|
||||||
|
data: {"type":"message_stop"}
|
||||||
|
|
||||||
|
`)}
|
||||||
|
ch <- provider.StreamEvent{Done: true}
|
||||||
|
}()
|
||||||
|
return ch, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||||
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||||
|
|
||||||
|
h.HandleProxy(c)
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type"))
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
// 验证跨协议流式中 model 被覆写为统一模型 ID
|
||||||
|
assert.Contains(t, body, "anthropic_p/claude-3", "跨协议流式响应中 model 应被覆写为统一模型 ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyHandler_HandleProxy_SmartPassthrough_Fidelity(t *testing.T) {
|
||||||
|
engine := setupProxyEngine(t)
|
||||||
|
routingSvc := &mockProxyRoutingService{
|
||||||
|
result: &domain.RouteResult{
|
||||||
|
Provider: &domain.Provider{ID: "openai_p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||||
|
Model: &domain.Model{ID: "m1", ProviderID: "openai_p", ModelName: "gpt-4", Enabled: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
var capturedRequestBody []byte
|
||||||
|
client := &mockProxyProviderClient{
|
||||||
|
sendFn: func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
||||||
|
capturedRequestBody = spec.Body
|
||||||
|
return &conversion.HTTPResponseSpec{
|
||||||
|
StatusCode: 200,
|
||||||
|
Headers: map[string]string{"Content-Type": "application/json"},
|
||||||
|
Body: []byte(`{"id":"resp-1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8},"unknown_field":"preserved"}`),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := newTestProxyHandler(engine, client, routingSvc, &mockProxyProviderService{})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||||
|
// 包含未知参数,验证 Smart Passthrough 保真性
|
||||||
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}],"custom_param":"should_be_preserved"}`)))
|
||||||
|
|
||||||
|
h.HandleProxy(c)
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
|
// 验证请求中 model 被改写为上游模型名,但未知参数保留
|
||||||
|
var reqBody map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(capturedRequestBody, &reqBody))
|
||||||
|
assert.Equal(t, "gpt-4", reqBody["model"], "请求中 model 应被改写为上游模型名")
|
||||||
|
assert.Equal(t, "should_be_preserved", reqBody["custom_param"], "Smart Passthrough 应保留未知参数")
|
||||||
|
|
||||||
|
// 验证响应中 model 被改写为统一模型 ID,但未知参数保留
|
||||||
|
var resp map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
|
assert.Equal(t, "openai_p/gpt-4", resp["model"], "响应中 model 应被改写为统一模型 ID")
|
||||||
|
assert.Equal(t, "preserved", resp["unknown_field"], "Smart Passthrough 应保留未知响应字段")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyHandler_HandleProxy_UnifiedID_ModelNotFound(t *testing.T) {
|
||||||
|
engine := setupProxyEngine(t)
|
||||||
|
routingSvc := &mockProxyRoutingService{err: appErrors.ErrModelNotFound}
|
||||||
|
providerSvc := &mockProxyProviderService{
|
||||||
|
providers: []domain.Provider{
|
||||||
|
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := newTestProxyHandler(engine, &mockProxyProviderClient{}, routingSvc, providerSvc)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||||
|
// 使用统一模型 ID 格式但模型不存在
|
||||||
|
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
|
||||||
|
|
||||||
|
h.HandleProxy(c)
|
||||||
|
assert.Equal(t, 404, w.Code)
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
|
assert.Contains(t, resp, "error")
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,7 +7,8 @@ type ModelRepository interface {
|
|||||||
Create(model *domain.Model) error
|
Create(model *domain.Model) error
|
||||||
GetByID(id string) (*domain.Model, error)
|
GetByID(id string) (*domain.Model, error)
|
||||||
List(providerID string) ([]domain.Model, error)
|
List(providerID string) ([]domain.Model, error)
|
||||||
GetByModelName(modelName string) (*domain.Model, error)
|
FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error)
|
||||||
|
ListEnabled() ([]domain.Model, error)
|
||||||
Update(id string, updates map[string]interface{}) error
|
Update(id string, updates map[string]interface{}) error
|
||||||
Delete(id string) error
|
Delete(id string) error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,9 +52,9 @@ func (r *modelRepository) List(providerID string) ([]domain.Model, error) {
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *modelRepository) GetByModelName(modelName string) (*domain.Model, error) {
|
func (r *modelRepository) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
|
||||||
var m config.Model
|
var m config.Model
|
||||||
err := r.db.Where("model_name = ?", modelName).First(&m).Error
|
err := r.db.Where("provider_id = ? AND model_name = ?", providerID, modelName).First(&m).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -62,6 +62,21 @@ func (r *modelRepository) GetByModelName(modelName string) (*domain.Model, error
|
|||||||
return &d, nil
|
return &d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *modelRepository) ListEnabled() ([]domain.Model, error) {
|
||||||
|
var models []config.Model
|
||||||
|
err := r.db.Joins("JOIN providers ON providers.id = models.provider_id").
|
||||||
|
Where("models.enabled = ? AND providers.enabled = ?", true, true).
|
||||||
|
Find(&models).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result := make([]domain.Model, len(models))
|
||||||
|
for i := range models {
|
||||||
|
result[i] = toDomainModel(&models[i])
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *modelRepository) Update(id string, updates map[string]interface{}) error {
|
func (r *modelRepository) Update(id string, updates map[string]interface{}) error {
|
||||||
result := r.db.Model(&config.Model{}).Where("id = ?", id).Updates(updates)
|
result := r.db.Model(&config.Model{}).Where("id = ?", id).Updates(updates)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
|
|||||||
@@ -9,4 +9,7 @@ type ProviderRepository interface {
|
|||||||
List() ([]domain.Provider, error)
|
List() ([]domain.Provider, error)
|
||||||
Update(id string, updates map[string]interface{}) error
|
Update(id string, updates map[string]interface{}) error
|
||||||
Delete(id string) error
|
Delete(id string) error
|
||||||
|
// 统一模型 ID 相关方法
|
||||||
|
ListEnabledModels() ([]domain.Model, error)
|
||||||
|
FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -71,6 +71,25 @@ func (r *providerRepository) Delete(id string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListEnabledModels 返回所有启用的模型(关联启用的供应商)
|
||||||
|
func (r *providerRepository) ListEnabledModels() ([]domain.Model, error) {
|
||||||
|
var models []domain.Model
|
||||||
|
err := r.db.Joins("JOIN providers ON providers.id = models.provider_id").
|
||||||
|
Where("models.enabled = ? AND providers.enabled = ?", true, true).
|
||||||
|
Find(&models).Error
|
||||||
|
return models, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindByProviderAndModelName 按 provider_id 和 model_name 查询模型
|
||||||
|
func (r *providerRepository) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
|
||||||
|
var model domain.Model
|
||||||
|
err := r.db.Where("provider_id = ? AND model_name = ?", providerID, modelName).First(&model).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &model, nil
|
||||||
|
}
|
||||||
|
|
||||||
func toDomainProvider(p *config.Provider) domain.Provider {
|
func toDomainProvider(p *config.Provider) domain.Provider {
|
||||||
return domain.Provider{
|
return domain.Provider{
|
||||||
ID: p.ID,
|
ID: p.ID,
|
||||||
|
|||||||
@@ -147,15 +147,36 @@ func TestModelRepository_GetByID(t *testing.T) {
|
|||||||
assert.Equal(t, "gpt-4", result.ModelName)
|
assert.Equal(t, "gpt-4", result.ModelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestModelRepository_GetByModelName(t *testing.T) {
|
func TestModelRepository_FindByProviderAndModelName(t *testing.T) {
|
||||||
db := setupTestDB(t)
|
db := setupTestDB(t)
|
||||||
repo := NewModelRepository(db)
|
repo := NewModelRepository(db)
|
||||||
|
|
||||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||||
|
|
||||||
result, err := repo.GetByModelName("gpt-4")
|
result, err := repo.FindByProviderAndModelName("p1", "gpt-4")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "m1", result.ID)
|
assert.Equal(t, "m1", result.ID)
|
||||||
|
assert.Equal(t, "p1", result.ProviderID)
|
||||||
|
assert.Equal(t, "gpt-4", result.ModelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelRepository_FindByProviderAndModelName_NotFound(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
repo := NewModelRepository(db)
|
||||||
|
|
||||||
|
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||||
|
|
||||||
|
// Wrong provider_id
|
||||||
|
_, err := repo.FindByProviderAndModelName("p2", "gpt-4")
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
// Wrong model_name
|
||||||
|
_, err = repo.FindByProviderAndModelName("p1", "gpt-3.5")
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
// Both wrong
|
||||||
|
_, err = repo.FindByProviderAndModelName("p2", "claude-3")
|
||||||
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestModelRepository_List(t *testing.T) {
|
func TestModelRepository_List(t *testing.T) {
|
||||||
@@ -175,6 +196,54 @@ func TestModelRepository_List(t *testing.T) {
|
|||||||
assert.Len(t, p1Models, 2)
|
assert.Len(t, p1Models, 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestModelRepository_ListEnabled(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
providerRepo := NewProviderRepository(db)
|
||||||
|
modelRepo := NewModelRepository(db)
|
||||||
|
|
||||||
|
// Create two providers (both start enabled due to gorm:"default:true")
|
||||||
|
err := providerRepo.Create(&domain.Provider{
|
||||||
|
ID: "enabled-provider", Name: "Enabled Provider",
|
||||||
|
APIKey: "key1", BaseURL: "https://enabled.com", Enabled: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = providerRepo.Create(&domain.Provider{
|
||||||
|
ID: "disabled-provider", Name: "Disabled Provider",
|
||||||
|
APIKey: "key2", BaseURL: "https://disabled.com", Enabled: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Disable the second provider via Update (GORM default:true skips zero values on Create)
|
||||||
|
err = providerRepo.Update("disabled-provider", map[string]interface{}{"enabled": false})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create models (all start enabled due to gorm:"default:true")
|
||||||
|
err = modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "enabled-provider", ModelName: "gpt-4", Enabled: true})
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = modelRepo.Create(&domain.Model{ID: "m2", ProviderID: "enabled-provider", ModelName: "gpt-3.5", Enabled: true})
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = modelRepo.Create(&domain.Model{ID: "m3", ProviderID: "disabled-provider", ModelName: "claude-3", Enabled: true})
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = modelRepo.Create(&domain.Model{ID: "m4", ProviderID: "disabled-provider", ModelName: "claude-3.5", Enabled: true})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Disable m2 via Update
|
||||||
|
err = modelRepo.Update("m2", map[string]interface{}{"enabled": false})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// ListEnabled should only return models where both model and provider are enabled:
|
||||||
|
// - m1: enabled provider + enabled model -> returned
|
||||||
|
// - m2: enabled provider + disabled model -> filtered out
|
||||||
|
// - m3: disabled provider + enabled model -> filtered out
|
||||||
|
// - m4: disabled provider + enabled model -> filtered out
|
||||||
|
enabled, err := modelRepo.ListEnabled()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, enabled, 1)
|
||||||
|
assert.Equal(t, "m1", enabled[0].ID)
|
||||||
|
assert.Equal(t, "enabled-provider", enabled[0].ProviderID)
|
||||||
|
assert.Equal(t, "gpt-4", enabled[0].ModelName)
|
||||||
|
}
|
||||||
|
|
||||||
func TestModelRepository_Update(t *testing.T) {
|
func TestModelRepository_Update(t *testing.T) {
|
||||||
db := setupTestDB(t)
|
db := setupTestDB(t)
|
||||||
repo := NewModelRepository(db)
|
repo := NewModelRepository(db)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ type ModelService interface {
|
|||||||
Create(model *domain.Model) error
|
Create(model *domain.Model) error
|
||||||
Get(id string) (*domain.Model, error)
|
Get(id string) (*domain.Model, error)
|
||||||
List(providerID string) ([]domain.Model, error)
|
List(providerID string) ([]domain.Model, error)
|
||||||
|
ListEnabled() ([]domain.Model, error)
|
||||||
Update(id string, updates map[string]interface{}) error
|
Update(id string, updates map[string]interface{}) error
|
||||||
Delete(id string) error
|
Delete(id string) error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/google/uuid"
|
||||||
appErrors "nex/backend/pkg/errors"
|
appErrors "nex/backend/pkg/errors"
|
||||||
|
|
||||||
"nex/backend/internal/domain"
|
"nex/backend/internal/domain"
|
||||||
@@ -17,11 +18,18 @@ func NewModelService(modelRepo repository.ModelRepository, providerRepo reposito
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *modelService) Create(model *domain.Model) error {
|
func (s *modelService) Create(model *domain.Model) error {
|
||||||
// Verify provider exists
|
// 校验供应商存在
|
||||||
_, err := s.providerRepo.GetByID(model.ProviderID)
|
if _, err := s.providerRepo.GetByID(model.ProviderID); err != nil {
|
||||||
if err != nil {
|
|
||||||
return appErrors.ErrProviderNotFound
|
return appErrors.ErrProviderNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 联合唯一校验:同一供应商下 model_name 不重复
|
||||||
|
if err := s.checkDuplicateModelName(model.ProviderID, model.ModelName, ""); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 自动生成 UUID 作为 id
|
||||||
|
model.ID = uuid.New().String()
|
||||||
model.Enabled = true
|
model.Enabled = true
|
||||||
return s.modelRepo.Create(model)
|
return s.modelRepo.Create(model)
|
||||||
}
|
}
|
||||||
@@ -34,17 +42,57 @@ func (s *modelService) List(providerID string) ([]domain.Model, error) {
|
|||||||
return s.modelRepo.List(providerID)
|
return s.modelRepo.List(providerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *modelService) ListEnabled() ([]domain.Model, error) {
|
||||||
|
return s.modelRepo.ListEnabled()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *modelService) Update(id string, updates map[string]interface{}) error {
|
func (s *modelService) Update(id string, updates map[string]interface{}) error {
|
||||||
// If updating provider_id, verify new provider exists
|
// 获取当前模型
|
||||||
|
current, err := s.modelRepo.GetByID(id)
|
||||||
|
if err != nil {
|
||||||
|
return appErrors.ErrModelNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果更新 provider_id,校验新供应商存在
|
||||||
if providerID, ok := updates["provider_id"].(string); ok {
|
if providerID, ok := updates["provider_id"].(string); ok {
|
||||||
_, err := s.providerRepo.GetByID(providerID)
|
if _, err := s.providerRepo.GetByID(providerID); err != nil {
|
||||||
if err != nil {
|
|
||||||
return appErrors.ErrProviderNotFound
|
return appErrors.ErrProviderNotFound
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 确定更新后的 provider_id 和 model_name
|
||||||
|
newProviderID := current.ProviderID
|
||||||
|
if v, ok := updates["provider_id"].(string); ok {
|
||||||
|
newProviderID = v
|
||||||
|
}
|
||||||
|
newModelName := current.ModelName
|
||||||
|
if v, ok := updates["model_name"].(string); ok {
|
||||||
|
newModelName = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果 provider_id 或 model_name 发生变化,校验联合唯一
|
||||||
|
if newProviderID != current.ProviderID || newModelName != current.ModelName {
|
||||||
|
if err := s.checkDuplicateModelName(newProviderID, newModelName, id); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return s.modelRepo.Update(id, updates)
|
return s.modelRepo.Update(id, updates)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *modelService) Delete(id string) error {
|
func (s *modelService) Delete(id string) error {
|
||||||
return s.modelRepo.Delete(id)
|
return s.modelRepo.Delete(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkDuplicateModelName 校验同一供应商下 model_name 是否重复
|
||||||
|
// excludeID 用于更新时排除自身
|
||||||
|
func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error {
|
||||||
|
existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
|
||||||
|
if err != nil {
|
||||||
|
return nil // 未找到,不重复
|
||||||
|
}
|
||||||
|
if excludeID != "" && existing.ID == excludeID {
|
||||||
|
return nil // 排除自身
|
||||||
|
}
|
||||||
|
return appErrors.ErrDuplicateModel
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,4 +9,7 @@ type ProviderService interface {
|
|||||||
List() ([]domain.Provider, error)
|
List() ([]domain.Provider, error)
|
||||||
Update(id string, updates map[string]interface{}) error
|
Update(id string, updates map[string]interface{}) error
|
||||||
Delete(id string) error
|
Delete(id string) error
|
||||||
|
// 统一模型 ID 相关方法
|
||||||
|
ListEnabledModels() ([]domain.Model, error)
|
||||||
|
GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,21 +1,35 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"nex/backend/pkg/modelid"
|
||||||
|
|
||||||
"nex/backend/internal/domain"
|
"nex/backend/internal/domain"
|
||||||
"nex/backend/internal/repository"
|
"nex/backend/internal/repository"
|
||||||
|
appErrors "nex/backend/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
type providerService struct {
|
type providerService struct {
|
||||||
providerRepo repository.ProviderRepository
|
providerRepo repository.ProviderRepository
|
||||||
|
modelRepo repository.ModelRepository
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProviderService(providerRepo repository.ProviderRepository) ProviderService {
|
func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository) ProviderService {
|
||||||
return &providerService{providerRepo: providerRepo}
|
return &providerService{providerRepo: providerRepo, modelRepo: modelRepo}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *providerService) Create(provider *domain.Provider) error {
|
func (s *providerService) Create(provider *domain.Provider) error {
|
||||||
|
// 校验 provider_id 字符集
|
||||||
|
if err := modelid.ValidateProviderID(provider.ID); err != nil {
|
||||||
|
return appErrors.ErrInvalidProviderID
|
||||||
|
}
|
||||||
provider.Enabled = true
|
provider.Enabled = true
|
||||||
return s.providerRepo.Create(provider)
|
err := s.providerRepo.Create(provider)
|
||||||
|
if err != nil && isUniqueConstraintError(err) {
|
||||||
|
return appErrors.ErrConflict
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *providerService) Get(id string, maskKey bool) (*domain.Provider, error) {
|
func (s *providerService) Get(id string, maskKey bool) (*domain.Provider, error) {
|
||||||
@@ -41,9 +55,31 @@ func (s *providerService) List() ([]domain.Provider, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *providerService) Update(id string, updates map[string]interface{}) error {
|
func (s *providerService) Update(id string, updates map[string]interface{}) error {
|
||||||
|
if _, ok := updates["id"]; ok {
|
||||||
|
return appErrors.ErrImmutableField
|
||||||
|
}
|
||||||
return s.providerRepo.Update(id, updates)
|
return s.providerRepo.Update(id, updates)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *providerService) Delete(id string) error {
|
func (s *providerService) Delete(id string) error {
|
||||||
return s.providerRepo.Delete(id)
|
return s.providerRepo.Delete(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListEnabledModels 返回所有启用的模型(用于 Models 接口本地聚合)
|
||||||
|
func (s *providerService) ListEnabledModels() ([]domain.Model, error) {
|
||||||
|
return s.modelRepo.ListEnabled()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelByProviderAndName 按 provider_id 和 model_name 查询模型(用于 ModelInfo 接口本地查询)
|
||||||
|
func (s *providerService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
|
||||||
|
return s.modelRepo.FindByProviderAndModelName(providerID, modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isUniqueConstraintError 判断是否为数据库唯一约束冲突错误
|
||||||
|
func isUniqueConstraintError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
msg := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(msg, "unique constraint") || strings.Contains(msg, "duplicate")
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,5 +4,5 @@ import "nex/backend/internal/domain"
|
|||||||
|
|
||||||
// RoutingService 路由服务接口
|
// RoutingService 路由服务接口
|
||||||
type RoutingService interface {
|
type RoutingService interface {
|
||||||
Route(modelName string) (*domain.RouteResult, error)
|
RouteByModelName(providerID, modelName string) (*domain.RouteResult, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ func NewRoutingService(modelRepo repository.ModelRepository, providerRepo reposi
|
|||||||
return &routingService{modelRepo: modelRepo, providerRepo: providerRepo}
|
return &routingService{modelRepo: modelRepo, providerRepo: providerRepo}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *routingService) Route(modelName string) (*domain.RouteResult, error) {
|
func (s *routingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
|
||||||
model, err := s.modelRepo.GetByModelName(modelName)
|
model, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, appErrors.ErrModelNotFound
|
return nil, appErrors.ErrModelNotFound
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,8 @@ import (
|
|||||||
func TestProviderService_Update(t *testing.T) {
|
func TestProviderService_Update(t *testing.T) {
|
||||||
db := setupServiceTestDB(t)
|
db := setupServiceTestDB(t)
|
||||||
repo := repository.NewProviderRepository(db)
|
repo := repository.NewProviderRepository(db)
|
||||||
svc := NewProviderService(repo)
|
modelRepo := repository.NewModelRepository(db)
|
||||||
|
svc := NewProviderService(repo, modelRepo)
|
||||||
|
|
||||||
svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"})
|
svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"})
|
||||||
|
|
||||||
@@ -28,7 +29,8 @@ func TestProviderService_Update(t *testing.T) {
|
|||||||
func TestProviderService_Update_NotFound(t *testing.T) {
|
func TestProviderService_Update_NotFound(t *testing.T) {
|
||||||
db := setupServiceTestDB(t)
|
db := setupServiceTestDB(t)
|
||||||
repo := repository.NewProviderRepository(db)
|
repo := repository.NewProviderRepository(db)
|
||||||
svc := NewProviderService(repo)
|
modelRepo := repository.NewModelRepository(db)
|
||||||
|
svc := NewProviderService(repo, modelRepo)
|
||||||
|
|
||||||
err := svc.Update("nonexistent", map[string]interface{}{"name": "test"})
|
err := svc.Update("nonexistent", map[string]interface{}{"name": "test"})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
@@ -41,11 +43,12 @@ func TestModelService_Get(t *testing.T) {
|
|||||||
svc := NewModelService(modelRepo, providerRepo)
|
svc := NewModelService(modelRepo, providerRepo)
|
||||||
|
|
||||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||||
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||||
|
require.NoError(t, svc.Create(model))
|
||||||
|
|
||||||
model, err := svc.Get("m1")
|
result, err := svc.Get(model.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "gpt-4", model.ModelName)
|
assert.Equal(t, "gpt-4", result.ModelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestModelService_Update(t *testing.T) {
|
func TestModelService_Update(t *testing.T) {
|
||||||
@@ -55,14 +58,15 @@ func TestModelService_Update(t *testing.T) {
|
|||||||
svc := NewModelService(modelRepo, providerRepo)
|
svc := NewModelService(modelRepo, providerRepo)
|
||||||
|
|
||||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||||
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||||
|
require.NoError(t, svc.Create(model))
|
||||||
|
|
||||||
err := svc.Update("m1", map[string]interface{}{"model_name": "gpt-4o"})
|
err := svc.Update(model.ID, map[string]interface{}{"model_name": "gpt-4o"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
model, err := svc.Get("m1")
|
result, err := svc.Get(model.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "gpt-4o", model.ModelName)
|
assert.Equal(t, "gpt-4o", result.ModelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
|
func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
|
||||||
@@ -72,9 +76,10 @@ func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
|
|||||||
svc := NewModelService(modelRepo, providerRepo)
|
svc := NewModelService(modelRepo, providerRepo)
|
||||||
|
|
||||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||||
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||||
|
require.NoError(t, svc.Create(model))
|
||||||
|
|
||||||
err := svc.Update("m1", map[string]interface{}{"provider_id": "nonexistent"})
|
err := svc.Update(model.ID, map[string]interface{}{"provider_id": "nonexistent"})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -85,12 +90,13 @@ func TestModelService_Delete(t *testing.T) {
|
|||||||
svc := NewModelService(modelRepo, providerRepo)
|
svc := NewModelService(modelRepo, providerRepo)
|
||||||
|
|
||||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||||
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||||
|
require.NoError(t, svc.Create(model))
|
||||||
|
|
||||||
err := svc.Delete("m1")
|
err := svc.Delete(model.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = svc.Get("m1")
|
_, err = svc.Get(model.ID)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"gorm.io/driver/sqlite"
|
"gorm.io/driver/sqlite"
|
||||||
@@ -11,6 +13,7 @@ import (
|
|||||||
"nex/backend/internal/config"
|
"nex/backend/internal/config"
|
||||||
"nex/backend/internal/domain"
|
"nex/backend/internal/domain"
|
||||||
"nex/backend/internal/repository"
|
"nex/backend/internal/repository"
|
||||||
|
appErrors "nex/backend/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupServiceTestDB(t *testing.T) *gorm.DB {
|
func setupServiceTestDB(t *testing.T) *gorm.DB {
|
||||||
@@ -29,80 +32,106 @@ func setupServiceTestDB(t *testing.T) *gorm.DB {
|
|||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============ ProviderService 测试 ============
|
// ============ RoutingService - RouteByModelName 测试 ============
|
||||||
|
|
||||||
func TestProviderService_Create(t *testing.T) {
|
func TestRoutingService_RouteByModelName_Success(t *testing.T) {
|
||||||
db := setupServiceTestDB(t)
|
db := setupServiceTestDB(t)
|
||||||
repo := repository.NewProviderRepository(db)
|
providerRepo := repository.NewProviderRepository(db)
|
||||||
svc := NewProviderService(repo)
|
modelRepo := repository.NewModelRepository(db)
|
||||||
|
svc := NewRoutingService(modelRepo, providerRepo)
|
||||||
|
|
||||||
provider := &domain.Provider{
|
// 创建供应商和模型
|
||||||
ID: "test-p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com",
|
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
|
||||||
}
|
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
|
||||||
err := svc.Create(provider)
|
|
||||||
|
result, err := svc.RouteByModelName("openai", "gpt-4")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, provider.Enabled)
|
assert.Equal(t, "openai", result.Provider.ID)
|
||||||
|
assert.Equal(t, "gpt-4", result.Model.ModelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProviderService_Get_MaskKey(t *testing.T) {
|
func TestRoutingService_RouteByModelName_NotFound(t *testing.T) {
|
||||||
db := setupServiceTestDB(t)
|
db := setupServiceTestDB(t)
|
||||||
repo := repository.NewProviderRepository(db)
|
providerRepo := repository.NewProviderRepository(db)
|
||||||
svc := NewProviderService(repo)
|
modelRepo := repository.NewModelRepository(db)
|
||||||
|
svc := NewRoutingService(modelRepo, providerRepo)
|
||||||
|
|
||||||
svc.Create(&domain.Provider{
|
_, err := svc.RouteByModelName("openai", "nonexistent-model")
|
||||||
ID: "p1", Name: "Test", APIKey: "sk-long-api-key-12345", BaseURL: "https://test.com",
|
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
|
||||||
})
|
|
||||||
|
|
||||||
result, err := svc.Get("p1", true)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "***2345", result.APIKey)
|
|
||||||
|
|
||||||
result, err = svc.Get("p1", false)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "sk-long-api-key-12345", result.APIKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProviderService_List(t *testing.T) {
|
func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
|
||||||
db := setupServiceTestDB(t)
|
db := setupServiceTestDB(t)
|
||||||
repo := repository.NewProviderRepository(db)
|
providerRepo := repository.NewProviderRepository(db)
|
||||||
svc := NewProviderService(repo)
|
modelRepo := repository.NewModelRepository(db)
|
||||||
|
svc := NewRoutingService(modelRepo, providerRepo)
|
||||||
|
|
||||||
svc.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key123", BaseURL: "https://a.com"})
|
// 创建启用的供应商和禁用的模型
|
||||||
svc.Create(&domain.Provider{ID: "p2", Name: "P2", APIKey: "key456", BaseURL: "https://b.com"})
|
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
|
||||||
|
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
|
||||||
|
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
|
||||||
|
|
||||||
providers, err := svc.List()
|
_, err := svc.RouteByModelName("openai", "gpt-4")
|
||||||
require.NoError(t, err)
|
assert.True(t, errors.Is(err, appErrors.ErrModelDisabled))
|
||||||
assert.Len(t, providers, 2)
|
|
||||||
assert.Contains(t, providers[0].APIKey, "***")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProviderService_Delete(t *testing.T) {
|
func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
|
||||||
db := setupServiceTestDB(t)
|
db := setupServiceTestDB(t)
|
||||||
repo := repository.NewProviderRepository(db)
|
providerRepo := repository.NewProviderRepository(db)
|
||||||
svc := NewProviderService(repo)
|
modelRepo := repository.NewModelRepository(db)
|
||||||
|
svc := NewRoutingService(modelRepo, providerRepo)
|
||||||
|
|
||||||
svc.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
|
// 创建启用的供应商和模型,然后禁用供应商
|
||||||
err := svc.Delete("p1")
|
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
|
||||||
require.NoError(t, err)
|
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
|
||||||
|
providerRepo.Update("openai", map[string]interface{}{"enabled": false})
|
||||||
|
|
||||||
_, err = svc.Get("p1", false)
|
_, err := svc.RouteByModelName("openai", "gpt-4")
|
||||||
assert.Error(t, err)
|
assert.True(t, errors.Is(err, appErrors.ErrProviderDisabled))
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============ ModelService 测试 ============
|
// ============ ModelService - Create with UUID 测试 ============
|
||||||
|
|
||||||
func TestModelService_Create(t *testing.T) {
|
func TestModelService_Create_GeneratesUUID(t *testing.T) {
|
||||||
db := setupServiceTestDB(t)
|
db := setupServiceTestDB(t)
|
||||||
providerRepo := repository.NewProviderRepository(db)
|
providerRepo := repository.NewProviderRepository(db)
|
||||||
modelRepo := repository.NewModelRepository(db)
|
modelRepo := repository.NewModelRepository(db)
|
||||||
svc := NewModelService(modelRepo, providerRepo)
|
svc := NewModelService(modelRepo, providerRepo)
|
||||||
|
|
||||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||||
|
|
||||||
model := &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}
|
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||||
err := svc.Create(model)
|
err := svc.Create(model)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, model.Enabled)
|
|
||||||
|
// 验证返回的 model 拥有有效的 UUID
|
||||||
|
assert.NotEmpty(t, model.ID)
|
||||||
|
_, err = uuid.Parse(model.ID)
|
||||||
|
assert.NoError(t, err, "model.ID should be a valid UUID")
|
||||||
|
|
||||||
|
// 通过 Get 验证持久化
|
||||||
|
stored, err := svc.Get(model.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, model.ID, stored.ID)
|
||||||
|
assert.Equal(t, "gpt-4", stored.ModelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelService_Create_DuplicateModelName(t *testing.T) {
|
||||||
|
db := setupServiceTestDB(t)
|
||||||
|
providerRepo := repository.NewProviderRepository(db)
|
||||||
|
modelRepo := repository.NewModelRepository(db)
|
||||||
|
svc := NewModelService(modelRepo, providerRepo)
|
||||||
|
|
||||||
|
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||||
|
|
||||||
|
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||||
|
err := svc.Create(model1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 使用相同的 (providerID, modelName) 创建第二个模型应失败
|
||||||
|
model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||||
|
err = svc.Create(model2)
|
||||||
|
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestModelService_Create_ProviderNotFound(t *testing.T) {
|
func TestModelService_Create_ProviderNotFound(t *testing.T) {
|
||||||
@@ -111,160 +140,135 @@ func TestModelService_Create_ProviderNotFound(t *testing.T) {
|
|||||||
modelRepo := repository.NewModelRepository(db)
|
modelRepo := repository.NewModelRepository(db)
|
||||||
svc := NewModelService(modelRepo, providerRepo)
|
svc := NewModelService(modelRepo, providerRepo)
|
||||||
|
|
||||||
model := &domain.Model{ID: "m1", ProviderID: "nonexistent", ModelName: "gpt-4"}
|
model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
|
||||||
err := svc.Create(model)
|
err := svc.Create(model)
|
||||||
assert.Error(t, err)
|
assert.True(t, errors.Is(err, appErrors.ErrProviderNotFound))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestModelService_List(t *testing.T) {
|
// ============ ProviderService - Create with validation 测试 ============
|
||||||
|
|
||||||
|
func TestProviderService_Create_InvalidID(t *testing.T) {
|
||||||
|
db := setupServiceTestDB(t)
|
||||||
|
repo := repository.NewProviderRepository(db)
|
||||||
|
modelRepo := repository.NewModelRepository(db)
|
||||||
|
svc := NewProviderService(repo, modelRepo)
|
||||||
|
|
||||||
|
provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||||
|
err := svc.Create(provider)
|
||||||
|
assert.True(t, errors.Is(err, appErrors.ErrInvalidProviderID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderService_Create_ValidID(t *testing.T) {
|
||||||
|
db := setupServiceTestDB(t)
|
||||||
|
repo := repository.NewProviderRepository(db)
|
||||||
|
modelRepo := repository.NewModelRepository(db)
|
||||||
|
svc := NewProviderService(repo, modelRepo)
|
||||||
|
|
||||||
|
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||||
|
err := svc.Create(provider)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "openai", provider.ID)
|
||||||
|
assert.True(t, provider.Enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ ModelService - Update with duplicate check 测试 ============
|
||||||
|
|
||||||
|
func TestModelService_Update_DuplicateModelName(t *testing.T) {
|
||||||
db := setupServiceTestDB(t)
|
db := setupServiceTestDB(t)
|
||||||
providerRepo := repository.NewProviderRepository(db)
|
providerRepo := repository.NewProviderRepository(db)
|
||||||
modelRepo := repository.NewModelRepository(db)
|
modelRepo := repository.NewModelRepository(db)
|
||||||
svc := NewModelService(modelRepo, providerRepo)
|
svc := NewModelService(modelRepo, providerRepo)
|
||||||
|
|
||||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||||
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"})
|
||||||
svc.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})
|
|
||||||
|
|
||||||
models, err := svc.List("p1")
|
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||||
|
err := svc.Create(model1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, models, 2)
|
|
||||||
|
model2 := &domain.Model{ProviderID: "anthropic", ModelName: "claude-3"}
|
||||||
|
err = svc.Create(model2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 将 model2 的 model_name 改为 "gpt-4" 且 provider_id 改为 "openai",与 model1 冲突
|
||||||
|
err = svc.Update(model2.ID, map[string]interface{}{
|
||||||
|
"provider_id": "openai",
|
||||||
|
"model_name": "gpt-4",
|
||||||
|
})
|
||||||
|
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============ RoutingService 测试 ============
|
func TestModelService_Update_ModelNotFound(t *testing.T) {
|
||||||
|
|
||||||
func TestRoutingService_Route(t *testing.T) {
|
|
||||||
db := setupServiceTestDB(t)
|
db := setupServiceTestDB(t)
|
||||||
providerRepo := repository.NewProviderRepository(db)
|
providerRepo := repository.NewProviderRepository(db)
|
||||||
modelRepo := repository.NewModelRepository(db)
|
modelRepo := repository.NewModelRepository(db)
|
||||||
svc := NewRoutingService(modelRepo, providerRepo)
|
svc := NewModelService(modelRepo, providerRepo)
|
||||||
|
|
||||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
|
err := svc.Update("nonexistent-id", map[string]interface{}{
|
||||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
"model_name": "gpt-4",
|
||||||
|
})
|
||||||
result, err := svc.Route("gpt-4")
|
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "p1", result.Provider.ID)
|
|
||||||
assert.Equal(t, "gpt-4", result.Model.ModelName)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRoutingService_Route_ModelNotFound(t *testing.T) {
|
func TestModelService_Update_Success(t *testing.T) {
|
||||||
db := setupServiceTestDB(t)
|
db := setupServiceTestDB(t)
|
||||||
providerRepo := repository.NewProviderRepository(db)
|
providerRepo := repository.NewProviderRepository(db)
|
||||||
modelRepo := repository.NewModelRepository(db)
|
modelRepo := repository.NewModelRepository(db)
|
||||||
svc := NewRoutingService(modelRepo, providerRepo)
|
svc := NewModelService(modelRepo, providerRepo)
|
||||||
|
|
||||||
_, err := svc.Route("nonexistent-model")
|
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||||
assert.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRoutingService_Route_ModelDisabled(t *testing.T) {
|
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||||
db := setupServiceTestDB(t)
|
err := svc.Create(model)
|
||||||
providerRepo := repository.NewProviderRepository(db)
|
|
||||||
modelRepo := repository.NewModelRepository(db)
|
|
||||||
svc := NewRoutingService(modelRepo, providerRepo)
|
|
||||||
|
|
||||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
|
|
||||||
// 先创建启用的模型,然后通过 Update 禁用
|
|
||||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
|
||||||
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
|
|
||||||
|
|
||||||
_, err := svc.Route("gpt-4")
|
|
||||||
assert.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRoutingService_Route_ProviderDisabled(t *testing.T) {
|
|
||||||
db := setupServiceTestDB(t)
|
|
||||||
providerRepo := repository.NewProviderRepository(db)
|
|
||||||
modelRepo := repository.NewModelRepository(db)
|
|
||||||
svc := NewRoutingService(modelRepo, providerRepo)
|
|
||||||
|
|
||||||
// 先创建启用的 provider,然后禁用
|
|
||||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
|
|
||||||
providerRepo.Update("p1", map[string]interface{}{"enabled": false})
|
|
||||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
|
||||||
|
|
||||||
_, err := svc.Route("gpt-4")
|
|
||||||
assert.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============ StatsService 测试 ============
|
|
||||||
|
|
||||||
func TestStatsService_RecordAndGet(t *testing.T) {
|
|
||||||
db := setupServiceTestDB(t)
|
|
||||||
statsRepo := repository.NewStatsRepository(db)
|
|
||||||
svc := NewStatsService(statsRepo)
|
|
||||||
|
|
||||||
err := svc.Record("p1", "gpt-4")
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
stats, err := svc.Get("p1", "", nil, nil)
|
// 更新 model_name 为不冲突的值
|
||||||
|
err = svc.Update(model.ID, map[string]interface{}{
|
||||||
|
"model_name": "gpt-4-turbo",
|
||||||
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, stats, 1)
|
|
||||||
|
updated, err := svc.Get(model.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "gpt-4-turbo", updated.ModelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStatsService_Aggregate_ByProvider(t *testing.T) {
|
// ============ ProviderService - Update immutable ID 测试 ============
|
||||||
statsRepo := repository.NewStatsRepository(nil)
|
|
||||||
svc := NewStatsService(statsRepo)
|
|
||||||
|
|
||||||
stats := []domain.UsageStats{
|
func TestProviderService_Update_ImmutableID(t *testing.T) {
|
||||||
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
|
db := setupServiceTestDB(t)
|
||||||
{ProviderID: "p1", ModelName: "gpt-3.5", RequestCount: 5},
|
repo := repository.NewProviderRepository(db)
|
||||||
{ProviderID: "p2", ModelName: "claude-3", RequestCount: 8},
|
modelRepo := repository.NewModelRepository(db)
|
||||||
}
|
svc := NewProviderService(repo, modelRepo)
|
||||||
|
|
||||||
result := svc.Aggregate(stats, "provider")
|
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||||
assert.Len(t, result, 2)
|
err := svc.Create(provider)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
p1Count := 0
|
// 尝试更新 id 字段
|
||||||
p2Count := 0
|
err = svc.Update("openai", map[string]interface{}{
|
||||||
for _, r := range result {
|
"id": "new-id",
|
||||||
if r["provider_id"] == "p1" {
|
})
|
||||||
p1Count = r["request_count"].(int)
|
assert.True(t, errors.Is(err, appErrors.ErrImmutableField))
|
||||||
}
|
|
||||||
if r["provider_id"] == "p2" {
|
|
||||||
p2Count = r["request_count"].(int)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert.Equal(t, 15, p1Count)
|
|
||||||
assert.Equal(t, 8, p2Count)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStatsService_Aggregate_ByDate(t *testing.T) {
|
func TestProviderService_Update_Success(t *testing.T) {
|
||||||
statsRepo := repository.NewStatsRepository(nil)
|
db := setupServiceTestDB(t)
|
||||||
svc := NewStatsService(statsRepo)
|
repo := repository.NewProviderRepository(db)
|
||||||
|
modelRepo := repository.NewModelRepository(db)
|
||||||
|
svc := NewProviderService(repo, modelRepo)
|
||||||
|
|
||||||
stats := []domain.UsageStats{
|
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||||
{ProviderID: "p1", RequestCount: 10},
|
err := svc.Create(provider)
|
||||||
{ProviderID: "p2", RequestCount: 5},
|
require.NoError(t, err)
|
||||||
}
|
|
||||||
|
|
||||||
result := svc.Aggregate(stats, "date")
|
// 更新 name
|
||||||
assert.Len(t, result, 1)
|
err = svc.Update("openai", map[string]interface{}{
|
||||||
assert.Equal(t, 15, result[0]["request_count"])
|
"name": "OpenAI Updated",
|
||||||
}
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
func TestStatsService_Aggregate_ByModel(t *testing.T) {
|
|
||||||
statsRepo := repository.NewStatsRepository(nil)
|
updated, err := svc.Get("openai", false)
|
||||||
svc := NewStatsService(statsRepo)
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "OpenAI Updated", updated.Name)
|
||||||
stats := []domain.UsageStats{
|
|
||||||
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10},
|
|
||||||
{ProviderID: "openai", ModelName: "gpt-3.5", RequestCount: 5},
|
|
||||||
{ProviderID: "anthropic", ModelName: "claude-3", RequestCount: 8},
|
|
||||||
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 3},
|
|
||||||
}
|
|
||||||
|
|
||||||
result := svc.Aggregate(stats, "model")
|
|
||||||
assert.Len(t, result, 3)
|
|
||||||
|
|
||||||
// 验证每个 provider/model 组合的计数
|
|
||||||
counts := make(map[string]int)
|
|
||||||
for _, r := range result {
|
|
||||||
key := r["provider_id"].(string) + "/" + r["model_name"].(string)
|
|
||||||
counts[key] = r["request_count"].(int)
|
|
||||||
}
|
|
||||||
assert.Equal(t, 13, counts["openai/gpt-4"])
|
|
||||||
assert.Equal(t, 5, counts["openai/gpt-3.5"])
|
|
||||||
assert.Equal(t, 8, counts["anthropic/claude-3"])
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
-- +goose Up
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_models_provider_id ON models(provider_id);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_models_model_name ON models(model_name);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_usage_stats_provider_model_date ON usage_stats(provider_id, model_name, date);
|
|
||||||
|
|
||||||
-- +goose Down
|
|
||||||
DROP INDEX IF EXISTS idx_usage_stats_provider_model_date;
|
|
||||||
DROP INDEX IF EXISTS idx_models_model_name;
|
|
||||||
DROP INDEX IF EXISTS idx_models_provider_id;
|
|
||||||
@@ -1,6 +0,0 @@
|
|||||||
-- +goose Up
|
|
||||||
ALTER TABLE providers ADD COLUMN protocol TEXT DEFAULT 'openai';
|
|
||||||
|
|
||||||
-- +goose Down
|
|
||||||
-- SQLite 不支持 DROP COLUMN(3.35.0 之前),但 goose 的 Down 通常不需要
|
|
||||||
CREATE TABLE providers_backup AS SELECT id, name, api_key, base_url, enabled, created_at, updated_at FROM providers;
|
|
||||||
@@ -1,9 +1,13 @@
|
|||||||
-- +goose Up
|
-- +goose Up
|
||||||
|
-- 统一初始迁移:providers、models、usage_stats 完整表结构
|
||||||
|
-- models 表使用 UUID 主键 + UNIQUE(provider_id, model_name) 联合唯一约束
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS providers (
|
CREATE TABLE IF NOT EXISTS providers (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
name TEXT NOT NULL,
|
name TEXT NOT NULL,
|
||||||
api_key TEXT NOT NULL,
|
api_key TEXT NOT NULL,
|
||||||
base_url TEXT NOT NULL,
|
base_url TEXT NOT NULL,
|
||||||
|
protocol TEXT DEFAULT 'openai',
|
||||||
enabled INTEGER DEFAULT 1,
|
enabled INTEGER DEFAULT 1,
|
||||||
created_at DATETIME,
|
created_at DATETIME,
|
||||||
updated_at DATETIME
|
updated_at DATETIME
|
||||||
@@ -15,7 +19,8 @@ CREATE TABLE IF NOT EXISTS models (
|
|||||||
model_name TEXT NOT NULL,
|
model_name TEXT NOT NULL,
|
||||||
enabled INTEGER DEFAULT 1,
|
enabled INTEGER DEFAULT 1,
|
||||||
created_at DATETIME,
|
created_at DATETIME,
|
||||||
FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
|
FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE,
|
||||||
|
UNIQUE(provider_id, model_name)
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS usage_stats (
|
CREATE TABLE IF NOT EXISTS usage_stats (
|
||||||
@@ -27,7 +32,14 @@ CREATE TABLE IF NOT EXISTS usage_stats (
|
|||||||
UNIQUE(provider_id, model_name, date)
|
UNIQUE(provider_id, model_name, date)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_models_provider_id ON models(provider_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_models_model_name ON models(model_name);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_usage_stats_provider_model_date ON usage_stats(provider_id, model_name, date);
|
||||||
|
|
||||||
-- +goose Down
|
-- +goose Down
|
||||||
|
DROP INDEX IF EXISTS idx_usage_stats_provider_model_date;
|
||||||
|
DROP INDEX IF EXISTS idx_models_model_name;
|
||||||
|
DROP INDEX IF EXISTS idx_models_provider_id;
|
||||||
DROP TABLE IF EXISTS usage_stats;
|
DROP TABLE IF EXISTS usage_stats;
|
||||||
DROP TABLE IF EXISTS models;
|
DROP TABLE IF EXISTS models;
|
||||||
DROP TABLE IF EXISTS providers;
|
DROP TABLE IF EXISTS providers;
|
||||||
@@ -49,17 +49,20 @@ func NewAppError(code, message string, httpStatus int) *AppError {
|
|||||||
|
|
||||||
// Predefined errors
|
// Predefined errors
|
||||||
var (
|
var (
|
||||||
ErrModelNotFound = NewAppError("model_not_found", "模型未找到", http.StatusNotFound)
|
ErrModelNotFound = NewAppError("model_not_found", "模型未找到", http.StatusNotFound)
|
||||||
ErrModelDisabled = NewAppError("model_disabled", "模型已禁用", http.StatusNotFound)
|
ErrModelDisabled = NewAppError("model_disabled", "模型已禁用", http.StatusNotFound)
|
||||||
ErrProviderNotFound = NewAppError("provider_not_found", "供应商未找到", http.StatusNotFound)
|
ErrProviderNotFound = NewAppError("provider_not_found", "供应商未找到", http.StatusNotFound)
|
||||||
ErrProviderDisabled = NewAppError("provider_disabled", "供应商已禁用", http.StatusNotFound)
|
ErrProviderDisabled = NewAppError("provider_disabled", "供应商已禁用", http.StatusNotFound)
|
||||||
ErrInvalidRequest = NewAppError("invalid_request", "无效的请求", http.StatusBadRequest)
|
ErrInvalidRequest = NewAppError("invalid_request", "无效的请求", http.StatusBadRequest)
|
||||||
ErrInternal = NewAppError("internal_error", "内部错误", http.StatusInternalServerError)
|
ErrInternal = NewAppError("internal_error", "内部错误", http.StatusInternalServerError)
|
||||||
ErrDatabaseNotInit = NewAppError("database_not_initialized", "数据库未初始化", http.StatusInternalServerError)
|
ErrDatabaseNotInit = NewAppError("database_not_initialized", "数据库未初始化", http.StatusInternalServerError)
|
||||||
ErrConflict = NewAppError("conflict", "资源已存在", http.StatusConflict)
|
ErrConflict = NewAppError("conflict", "资源已存在", http.StatusConflict)
|
||||||
ErrRequestCreate = NewAppError("request_create_error", "创建请求失败", http.StatusInternalServerError)
|
ErrRequestCreate = NewAppError("request_create_error", "创建请求失败", http.StatusInternalServerError)
|
||||||
ErrRequestSend = NewAppError("request_send_error", "发送请求失败", http.StatusBadGateway)
|
ErrRequestSend = NewAppError("request_send_error", "发送请求失败", http.StatusBadGateway)
|
||||||
ErrResponseRead = NewAppError("response_read_error", "读取响应失败", http.StatusBadGateway)
|
ErrResponseRead = NewAppError("response_read_error", "读取响应失败", http.StatusBadGateway)
|
||||||
|
ErrInvalidProviderID = NewAppError("invalid_provider_id", "供应商 ID 仅允许字母、数字、下划线,长度 1-64", http.StatusBadRequest)
|
||||||
|
ErrDuplicateModel = NewAppError("duplicate_model", "同一供应商下模型名称已存在", http.StatusConflict)
|
||||||
|
ErrImmutableField = NewAppError("immutable_field", "供应商 ID 不允许修改", http.StatusBadRequest)
|
||||||
)
|
)
|
||||||
|
|
||||||
// AsAppError 尝试将 error 转换为 *AppError
|
// AsAppError 尝试将 error 转换为 *AppError
|
||||||
|
|||||||
@@ -9,6 +9,14 @@ import (
|
|||||||
"go.uber.org/zap/zapcore"
|
"go.uber.org/zap/zapcore"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// stdoutWriter 包装 os.Stdout,忽略 Sync() 错误。
|
||||||
|
// 在非 TTY 环境(如 go test)中,os.Stdout 被重定向为 pipe,
|
||||||
|
// 底层 fsync 会返回 "bad file descriptor"。zap 社区标准做法。
|
||||||
|
type stdoutWriter struct{}
|
||||||
|
|
||||||
|
func (stdoutWriter) Write(p []byte) (int, error) { return os.Stdout.Write(p) }
|
||||||
|
func (stdoutWriter) Sync() error { return nil }
|
||||||
|
|
||||||
// Config 日志配置
|
// Config 日志配置
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Level string // 日志级别: debug, info, warn, error
|
Level string // 日志级别: debug, info, warn, error
|
||||||
@@ -46,7 +54,7 @@ func New(cfg Config) (*zap.Logger, error) {
|
|||||||
|
|
||||||
stdoutCore := zapcore.NewCore(
|
stdoutCore := zapcore.NewCore(
|
||||||
stdoutEncoder,
|
stdoutEncoder,
|
||||||
zapcore.AddSync(os.Stdout),
|
zapcore.AddSync(stdoutWriter{}),
|
||||||
level,
|
level,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
63
backend/pkg/modelid/model_id.go
Normal file
63
backend/pkg/modelid/model_id.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package modelid
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var providerIDRegex = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
|
||||||
|
|
||||||
|
// ParseUnifiedModelID 将 "provider_id/model_name" 格式的字符串解析为 providerID 和 modelName
|
||||||
|
// 在第一个 "/" 处分割,model_name 可以包含 "/"
|
||||||
|
func ParseUnifiedModelID(id string) (providerID, modelName string, err error) {
|
||||||
|
if id == "" {
|
||||||
|
return "", "", errors.New("统一模型 ID 不能为空")
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.SplitN(id, "/", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return "", "", errors.New("统一模型 ID 格式错误,缺少分隔符 \"/\"")
|
||||||
|
}
|
||||||
|
|
||||||
|
providerID = parts[0]
|
||||||
|
modelName = parts[1]
|
||||||
|
|
||||||
|
if providerID == "" {
|
||||||
|
return "", "", errors.New("provider_id 不能为空")
|
||||||
|
}
|
||||||
|
if modelName == "" {
|
||||||
|
return "", "", errors.New("model_name 不能为空")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !providerIDRegex.MatchString(providerID) {
|
||||||
|
return "", "", errors.New("provider_id 仅允许字母、数字、下划线")
|
||||||
|
}
|
||||||
|
|
||||||
|
return providerID, modelName, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatUnifiedModelID 将 providerID 和 modelName 组合格式化为统一模型 ID
|
||||||
|
func FormatUnifiedModelID(providerID, modelName string) string {
|
||||||
|
return providerID + "/" + modelName
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateProviderID 校验 providerID 仅包含字母、数字、下划线,长度 1-64
|
||||||
|
func ValidateProviderID(id string) error {
|
||||||
|
if id == "" {
|
||||||
|
return errors.New("provider_id 不能为空")
|
||||||
|
}
|
||||||
|
if len(id) > 64 {
|
||||||
|
return errors.New("provider_id 长度不能超过 64 个字符")
|
||||||
|
}
|
||||||
|
if !providerIDRegex.MatchString(id) {
|
||||||
|
return errors.New("provider_id 仅允许字母、数字、下划线")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidUnifiedModelID 判断字符串是否为合法的统一模型 ID 格式
|
||||||
|
func IsValidUnifiedModelID(id string) bool {
|
||||||
|
_, _, err := ParseUnifiedModelID(id)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
96
backend/pkg/modelid/model_id_test.go
Normal file
96
backend/pkg/modelid/model_id_test.go
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
package modelid
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseUnifiedModelID_StandardFormat(t *testing.T) {
|
||||||
|
providerID, modelName, err := ParseUnifiedModelID("openai/gpt-4")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "openai", providerID)
|
||||||
|
assert.Equal(t, "gpt-4", modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseUnifiedModelID_ModelNameWithSlashes(t *testing.T) {
|
||||||
|
providerID, modelName, err := ParseUnifiedModelID("azure/accounts/org-123/models/gpt-4")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "azure", providerID)
|
||||||
|
assert.Equal(t, "accounts/org-123/models/gpt-4", modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseUnifiedModelID_MissingSeparator(t *testing.T) {
|
||||||
|
_, _, err := ParseUnifiedModelID("gpt-4")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseUnifiedModelID_EmptyString(t *testing.T) {
|
||||||
|
_, _, err := ParseUnifiedModelID("")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseUnifiedModelID_OnlySeparator(t *testing.T) {
|
||||||
|
tests := []string{"/model", "provider/", "/"}
|
||||||
|
for _, tc := range tests {
|
||||||
|
_, _, err := ParseUnifiedModelID(tc)
|
||||||
|
assert.Error(t, err, "输入 %q 应返回错误", tc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseUnifiedModelID_InvalidProviderID(t *testing.T) {
|
||||||
|
tests := []string{
|
||||||
|
"open-ai/gpt-4",
|
||||||
|
"open.ai/gpt-4",
|
||||||
|
"供应商/gpt-4",
|
||||||
|
"open ai/gpt-4",
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
_, _, err := ParseUnifiedModelID(tc)
|
||||||
|
assert.Error(t, err, "providerID 含非法字符 %q 应返回错误", tc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatUnifiedModelID(t *testing.T) {
|
||||||
|
assert.Equal(t, "openai/gpt-4", FormatUnifiedModelID("openai", "gpt-4"))
|
||||||
|
assert.Equal(t, "anthropic/claude-3-opus", FormatUnifiedModelID("anthropic", "claude-3-opus"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateProviderID_Valid(t *testing.T) {
|
||||||
|
validIDs := []string{"openai", "deep_seek", "provider01", "OpenAI"}
|
||||||
|
for _, id := range validIDs {
|
||||||
|
assert.NoError(t, ValidateProviderID(id), "%q 应校验通过", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateProviderID_InvalidChars(t *testing.T) {
|
||||||
|
invalidIDs := []string{"open-ai", "open.ai", "open ai", "供应商", "open/ai"}
|
||||||
|
for _, id := range invalidIDs {
|
||||||
|
assert.Error(t, ValidateProviderID(id), "%q 应校验失败", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateProviderID_Empty(t *testing.T) {
|
||||||
|
assert.Error(t, ValidateProviderID(""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateProviderID_TooLong(t *testing.T) {
|
||||||
|
longID := strings.Repeat("a", 65)
|
||||||
|
assert.Error(t, ValidateProviderID(longID))
|
||||||
|
|
||||||
|
exactly64 := strings.Repeat("a", 64)
|
||||||
|
assert.NoError(t, ValidateProviderID(exactly64))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsValidUnifiedModelID(t *testing.T) {
|
||||||
|
assert.True(t, IsValidUnifiedModelID("openai/gpt-4"))
|
||||||
|
assert.True(t, IsValidUnifiedModelID("anthropic/claude-3-opus-20240229"))
|
||||||
|
assert.True(t, IsValidUnifiedModelID("azure/accounts/org/models/gpt-4"))
|
||||||
|
|
||||||
|
assert.False(t, IsValidUnifiedModelID(""))
|
||||||
|
assert.False(t, IsValidUnifiedModelID("gpt-4"))
|
||||||
|
assert.False(t, IsValidUnifiedModelID("open-ai/gpt-4"))
|
||||||
|
assert.False(t, IsValidUnifiedModelID("/model"))
|
||||||
|
assert.False(t, IsValidUnifiedModelID("provider/"))
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package tests
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"nex/backend/internal/config"
|
"nex/backend/internal/config"
|
||||||
|
|
||||||
@@ -11,26 +12,36 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SetupTestDB initializes a temporary SQLite database with auto-migration.
|
// SetupTestDB initializes an in-memory SQLite database with auto-migration.
|
||||||
|
// Uses :memory: mode with MaxOpenConns(1) to ensure all operations share the
|
||||||
|
// same connection, avoiding "database is closed" errors from connection pool.
|
||||||
|
// Enables foreign key constraints for SQLite.
|
||||||
func SetupTestDB(t *testing.T) *gorm.DB {
|
func SetupTestDB(t *testing.T) *gorm.DB {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
dir := t.TempDir()
|
db, err := gorm.Open(sqlite.Open(":memory:?_foreign_keys=on"), &gorm.Config{})
|
||||||
dsn := dir + "/test.db"
|
|
||||||
|
|
||||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
|
||||||
assert.NoError(t, err, "failed to open test database")
|
assert.NoError(t, err, "failed to open test database")
|
||||||
|
|
||||||
|
// 限制为单连接,确保 :memory: 数据库不被连接池丢弃
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
assert.NoError(t, err, "failed to get underlying sql.DB")
|
||||||
|
sqlDB.SetMaxOpenConns(1)
|
||||||
|
sqlDB.SetConnMaxLifetime(0)
|
||||||
|
|
||||||
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
||||||
assert.NoError(t, err, "failed to auto-migrate test database")
|
assert.NoError(t, err, "failed to auto-migrate test database")
|
||||||
|
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanupTestDB closes the database and removes the temp database file.
|
// CleanupTestDB closes the database after a brief delay to allow async
|
||||||
|
// goroutines (e.g. stats recording) to finish.
|
||||||
func CleanupTestDB(t *testing.T, db *gorm.DB) {
|
func CleanupTestDB(t *testing.T, db *gorm.DB) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
// 等待异步 goroutine(如 statsService.Record)完成
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
sqlDB, err := db.DB()
|
sqlDB, err := db.DB()
|
||||||
assert.NoError(t, err, "failed to get underlying sql.DB")
|
assert.NoError(t, err, "failed to get underlying sql.DB")
|
||||||
|
|
||||||
@@ -57,7 +68,8 @@ func CreateTestProvider(t *testing.T, db *gorm.DB, id string) config.Provider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateTestModel creates a test model and returns it.
|
// CreateTestModel creates a test model and returns it.
|
||||||
func CreateTestModel(t *testing.T, db *gorm.DB, id string, providerID string, modelName string) config.Model {
|
// Does NOT assert on error - returns the model and error for caller to verify.
|
||||||
|
func CreateTestModel(t *testing.T, db *gorm.DB, id string, providerID string, modelName string) (config.Model, error) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
model := config.Model{
|
model := config.Model{
|
||||||
@@ -68,7 +80,5 @@ func CreateTestModel(t *testing.T, db *gorm.DB, id string, providerID string, mo
|
|||||||
}
|
}
|
||||||
|
|
||||||
err := db.Create(&model).Error
|
err := db.Create(&model).Error
|
||||||
assert.NoError(t, err, "failed to create test model")
|
return model, err
|
||||||
|
|
||||||
return model
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,10 +14,8 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"gorm.io/driver/sqlite"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"nex/backend/internal/config"
|
|
||||||
"nex/backend/internal/conversion"
|
"nex/backend/internal/conversion"
|
||||||
"nex/backend/internal/conversion/anthropic"
|
"nex/backend/internal/conversion/anthropic"
|
||||||
openaiConv "nex/backend/internal/conversion/openai"
|
openaiConv "nex/backend/internal/conversion/openai"
|
||||||
@@ -43,11 +41,7 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server)
|
|||||||
w.Write([]byte(`{"error":"not mocked"}`))
|
w.Write([]byte(`{"error":"not mocked"}`))
|
||||||
}))
|
}))
|
||||||
|
|
||||||
dir := t.TempDir()
|
db := setupTestDB(t)
|
||||||
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
sqlDB, _ := db.DB()
|
sqlDB, _ := db.DB()
|
||||||
if sqlDB != nil {
|
if sqlDB != nil {
|
||||||
@@ -60,7 +54,7 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server)
|
|||||||
modelRepo := repository.NewModelRepository(db)
|
modelRepo := repository.NewModelRepository(db)
|
||||||
statsRepo := repository.NewStatsRepository(db)
|
statsRepo := repository.NewStatsRepository(db)
|
||||||
|
|
||||||
providerService := service.NewProviderService(providerRepo)
|
providerService := service.NewProviderService(providerRepo, modelRepo)
|
||||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||||
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
||||||
statsService := service.NewStatsService(statsRepo)
|
statsService := service.NewStatsService(statsRepo)
|
||||||
@@ -125,7 +119,7 @@ func createProviderAndModel(t *testing.T, r *gin.Engine, providerID, protocol, m
|
|||||||
require.Equal(t, 201, w.Code)
|
require.Equal(t, 201, w.Code)
|
||||||
|
|
||||||
modelBody, _ := json.Marshal(map[string]string{
|
modelBody, _ := json.Marshal(map[string]string{
|
||||||
"id": modelName,
|
|
||||||
"provider_id": providerID,
|
"provider_id": providerID,
|
||||||
"model_name": modelName,
|
"model_name": modelName,
|
||||||
})
|
})
|
||||||
@@ -156,7 +150,7 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
|
|||||||
"id": "msg_test",
|
"id": "msg_test",
|
||||||
"type": "message",
|
"type": "message",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"model": "claude-3-opus",
|
"model": "anthropic_p/claude-3-opus",
|
||||||
"content": []map[string]any{
|
"content": []map[string]any{
|
||||||
{"type": "text", "text": "Hello from Anthropic!"},
|
{"type": "text", "text": "Hello from Anthropic!"},
|
||||||
},
|
},
|
||||||
@@ -170,11 +164,11 @@ func TestConversion_OpenAIToAnthropic_NonStream(t *testing.T) {
|
|||||||
json.NewEncoder(w).Encode(resp)
|
json.NewEncoder(w).Encode(resp)
|
||||||
})
|
})
|
||||||
|
|
||||||
createProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-3-opus", upstream.URL)
|
createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL)
|
||||||
|
|
||||||
// 使用 OpenAI 格式发送请求
|
// 使用 OpenAI 格式发送请求
|
||||||
openaiReq := map[string]any{
|
openaiReq := map[string]any{
|
||||||
"model": "claude-3-opus",
|
"model": "anthropic_p/claude-3-opus",
|
||||||
"messages": []map[string]any{
|
"messages": []map[string]any{
|
||||||
{"role": "user", "content": "Hello"},
|
{"role": "user", "content": "Hello"},
|
||||||
},
|
},
|
||||||
@@ -233,10 +227,10 @@ func TestConversion_AnthropicToOpenAI_NonStream(t *testing.T) {
|
|||||||
json.NewEncoder(w).Encode(resp)
|
json.NewEncoder(w).Encode(resp)
|
||||||
})
|
})
|
||||||
|
|
||||||
createProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL)
|
createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
|
||||||
|
|
||||||
anthropicReq := map[string]any{
|
anthropicReq := map[string]any{
|
||||||
"model": "gpt-4",
|
"model": "openai_p/gpt-4",
|
||||||
"max_tokens": 1024,
|
"max_tokens": 1024,
|
||||||
"messages": []map[string]any{
|
"messages": []map[string]any{
|
||||||
{"role": "user", "content": "Hello"},
|
{"role": "user", "content": "Hello"},
|
||||||
@@ -273,16 +267,18 @@ func TestConversion_OpenAIToOpenAI_Passthrough(t *testing.T) {
|
|||||||
body, _ := io.ReadAll(r.Body)
|
body, _ := io.ReadAll(r.Body)
|
||||||
var req map[string]any
|
var req map[string]any
|
||||||
json.Unmarshal(body, &req)
|
json.Unmarshal(body, &req)
|
||||||
|
// Smart Passthrough: 请求体中的统一 ID 应被改写为上游模型名
|
||||||
assert.Equal(t, "gpt-4", req["model"])
|
assert.Equal(t, "gpt-4", req["model"])
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
// 上游返回上游模型名
|
||||||
w.Write([]byte(`{"id":"chatcmpl-pass","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"passthrough"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":1,"total_tokens":6}}`))
|
w.Write([]byte(`{"id":"chatcmpl-pass","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"passthrough"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":1,"total_tokens":6}}`))
|
||||||
})
|
})
|
||||||
|
|
||||||
createProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL)
|
createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
|
||||||
|
|
||||||
reqBody := map[string]any{
|
reqBody := map[string]any{
|
||||||
"model": "gpt-4",
|
"model": "openai_p/gpt-4", // 客户端发送统一 ID
|
||||||
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
||||||
}
|
}
|
||||||
body, _ := json.Marshal(reqBody)
|
body, _ := json.Marshal(reqBody)
|
||||||
@@ -293,7 +289,8 @@ func TestConversion_OpenAIToOpenAI_Passthrough(t *testing.T) {
|
|||||||
r.ServeHTTP(w, req)
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), "passthrough")
|
// Smart Passthrough: 响应体中的上游模型名应被改写为统一 ID
|
||||||
|
assert.Contains(t, w.Body.String(), `"model":"openai_p/gpt-4"`)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConversion_AnthropicToAnthropic_Passthrough(t *testing.T) {
|
func TestConversion_AnthropicToAnthropic_Passthrough(t *testing.T) {
|
||||||
@@ -302,14 +299,21 @@ func TestConversion_AnthropicToAnthropic_Passthrough(t *testing.T) {
|
|||||||
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, "/v1/messages", r.URL.Path)
|
assert.Equal(t, "/v1/messages", r.URL.Path)
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
var req map[string]any
|
||||||
|
json.Unmarshal(body, &req)
|
||||||
|
// Smart Passthrough: 请求体中的统一 ID 应被改写为上游模型名
|
||||||
|
assert.Equal(t, "claude-3-opus", req["model"])
|
||||||
|
|
||||||
|
// 上游返回上游模型名
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.Write([]byte(`{"id":"msg-pass","type":"message","role":"assistant","model":"claude-3-opus","content":[{"type":"text","text":"passthrough"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":1}}`))
|
w.Write([]byte(`{"id":"msg-pass","type":"message","role":"assistant","model":"claude-3-opus","content":[{"type":"text","text":"passthrough"}],"stop_reason":"end_turn","usage":{"input_tokens":5,"output_tokens":1}}`))
|
||||||
})
|
})
|
||||||
|
|
||||||
createProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-3-opus", upstream.URL)
|
createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL)
|
||||||
|
|
||||||
reqBody := map[string]any{
|
reqBody := map[string]any{
|
||||||
"model": "claude-3-opus",
|
"model": "anthropic_p/claude-3-opus", // 客户端发送统一 ID
|
||||||
"max_tokens": 1024,
|
"max_tokens": 1024,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
||||||
}
|
}
|
||||||
@@ -321,7 +325,8 @@ func TestConversion_AnthropicToAnthropic_Passthrough(t *testing.T) {
|
|||||||
r.ServeHTTP(w, req)
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
assert.Contains(t, w.Body.String(), "passthrough")
|
// Smart Passthrough: 响应体中的上游模型名应被改写为统一 ID
|
||||||
|
assert.Contains(t, w.Body.String(), `"model":"anthropic_p/claude-3-opus"`)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============ 流式转换测试 ============
|
// ============ 流式转换测试 ============
|
||||||
@@ -349,10 +354,10 @@ func TestConversion_OpenAIToAnthropic_Stream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
createProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-3-opus", upstream.URL)
|
createProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-3-opus", upstream.URL)
|
||||||
|
|
||||||
openaiReq := map[string]any{
|
openaiReq := map[string]any{
|
||||||
"model": "claude-3-opus",
|
"model": "anthropic_p/claude-3-opus",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
|
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
|
||||||
"stream": true,
|
"stream": true,
|
||||||
}
|
}
|
||||||
@@ -390,10 +395,10 @@ func TestConversion_AnthropicToOpenAI_Stream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
createProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL)
|
createProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
|
||||||
|
|
||||||
anthropicReq := map[string]any{
|
anthropicReq := map[string]any{
|
||||||
"model": "gpt-4",
|
"model": "openai_p/gpt-4",
|
||||||
"max_tokens": 1024,
|
"max_tokens": 1024,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
|
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
|
||||||
"stream": true,
|
"stream": true,
|
||||||
@@ -512,7 +517,7 @@ func TestConversion_ProviderWithProtocol(t *testing.T) {
|
|||||||
|
|
||||||
// 创建带 protocol 字段的 provider
|
// 创建带 protocol 字段的 provider
|
||||||
providerBody := map[string]any{
|
providerBody := map[string]any{
|
||||||
"id": "test-protocol",
|
"id": "test_protocol",
|
||||||
"name": "Test Protocol",
|
"name": "Test Protocol",
|
||||||
"api_key": "sk-test",
|
"api_key": "sk-test",
|
||||||
"base_url": "https://test.com",
|
"base_url": "https://test.com",
|
||||||
@@ -533,7 +538,7 @@ func TestConversion_ProviderWithProtocol(t *testing.T) {
|
|||||||
|
|
||||||
// 获取时应包含 protocol
|
// 获取时应包含 protocol
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
req = httptest.NewRequest("GET", "/api/providers/test-protocol", nil)
|
req = httptest.NewRequest("GET", "/api/providers/test_protocol", nil)
|
||||||
r.ServeHTTP(w, req)
|
r.ServeHTTP(w, req)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
|
|
||||||
@@ -547,7 +552,7 @@ func TestConversion_ProviderDefaultProtocol(t *testing.T) {
|
|||||||
|
|
||||||
// 不指定 protocol,默认应为 openai
|
// 不指定 protocol,默认应为 openai
|
||||||
providerBody := map[string]any{
|
providerBody := map[string]any{
|
||||||
"id": "default-proto",
|
"id": "default_proto",
|
||||||
"name": "Default",
|
"name": "Default",
|
||||||
"api_key": "sk-test",
|
"api_key": "sk-test",
|
||||||
"base_url": "https://test.com",
|
"base_url": "https://test.com",
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -17,10 +15,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"gorm.io/driver/sqlite"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
|
|
||||||
"nex/backend/internal/config"
|
|
||||||
"nex/backend/internal/conversion"
|
"nex/backend/internal/conversion"
|
||||||
"nex/backend/internal/conversion/anthropic"
|
"nex/backend/internal/conversion/anthropic"
|
||||||
openaiConv "nex/backend/internal/conversion/openai"
|
openaiConv "nex/backend/internal/conversion/openai"
|
||||||
@@ -40,25 +35,20 @@ func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) {
|
|||||||
w.Write([]byte(`{"error":"not mocked"}`))
|
w.Write([]byte(`{"error":"not mocked"}`))
|
||||||
}))
|
}))
|
||||||
|
|
||||||
dir, _ := os.MkdirTemp("", "e2e-test-*")
|
db := setupTestDB(t)
|
||||||
db, err := gorm.Open(sqlite.Open(filepath.Join(dir, "test.db")), &gorm.Config{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
sqlDB, _ := db.DB()
|
sqlDB, _ := db.DB()
|
||||||
if sqlDB != nil {
|
if sqlDB != nil {
|
||||||
sqlDB.Close()
|
sqlDB.Close()
|
||||||
}
|
}
|
||||||
upstream.Close()
|
upstream.Close()
|
||||||
os.RemoveAll(dir)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
providerRepo := repository.NewProviderRepository(db)
|
providerRepo := repository.NewProviderRepository(db)
|
||||||
modelRepo := repository.NewModelRepository(db)
|
modelRepo := repository.NewModelRepository(db)
|
||||||
statsRepo := repository.NewStatsRepository(db)
|
statsRepo := repository.NewStatsRepository(db)
|
||||||
|
|
||||||
providerService := service.NewProviderService(providerRepo)
|
providerService := service.NewProviderService(providerRepo, modelRepo)
|
||||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||||
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
||||||
statsService := service.NewStatsService(statsRepo)
|
statsService := service.NewStatsService(statsRepo)
|
||||||
@@ -105,7 +95,7 @@ func e2eCreateProviderAndModel(t *testing.T, r *gin.Engine, providerID, protocol
|
|||||||
require.Equal(t, 201, w.Code)
|
require.Equal(t, 201, w.Code)
|
||||||
|
|
||||||
modelBody, _ := json.Marshal(map[string]string{
|
modelBody, _ := json.Marshal(map[string]string{
|
||||||
"id": modelName, "provider_id": providerID, "model_name": modelName,
|
"provider_id": providerID, "model_name": modelName,
|
||||||
})
|
})
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody))
|
req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody))
|
||||||
@@ -178,10 +168,10 @@ func TestE2E_OpenAI_NonStream_BasicText(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "gpt-4o",
|
"model": "openai_p/gpt-4o",
|
||||||
"messages": []map[string]any{
|
"messages": []map[string]any{
|
||||||
{"role": "user", "content": "你好"},
|
{"role": "user", "content": "你好"},
|
||||||
},
|
},
|
||||||
@@ -195,7 +185,7 @@ func TestE2E_OpenAI_NonStream_BasicText(t *testing.T) {
|
|||||||
var resp map[string]any
|
var resp map[string]any
|
||||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
assert.Equal(t, "chat.completion", resp["object"])
|
assert.Equal(t, "chat.completion", resp["object"])
|
||||||
assert.Equal(t, "gpt-4o", resp["model"])
|
assert.Equal(t, "openai_p/gpt-4o", resp["model"])
|
||||||
|
|
||||||
choices := resp["choices"].([]any)
|
choices := resp["choices"].([]any)
|
||||||
require.Len(t, choices, 1)
|
require.Len(t, choices, 1)
|
||||||
@@ -231,10 +221,10 @@ func TestE2E_OpenAI_NonStream_MultiTurn(t *testing.T) {
|
|||||||
"usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120},
|
"usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "gpt-4o",
|
"model": "openai_p/gpt-4o",
|
||||||
"messages": []map[string]any{
|
"messages": []map[string]any{
|
||||||
{"role": "system", "content": "你是编程助手"},
|
{"role": "system", "content": "你是编程助手"},
|
||||||
{"role": "user", "content": "什么是interface?"},
|
{"role": "user", "content": "什么是interface?"},
|
||||||
@@ -279,10 +269,10 @@ func TestE2E_OpenAI_NonStream_ToolCalls(t *testing.T) {
|
|||||||
"usage": map[string]any{"prompt_tokens": 80, "completion_tokens": 18, "total_tokens": 98},
|
"usage": map[string]any{"prompt_tokens": 80, "completion_tokens": 18, "total_tokens": 98},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "gpt-4o",
|
"model": "openai_p/gpt-4o",
|
||||||
"messages": []map[string]any{
|
"messages": []map[string]any{
|
||||||
{"role": "user", "content": "北京天气"},
|
{"role": "user", "content": "北京天气"},
|
||||||
},
|
},
|
||||||
@@ -335,10 +325,10 @@ func TestE2E_OpenAI_NonStream_MaxTokens_Length(t *testing.T) {
|
|||||||
"usage": map[string]any{"prompt_tokens": 20, "completion_tokens": 30, "total_tokens": 50},
|
"usage": map[string]any{"prompt_tokens": 20, "completion_tokens": 30, "total_tokens": 50},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "gpt-4o",
|
"model": "openai_p/gpt-4o",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "介绍AI历史"}},
|
"messages": []map[string]any{{"role": "user", "content": "介绍AI历史"}},
|
||||||
"max_tokens": 30,
|
"max_tokens": 30,
|
||||||
})
|
})
|
||||||
@@ -372,10 +362,10 @@ func TestE2E_OpenAI_NonStream_UsageWithReasoning(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "o3", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "o3", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "o3",
|
"model": "openai_p/o3",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "15+23*2=?"}},
|
"messages": []map[string]any{{"role": "user", "content": "15+23*2=?"}},
|
||||||
})
|
})
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -413,10 +403,10 @@ func TestE2E_OpenAI_NonStream_Refusal(t *testing.T) {
|
|||||||
"usage": map[string]any{"prompt_tokens": 12, "completion_tokens": 35, "total_tokens": 47},
|
"usage": map[string]any{"prompt_tokens": 12, "completion_tokens": 35, "total_tokens": 47},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "gpt-4o",
|
"model": "openai_p/gpt-4o",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "做坏事"}},
|
"messages": []map[string]any{{"role": "user", "content": "做坏事"}},
|
||||||
})
|
})
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -455,10 +445,10 @@ func TestE2E_OpenAI_Stream_Text(t *testing.T) {
|
|||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "gpt-4o",
|
"model": "openai_p/gpt-4o",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "你好"}},
|
"messages": []map[string]any{{"role": "user", "content": "你好"}},
|
||||||
"stream": true,
|
"stream": true,
|
||||||
})
|
})
|
||||||
@@ -499,10 +489,10 @@ func TestE2E_OpenAI_Stream_ToolCalls(t *testing.T) {
|
|||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "gpt-4o",
|
"model": "openai_p/gpt-4o",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
|
"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
|
||||||
"tools": []map[string]any{{
|
"tools": []map[string]any{{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
@@ -548,10 +538,10 @@ func TestE2E_OpenAI_Stream_WithUsage(t *testing.T) {
|
|||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "gpt-4o",
|
"model": "openai_p/gpt-4o",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "hi"}},
|
"messages": []map[string]any{{"role": "user", "content": "hi"}},
|
||||||
"stream": true,
|
"stream": true,
|
||||||
})
|
})
|
||||||
@@ -583,10 +573,10 @@ func TestE2E_Anthropic_NonStream_BasicText(t *testing.T) {
|
|||||||
"usage": map[string]any{"input_tokens": 15, "output_tokens": 25},
|
"usage": map[string]any{"input_tokens": 15, "output_tokens": 25},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-opus-4-7", "max_tokens": 1024,
|
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "你好"}},
|
"messages": []map[string]any{{"role": "user", "content": "你好"}},
|
||||||
})
|
})
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -599,7 +589,7 @@ func TestE2E_Anthropic_NonStream_BasicText(t *testing.T) {
|
|||||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
assert.Equal(t, "message", resp["type"])
|
assert.Equal(t, "message", resp["type"])
|
||||||
assert.Equal(t, "assistant", resp["role"])
|
assert.Equal(t, "assistant", resp["role"])
|
||||||
assert.Equal(t, "claude-opus-4-7", resp["model"])
|
assert.Equal(t, "anthropic_p/claude-opus-4-7", resp["model"])
|
||||||
assert.Equal(t, "end_turn", resp["stop_reason"])
|
assert.Equal(t, "end_turn", resp["stop_reason"])
|
||||||
|
|
||||||
content := resp["content"].([]any)
|
content := resp["content"].([]any)
|
||||||
@@ -629,10 +619,10 @@ func TestE2E_Anthropic_NonStream_WithSystem(t *testing.T) {
|
|||||||
"usage": map[string]any{"input_tokens": 30, "output_tokens": 15},
|
"usage": map[string]any{"input_tokens": 30, "output_tokens": 15},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-opus-4-7", "max_tokens": 1024,
|
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||||
"system": "你是编程助手",
|
"system": "你是编程助手",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "什么是递归?"}},
|
"messages": []map[string]any{{"role": "user", "content": "什么是递归?"}},
|
||||||
})
|
})
|
||||||
@@ -658,10 +648,10 @@ func TestE2E_Anthropic_NonStream_ToolUse(t *testing.T) {
|
|||||||
"usage": map[string]any{"input_tokens": 180, "output_tokens": 42},
|
"usage": map[string]any{"input_tokens": 180, "output_tokens": 42},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-opus-4-7", "max_tokens": 1024,
|
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
|
"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
|
||||||
"tools": []map[string]any{{
|
"tools": []map[string]any{{
|
||||||
"name": "get_weather", "description": "获取天气",
|
"name": "get_weather", "description": "获取天气",
|
||||||
@@ -704,10 +694,10 @@ func TestE2E_Anthropic_NonStream_Thinking(t *testing.T) {
|
|||||||
"usage": map[string]any{"input_tokens": 95, "output_tokens": 280},
|
"usage": map[string]any{"input_tokens": 95, "output_tokens": 280},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-opus-4-7", "max_tokens": 4096,
|
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 4096,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "15+23*2=?"}},
|
"messages": []map[string]any{{"role": "user", "content": "15+23*2=?"}},
|
||||||
"thinking": map[string]any{"type": "enabled", "budget_tokens": 2048},
|
"thinking": map[string]any{"type": "enabled", "budget_tokens": 2048},
|
||||||
})
|
})
|
||||||
@@ -736,10 +726,10 @@ func TestE2E_Anthropic_NonStream_MaxTokens(t *testing.T) {
|
|||||||
"usage": map[string]any{"input_tokens": 22, "output_tokens": 20},
|
"usage": map[string]any{"input_tokens": 22, "output_tokens": 20},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-opus-4-7", "max_tokens": 20,
|
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 20,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "介绍AI历史"}},
|
"messages": []map[string]any{{"role": "user", "content": "介绍AI历史"}},
|
||||||
})
|
})
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -764,10 +754,10 @@ func TestE2E_Anthropic_NonStream_StopSequence(t *testing.T) {
|
|||||||
"usage": map[string]any{"input_tokens": 22, "output_tokens": 10},
|
"usage": map[string]any{"input_tokens": 22, "output_tokens": 10},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-opus-4-7", "max_tokens": 1024,
|
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "从1数到10"}},
|
"messages": []map[string]any{{"role": "user", "content": "从1数到10"}},
|
||||||
"stop_sequences": []string{"5"},
|
"stop_sequences": []string{"5"},
|
||||||
})
|
})
|
||||||
@@ -800,10 +790,10 @@ func TestE2E_Anthropic_NonStream_MetadataUserID(t *testing.T) {
|
|||||||
"usage": map[string]any{"input_tokens": 12, "output_tokens": 5},
|
"usage": map[string]any{"input_tokens": 12, "output_tokens": 5},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-opus-4-7", "max_tokens": 1024,
|
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "你好"}},
|
"messages": []map[string]any{{"role": "user", "content": "你好"}},
|
||||||
"metadata": map[string]any{"user_id": "user_12345"},
|
"metadata": map[string]any{"user_id": "user_12345"},
|
||||||
})
|
})
|
||||||
@@ -829,10 +819,10 @@ func TestE2E_Anthropic_NonStream_UsageWithCache(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-opus-4-7", "max_tokens": 1024,
|
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||||
"system": []map[string]any{{"type": "text", "text": "你是编程助手。"}},
|
"system": []map[string]any{{"type": "text", "text": "你是编程助手。"}},
|
||||||
"messages": []map[string]any{{"role": "user", "content": "你好"}},
|
"messages": []map[string]any{{"role": "user", "content": "你好"}},
|
||||||
})
|
})
|
||||||
@@ -874,10 +864,10 @@ func TestE2E_Anthropic_Stream_Text(t *testing.T) {
|
|||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-opus-4-7", "max_tokens": 1024,
|
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "你好"}},
|
"messages": []map[string]any{{"role": "user", "content": "你好"}},
|
||||||
"stream": true,
|
"stream": true,
|
||||||
})
|
})
|
||||||
@@ -921,10 +911,10 @@ func TestE2E_Anthropic_Stream_Thinking(t *testing.T) {
|
|||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-opus-4-7", "max_tokens": 4096,
|
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 4096,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "1+1=?"}},
|
"messages": []map[string]any{{"role": "user", "content": "1+1=?"}},
|
||||||
"thinking": map[string]any{"type": "enabled", "budget_tokens": 1024},
|
"thinking": map[string]any{"type": "enabled", "budget_tokens": 1024},
|
||||||
"stream": true,
|
"stream": true,
|
||||||
@@ -970,10 +960,10 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_RequestFormat(t *testing.T) {
|
|||||||
"usage": map[string]any{"input_tokens": 10, "output_tokens": 5},
|
"usage": map[string]any{"input_tokens": 10, "output_tokens": 5},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-model", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-model",
|
"model": "anthropic_p/claude-model",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
|
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
|
||||||
})
|
})
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -1011,10 +1001,10 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_RequestFormat(t *testing.T) {
|
|||||||
"usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18},
|
"usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "gpt-4", "max_tokens": 1024,
|
"model": "openai_p/gpt-4", "max_tokens": 1024,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
|
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
|
||||||
})
|
})
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -1052,10 +1042,10 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_Stream(t *testing.T) {
|
|||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-model", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-model",
|
"model": "anthropic_p/claude-model",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
|
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
|
||||||
"stream": true,
|
"stream": true,
|
||||||
})
|
})
|
||||||
@@ -1092,10 +1082,10 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream(t *testing.T) {
|
|||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "gpt-4", "max_tokens": 1024,
|
"model": "openai_p/gpt-4", "max_tokens": 1024,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
|
"messages": []map[string]any{{"role": "user", "content": "Hello"}},
|
||||||
"stream": true,
|
"stream": true,
|
||||||
})
|
})
|
||||||
@@ -1130,10 +1120,10 @@ func TestE2E_OpenAI_ErrorResponse(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "nonexistent", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "nonexistent", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "nonexistent",
|
"model": "openai_p/nonexistent",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
||||||
})
|
})
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -1157,10 +1147,10 @@ func TestE2E_Anthropic_ErrorResponse(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-opus-4-7", "max_tokens": 1024,
|
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
||||||
})
|
})
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -1203,10 +1193,10 @@ func TestE2E_OpenAI_NonStream_ParallelToolCalls(t *testing.T) {
|
|||||||
"usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 36, "total_tokens": 136},
|
"usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 36, "total_tokens": 136},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "gpt-4o",
|
"model": "openai_p/gpt-4o",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "北京和上海的天气"}},
|
"messages": []map[string]any{{"role": "user", "content": "北京和上海的天气"}},
|
||||||
"tools": []map[string]any{{
|
"tools": []map[string]any{{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
@@ -1255,10 +1245,10 @@ func TestE2E_OpenAI_NonStream_StopSequence(t *testing.T) {
|
|||||||
"usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18},
|
"usage": map[string]any{"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "gpt-4o",
|
"model": "openai_p/gpt-4o",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "从1数到10"}},
|
"messages": []map[string]any{{"role": "user", "content": "从1数到10"}},
|
||||||
"stop": []string{"5"},
|
"stop": []string{"5"},
|
||||||
})
|
})
|
||||||
@@ -1293,10 +1283,10 @@ func TestE2E_OpenAI_NonStream_ContentFilter(t *testing.T) {
|
|||||||
"usage": map[string]any{"prompt_tokens": 8, "completion_tokens": 0, "total_tokens": 8},
|
"usage": map[string]any{"prompt_tokens": 8, "completion_tokens": 0, "total_tokens": 8},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "gpt-4o",
|
"model": "openai_p/gpt-4o",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "危险内容"}},
|
"messages": []map[string]any{{"role": "user", "content": "危险内容"}},
|
||||||
})
|
})
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -1325,10 +1315,10 @@ func TestE2E_Anthropic_NonStream_MultiToolUse(t *testing.T) {
|
|||||||
"usage": map[string]any{"input_tokens": 200, "output_tokens": 84},
|
"usage": map[string]any{"input_tokens": 200, "output_tokens": 84},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-opus-4-7", "max_tokens": 1024,
|
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "北京和上海的天气"}},
|
"messages": []map[string]any{{"role": "user", "content": "北京和上海的天气"}},
|
||||||
"tools": []map[string]any{{
|
"tools": []map[string]any{{
|
||||||
"name": "get_weather", "description": "获取天气",
|
"name": "get_weather", "description": "获取天气",
|
||||||
@@ -1374,10 +1364,10 @@ func TestE2E_Anthropic_NonStream_ToolChoiceAny(t *testing.T) {
|
|||||||
"usage": map[string]any{"input_tokens": 100, "output_tokens": 30},
|
"usage": map[string]any{"input_tokens": 100, "output_tokens": 30},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-opus-4-7", "max_tokens": 1024,
|
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "现在几点了?"}},
|
"messages": []map[string]any{{"role": "user", "content": "现在几点了?"}},
|
||||||
"tools": []map[string]any{{
|
"tools": []map[string]any{{
|
||||||
"name": "get_time", "description": "获取当前时间",
|
"name": "get_time", "description": "获取当前时间",
|
||||||
@@ -1417,10 +1407,10 @@ func TestE2E_Anthropic_NonStream_ArraySystemPrompt(t *testing.T) {
|
|||||||
"usage": map[string]any{"input_tokens": 50, "output_tokens": 10},
|
"usage": map[string]any{"input_tokens": 50, "output_tokens": 10},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-opus-4-7", "max_tokens": 1024,
|
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||||
"system": []map[string]any{
|
"system": []map[string]any{
|
||||||
{"type": "text", "text": "你是编程助手。"},
|
{"type": "text", "text": "你是编程助手。"},
|
||||||
{"type": "text", "text": "请用中文回答。"},
|
{"type": "text", "text": "请用中文回答。"},
|
||||||
@@ -1454,10 +1444,10 @@ func TestE2E_Anthropic_NonStream_ToolResultMessage(t *testing.T) {
|
|||||||
"usage": map[string]any{"input_tokens": 150, "output_tokens": 20},
|
"usage": map[string]any{"input_tokens": 150, "output_tokens": 20},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-opus-4-7", "max_tokens": 1024,
|
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||||
"messages": []map[string]any{
|
"messages": []map[string]any{
|
||||||
{"role": "user", "content": "北京天气"},
|
{"role": "user", "content": "北京天气"},
|
||||||
{"role": "assistant", "content": []map[string]any{
|
{"role": "assistant", "content": []map[string]any{
|
||||||
@@ -1507,10 +1497,10 @@ func TestE2E_Anthropic_Stream_ToolCalls(t *testing.T) {
|
|||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-opus-4-7", "max_tokens": 1024,
|
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
|
"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
|
||||||
"tools": []map[string]any{{
|
"tools": []map[string]any{{
|
||||||
"name": "get_weather", "description": "获取天气",
|
"name": "get_weather", "description": "获取天气",
|
||||||
@@ -1561,10 +1551,10 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_NonStream_ToolCalls(t *testing.T) {
|
|||||||
"usage": map[string]any{"input_tokens": 100, "output_tokens": 30},
|
"usage": map[string]any{"input_tokens": 100, "output_tokens": 30},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-model", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-model",
|
"model": "anthropic_p/claude-model",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
|
"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
|
||||||
"tools": []map[string]any{{
|
"tools": []map[string]any{{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
@@ -1613,10 +1603,10 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_NonStream_Thinking(t *testing.T) {
|
|||||||
"usage": map[string]any{"prompt_tokens": 50, "completion_tokens": 100, "total_tokens": 150},
|
"usage": map[string]any{"prompt_tokens": 50, "completion_tokens": 100, "total_tokens": 150},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "gpt-4", "max_tokens": 4096,
|
"model": "openai_p/gpt-4", "max_tokens": 4096,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "宇宙的答案"}},
|
"messages": []map[string]any{{"role": "user", "content": "宇宙的答案"}},
|
||||||
})
|
})
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -1643,10 +1633,10 @@ func TestE2E_CrossProtocol_StopReasonMapping(t *testing.T) {
|
|||||||
"usage": map[string]any{"input_tokens": 10, "output_tokens": 20},
|
"usage": map[string]any{"input_tokens": 10, "output_tokens": 20},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-model", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-model",
|
"model": "anthropic_p/claude-model",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "长文"}},
|
"messages": []map[string]any{{"role": "user", "content": "长文"}},
|
||||||
})
|
})
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -1685,10 +1675,10 @@ func TestE2E_OpenAI_NonStream_AssistantWithToolResult(t *testing.T) {
|
|||||||
"usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120},
|
"usage": map[string]any{"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "gpt-4o",
|
"model": "openai_p/gpt-4o",
|
||||||
"messages": []map[string]any{
|
"messages": []map[string]any{
|
||||||
{"role": "user", "content": "北京天气"},
|
{"role": "user", "content": "北京天气"},
|
||||||
{"role": "assistant", "content": nil, "tool_calls": []map[string]any{{
|
{"role": "assistant", "content": nil, "tool_calls": []map[string]any{{
|
||||||
@@ -1732,10 +1722,10 @@ func TestE2E_CrossProtocol_AnthropicToOpenAI_Stream_ToolCalls(t *testing.T) {
|
|||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-model", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-model", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-model",
|
"model": "anthropic_p/claude-model",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
|
"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
|
||||||
"tools": []map[string]any{{
|
"tools": []map[string]any{{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
@@ -1781,10 +1771,10 @@ func TestE2E_CrossProtocol_OpenAIToAnthropic_Stream_ToolCalls(t *testing.T) {
|
|||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "gpt-4", "max_tokens": 1024,
|
"model": "openai_p/gpt-4", "max_tokens": 1024,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
|
"messages": []map[string]any{{"role": "user", "content": "北京天气"}},
|
||||||
"tools": []map[string]any{{
|
"tools": []map[string]any{{
|
||||||
"name": "get_weather", "description": "获取天气",
|
"name": "get_weather", "description": "获取天气",
|
||||||
@@ -1819,10 +1809,10 @@ func TestE2E_OpenAI_Upstream5xx_ErrorPassthrough(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "openai-p", "openai", "gpt-4o", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "openai_p", "openai", "gpt-4o", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "gpt-4o",
|
"model": "openai_p/gpt-4o",
|
||||||
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
||||||
})
|
})
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -1851,10 +1841,10 @@ func TestE2E_Anthropic_Upstream5xx_ErrorPassthrough(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-opus-4-7", "max_tokens": 1024,
|
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
||||||
})
|
})
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -1889,10 +1879,10 @@ func TestE2E_Anthropic_Stream_TruncatedSSE(t *testing.T) {
|
|||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
e2eCreateProviderAndModel(t, r, "anthropic-p", "anthropic", "claude-opus-4-7", upstream.URL)
|
e2eCreateProviderAndModel(t, r, "anthropic_p", "anthropic", "claude-opus-4-7", upstream.URL)
|
||||||
|
|
||||||
body, _ := json.Marshal(map[string]any{
|
body, _ := json.Marshal(map[string]any{
|
||||||
"model": "claude-opus-4-7", "max_tokens": 1024,
|
"model": "anthropic_p/claude-opus-4-7", "max_tokens": 1024,
|
||||||
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
"messages": []map[string]any{{"role": "user", "content": "test"}},
|
||||||
"stream": true,
|
"stream": true,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -9,11 +9,8 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"gorm.io/driver/sqlite"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"nex/backend/internal/config"
|
|
||||||
"nex/backend/internal/domain"
|
"nex/backend/internal/domain"
|
||||||
"nex/backend/internal/handler"
|
"nex/backend/internal/handler"
|
||||||
"nex/backend/internal/handler/middleware"
|
"nex/backend/internal/handler/middleware"
|
||||||
@@ -27,23 +24,13 @@ func init() {
|
|||||||
|
|
||||||
func setupIntegrationTest(t *testing.T) (*gin.Engine, *gorm.DB) {
|
func setupIntegrationTest(t *testing.T) (*gin.Engine, *gorm.DB) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
dir := t.TempDir()
|
db := setupTestDB(t)
|
||||||
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
sqlDB, _ := db.DB()
|
|
||||||
if sqlDB != nil {
|
|
||||||
sqlDB.Close()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
providerRepo := repository.NewProviderRepository(db)
|
providerRepo := repository.NewProviderRepository(db)
|
||||||
modelRepo := repository.NewModelRepository(db)
|
modelRepo := repository.NewModelRepository(db)
|
||||||
statsRepo := repository.NewStatsRepository(db)
|
statsRepo := repository.NewStatsRepository(db)
|
||||||
|
|
||||||
providerService := service.NewProviderService(providerRepo)
|
providerService := service.NewProviderService(providerRepo, modelRepo)
|
||||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||||
_ = service.NewRoutingService(modelRepo, providerRepo)
|
_ = service.NewRoutingService(modelRepo, providerRepo)
|
||||||
statsService := service.NewStatsService(statsRepo)
|
statsService := service.NewStatsService(statsRepo)
|
||||||
@@ -97,13 +84,16 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
|
|||||||
|
|
||||||
// 2. 创建 Model
|
// 2. 创建 Model
|
||||||
modelBody, _ := json.Marshal(map[string]string{
|
modelBody, _ := json.Marshal(map[string]string{
|
||||||
"id": "gpt4", "provider_id": "openai", "model_name": "gpt-4",
|
"provider_id": "openai", "model_name": "gpt-4",
|
||||||
})
|
})
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody))
|
req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
r.ServeHTTP(w, req)
|
r.ServeHTTP(w, req)
|
||||||
assert.Equal(t, 201, w.Code)
|
assert.Equal(t, 201, w.Code)
|
||||||
|
var createdModel domain.Model
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &createdModel)
|
||||||
|
assert.NotEmpty(t, createdModel.ID)
|
||||||
|
|
||||||
// 3. 列出 Provider
|
// 3. 列出 Provider
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
@@ -135,7 +125,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
|
|||||||
|
|
||||||
// 6. 删除 Model
|
// 6. 删除 Model
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
req = httptest.NewRequest("DELETE", "/api/models/gpt4", nil)
|
req = httptest.NewRequest("DELETE", "/api/models/"+createdModel.ID, nil)
|
||||||
r.ServeHTTP(w, req)
|
r.ServeHTTP(w, req)
|
||||||
assert.Equal(t, 204, w.Code)
|
assert.Equal(t, 204, w.Code)
|
||||||
|
|
||||||
@@ -160,17 +150,19 @@ func TestAnthropic_ModelCreation(t *testing.T) {
|
|||||||
assert.Equal(t, 201, w.Code)
|
assert.Equal(t, 201, w.Code)
|
||||||
|
|
||||||
modelBody, _ := json.Marshal(map[string]string{
|
modelBody, _ := json.Marshal(map[string]string{
|
||||||
"id": "claude3", "provider_id": "anthropic", "model_name": "claude-3-opus-20240229",
|
"provider_id": "anthropic", "model_name": "claude-3-opus-20240229",
|
||||||
})
|
})
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody))
|
req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
r.ServeHTTP(w, req)
|
r.ServeHTTP(w, req)
|
||||||
assert.Equal(t, 201, w.Code)
|
assert.Equal(t, 201, w.Code)
|
||||||
|
var createdModel domain.Model
|
||||||
|
json.Unmarshal(w.Body.Bytes(), &createdModel)
|
||||||
|
|
||||||
// 验证创建成功
|
// 验证创建成功
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
req = httptest.NewRequest("GET", "/api/models/claude3", nil)
|
req = httptest.NewRequest("GET", "/api/models/"+createdModel.ID, nil)
|
||||||
r.ServeHTTP(w, req)
|
r.ServeHTTP(w, req)
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, 200, w.Code)
|
||||||
}
|
}
|
||||||
@@ -188,7 +180,7 @@ func TestStats_RecordingAndQuery(t *testing.T) {
|
|||||||
r.ServeHTTP(w, req)
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
modelBody, _ := json.Marshal(map[string]string{
|
modelBody, _ := json.Marshal(map[string]string{
|
||||||
"id": "m1", "provider_id": "p1", "model_name": "gpt-4",
|
"provider_id": "p1", "model_name": "gpt-4",
|
||||||
})
|
})
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody))
|
req = httptest.NewRequest("POST", "/api/models", bytes.NewReader(modelBody))
|
||||||
|
|||||||
37
backend/tests/integration/testhelper.go
Normal file
37
backend/tests/integration/testhelper.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package integration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"nex/backend/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setupTestDB 创建内存 SQLite 数据库并执行 AutoMigrate。
|
||||||
|
// 使用 MaxOpenConns(1) 确保 :memory: 模式不会被连接池丢弃。
|
||||||
|
func setupTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
require.NoError(t, err)
|
||||||
|
sqlDB.SetMaxOpenConns(1)
|
||||||
|
sqlDB.SetConnMaxLifetime(0)
|
||||||
|
|
||||||
|
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
// 等待异步 goroutine(如 statsService.Record)完成
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
sqlDB.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
return db
|
||||||
|
}
|
||||||
79
backend/tests/migration_test.go
Normal file
79
backend/tests/migration_test.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package tests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"nex/backend/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMigration_ModelsUUIDPrimaryKey(t *testing.T) {
|
||||||
|
db := SetupTestDB(t)
|
||||||
|
defer CleanupTestDB(t, db)
|
||||||
|
|
||||||
|
// 创建供应商
|
||||||
|
_ = CreateTestProvider(t, db, "openai")
|
||||||
|
|
||||||
|
// 创建模型使用 UUID 作为 id
|
||||||
|
model, err := CreateTestModel(t, db, "550e8400-e29b-41d4-a716-446655440000", "openai", "gpt-4")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 通过 UUID 查询
|
||||||
|
var result config.Model
|
||||||
|
require.NoError(t, db.First(&result, "id = ?", model.ID).Error)
|
||||||
|
assert.Equal(t, "gpt-4", result.ModelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigration_UniqueProviderModelConstraint(t *testing.T) {
|
||||||
|
db := SetupTestDB(t)
|
||||||
|
defer CleanupTestDB(t, db)
|
||||||
|
|
||||||
|
// 创建供应商
|
||||||
|
_ = CreateTestProvider(t, db, "openai")
|
||||||
|
|
||||||
|
// 创建第一个模型
|
||||||
|
_, err := CreateTestModel(t, db, "uuid-1", "openai", "gpt-4")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 尝试创建相同 (provider_id, model_name) 的模型应失败
|
||||||
|
_, err = CreateTestModel(t, db, "uuid-2", "openai", "gpt-4")
|
||||||
|
assert.Error(t, err, "UNIQUE(provider_id, model_name) 约束应阻止重复")
|
||||||
|
|
||||||
|
// 不同 provider_id 下相同 model_name 应成功
|
||||||
|
_ = CreateTestProvider(t, db, "anthropic")
|
||||||
|
_, err = CreateTestModel(t, db, "uuid-3", "anthropic", "gpt-4")
|
||||||
|
require.NoError(t, err, "不同 provider_id 下相同 model_name 应允许")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigration_CascadeDelete(t *testing.T) {
|
||||||
|
db := SetupTestDB(t)
|
||||||
|
defer CleanupTestDB(t, db)
|
||||||
|
|
||||||
|
// 创建供应商和模型
|
||||||
|
provider := CreateTestProvider(t, db, "openai")
|
||||||
|
_, err := CreateTestModel(t, db, "uuid-1", "openai", "gpt-4")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 删除供应商应级联删除模型
|
||||||
|
require.NoError(t, db.Delete(&provider).Error)
|
||||||
|
|
||||||
|
var count int64
|
||||||
|
db.Model(&config.Model{}).Where("provider_id = ?", "openai").Count(&count)
|
||||||
|
assert.Equal(t, int64(0), count, "删除供应商后其模型应被级联删除")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigration_ModelDefaultEnabled(t *testing.T) {
|
||||||
|
db := SetupTestDB(t)
|
||||||
|
defer CleanupTestDB(t, db)
|
||||||
|
|
||||||
|
_ = CreateTestProvider(t, db, "openai")
|
||||||
|
|
||||||
|
// 创建模型不指定 enabled,应默认为 true
|
||||||
|
_, err := CreateTestModel(t, db, "uuid-1", "openai", "gpt-4")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var result config.Model
|
||||||
|
require.NoError(t, db.First(&result, "id = ?", "uuid-1").Error)
|
||||||
|
assert.True(t, result.Enabled, "enabled 字段默认应为 true")
|
||||||
|
}
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
schema: spec-driven
|
|
||||||
created: 2026-04-20
|
|
||||||
@@ -1,137 +0,0 @@
|
|||||||
## Context
|
|
||||||
|
|
||||||
Nex 是一个 AI 网关,屏蔽多个 AI 供应商(OpenAI、Anthropic 等)的差异,提供统一的 API 接口。当前后端直接透传上游供应商的原始模型名称(如 `gpt-4`),通过 `models` 表的 `model_name` 字段路由。`models` 表的 `id` 字段当前语义是用户自定义标识符,与上游模型名 `model_name` 之间没有明确的职责分离。
|
|
||||||
|
|
||||||
当前架构:
|
|
||||||
- `ProxyHandler` 从请求体中提取 `model` 字段 → `RoutingService.Route(modelName)` 按 `model_name` 查询
|
|
||||||
- `GET /v1/models` 直接透传到第一个供应商的上游接口
|
|
||||||
- `GET /v1/models/{id}` 直接透传到上游
|
|
||||||
- `TargetProvider.ModelName` 在 encoder 中覆盖请求体的 `model` 字段
|
|
||||||
|
|
||||||
## Goals / Non-Goals
|
|
||||||
|
|
||||||
**Goals:**
|
|
||||||
- 定义统一模型 ID 格式 `provider_id/model_name`,全局唯一标识一个模型
|
|
||||||
- 拦截 `/v1/models` 和 `/v1/models/{unified_id}` 接口,从数据库聚合返回,不再透传上游
|
|
||||||
- 所有代理接口(Chat、Embeddings、Rerank)使用统一模型 ID 路由,响应中 `model` 字段覆写为统一 ID
|
|
||||||
- `models.id` 改为 UUID(内部标识),`models.model_name` 存储上游供应商的模型名称
|
|
||||||
- `provider_id` 约束为 `[a-zA-Z0-9_]+`,防止特殊字符影响 URL 和 JSON 交互
|
|
||||||
- 保持协议无关、供应商无关的设计
|
|
||||||
|
|
||||||
**Non-Goals:**
|
|
||||||
- 不支持供应商别名或模型别名
|
|
||||||
- 不做上游模型列表自动同步(管理员手动配置可见模型)
|
|
||||||
- 不适配前端(后续统一适配)
|
|
||||||
|
|
||||||
## Decisions
|
|
||||||
|
|
||||||
### D1: 统一模型 ID 格式 — `provider_id/model_name`
|
|
||||||
|
|
||||||
格式: `{provider_id}/{model_name}`,例如 `openai/gpt-4`、`anthropic/claude-3-opus-20240229`。
|
|
||||||
|
|
||||||
- 使用 `strings.SplitN(id, "/", 2)` 解析,只在第一个 `/` 处分割
|
|
||||||
- `provider_id` 约束为 `[a-zA-Z0-9_]+`,保证不含 `/`,解析安全
|
|
||||||
- `model_name`(上游模型名)不受字符约束,因为它不出现在管理 API 的 URL 主键中
|
|
||||||
|
|
||||||
选择此格式而非 `provider_id:model_name`(冒号分隔)的原因:斜杠在 JSON 字符串中天然安全,且在 URL 路径中语义清晰(`/v1/models/openai/gpt-4`),更符合 REST 风格。
|
|
||||||
|
|
||||||
### D2: models 表 schema 变更
|
|
||||||
|
|
||||||
```
|
|
||||||
旧: id(TEXT PK, 用户自定义), provider_id, model_name(上游模型名), enabled, created_at
|
|
||||||
新: id(UUID PK, 自动生成), provider_id, model_name(上游模型名), enabled, created_at
|
|
||||||
UNIQUE(provider_id, model_name)
|
|
||||||
```
|
|
||||||
|
|
||||||
关键语义变化:
|
|
||||||
- `id` 从用户自定义标识符变为 UUID 内部主键(自动生成),用于管理接口 CRUD
|
|
||||||
- `model_name` 语义不变,始终存储上游供应商的模型名称,发给上游的实际值
|
|
||||||
- 新增联合唯一约束 `UNIQUE(provider_id, model_name)` 保证同一供应商内模型不重复
|
|
||||||
|
|
||||||
选择保留 `id` 作为 PK 而非使用 `(provider_id, model_name)` 联合主键的原因:上游模型名可能含 `/` 等特殊字符(如 Azure OpenAI 的 deployment 路径),不适合作为管理接口的 URL 参数。`id` 为 UUID 可以避免所有特殊字符问题。
|
|
||||||
|
|
||||||
### D3: Models/ModelInfo 接口本地聚合
|
|
||||||
|
|
||||||
`GET /v1/models` 从数据库查询所有 `enabled` 的模型(JOIN providers),组装为 `CanonicalModelList`,`ID` 字段使用统一模型 ID,通过客户端协议的 adapter 编码返回。不请求上游。
|
|
||||||
|
|
||||||
`GET /v1/models/{provider_id}/{model_name}` 从 URL 提取统一模型 ID,解析后查询数据库,组装为 `CanonicalModelInfo` 返回。不请求上游。
|
|
||||||
|
|
||||||
选择纯 DB 聚合而非实时查询上游的原因:
|
|
||||||
1. 管理员通过 `/api/models` 控制哪些模型对用户可见,网关的意义在于控制可见性
|
|
||||||
2. 响应速度快,不依赖上游可用性
|
|
||||||
3. 符合当前架构中管理员手动配置 provider 和 model 的设计哲学
|
|
||||||
|
|
||||||
### D4: 跨协议响应 model 字段覆写
|
|
||||||
|
|
||||||
跨协议场景下,上游返回的响应经过 decode → encode 全量转换。上游响应中的 `model` 字段是原生模型名(如 `gpt-4`),需要在返回给客户端前覆写为统一模型 ID。
|
|
||||||
|
|
||||||
实现位置:`ConversionEngine.ConvertHttpResponse` 新增 `modelOverride string` 参数。在解码上游响应到 canonical 后、编码客户端响应前,将 `canonical.Model` 设为 `modelOverride`。流式场景同理,`CreateStreamConverter` 同样接收 `modelOverride` 参数。
|
|
||||||
|
|
||||||
此方案仅在跨协议转换路径使用。选择在 canonical 层面处理的原因:
|
|
||||||
1. 跨协议必须全量 decode → encode,canonical 的 Model 字段天然可覆写
|
|
||||||
2. 不侵入各协议 adapter 的实现
|
|
||||||
3. 与 Smart Passthrough 互补——跨协议不可保真,canonical 覆写是自然的
|
|
||||||
|
|
||||||
### D5: ProtocolAdapter 接口扩展
|
|
||||||
|
|
||||||
在 `ProtocolAdapter` 接口新增四个方法,将所有协议相关的 model 字段知识归属到 adapter:
|
|
||||||
|
|
||||||
1. `ExtractUnifiedModelID(nativePath string) (string, error)` — 从路径中提取统一模型 ID
|
|
||||||
2. `ExtractModelName(body []byte, ifaceType InterfaceType) (string, error)` — 从请求体中提取 model 值(所有流程复用,替代 handler 层硬编码的 `extractModelName`)
|
|
||||||
3. `RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)` — 最小化 JSON 改写请求体中的 model 字段(Smart Passthrough 请求方向)
|
|
||||||
4. `RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)` — 最小化 JSON 改写响应体中的 model 字段(Smart Passthrough 响应方向)
|
|
||||||
|
|
||||||
拆分请求/响应方向的原因:请求体和响应体的 JSON 结构可能不同,model 字段的位置可能不同(当前 OpenAI/Anthropic 协议碰巧都在顶层 `"model"`,但未来协议不一定)。拆分后 adapter 各自独立实现,各自按 ifaceType 分派。
|
|
||||||
|
|
||||||
`ExtractModelName` 和两个 `Rewrite*` 方法均接收 `InterfaceType` 参数,因为不同接口类型的请求体/响应体结构可能不同,adapter 按 ifaceType 分派具体的定位和改写逻辑。
|
|
||||||
|
|
||||||
对于 `isModelInfoPath` 的调整:允许 suffix 中包含 `/`,因为统一模型 ID 格式为 `provider_id/model_name`。
|
|
||||||
|
|
||||||
将此方法放在适配器接口而非 handler 中通用实现的原因:不同协议的模型详情路径格式和请求体结构可能不同,各自拥有独立演进能力。
|
|
||||||
|
|
||||||
### D6: provider_id 字符集约束
|
|
||||||
|
|
||||||
创建供应商时校验 `id` 字段必须匹配 `^[a-zA-Z0-9_]+$`,长度 1-64。
|
|
||||||
|
|
||||||
选择严格限制而非仅排除 `/` 的原因:统一模型 ID 出现在 URL 路径和 JSON 中,`?`、`#`、`&`、`=` 等字符会在 URL 中引起解析问题。限制为字母数字下划线后,URL 中永远安全,不需要编码。
|
|
||||||
|
|
||||||
### D7: pkg/modelid 工具包
|
|
||||||
|
|
||||||
新增 `pkg/modelid` 包,提供:
|
|
||||||
- `ParseUnifiedModelID(id string) (providerID, modelName string, error)` — 解析
|
|
||||||
- `FormatUnifiedModelID(providerID, modelName string) string` — 格式化
|
|
||||||
- `ValidateProviderID(id string) error` — 校验供应商 ID
|
|
||||||
- `IsValidUnifiedModelID(id string) bool` — 校验统一模型 ID
|
|
||||||
|
|
||||||
使用标准库 `strings.SplitN` 和 `regexp` 实现,不引入新依赖。
|
|
||||||
|
|
||||||
### D8: 同协议 Smart Passthrough
|
|
||||||
|
|
||||||
当前同协议透传将请求体原样转发,跳过 decode → encode,保持参数完全保真。但统一模型 ID 要求改写 model 字段,原样透传无法满足。
|
|
||||||
|
|
||||||
**Smart Passthrough**:保留同协议透传的保真优势,通过 `json.RawMessage` 做最小化改写。
|
|
||||||
|
|
||||||
实现方式:adapter 的 `RewriteRequestModelName` 和 `RewriteResponseModelName` 方法各自解析 JSON 为 `map[string]json.RawMessage`,只替换 model 字段的 value,其余字段保留原始 bytes,不经过任何类型转换。参数保真、不丢精度、不改字段顺序。
|
|
||||||
|
|
||||||
各接口类型策略:
|
|
||||||
- Chat/Embedding/Rerank(同协议):Smart Passthrough — 请求改写 model(统一 ID → 上游名),响应改写 model(上游名 → 统一 ID)
|
|
||||||
- Chat/Embedding/Rerank(跨协议):全量 decode → encode + modelOverride
|
|
||||||
- Models/ModelInfo:本地数据库聚合,不请求上游
|
|
||||||
- Passthrough(未知路径):原样透传,不改写 model
|
|
||||||
|
|
||||||
选择让 adapter 拥有完整协议知识(而非通用 json hack)的原因:
|
|
||||||
1. 不同协议的 model 字段位置可能不同,adapter 按 InterfaceType 分派
|
|
||||||
2. 请求和响应的 model 字段位置可能不同,拆分 RewriteRequestModelName/RewriteResponseModelName 各自独立实现
|
|
||||||
3. adapter 内部实现 `ExtractModelName` 和两个 `Rewrite*` 方法可共享同一份"model 在哪"的定位逻辑
|
|
||||||
4. 所有流程复用 `ExtractModelName`,同协议额外复用 `RewriteRequestModelName` + `RewriteResponseModelName`
|
|
||||||
|
|
||||||
## Risks / Trade-offs
|
|
||||||
|
|
||||||
- **[BREAKING CHANGE]** 代理接口 model 字段格式变更,现有客户端必须适配 → 统一 ID 格式简单直观,服务尚未上线无旧客户端
|
|
||||||
- **[联合唯一约束]** 同一供应商下相同 model_name 不允许重复 → 这是正确的行为,语义上就不应该重复
|
|
||||||
- **[model_name 含特殊字符]** 上游模型名可能含 `/`(如 Azure deployment 路径)→ 解析用 `SplitN("/", 2)` 安全,管理接口用 `id` 定位不受影响,代理接口中统一 ID 出现在 JSON body 和 URL 路径中均安全
|
|
||||||
- **[流式响应覆写]** 同协议流式场景需逐 SSE chunk 调用 RewriteResponseModelName → 每个 chunk 多一次轻量 JSON 解析,用 json.RawMessage 保证开销极小
|
|
||||||
|
|
||||||
## Open Questions
|
|
||||||
|
|
||||||
无。所有关键决策已在探索阶段确认。
|
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
## Why
|
|
||||||
|
|
||||||
当前网关直接透传上游供应商的原始模型名称(如 `gpt-4`、`claude-3-opus`),无法在多供应商场景下唯一标识一个模型。不同供应商可能存在同名模型,客户端无法区分应路由到哪个供应商。网关作为屏蔽供应商差异的统一入口,需要定义自有的模型标识体系,让客户端通过统一的 model ID 访问任意供应商的模型,同时拦截 `/v1/models` 等模型查询接口,聚合所有供应商的模型信息返回。
|
|
||||||
|
|
||||||
## What Changes
|
|
||||||
|
|
||||||
- **BREAKING**: 引入统一模型 ID 格式 `provider_id/model_name`(如 `openai/gpt-4`),所有代理接口(Chat、Embeddings、Rerank)的 `model` 字段必须使用此格式
|
|
||||||
- **BREAKING**: `models` 表主键 `id` 改为 UUID 自动生成(不再由用户提供),`model_name` 字段语义保持不变(存储上游供应商模型名称),新增 `UNIQUE(provider_id, model_name)` 联合唯一约束
|
|
||||||
- **BREAKING**: `provider_id` 限制为 `[a-zA-Z0-9_]+` 字符集,禁止特殊字符
|
|
||||||
- `GET /v1/models` 改为从数据库聚合返回所有已启用模型,不再透传到上游供应商
|
|
||||||
- `GET /v1/models/{unified_id}` 改为从数据库查询返回模型详情,不再透传到上游供应商
|
|
||||||
- 同协议透传改为 Smart Passthrough:通过 `json.RawMessage` 最小化改写 model 字段,保持其余参数完全保真
|
|
||||||
- 跨协议转换路径:通过 canonical 层面 modelOverride 参数覆写响应 model 字段
|
|
||||||
- 管理 API (`/api/models`) 请求体字段适配,响应中新增 `unified_id` 字段
|
|
||||||
- 新增 `pkg/modelid` 工具包,提供统一模型 ID 的解析、格式化、校验
|
|
||||||
- ProtocolAdapter 接口新增 `ExtractUnifiedModelID`、`ExtractModelName`、`RewriteRequestModelName`、`RewriteResponseModelName` 方法,协议无关地处理 model 字段
|
|
||||||
|
|
||||||
## Capabilities
|
|
||||||
|
|
||||||
### New Capabilities
|
|
||||||
- `unified-model-id`: 统一模型 ID 的解析、格式化、校验工具包,以及 `provider_id` 字符集约束
|
|
||||||
|
|
||||||
### Modified Capabilities
|
|
||||||
- `model-management`: 模型表结构调整(id 改 UUID 自动生成、新增联合唯一约束),CRUD 接口字段变更(创建不再提供 id)
|
|
||||||
- `provider-management`: provider_id 创建时增加字符集校验(`[a-zA-Z0-9_]+`)
|
|
||||||
- `unified-proxy-handler`: 统一模型 ID 解析路由、Models/ModelInfo 接口改为本地聚合、同协议 Smart Passthrough、跨协议 modelOverride 覆写
|
|
||||||
- `conversion-engine`: 跨协议场景下 ConvertHttpResponse 支持 model 覆写参数
|
|
||||||
- `protocol-adapter-openai`: isModelInfoPath 适配含 `/` 路径、新增 ExtractUnifiedModelID、ExtractModelName、RewriteRequestModelName、RewriteResponseModelName
|
|
||||||
- `protocol-adapter-anthropic`: isModelInfoPath 适配含 `/` 路径、新增 ExtractUnifiedModelID、ExtractModelName、RewriteRequestModelName、RewriteResponseModelName
|
|
||||||
- `request-validation`: provider_id 字符集校验规则、模型创建校验适配
|
|
||||||
- `database-migration`: models 表 schema 变更迁移(DROP + CREATE 重建)
|
|
||||||
- `usage-statistics`: 明确统计记录使用 providerID + modelName 的上游模型名
|
|
||||||
|
|
||||||
## Impact
|
|
||||||
|
|
||||||
- **数据库**: models 表 schema 变更(DROP + CREATE 重建)
|
|
||||||
- **API 兼容性**: 代理接口 model 字段格式为 BREAKING CHANGE,需客户端适配
|
|
||||||
- **管理 API**: `/api/models` 请求体变更(创建不再提供 id,自动生成 UUID),响应新增 unified_id 字段
|
|
||||||
- **代码模块**: domain、repository、service、handler、conversion、adapter 层均有改动
|
|
||||||
- **测试**: routing service、proxy handler、adapter、model handler 需要新增/更新测试
|
|
||||||
- **前端**: 本次变更不涉及前端适配,前端后续统一适配
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
## MODIFIED Requirements
|
|
||||||
|
|
||||||
### Requirement: 跨协议响应转换支持 model 覆写
|
|
||||||
|
|
||||||
ConversionEngine SHALL 在跨协议响应转换时支持 model 字段覆写。
|
|
||||||
|
|
||||||
#### Scenario: ConvertHttpResponse 接收 modelOverride 参数
|
|
||||||
|
|
||||||
- **WHEN** 调用 `ConvertHttpResponse` 时传入 `modelOverride` 参数(跨协议场景,非空字符串)
|
|
||||||
- **THEN** SHALL 在解码上游响应到 canonical 后,将 `Model` 字段设为 `modelOverride`
|
|
||||||
- **THEN** SHALL 使用覆写后的 canonical 编码为客户端协议格式
|
|
||||||
|
|
||||||
#### Scenario: modelOverride 为空
|
|
||||||
|
|
||||||
- **WHEN** 调用 `ConvertHttpResponse` 时 `modelOverride` 为空字符串
|
|
||||||
- **THEN** SHALL NOT 覆写 canonical 的 Model 字段,保持上游原始值
|
|
||||||
|
|
||||||
#### Scenario: Chat 响应 model 覆写
|
|
||||||
|
|
||||||
- **WHEN** 跨协议转换 Chat 类型响应且 `modelOverride` 非空
|
|
||||||
- **THEN** `CanonicalResponse.Model` SHALL 被设为 `modelOverride`
|
|
||||||
|
|
||||||
#### Scenario: Embedding 响应 model 覆写
|
|
||||||
|
|
||||||
- **WHEN** 跨协议转换 Embedding 类型响应且 `modelOverride` 非空
|
|
||||||
- **THEN** `CanonicalEmbeddingResponse.Model` SHALL 被设为 `modelOverride`
|
|
||||||
|
|
||||||
#### Scenario: Rerank 响应 model 覆写
|
|
||||||
|
|
||||||
- **WHEN** 跨协议转换 Rerank 类型响应且 `modelOverride` 非空
|
|
||||||
- **THEN** `CanonicalRerankResponse.Model` SHALL 被设为 `modelOverride`
|
|
||||||
|
|
||||||
### Requirement: 跨协议流式转换支持 model 覆写
|
|
||||||
|
|
||||||
ConversionEngine SHALL 在跨协议流式转换时支持 model 字段覆写。
|
|
||||||
|
|
||||||
#### Scenario: CreateStreamConverter 接收 modelOverride 参数
|
|
||||||
|
|
||||||
- **WHEN** 调用 `CreateStreamConverter` 时传入 `modelOverride` 参数(跨协议场景)
|
|
||||||
- **THEN** SHALL 在流式 canonical 事件中将 `Model` 字段设为 `modelOverride`
|
|
||||||
|
|
||||||
### Requirement: TargetProvider 字段语义
|
|
||||||
|
|
||||||
TargetProvider 的 ModelName 字段 SHALL 存储上游供应商的模型名称(即 `model_name` 字段值),语义保持不变。
|
|
||||||
|
|
||||||
#### Scenario: encoder 使用 TargetProvider.ModelName
|
|
||||||
|
|
||||||
- **WHEN** 协议适配器编码请求时
|
|
||||||
- **THEN** SHALL 使用 `TargetProvider.ModelName` 作为发给上游的 `model` 字段值(值为路由结果中的 model_name)
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
## MODIFIED Requirements
|
|
||||||
|
|
||||||
### Requirement: models 表 schema 变更
|
|
||||||
|
|
||||||
系统 SHALL 通过迁移脚本重建 models 表结构(服务未上线,无需考虑数据迁移)。
|
|
||||||
|
|
||||||
#### Scenario: 迁移后 models 表结构
|
|
||||||
|
|
||||||
- **WHEN** 执行迁移
|
|
||||||
- **THEN** SHALL 先 DROP 已有的 models 表(无旧数据)
|
|
||||||
- **THEN** SHALL CREATE 新的 models 表,包含字段:id(TEXT PRIMARY KEY)、provider_id(TEXT NOT NULL)、model_name(TEXT NOT NULL)、enabled(INTEGER DEFAULT 1)、created_at(DATETIME)
|
|
||||||
- **THEN** SHALL 存在 UNIQUE(provider_id, model_name) 约束
|
|
||||||
- **THEN** SHALL 存在 FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
|
|
||||||
@@ -1,105 +0,0 @@
|
|||||||
## MODIFIED Requirements
|
|
||||||
|
|
||||||
### Requirement: 创建模型配置
|
|
||||||
|
|
||||||
网关 SHALL 允许为供应商创建新的模型配置。
|
|
||||||
|
|
||||||
#### Scenario: 使用有效数据创建模型
|
|
||||||
|
|
||||||
- **WHEN** 向 `/api/models` 发送 POST 请求,携带有效的模型数据(provider_id, model_name),不提供 id 字段
|
|
||||||
- **THEN** 网关 SHALL 自动生成 UUID 作为模型 id
|
|
||||||
- **THEN** 网关 SHALL 在数据库中创建新的模型记录
|
|
||||||
- **THEN** 网关 SHALL 返回创建的模型,状态码为 201
|
|
||||||
- **THEN** 模型 SHALL 默认启用
|
|
||||||
- **THEN** 返回的模型 SHALL 包含 `unified_id` 字段,值为 `{provider_id}/{model_name}`
|
|
||||||
|
|
||||||
#### Scenario: 使用不存在的供应商创建模型
|
|
||||||
|
|
||||||
- **WHEN** 向 `/api/models` 发送 POST 请求,携带不存在的 provider_id
|
|
||||||
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
|
|
||||||
- **THEN** 错误 SHALL 指示供应商不存在
|
|
||||||
|
|
||||||
#### Scenario: 创建重复模型
|
|
||||||
|
|
||||||
- **WHEN** 向 `/api/models` 发送 POST 请求,携带已存在的 provider_id + model_name 组合
|
|
||||||
- **THEN** 网关 SHALL 返回错误,状态码为 409 (Conflict)
|
|
||||||
- **THEN** 错误 SHALL 指示该供应商下已存在相同模型
|
|
||||||
|
|
||||||
### Requirement: 列出所有模型
|
|
||||||
|
|
||||||
网关 SHALL 允许获取所有模型配置。
|
|
||||||
|
|
||||||
#### Scenario: 成功列出模型
|
|
||||||
|
|
||||||
- **WHEN** 向 `/api/models` 发送 GET 请求
|
|
||||||
- **THEN** 网关 SHALL 返回所有模型的列表
|
|
||||||
- **THEN** 每个模型 SHALL 包含 id, provider_id, model_name, unified_id, enabled, created_at
|
|
||||||
|
|
||||||
**变更说明:** 响应新增 unified_id 字段,移除旧语义的 id 自定义输入。
|
|
||||||
|
|
||||||
### Requirement: 按供应商列出模型
|
|
||||||
|
|
||||||
网关 SHALL 允许获取特定供应商的模型。
|
|
||||||
|
|
||||||
#### Scenario: 列出存在供应商的模型
|
|
||||||
|
|
||||||
- **WHEN** 向 `/api/models?provider_id=<provider_id>` 发送 GET 请求
|
|
||||||
- **THEN** 网关 SHALL 返回指定供应商的模型列表
|
|
||||||
- **THEN** 每个模型 SHALL 包含 unified_id 字段
|
|
||||||
|
|
||||||
### Requirement: 更新模型配置
|
|
||||||
|
|
||||||
网关 SHALL 允许更新现有模型配置。
|
|
||||||
|
|
||||||
#### Scenario: 使用有效数据更新模型
|
|
||||||
|
|
||||||
- **WHEN** 向 `/api/models/:id` 发送 PUT 请求,携带有效的模型数据
|
|
||||||
- **THEN** 网关 SHALL 更新数据库中的模型记录
|
|
||||||
- **THEN** 网关 SHALL 返回更新后的模型
|
|
||||||
- **THEN** 返回的模型 SHALL 包含更新后的 unified_id
|
|
||||||
|
|
||||||
#### Scenario: 更新模型为重复组合
|
|
||||||
|
|
||||||
- **WHEN** 向 `/api/models/:id` 发送 PUT 请求,更新 provider_id 或 model_name 导致与已有记录重复
|
|
||||||
- **THEN** 网关 SHALL 返回错误,状态码为 409 (Conflict)
|
|
||||||
|
|
||||||
### Requirement: 删除模型配置
|
|
||||||
|
|
||||||
网关 SHALL 允许删除模型配置。
|
|
||||||
|
|
||||||
#### Scenario: 删除存在的模型
|
|
||||||
|
|
||||||
- **WHEN** 向 `/api/models/:id` 发送 DELETE 请求,携带有效的模型 ID
|
|
||||||
- **THEN** 网关 SHALL 删除模型记录
|
|
||||||
- **THEN** 网关 SHALL 返回状态码 204 (No Content)
|
|
||||||
|
|
||||||
### Requirement: 使用 service 层处理业务逻辑
|
|
||||||
|
|
||||||
Handler SHALL 通过 ModelService 处理业务逻辑。
|
|
||||||
|
|
||||||
#### Scenario: 调用 service 方法
|
|
||||||
|
|
||||||
- **WHEN** handler 收到请求
|
|
||||||
- **THEN** SHALL 调用对应的 ModelService 方法(Create、Get、List、Update、Delete)
|
|
||||||
- **THEN** SHALL 使用 domain.Model 类型
|
|
||||||
- **THEN** Create 时 SHALL 调用 `uuid.New()` 生成 id
|
|
||||||
|
|
||||||
#### Scenario: 供应商验证和唯一性校验
|
|
||||||
|
|
||||||
- **WHEN** 创建或更新模型
|
|
||||||
- **THEN** SHALL 在 service 层验证供应商存在
|
|
||||||
- **THEN** SHALL 在 service 层验证 provider_id + model_name 联合唯一
|
|
||||||
|
|
||||||
### Requirement: 使用 repository 层访问数据
|
|
||||||
|
|
||||||
Service SHALL 通过 ModelRepository 访问数据。
|
|
||||||
|
|
||||||
#### Scenario: 联合查询
|
|
||||||
|
|
||||||
- **WHEN** service 需要按 provider 和 model_name 查询模型
|
|
||||||
- **THEN** SHALL 调用 `FindByProviderAndModelName(providerID, modelName)` 方法
|
|
||||||
|
|
||||||
#### Scenario: 查询所有启用模型
|
|
||||||
|
|
||||||
- **WHEN** proxy handler 需要聚合模型列表
|
|
||||||
- **THEN** SHALL 调用 `ListEnabled()` 方法,返回所有 enabled 的模型(关联 enabled 的供应商)
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
## MODIFIED Requirements
|
|
||||||
|
|
||||||
### Requirement: 模型详情路径识别
|
|
||||||
|
|
||||||
Anthropic 适配器 SHALL 正确识别包含统一模型 ID 的模型详情路径。
|
|
||||||
|
|
||||||
#### Scenario: 含斜杠的统一模型 ID 路径
|
|
||||||
|
|
||||||
- **WHEN** 路径为 `/v1/models/anthropic/claude-3-opus`
|
|
||||||
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo`
|
|
||||||
|
|
||||||
#### Scenario: 含多段斜杠的统一模型 ID 路径
|
|
||||||
|
|
||||||
- **WHEN** 路径为 `/v1/models/azure/deployments/gpt-4`
|
|
||||||
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo`
|
|
||||||
|
|
||||||
#### Scenario: 模型列表路径不受影响
|
|
||||||
|
|
||||||
- **WHEN** 路径为 `/v1/models`
|
|
||||||
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModels`
|
|
||||||
|
|
||||||
### Requirement: 提取统一模型 ID
|
|
||||||
|
|
||||||
Anthropic 适配器 SHALL 从路径中提取统一模型 ID。
|
|
||||||
|
|
||||||
#### Scenario: 标准路径提取
|
|
||||||
|
|
||||||
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/anthropic/claude-3-opus")`
|
|
||||||
- **THEN** SHALL 返回 `"anthropic/claude-3-opus"`
|
|
||||||
|
|
||||||
#### Scenario: 复杂路径提取
|
|
||||||
|
|
||||||
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")`
|
|
||||||
- **THEN** SHALL 返回 `"azure/accounts/org/models/gpt-4"`
|
|
||||||
|
|
||||||
#### Scenario: 非模型详情路径
|
|
||||||
|
|
||||||
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models")`
|
|
||||||
- **THEN** SHALL 返回错误
|
|
||||||
|
|
||||||
### Requirement: 从请求体提取 model
|
|
||||||
|
|
||||||
Anthropic 适配器 SHALL 按 InterfaceType 从请求体中提取 model 值。
|
|
||||||
|
|
||||||
#### Scenario: Chat 请求提取 model
|
|
||||||
|
|
||||||
- **WHEN** 调用 `ExtractModelName(body, InterfaceTypeChat)`,body 为 `{"model":"anthropic/claude-3-opus","messages":[...]}`
|
|
||||||
- **THEN** SHALL 返回 `"anthropic/claude-3-opus"`
|
|
||||||
|
|
||||||
#### Scenario: 无 model 字段
|
|
||||||
|
|
||||||
- **WHEN** 调用 `ExtractModelName(body, ifaceType)`,body 中不含 model 字段
|
|
||||||
- **THEN** SHALL 返回空字符串,不返回错误
|
|
||||||
|
|
||||||
### Requirement: 最小化改写请求体 model 字段
|
|
||||||
|
|
||||||
Anthropic 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写请求体中的 model 字段,其余字段保持原始 bytes。
|
|
||||||
|
|
||||||
#### Scenario: 请求体 model 改写(统一 ID → 上游名)
|
|
||||||
|
|
||||||
- **WHEN** 调用 `RewriteRequestModelName(body, "claude-3-opus-20240229", InterfaceTypeChat)`
|
|
||||||
- **THEN** SHALL 将请求体中 model 字段替换为 `"claude-3-opus-20240229"`,其余字段原样保留
|
|
||||||
|
|
||||||
### Requirement: 最小化改写响应体 model 字段
|
|
||||||
|
|
||||||
Anthropic 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写响应体中的 model 字段,其余字段保持原始 bytes。请求体和响应体的 model 字段位置可能不同,各自独立实现。
|
|
||||||
|
|
||||||
#### Scenario: 响应体 model 改写(上游名 → 统一 ID)
|
|
||||||
|
|
||||||
- **WHEN** 调用 `RewriteResponseModelName(body, "anthropic/claude-3-opus", InterfaceTypeChat)`
|
|
||||||
- **THEN** SHALL 将响应体中 model 字段替换为 `"anthropic/claude-3-opus"`,其余字段原样保留
|
|
||||||
@@ -1,89 +0,0 @@
|
|||||||
## MODIFIED Requirements
|
|
||||||
|
|
||||||
### Requirement: 模型详情路径识别
|
|
||||||
|
|
||||||
OpenAI 适配器 SHALL 正确识别包含统一模型 ID 的模型详情路径。
|
|
||||||
|
|
||||||
#### Scenario: 含斜杠的统一模型 ID 路径
|
|
||||||
|
|
||||||
- **WHEN** 路径为 `/v1/models/openai/gpt-4`
|
|
||||||
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo`
|
|
||||||
|
|
||||||
#### Scenario: 含多段斜杠的统一模型 ID 路径
|
|
||||||
|
|
||||||
- **WHEN** 路径为 `/v1/models/azure/accounts/org-123/models/gpt-4`
|
|
||||||
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo`
|
|
||||||
|
|
||||||
#### Scenario: 模型列表路径不受影响
|
|
||||||
|
|
||||||
- **WHEN** 路径为 `/v1/models`
|
|
||||||
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModels`
|
|
||||||
|
|
||||||
### Requirement: 提取统一模型 ID
|
|
||||||
|
|
||||||
OpenAI 适配器 SHALL 从路径中提取统一模型 ID。
|
|
||||||
|
|
||||||
#### Scenario: 标准路径提取
|
|
||||||
|
|
||||||
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/openai/gpt-4")`
|
|
||||||
- **THEN** SHALL 返回 `"openai/gpt-4"`
|
|
||||||
|
|
||||||
#### Scenario: 复杂路径提取
|
|
||||||
|
|
||||||
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")`
|
|
||||||
- **THEN** SHALL 返回 `"azure/accounts/org/models/gpt-4"`
|
|
||||||
|
|
||||||
#### Scenario: 非模型详情路径
|
|
||||||
|
|
||||||
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models")`
|
|
||||||
- **THEN** SHALL 返回错误
|
|
||||||
|
|
||||||
### Requirement: 从请求体提取 model
|
|
||||||
|
|
||||||
OpenAI 适配器 SHALL 按 InterfaceType 从请求体中提取 model 值。
|
|
||||||
|
|
||||||
#### Scenario: Chat 请求提取 model
|
|
||||||
|
|
||||||
- **WHEN** 调用 `ExtractModelName(body, InterfaceTypeChat)`,body 为 `{"model":"openai/gpt-4","messages":[...]}`
|
|
||||||
- **THEN** SHALL 返回 `"openai/gpt-4"`
|
|
||||||
|
|
||||||
#### Scenario: Embedding 请求提取 model
|
|
||||||
|
|
||||||
- **WHEN** 调用 `ExtractModelName(body, InterfaceTypeEmbeddings)`,body 为 `{"model":"openai/text-embedding-3","input":"text"}`
|
|
||||||
- **THEN** SHALL 返回 `"openai/text-embedding-3"`
|
|
||||||
|
|
||||||
#### Scenario: 无 model 字段
|
|
||||||
|
|
||||||
- **WHEN** 调用 `ExtractModelName(body, ifaceType)`,body 中不含 model 字段
|
|
||||||
- **THEN** SHALL 返回空字符串,不返回错误
|
|
||||||
|
|
||||||
### Requirement: 最小化改写请求体 model 字段
|
|
||||||
|
|
||||||
OpenAI 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写请求体中的 model 字段,其余字段保持原始 bytes。
|
|
||||||
|
|
||||||
#### Scenario: 请求体 model 改写(统一 ID → 上游名)
|
|
||||||
|
|
||||||
- **WHEN** 调用 `RewriteRequestModelName(body, "gpt-4", InterfaceTypeChat)`,body 为 `{"model":"openai/gpt-4","messages":[...],"some_param":"value"}`
|
|
||||||
- **THEN** SHALL 返回 `{"model":"gpt-4","messages":[...],"some_param":"value"}`
|
|
||||||
- **THEN** 除 model 外的字段 SHALL 保持原始 bytes 不变
|
|
||||||
|
|
||||||
#### Scenario: 不同 InterfaceType 的请求改写
|
|
||||||
|
|
||||||
- **WHEN** 调用 `RewriteRequestModelName(body, "gpt-4", InterfaceTypeEmbeddings)`
|
|
||||||
- **THEN** SHALL 按 Embedding 接口的请求体 model 字段位置进行改写
|
|
||||||
- **WHEN** 调用 `RewriteRequestModelName(body, "gpt-4", InterfaceTypeRerank)`
|
|
||||||
- **THEN** SHALL 按 Rerank 接口的请求体 model 字段位置进行改写
|
|
||||||
|
|
||||||
### Requirement: 最小化改写响应体 model 字段
|
|
||||||
|
|
||||||
OpenAI 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写响应体中的 model 字段,其余字段保持原始 bytes。请求体和响应体的 model 字段位置可能不同,各自独立实现。
|
|
||||||
|
|
||||||
#### Scenario: 响应体 model 改写(上游名 → 统一 ID)
|
|
||||||
|
|
||||||
- **WHEN** 调用 `RewriteResponseModelName(body, "openai/gpt-4", InterfaceTypeChat)`,body 为上游 Chat 响应
|
|
||||||
- **THEN** SHALL 将 model 字段替换为 `"openai/gpt-4"`,其余字段原样保留
|
|
||||||
|
|
||||||
#### Scenario: 不同 InterfaceType 的响应改写
|
|
||||||
|
|
||||||
- **WHEN** 调用 `RewriteResponseModelName(body, "openai/gpt-4", InterfaceTypeEmbeddings)`
|
|
||||||
- **THEN** SHALL 按 Embedding 接口的响应体 model 字段位置进行改写
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
## MODIFIED Requirements
|
|
||||||
|
|
||||||
### Requirement: 创建供应商配置
|
|
||||||
|
|
||||||
网关 SHALL 允许通过管理 API 创建新的供应商配置。
|
|
||||||
|
|
||||||
#### Scenario: 使用有效数据创建供应商
|
|
||||||
|
|
||||||
- **WHEN** 向 `/api/providers` 发送 POST 请求,携带有效的供应商数据(id, name, api_key, base_url, protocol)
|
|
||||||
- **THEN** 网关 SHALL 在数据库中创建新的供应商记录
|
|
||||||
- **THEN** 网关 SHALL 返回创建的供应商,状态码为 201
|
|
||||||
- **THEN** 供应商 SHALL 默认启用
|
|
||||||
- **THEN** protocol 字段 SHALL 默认为 "openai"
|
|
||||||
|
|
||||||
#### Scenario: 使用重复 ID 创建供应商
|
|
||||||
|
|
||||||
- **WHEN** 向 `/api/providers` 发送 POST 请求,携带已存在的 ID
|
|
||||||
- **THEN** 网关 SHALL 返回错误,状态码为 409 (Conflict)
|
|
||||||
|
|
||||||
#### Scenario: 创建供应商时缺少必需字段
|
|
||||||
|
|
||||||
- **WHEN** 向 `/api/providers` 发送 POST 请求,缺少必需字段(id, name, api_key 或 base_url)
|
|
||||||
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
|
|
||||||
- **THEN** 错误 SHALL 指示缺少哪些字段
|
|
||||||
|
|
||||||
#### Scenario: 创建供应商时 ID 包含非法字符
|
|
||||||
|
|
||||||
- **WHEN** 向 `/api/providers` 发送 POST 请求,id 包含非 `[a-zA-Z0-9_]` 字符
|
|
||||||
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
|
|
||||||
- **THEN** 错误 SHALL 指示 id 仅允许字母、数字、下划线
|
|
||||||
|
|
||||||
#### Scenario: 创建供应商时 ID 过长
|
|
||||||
|
|
||||||
- **WHEN** 向 `/api/providers` 发送 POST 请求,id 长度超过 64
|
|
||||||
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
## MODIFIED Requirements
|
|
||||||
|
|
||||||
### Requirement: 供应商 ID 校验
|
|
||||||
|
|
||||||
创建供应商时,SHALL 对 `id` 字段进行字符集校验。
|
|
||||||
|
|
||||||
#### Scenario: 合法字符集
|
|
||||||
|
|
||||||
- **WHEN** 创建供应商,id 仅包含 `[a-zA-Z0-9_]` 字符
|
|
||||||
- **THEN** SHALL 校验通过
|
|
||||||
|
|
||||||
#### Scenario: 非法字符
|
|
||||||
|
|
||||||
- **WHEN** 创建供应商,id 包含 `-`、`.`、`/`、空格、中文等非 `[a-zA-Z0-9_]` 字符
|
|
||||||
- **THEN** SHALL 返回 400 错误
|
|
||||||
|
|
||||||
#### Scenario: 长度限制
|
|
||||||
|
|
||||||
- **WHEN** 创建供应商,id 长度超过 64
|
|
||||||
- **THEN** SHALL 返回 400 错误
|
|
||||||
|
|
||||||
### Requirement: 模型创建校验
|
|
||||||
|
|
||||||
创建模型时,SHALL 对 `provider_id` + `model_name` 进行联合唯一性校验。
|
|
||||||
|
|
||||||
#### Scenario: 正常创建
|
|
||||||
|
|
||||||
- **WHEN** 创建模型,provider_id 存在且 provider_id + model_name 组合唯一
|
|
||||||
- **THEN** SHALL 校验通过
|
|
||||||
|
|
||||||
#### Scenario: 联合唯一冲突
|
|
||||||
|
|
||||||
- **WHEN** 创建模型,provider_id + model_name 组合已存在
|
|
||||||
- **THEN** SHALL 返回 409 错误
|
|
||||||
|
|
||||||
#### Scenario: model_name 为空
|
|
||||||
|
|
||||||
- **WHEN** 创建模型,未提供 model_name
|
|
||||||
- **THEN** SHALL 返回 400 错误
|
|
||||||
@@ -1,123 +0,0 @@
|
|||||||
## MODIFIED Requirements
|
|
||||||
|
|
||||||
### Requirement: 代理请求路由
|
|
||||||
|
|
||||||
ProxyHandler SHALL 使用统一模型 ID 路由所有代理请求。
|
|
||||||
|
|
||||||
#### Scenario: 提取统一模型 ID
|
|
||||||
|
|
||||||
- **WHEN** 收到 Chat、Embeddings 或 Rerank 接口的 POST 请求(含请求体)
|
|
||||||
- **THEN** SHALL 调用客户端协议 adapter 的 `ExtractModelName(body, ifaceType)` 提取 model 值
|
|
||||||
- **THEN** SHALL 调用 `ParseUnifiedModelID` 解析得到 providerID 和 modelName
|
|
||||||
- **THEN** SHALL 调用 `RoutingService.RouteByModelName(providerID, modelName)` 路由
|
|
||||||
|
|
||||||
#### Scenario: GET 请求或无请求体
|
|
||||||
|
|
||||||
- **WHEN** 收到 GET 请求或请求体为空
|
|
||||||
- **THEN** SHALL 返回错误响应,状态码为 400,提示缺少 model 字段
|
|
||||||
|
|
||||||
#### Scenario: 无效的统一模型 ID
|
|
||||||
|
|
||||||
- **WHEN** 请求体中 `model` 字段不是有效的统一模型 ID 格式
|
|
||||||
- **THEN** SHALL 返回错误响应,状态码为 400
|
|
||||||
|
|
||||||
#### Scenario: 模型不存在
|
|
||||||
|
|
||||||
- **WHEN** 解析统一模型 ID 后,数据库中找不到对应的 provider_id + model_name 组合
|
|
||||||
- **THEN** SHALL 返回错误响应,状态码为 404
|
|
||||||
|
|
||||||
#### Scenario: 模型已禁用
|
|
||||||
|
|
||||||
- **WHEN** 解析统一模型 ID 后,对应的模型 enabled 为 false
|
|
||||||
- **THEN** SHALL 返回错误响应,状态码为 404
|
|
||||||
|
|
||||||
#### Scenario: 供应商已禁用
|
|
||||||
|
|
||||||
- **WHEN** 解析统一模型 ID 后,对应的供应商 enabled 为 false
|
|
||||||
- **THEN** SHALL 返回错误响应,状态码为 404
|
|
||||||
|
|
||||||
### Requirement: 同协议 Smart Passthrough
|
|
||||||
|
|
||||||
当客户端协议与供应商协议相同时,ProxyHandler SHALL 使用 Smart Passthrough 处理 Chat、Embedding、Rerank 请求。
|
|
||||||
|
|
||||||
#### Scenario: 同协议非流式请求
|
|
||||||
|
|
||||||
- **WHEN** 客户端协议 == 供应商协议,且为非流式请求
|
|
||||||
- **THEN** SHALL 调用 adapter 的 `RewriteRequestModelName(body, modelName, ifaceType)` 将请求体中 model 从统一 ID 改写为上游模型名
|
|
||||||
- **THEN** SHALL 构建 URL 和 Headers(同当前透传逻辑)
|
|
||||||
- **THEN** SHALL 发送改写后的请求体到上游
|
|
||||||
- **THEN** SHALL 调用 adapter 的 `RewriteResponseModelName(resp.Body, unifiedModelID, ifaceType)` 将响应中 model 从上游名改写为统一 ID
|
|
||||||
- **THEN** SHALL NOT 对 body 做全量 decode → encode,保持未改写字段的原始 bytes
|
|
||||||
|
|
||||||
#### Scenario: 同协议流式请求
|
|
||||||
|
|
||||||
- **WHEN** 客户端协议 == 供应商协议,且为流式请求
|
|
||||||
- **THEN** SHALL 对请求体做 `RewriteRequestModelName` 改写 model 字段
|
|
||||||
- **THEN** SHALL 逐 SSE chunk 调用 `RewriteResponseModelName` 改写响应中 model 字段
|
|
||||||
- **THEN** SHALL NOT 对 chunk 做全量 decode → encode
|
|
||||||
|
|
||||||
#### Scenario: Smart Passthrough 保真性
|
|
||||||
|
|
||||||
- **WHEN** 客户端发送含未知参数的请求(如 `{"model":"openai/gpt-4","some_new_param":"value"}`)
|
|
||||||
- **THEN** 上游 SHALL 收到 `{"model":"gpt-4","some_new_param":"value"}`
|
|
||||||
- **THEN** `some_new_param` SHALL 保持原始值不变,不丢失、不改变类型
|
|
||||||
|
|
||||||
### Requirement: 跨协议完整转换
|
|
||||||
|
|
||||||
当客户端协议与供应商协议不同时,ProxyHandler SHALL 使用全量转换路径。
|
|
||||||
|
|
||||||
#### Scenario: 跨协议非流式请求
|
|
||||||
|
|
||||||
- **WHEN** 客户端协议 != 供应商协议
|
|
||||||
- **THEN** SHALL 走 `ConvertHttpRequest` 全量转换,encoder 中 provider.ModelName 覆盖 model
|
|
||||||
- **THEN** SHALL 走 `ConvertHttpResponse` 全量转换,modelOverride 参数覆写 canonical.Model
|
|
||||||
|
|
||||||
#### Scenario: 跨协议流式请求
|
|
||||||
|
|
||||||
- **WHEN** 客户端协议 != 供应商协议,且为流式请求
|
|
||||||
- **THEN** SHALL 走 `CreateStreamConverter` 全量转换,modelOverride 参数覆写流式 canonical 事件中的 Model
|
|
||||||
|
|
||||||
### Requirement: 模型列表本地聚合
|
|
||||||
|
|
||||||
ProxyHandler SHALL 从数据库聚合返回模型列表,不再透传上游。
|
|
||||||
|
|
||||||
#### Scenario: GET /v1/models
|
|
||||||
|
|
||||||
- **WHEN** 收到 `GET /{protocol}/v1/models` 请求
|
|
||||||
- **THEN** SHALL 从数据库查询所有 enabled 的模型(关联 enabled 的供应商)
|
|
||||||
- **THEN** SHALL 组装 `CanonicalModelList`,每个模型的 ID 字段为统一模型 ID(`provider_id/model_name`),Name 字段为 model_name,OwnedBy 字段为 provider_id
|
|
||||||
- **THEN** SHALL 使用客户端协议的 adapter 编码响应
|
|
||||||
- **THEN** SHALL NOT 请求上游供应商
|
|
||||||
|
|
||||||
#### Scenario: 无可用模型
|
|
||||||
|
|
||||||
- **WHEN** 数据库中没有 enabled 的模型
|
|
||||||
- **THEN** SHALL 返回空列表
|
|
||||||
|
|
||||||
### Requirement: 模型详情本地查询
|
|
||||||
|
|
||||||
ProxyHandler SHALL 从数据库查询返回模型详情,不再透传上游。
|
|
||||||
|
|
||||||
#### Scenario: GET /v1/models/{unified_id}
|
|
||||||
|
|
||||||
- **WHEN** 收到 `GET /{protocol}/v1/models/{provider_id}/{model_name}` 请求
|
|
||||||
- **THEN** SHALL 调用 adapter 的 `ExtractUnifiedModelID` 提取统一模型 ID
|
|
||||||
- **THEN** SHALL 解析统一模型 ID 得到 providerID 和 modelName
|
|
||||||
- **THEN** SHALL 从数据库查询对应的模型和供应商
|
|
||||||
- **THEN** SHALL 组装 `CanonicalModelInfo`,ID 字段为统一模型 ID(`provider_id/model_name`),Name 字段为 model_name,OwnedBy 字段为 provider_id
|
|
||||||
- **THEN** SHALL 使用客户端协议的 adapter 编码响应
|
|
||||||
- **THEN** SHALL NOT 请求上游供应商
|
|
||||||
|
|
||||||
#### Scenario: 模型详情不存在
|
|
||||||
|
|
||||||
- **WHEN** 统一模型 ID 对应的模型不存在或已禁用
|
|
||||||
- **THEN** SHALL 返回错误响应,状态码为 404
|
|
||||||
|
|
||||||
### Requirement: 统计记录
|
|
||||||
|
|
||||||
ProxyHandler SHALL 使用 providerID 和 modelName 记录使用统计。
|
|
||||||
|
|
||||||
#### Scenario: 异步记录统计
|
|
||||||
|
|
||||||
- **WHEN** 代理请求成功完成
|
|
||||||
- **THEN** SHALL 异步调用 `StatsService.Record(providerID, modelName)`
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
## ADDED Requirements
|
|
||||||
|
|
||||||
### Requirement: 使用统计记录统一模型标识
|
|
||||||
|
|
||||||
系统 SHALL 使用 providerID 和 modelName(上游模型名)记录使用统计。
|
|
||||||
|
|
||||||
#### Scenario: 代理请求统计记录
|
|
||||||
|
|
||||||
- **WHEN** 代理请求成功完成
|
|
||||||
- **THEN** SHALL 记录 provider_id 和 model_name 到 usage_stats 表(参数来自路由结果)
|
|
||||||
- **THEN** SHALL 异步执行,不阻塞响应
|
|
||||||
|
|
||||||
#### Scenario: 查询统计
|
|
||||||
|
|
||||||
- **WHEN** 查询统计数据
|
|
||||||
- **THEN** 支持按 provider_id 和 model_name 过滤
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
## 1. 数据库迁移
|
|
||||||
|
|
||||||
- [ ] 1.1 新增迁移脚本:DROP 旧 models 表 + CREATE 新 models 表(id UUID PK, provider_id, model_name, enabled, created_at),UNIQUE(provider_id, model_name)
|
|
||||||
- [ ] 1.2 更新 config/models.go:Model 结构体适配(id 改为 UUID 自动生成,model_name 保持不变)
|
|
||||||
- [ ] 1.3 编写迁移脚本测试
|
|
||||||
|
|
||||||
## 2. 统一模型 ID 工具包
|
|
||||||
|
|
||||||
- [ ] 2.1 新增 pkg/modelid/model_id.go:实现 ParseUnifiedModelID、FormatUnifiedModelID、ValidateProviderID、IsValidUnifiedModelID
|
|
||||||
- [ ] 2.2 新增 pkg/modelid/model_id_test.go:覆盖标准格式、含斜杠 model_name、空字符串、非法字符等边界情况
|
|
||||||
|
|
||||||
## 3. Domain 层适配
|
|
||||||
|
|
||||||
- [ ] 3.1 修改 domain/model.go:Model 结构体字段适配,新增 UnifiedModelID() 方法
|
|
||||||
- [ ] 3.2 修改 domain/route.go:RouteResult 适配新字段
|
|
||||||
|
|
||||||
## 4. Repository 层适配
|
|
||||||
|
|
||||||
- [ ] 4.1 修改 repository/model_repo.go:接口变更 — GetByModelName 改为 FindByProviderAndModelName,新增 ListEnabled
|
|
||||||
- [ ] 4.2 修改 repository/model_repo_impl.go:实现 FindByProviderAndModelName(WHERE provider_id=? AND model_name=?)、ListEnabled(JOIN providers WHERE enabled)
|
|
||||||
- [ ] 4.3 编写 repository 层测试
|
|
||||||
|
|
||||||
## 5. Service 层适配
|
|
||||||
|
|
||||||
- [ ] 5.1 修改 service/routing_service.go:Route 接口改为 RouteByModelName(providerID, modelName string)
|
|
||||||
- [ ] 5.2 修改 service/routing_service_impl.go:调用 FindByProviderAndModelName 替代 GetByModelName
|
|
||||||
- [ ] 5.3 修改 service/model_service.go:Create 生成 UUID、新增联合唯一校验方法
|
|
||||||
- [ ] 5.4 修改 service/model_service_impl.go:实现联合唯一校验、UUID 生成
|
|
||||||
- [ ] 5.5 修改 service/provider_service_impl.go:Create 时调用 ValidateProviderID 校验 ID 字符集
|
|
||||||
- [ ] 5.6 编写 service 层测试
|
|
||||||
|
|
||||||
## 6. Conversion 层适配
|
|
||||||
|
|
||||||
- [ ] 6.1 修改 conversion/adapter.go:ProtocolAdapter 接口新增 ExtractUnifiedModelID、ExtractModelName、RewriteRequestModelName、RewriteResponseModelName 四个方法
|
|
||||||
- [ ] 6.2 修改 conversion/engine.go:ConvertHttpResponse 新增 modelOverride 参数(跨协议场景),各 convert*ResponseBody 中覆写 canonical Model;CreateStreamConverter 新增 modelOverride 参数
|
|
||||||
- [ ] 6.3 修改 conversion/openai/adapter.go:实现 ExtractUnifiedModelID、ExtractModelName(按 ifaceType 提取 model)、RewriteRequestModelName 和 RewriteResponseModelName(json.RawMessage 最小化改写,按 ifaceType 定位 model 字段,请求/响应独立实现),修改 isModelInfoPath 允许 suffix 含 "/"
|
|
||||||
- [ ] 6.4 修改 conversion/anthropic/adapter.go:实现 ExtractUnifiedModelID、ExtractModelName、RewriteRequestModelName、RewriteResponseModelName,修改 isModelInfoPath 允许 suffix 含 "/"
|
|
||||||
- [ ] 6.5 编写 conversion 层测试:ExtractUnifiedModelID、ExtractModelName 各 ifaceType、RewriteRequestModelName/RewriteResponseModelName 保真性(含未知参数不丢失)、isModelInfoPath 含斜杠路径、modelOverride 覆写
|
|
||||||
|
|
||||||
## 7. Handler 层改造
|
|
||||||
|
|
||||||
- [ ] 7.1 修改 handler/proxy_handler.go:HandleProxy 按接口类型分发 — Models/ModelInfo 本地聚合;Chat/Embed/Rerank 用 adapter.ExtractModelName 提取统一 ID 路由,同协议走 Smart Passthrough(adapter.RewriteRequestModelName 改写请求、adapter.RewriteResponseModelName 改写响应),跨协议走全量转换(modelOverride);删除 forwardPassthrough 和硬编码的 extractModelName
|
|
||||||
- [ ] 7.2 修改 handler/model_handler.go:请求体字段适配(移除 id 输入、保留 provider_id 和 model_name),响应新增 unified_id,Create 使用 UUID
|
|
||||||
- [ ] 7.3 修改 handler/provider_handler.go:CreateProvider 校验 ID 字符集
|
|
||||||
- [ ] 7.4 编写 handler 层测试:统一模型 ID 路由、同协议 Smart Passthrough 保真性、跨协议 modelOverride、Models 聚合、ModelInfo 查询、流式场景 model 覆写、provider ID 校验
|
|
||||||
|
|
||||||
## 8. 路由注册适配
|
|
||||||
|
|
||||||
- [ ] 8.1 修改 cmd/server/main.go:setupRoutes 适配 handler 签名变更,传递新增依赖
|
|
||||||
|
|
||||||
## 9. 文档更新
|
|
||||||
|
|
||||||
- [ ] 9.1 按需更新 README.md:同步 models 表结构、API 接口字段、统一模型 ID 格式、Smart Passthrough 策略等变更说明
|
|
||||||
@@ -277,4 +277,51 @@ ErrorCode SHALL 包含:INVALID_INPUT、MISSING_REQUIRED_FIELD、INCOMPATIBLE_F
|
|||||||
- **WHEN** Adapter 调用 buildHeaders(provider)
|
- **WHEN** Adapter 调用 buildHeaders(provider)
|
||||||
- **THEN** SHALL 从 provider.api_key 提取认证信息
|
- **THEN** SHALL 从 provider.api_key 提取认证信息
|
||||||
- **THEN** SHALL 从 provider.adapter_config 提取协议专属配置
|
- **THEN** SHALL 从 provider.adapter_config 提取协议专属配置
|
||||||
- **THEN** SHALL 使用 provider.model_name 覆盖请求中的 model 字段
|
- **THEN** SHALL 使用 provider.model_name 覆盖请求中的 model 字段
|
||||||
|
### Requirement: 跨协议响应转换支持 model 覆写
|
||||||
|
|
||||||
|
ConversionEngine SHALL 在跨协议响应转换时支持 model 字段覆写。
|
||||||
|
|
||||||
|
#### Scenario: ConvertHttpResponse 接收 modelOverride 参数
|
||||||
|
|
||||||
|
- **WHEN** 调用 `ConvertHttpResponse` 时传入 `modelOverride` 参数(跨协议场景,非空字符串)
|
||||||
|
- **THEN** SHALL 在解码上游响应到 canonical 后,将 `Model` 字段设为 `modelOverride`
|
||||||
|
- **THEN** SHALL 使用覆写后的 canonical 编码为客户端协议格式
|
||||||
|
|
||||||
|
#### Scenario: modelOverride 为空
|
||||||
|
|
||||||
|
- **WHEN** 调用 `ConvertHttpResponse` 时 `modelOverride` 为空字符串
|
||||||
|
- **THEN** SHALL NOT 覆写 canonical 的 Model 字段,保持上游原始值
|
||||||
|
|
||||||
|
#### Scenario: Chat 响应 model 覆写
|
||||||
|
|
||||||
|
- **WHEN** 跨协议转换 Chat 类型响应且 `modelOverride` 非空
|
||||||
|
- **THEN** `CanonicalResponse.Model` SHALL 被设为 `modelOverride`
|
||||||
|
|
||||||
|
#### Scenario: Embedding 响应 model 覆写
|
||||||
|
|
||||||
|
- **WHEN** 跨协议转换 Embedding 类型响应且 `modelOverride` 非空
|
||||||
|
- **THEN** `CanonicalEmbeddingResponse.Model` SHALL 被设为 `modelOverride`
|
||||||
|
|
||||||
|
#### Scenario: Rerank 响应 model 覆写
|
||||||
|
|
||||||
|
- **WHEN** 跨协议转换 Rerank 类型响应且 `modelOverride` 非空
|
||||||
|
- **THEN** `CanonicalRerankResponse.Model` SHALL 被设为 `modelOverride`
|
||||||
|
|
||||||
|
### Requirement: 跨协议流式转换支持 model 覆写
|
||||||
|
|
||||||
|
ConversionEngine SHALL 在跨协议流式转换时支持 model 字段覆写。
|
||||||
|
|
||||||
|
#### Scenario: CreateStreamConverter 接收 modelOverride 参数
|
||||||
|
|
||||||
|
- **WHEN** 调用 `CreateStreamConverter` 时传入 `modelOverride` 参数(跨协议场景)
|
||||||
|
- **THEN** SHALL 在流式 canonical 事件中将 `Model` 字段设为 `modelOverride`
|
||||||
|
|
||||||
|
### Requirement: TargetProvider 字段语义
|
||||||
|
|
||||||
|
TargetProvider 的 ModelName 字段 SHALL 存储上游供应商的模型名称(即 `model_name` 字段值),语义保持不变。
|
||||||
|
|
||||||
|
#### Scenario: encoder 使用 TargetProvider.ModelName
|
||||||
|
|
||||||
|
- **WHEN** 协议适配器编码请求时
|
||||||
|
- **THEN** SHALL 使用 `TargetProvider.ModelName` 作为发给上游的 `model` 字段值(值为路由结果中的 model_name)
|
||||||
|
|||||||
@@ -28,9 +28,10 @@
|
|||||||
#### Scenario: 初始迁移文件
|
#### Scenario: 初始迁移文件
|
||||||
|
|
||||||
- **WHEN** 创建初始迁移
|
- **WHEN** 创建初始迁移
|
||||||
- **THEN** SHALL 创建 001_initial_schema.sql
|
- **THEN** SHALL 创建单个初始迁移文件(如 `20260421000001_initial_schema.sql`)
|
||||||
- **THEN** SHALL 包含 providers、models、usage_stats 表的创建语句
|
- **THEN** SHALL 包含 providers、models、usage_stats 表的创建语句
|
||||||
- **THEN** SHALL 包含外键约束
|
- **THEN** SHALL 包含外键约束
|
||||||
|
- **THEN** SHALL 包含索引创建语句
|
||||||
|
|
||||||
#### Scenario: Up 迁移
|
#### Scenario: Up 迁移
|
||||||
|
|
||||||
@@ -42,25 +43,19 @@
|
|||||||
#### Scenario: Down 迁移
|
#### Scenario: Down 迁移
|
||||||
|
|
||||||
- **WHEN** 执行 down 迁移
|
- **WHEN** 执行 down 迁移
|
||||||
- **THEN** SHALL 删除所有表
|
- **THEN** SHALL 删除所有表和索引
|
||||||
- **THEN** SHALL 按正确顺序删除(避免外键约束错误)
|
- **THEN** SHALL 按正确顺序删除(避免外键约束错误)
|
||||||
|
|
||||||
### Requirement: 添加索引迁移
|
### Requirement: models 表 schema 变更
|
||||||
|
|
||||||
系统 SHALL 创建索引迁移。
|
系统 SHALL 在初始迁移脚本中直接创建新的 models 表结构(服务未上线,无需考虑数据迁移,迁移脚本已合并为单个初始迁移文件)。
|
||||||
|
|
||||||
#### Scenario: 索引迁移文件
|
#### Scenario: 初始迁移 models 表结构
|
||||||
|
|
||||||
- **WHEN** 创建索引迁移
|
- **WHEN** 执行迁移
|
||||||
- **THEN** SHALL 创建 002_add_indexes.sql
|
- **THEN** SHALL CREATE models 表,包含字段:id(TEXT PRIMARY KEY,存储 UUID 字符串)、provider_id(TEXT NOT NULL)、model_name(TEXT NOT NULL)、enabled(INTEGER DEFAULT 1)、created_at(DATETIME)
|
||||||
- **THEN** SHALL 为常用查询字段添加索引
|
- **THEN** SHALL 存在 UNIQUE(provider_id, model_name) 约束
|
||||||
|
- **THEN** SHALL 存在 FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
|
||||||
#### Scenario: 索引定义
|
|
||||||
|
|
||||||
- **WHEN** 添加索引
|
|
||||||
- **THEN** SHALL 为 models(provider_id) 添加索引
|
|
||||||
- **THEN** SHALL 为 models(model_name) 添加索引
|
|
||||||
- **THEN** SHALL 为 usage_stats(provider_id, model_name, date) 添加复合索引
|
|
||||||
|
|
||||||
### Requirement: 迁移命令集成
|
### Requirement: 迁移命令集成
|
||||||
|
|
||||||
|
|||||||
209
openspec/specs/error-responses/spec.md
Normal file
209
openspec/specs/error-responses/spec.md
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
# Error Responses
|
||||||
|
|
||||||
|
## Purpose
|
||||||
|
|
||||||
|
定义系统统一的错误响应格式和各类错误场景,确保客户端能够一致地处理错误。
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
### Requirement: 统一错误响应格式
|
||||||
|
|
||||||
|
系统 SHALL 使用统一的错误响应格式。
|
||||||
|
|
||||||
|
#### Scenario: 标准错误格式
|
||||||
|
|
||||||
|
- **WHEN** 返回错误响应
|
||||||
|
- **THEN** SHALL 使用以下 JSON 格式:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": "错误描述",
|
||||||
|
"code": "ERROR_CODE"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
- **THEN** `error` 字段 SHALL 包含人类可读的错误描述
|
||||||
|
- **THEN** `code` 字段 SHALL 包含机器可读的错误代码(可选)
|
||||||
|
|
||||||
|
### Requirement: provider_id 校验错误
|
||||||
|
|
||||||
|
系统 SHALL 对 provider_id 校验错误返回明确的错误信息。
|
||||||
|
|
||||||
|
#### Scenario: provider_id 包含非法字符
|
||||||
|
|
||||||
|
- **WHEN** 创建或更新供应商时,provider_id 包含非 `[a-zA-Z0-9_]` 字符
|
||||||
|
- **THEN** SHALL 返回 HTTP 400 Bad Request
|
||||||
|
- **THEN** SHALL 返回以下 JSON 格式:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": "供应商 ID 仅允许字母、数字、下划线",
|
||||||
|
"code": "INVALID_PROVIDER_ID"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Scenario: provider_id 长度超限
|
||||||
|
|
||||||
|
- **WHEN** 创建或更新供应商时,provider_id 长度超过 64
|
||||||
|
- **THEN** SHALL 返回 HTTP 400 Bad Request
|
||||||
|
- **THEN** SHALL 返回以下 JSON 格式:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": "供应商 ID 长度不能超过 64 个字符",
|
||||||
|
"code": "INVALID_PROVIDER_ID"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Requirement: 联合唯一约束冲突错误
|
||||||
|
|
||||||
|
系统 SHALL 对联合唯一约束冲突返回明确的错误信息。
|
||||||
|
|
||||||
|
#### Scenario: 创建模型时 provider_id + model_name 组合已存在
|
||||||
|
|
||||||
|
- **WHEN** 创建模型时,provider_id + model_name 组合已存在
|
||||||
|
- **THEN** SHALL 返回 HTTP 409 Conflict
|
||||||
|
- **THEN** SHALL 返回以下 JSON 格式:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": "同一供应商下模型名称已存在",
|
||||||
|
"code": "duplicate_model"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Scenario: 更新模型时导致 provider_id + model_name 组合冲突
|
||||||
|
|
||||||
|
- **WHEN** 更新模型时,修改 provider_id 或 model_name 导致与已有记录冲突
|
||||||
|
- **THEN** SHALL 返回 HTTP 409 Conflict
|
||||||
|
- **THEN** SHALL 返回以下 JSON 格式:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": "同一供应商下模型名称已存在",
|
||||||
|
"code": "duplicate_model"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Requirement: 资源不存在错误
|
||||||
|
|
||||||
|
系统 SHALL 对资源不存在返回明确的错误信息。
|
||||||
|
|
||||||
|
#### Scenario: 模型不存在
|
||||||
|
|
||||||
|
- **WHEN** 查询或操作不存在的模型
|
||||||
|
- **THEN** SHALL 返回 HTTP 404 Not Found
|
||||||
|
- **THEN** SHALL 返回以下 JSON 格式:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": "模型未找到"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Scenario: 供应商不存在
|
||||||
|
|
||||||
|
- **WHEN** 创建模型时指定的供应商不存在
|
||||||
|
- **THEN** SHALL 返回 HTTP 400 Bad Request
|
||||||
|
- **THEN** SHALL 返回以下 JSON 格式:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": "供应商不存在"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Requirement: 统一模型 ID 格式错误
|
||||||
|
|
||||||
|
系统 SHALL 对统一模型 ID 格式错误返回明确的错误信息。
|
||||||
|
|
||||||
|
#### Scenario: 统一模型 ID 格式无效
|
||||||
|
|
||||||
|
- **WHEN** 代理请求中的 model 字段不是有效的统一模型 ID 格式
|
||||||
|
- **THEN** 请求 SHALL 走 forwardPassthrough 透传到上游(兼容未适配的客户端)
|
||||||
|
- **THEN** 不返回错误,保持与上游的兼容性
|
||||||
|
|
||||||
|
**设计理由:**
|
||||||
|
- 统一模型 ID 是 BREAKING CHANGE,部分旧客户端可能仍使用原始模型名
|
||||||
|
- 透传策略允许上游自行判断并返回错误(如 404 model not found)
|
||||||
|
- 网关作为透明代理,不应拦截所有格式非法的请求
|
||||||
|
|
||||||
|
#### Scenario: 统一模型 ID 格式有效但对应模型不存在
|
||||||
|
|
||||||
|
- **WHEN** 代理请求中的 model 字段是有效的统一模型 ID 格式(含 `/`),但数据库中找不到对应的模型
|
||||||
|
- **THEN** SHALL 返回 HTTP 404 Not Found
|
||||||
|
- **THEN** SHALL 返回以下 JSON 格式:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": "模型未找到"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Scenario: 统一模型 ID 对应的模型不存在
|
||||||
|
|
||||||
|
- **WHEN** 解析统一模型 ID 后,数据库中找不到对应的模型
|
||||||
|
- **THEN** SHALL 返回 HTTP 404 Not Found
|
||||||
|
- **THEN** SHALL 返回以下 JSON 格式:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": "模型未找到"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Scenario: 统一模型 ID 对应的模型已禁用
|
||||||
|
|
||||||
|
- **WHEN** 解析统一模型 ID 后,对应的模型 enabled 为 false
|
||||||
|
- **THEN** SHALL 返回 HTTP 404 Not Found
|
||||||
|
- **THEN** SHALL 返回以下 JSON 格式:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": "模型未找到"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Scenario: 统一模型 ID 对应的供应商已禁用
|
||||||
|
|
||||||
|
- **WHEN** 解析统一模型 ID 后,对应的供应商 enabled 为 false
|
||||||
|
- **THEN** SHALL 返回 HTTP 404 Not Found
|
||||||
|
- **THEN** SHALL 返回以下 JSON 格式:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": "模型未找到"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Requirement: JSON 格式错误
|
||||||
|
|
||||||
|
系统 SHALL 对请求体 JSON 格式错误返回明确的错误信息。
|
||||||
|
|
||||||
|
#### Scenario: 请求体 JSON 格式错误
|
||||||
|
|
||||||
|
- **WHEN** 代理请求的请求体不是有效的 JSON 格式
|
||||||
|
- **THEN** SHALL 返回 HTTP 400 Bad Request
|
||||||
|
- **THEN** SHALL 返回以下 JSON 格式:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": "请求体 JSON 格式错误",
|
||||||
|
"code": "INVALID_JSON"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Scenario: Smart Passthrough 时请求体 JSON 格式错误
|
||||||
|
|
||||||
|
- **WHEN** 同协议 Smart Passthrough 场景下,请求体 JSON 格式不正确
|
||||||
|
- **THEN** SHALL 返回 HTTP 400 Bad Request
|
||||||
|
- **THEN** SHALL 返回以下 JSON 格式:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": "请求体 JSON 格式错误",
|
||||||
|
"code": "INVALID_JSON"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Requirement: 不可变字段错误
|
||||||
|
|
||||||
|
系统 SHALL 对尝试修改不可变字段返回明确的错误信息。
|
||||||
|
|
||||||
|
#### Scenario: 尝试修改供应商 ID
|
||||||
|
|
||||||
|
- **WHEN** 更新供应商时,请求体中包含 `id` 字段
|
||||||
|
- **THEN** SHALL 返回 HTTP 400 Bad Request
|
||||||
|
- **THEN** SHALL 返回以下 JSON 格式:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": "供应商 ID 不允许修改",
|
||||||
|
"code": "IMMUTABLE_FIELD"
|
||||||
|
}
|
||||||
|
```
|
||||||
@@ -10,10 +10,12 @@
|
|||||||
|
|
||||||
#### Scenario: 使用有效数据创建模型
|
#### Scenario: 使用有效数据创建模型
|
||||||
|
|
||||||
- **WHEN** 向 `/api/models` 发送 POST 请求,携带有效的模型数据(id, provider_id, model_name)
|
- **WHEN** 向 `/api/models` 发送 POST 请求,携带有效的模型数据(provider_id, model_name),不提供 id 字段
|
||||||
|
- **THEN** 网关 SHALL 自动生成 UUID 作为模型 id
|
||||||
- **THEN** 网关 SHALL 在数据库中创建新的模型记录
|
- **THEN** 网关 SHALL 在数据库中创建新的模型记录
|
||||||
- **THEN** 网关 SHALL 返回创建的模型,状态码为 201
|
- **THEN** 网关 SHALL 返回创建的模型,状态码为 201
|
||||||
- **THEN** 模型 SHALL 默认启用
|
- **THEN** 模型 SHALL 默认启用
|
||||||
|
- **THEN** 返回的模型 SHALL 包含 `unified_id` 字段,值为 `{provider_id}/{model_name}`
|
||||||
|
|
||||||
#### Scenario: 使用不存在的供应商创建模型
|
#### Scenario: 使用不存在的供应商创建模型
|
||||||
|
|
||||||
@@ -21,7 +23,11 @@
|
|||||||
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
|
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
|
||||||
- **THEN** 错误 SHALL 指示供应商不存在
|
- **THEN** 错误 SHALL 指示供应商不存在
|
||||||
|
|
||||||
**变更说明:** handler 通过 ModelService 调用,数据访问通过 ModelRepository 和 ProviderRepository。API 接口保持不变。
|
#### Scenario: 创建重复模型
|
||||||
|
|
||||||
|
- **WHEN** 向 `/api/models` 发送 POST 请求,携带已存在的 provider_id + model_name 组合
|
||||||
|
- **THEN** 网关 SHALL 返回错误,状态码为 409 (Conflict)
|
||||||
|
- **THEN** 错误 SHALL 指示同一供应商下模型名称已存在
|
||||||
|
|
||||||
### Requirement: 列出所有模型
|
### Requirement: 列出所有模型
|
||||||
|
|
||||||
@@ -31,9 +37,7 @@
|
|||||||
|
|
||||||
- **WHEN** 向 `/api/models` 发送 GET 请求
|
- **WHEN** 向 `/api/models` 发送 GET 请求
|
||||||
- **THEN** 网关 SHALL 返回所有模型的列表
|
- **THEN** 网关 SHALL 返回所有模型的列表
|
||||||
- **THEN** 每个模型 SHALL 包含 id, provider_id, model_name, enabled, created_at
|
- **THEN** 每个模型 SHALL 包含 id, provider_id, model_name, unified_id, enabled, created_at
|
||||||
|
|
||||||
**变更说明:** 数据访问从 config 包迁移到 ModelRepository。API 接口保持不变。
|
|
||||||
|
|
||||||
### Requirement: 按供应商列出模型
|
### Requirement: 按供应商列出模型
|
||||||
|
|
||||||
@@ -43,8 +47,7 @@
|
|||||||
|
|
||||||
- **WHEN** 向 `/api/models?provider_id=<provider_id>` 发送 GET 请求
|
- **WHEN** 向 `/api/models?provider_id=<provider_id>` 发送 GET 请求
|
||||||
- **THEN** 网关 SHALL 返回指定供应商的模型列表
|
- **THEN** 网关 SHALL 返回指定供应商的模型列表
|
||||||
|
- **THEN** 每个模型 SHALL 包含 unified_id 字段
|
||||||
**变更说明:** 通过 ModelService 和 ModelRepository 实现。API 接口保持不变。
|
|
||||||
|
|
||||||
### Requirement: 更新模型配置
|
### Requirement: 更新模型配置
|
||||||
|
|
||||||
@@ -55,14 +58,12 @@
|
|||||||
- **WHEN** 向 `/api/models/:id` 发送 PUT 请求,携带有效的模型数据
|
- **WHEN** 向 `/api/models/:id` 发送 PUT 请求,携带有效的模型数据
|
||||||
- **THEN** 网关 SHALL 更新数据库中的模型记录
|
- **THEN** 网关 SHALL 更新数据库中的模型记录
|
||||||
- **THEN** 网关 SHALL 返回更新后的模型
|
- **THEN** 网关 SHALL 返回更新后的模型
|
||||||
|
- **THEN** 返回的模型 SHALL 包含更新后的 unified_id
|
||||||
|
|
||||||
#### Scenario: 更新模型供应商
|
#### Scenario: 更新模型为重复组合
|
||||||
|
|
||||||
- **WHEN** 向 `/api/models/:id` 发送 PUT 请求,携带新的 provider_id
|
- **WHEN** 向 `/api/models/:id` 发送 PUT 请求,更新 provider_id 或 model_name 导致与已有记录重复
|
||||||
- **THEN** 网关 SHALL 验证新供应商是否存在
|
- **THEN** 网关 SHALL 返回错误,状态码为 409 (Conflict)
|
||||||
- **THEN** 网关 SHALL 更新模型的供应商关联
|
|
||||||
|
|
||||||
**变更说明:** 通过 ModelService、ModelRepository 和 ProviderRepository 实现。API 接口保持不变。
|
|
||||||
|
|
||||||
### Requirement: 删除模型配置
|
### Requirement: 删除模型配置
|
||||||
|
|
||||||
@@ -74,8 +75,6 @@
|
|||||||
- **THEN** 网关 SHALL 删除模型记录
|
- **THEN** 网关 SHALL 删除模型记录
|
||||||
- **THEN** 网关 SHALL 返回状态码 204 (No Content)
|
- **THEN** 网关 SHALL 返回状态码 204 (No Content)
|
||||||
|
|
||||||
**变更说明:** 通过 ModelService 和 ModelRepository 实现。API 接口保持不变。
|
|
||||||
|
|
||||||
### Requirement: 使用 service 层处理业务逻辑
|
### Requirement: 使用 service 层处理业务逻辑
|
||||||
|
|
||||||
Handler SHALL 通过 ModelService 处理业务逻辑。
|
Handler SHALL 通过 ModelService 处理业务逻辑。
|
||||||
@@ -85,25 +84,60 @@ Handler SHALL 通过 ModelService 处理业务逻辑。
|
|||||||
- **WHEN** handler 收到请求
|
- **WHEN** handler 收到请求
|
||||||
- **THEN** SHALL 调用对应的 ModelService 方法(Create、Get、List、Update、Delete)
|
- **THEN** SHALL 调用对应的 ModelService 方法(Create、Get、List、Update、Delete)
|
||||||
- **THEN** SHALL 使用 domain.Model 类型
|
- **THEN** SHALL 使用 domain.Model 类型
|
||||||
|
- **THEN** Create 时 SHALL 调用 `uuid.New()` 生成 id
|
||||||
|
|
||||||
#### Scenario: 供应商验证
|
#### Scenario: 供应商验证和唯一性校验
|
||||||
|
|
||||||
- **WHEN** 创建或更新模型
|
- **WHEN** 创建或更新模型
|
||||||
- **THEN** SHALL 在 service 层验证供应商存在
|
- **THEN** SHALL 在 service 层验证供应商存在
|
||||||
- **THEN** SHALL 通过 ProviderRepository 查询供应商
|
- **THEN** SHALL 在 service 层验证 provider_id + model_name 联合唯一
|
||||||
|
|
||||||
|
### Requirement: 联合唯一约束并发处理
|
||||||
|
|
||||||
|
创建或更新模型时,SHALL 使用应用层校验 + 数据库约束双重保险处理联合唯一约束。
|
||||||
|
|
||||||
|
#### Scenario: 应用层快速失败
|
||||||
|
|
||||||
|
- **WHEN** 创建或更新模型前
|
||||||
|
- **THEN** SHALL 先检查 provider_id + model_name 是否已存在
|
||||||
|
- **THEN** 如已存在,SHALL 返回 HTTP 409 Conflict
|
||||||
|
- **THEN** SHALL 返回错误格式:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": "同一供应商下模型名称已存在",
|
||||||
|
"code": "duplicate_model"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Scenario: 数据库约束兜底
|
||||||
|
|
||||||
|
- **WHEN** 并发创建导致应用层校验通过但数据库写入失败
|
||||||
|
- **THEN** SHALL 捕获数据库 UNIQUE 约束错误
|
||||||
|
- **THEN** SHALL 转换为 HTTP 409 Conflict 错误返回
|
||||||
|
- **THEN** SHALL 返回错误格式:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": "同一供应商下模型名称已存在",
|
||||||
|
"code": "duplicate_model"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Scenario: SQLite UNIQUE 约束错误检测
|
||||||
|
|
||||||
|
- **WHEN** 捕获数据库错误
|
||||||
|
- **THEN** SHALL 检查错误信息是否包含 "UNIQUE constraint failed"
|
||||||
|
- **THEN** 如匹配,SHALL 识别为联合唯一约束冲突
|
||||||
|
|
||||||
### Requirement: 使用 repository 层访问数据
|
### Requirement: 使用 repository 层访问数据
|
||||||
|
|
||||||
Service SHALL 通过 ModelRepository 访问数据。
|
Service SHALL 通过 ModelRepository 访问数据。
|
||||||
|
|
||||||
#### Scenario: 调用 repository 方法
|
#### Scenario: 联合查询
|
||||||
|
|
||||||
- **WHEN** service 处理业务逻辑
|
- **WHEN** service 需要按 provider 和 model_name 查询模型
|
||||||
- **THEN** SHALL 调用对应的 ModelRepository 方法
|
- **THEN** SHALL 调用 `FindByProviderAndModelName(providerID, modelName)` 方法
|
||||||
- **THEN** SHALL 使用 domain.Model 类型
|
|
||||||
|
|
||||||
#### Scenario: 数据验证
|
#### Scenario: 查询所有启用模型
|
||||||
|
|
||||||
- **WHEN** 创建或更新模型
|
- **WHEN** proxy handler 需要聚合模型列表
|
||||||
- **THEN** SHALL 在 service 层验证业务规则
|
- **THEN** SHALL 调用 `ListEnabled()` 方法,返回所有 enabled 的模型(关联 enabled 的供应商)
|
||||||
- **THEN** SHALL 在 repository 层执行数据库操作
|
|
||||||
|
|||||||
@@ -270,4 +270,73 @@ Decoder 几乎 1:1 映射,维护最小状态机:
|
|||||||
|
|
||||||
- **WHEN** interfaceType 为 EMBEDDINGS 或 RERANK
|
- **WHEN** interfaceType 为 EMBEDDINGS 或 RERANK
|
||||||
- **THEN** supportsInterface SHALL 返回 false
|
- **THEN** supportsInterface SHALL 返回 false
|
||||||
- **THEN** 引擎 SHALL 走透传或返回空响应
|
- **THEN** 引擎 SHALL 走透传或返回空响应
|
||||||
|
### Requirement: 模型详情路径识别
|
||||||
|
|
||||||
|
Anthropic 适配器 SHALL 正确识别包含统一模型 ID 的模型详情路径。
|
||||||
|
|
||||||
|
#### Scenario: 含斜杠的统一模型 ID 路径
|
||||||
|
|
||||||
|
- **WHEN** 路径为 `/v1/models/anthropic/claude-3-opus`
|
||||||
|
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo`
|
||||||
|
|
||||||
|
#### Scenario: 含多段斜杠的统一模型 ID 路径
|
||||||
|
|
||||||
|
- **WHEN** 路径为 `/v1/models/azure/deployments/gpt-4`
|
||||||
|
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo`
|
||||||
|
|
||||||
|
#### Scenario: 模型列表路径不受影响
|
||||||
|
|
||||||
|
- **WHEN** 路径为 `/v1/models`
|
||||||
|
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModels`
|
||||||
|
|
||||||
|
### Requirement: 提取统一模型 ID
|
||||||
|
|
||||||
|
Anthropic 适配器 SHALL 从路径中提取统一模型 ID。
|
||||||
|
|
||||||
|
#### Scenario: 标准路径提取
|
||||||
|
|
||||||
|
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/anthropic/claude-3-opus")`
|
||||||
|
- **THEN** SHALL 返回 `"anthropic/claude-3-opus"`
|
||||||
|
|
||||||
|
#### Scenario: 复杂路径提取
|
||||||
|
|
||||||
|
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")`
|
||||||
|
- **THEN** SHALL 返回 `"azure/accounts/org/models/gpt-4"`
|
||||||
|
|
||||||
|
#### Scenario: 非模型详情路径
|
||||||
|
|
||||||
|
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models")`
|
||||||
|
- **THEN** SHALL 返回错误
|
||||||
|
|
||||||
|
### Requirement: 从请求体提取 model
|
||||||
|
|
||||||
|
Anthropic 适配器 SHALL 按 InterfaceType 从请求体中提取 model 值。
|
||||||
|
|
||||||
|
#### Scenario: Chat 请求提取 model
|
||||||
|
|
||||||
|
- **WHEN** 调用 `ExtractModelName(body, InterfaceTypeChat)`,body 为 `{"model":"anthropic/claude-3-opus","messages":[...]}`
|
||||||
|
- **THEN** SHALL 返回 `"anthropic/claude-3-opus"`
|
||||||
|
|
||||||
|
#### Scenario: 无 model 字段
|
||||||
|
|
||||||
|
- **WHEN** 调用 `ExtractModelName(body, ifaceType)`,body 中不含 model 字段
|
||||||
|
- **THEN** SHALL 返回错误
|
||||||
|
|
||||||
|
### Requirement: 最小化改写请求体 model 字段
|
||||||
|
|
||||||
|
Anthropic 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写请求体中的 model 字段,其余字段保持原始 bytes。
|
||||||
|
|
||||||
|
#### Scenario: 请求体 model 改写(统一 ID → 上游名)
|
||||||
|
|
||||||
|
- **WHEN** 调用 `RewriteRequestModelName(body, "claude-3-opus-20240229", InterfaceTypeChat)`
|
||||||
|
- **THEN** SHALL 将请求体中 model 字段替换为 `"claude-3-opus-20240229"`,其余字段原样保留
|
||||||
|
|
||||||
|
### Requirement: 最小化改写响应体 model 字段
|
||||||
|
|
||||||
|
Anthropic 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写响应体中的 model 字段,其余字段保持原始 bytes。请求体和响应体的 model 字段位置可能不同,各自独立实现。
|
||||||
|
|
||||||
|
#### Scenario: 响应体 model 改写(上游名 → 统一 ID)
|
||||||
|
|
||||||
|
- **WHEN** 调用 `RewriteResponseModelName(body, "anthropic/claude-3-opus", InterfaceTypeChat)`
|
||||||
|
- **THEN** SHALL 将响应体中 model 字段替换为 `"anthropic/claude-3-opus"`,其余字段原样保留
|
||||||
|
|||||||
@@ -269,4 +269,91 @@ Encoder SHALL 维护状态:
|
|||||||
#### Scenario: /rerank 接口
|
#### Scenario: /rerank 接口
|
||||||
|
|
||||||
- **WHEN** 解码/编码 rerank 请求和响应
|
- **WHEN** 解码/编码 rerank 请求和响应
|
||||||
- **THEN** SHALL 使用 CanonicalRerankRequest/Response 做字段映射
|
- **THEN** SHALL 使用 CanonicalRerankRequest/Response 做字段映射
|
||||||
|
### Requirement: 模型详情路径识别
|
||||||
|
|
||||||
|
OpenAI 适配器 SHALL 正确识别包含统一模型 ID 的模型详情路径。
|
||||||
|
|
||||||
|
#### Scenario: 含斜杠的统一模型 ID 路径
|
||||||
|
|
||||||
|
- **WHEN** 路径为 `/v1/models/openai/gpt-4`
|
||||||
|
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo`
|
||||||
|
|
||||||
|
#### Scenario: 含多段斜杠的统一模型 ID 路径
|
||||||
|
|
||||||
|
- **WHEN** 路径为 `/v1/models/azure/accounts/org-123/models/gpt-4`
|
||||||
|
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModelInfo`
|
||||||
|
|
||||||
|
#### Scenario: 模型列表路径不受影响
|
||||||
|
|
||||||
|
- **WHEN** 路径为 `/v1/models`
|
||||||
|
- **THEN** `DetectInterfaceType` SHALL 返回 `InterfaceTypeModels`
|
||||||
|
|
||||||
|
### Requirement: 提取统一模型 ID
|
||||||
|
|
||||||
|
OpenAI 适配器 SHALL 从路径中提取统一模型 ID。
|
||||||
|
|
||||||
|
#### Scenario: 标准路径提取
|
||||||
|
|
||||||
|
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/openai/gpt-4")`
|
||||||
|
- **THEN** SHALL 返回 `"openai/gpt-4"`
|
||||||
|
|
||||||
|
#### Scenario: 复杂路径提取
|
||||||
|
|
||||||
|
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")`
|
||||||
|
- **THEN** SHALL 返回 `"azure/accounts/org/models/gpt-4"`
|
||||||
|
|
||||||
|
#### Scenario: 非模型详情路径
|
||||||
|
|
||||||
|
- **WHEN** 调用 `ExtractUnifiedModelID("/v1/models")`
|
||||||
|
- **THEN** SHALL 返回错误
|
||||||
|
|
||||||
|
### Requirement: 从请求体提取 model
|
||||||
|
|
||||||
|
OpenAI 适配器 SHALL 按 InterfaceType 从请求体中提取 model 值。
|
||||||
|
|
||||||
|
#### Scenario: Chat 请求提取 model
|
||||||
|
|
||||||
|
- **WHEN** 调用 `ExtractModelName(body, InterfaceTypeChat)`,body 为 `{"model":"openai/gpt-4","messages":[...]}`
|
||||||
|
- **THEN** SHALL 返回 `"openai/gpt-4"`
|
||||||
|
|
||||||
|
#### Scenario: Embedding 请求提取 model
|
||||||
|
|
||||||
|
- **WHEN** 调用 `ExtractModelName(body, InterfaceTypeEmbeddings)`,body 为 `{"model":"openai/text-embedding-3","input":"text"}`
|
||||||
|
- **THEN** SHALL 返回 `"openai/text-embedding-3"`
|
||||||
|
|
||||||
|
#### Scenario: 无 model 字段
|
||||||
|
|
||||||
|
- **WHEN** 调用 `ExtractModelName(body, ifaceType)`,body 中不含 model 字段
|
||||||
|
- **THEN** SHALL 返回错误
|
||||||
|
|
||||||
|
### Requirement: 最小化改写请求体 model 字段
|
||||||
|
|
||||||
|
OpenAI 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写请求体中的 model 字段,其余字段保持原始 bytes。
|
||||||
|
|
||||||
|
#### Scenario: 请求体 model 改写(统一 ID → 上游名)
|
||||||
|
|
||||||
|
- **WHEN** 调用 `RewriteRequestModelName(body, "gpt-4", InterfaceTypeChat)`,body 为 `{"model":"openai/gpt-4","messages":[...],"some_param":"value"}`
|
||||||
|
- **THEN** SHALL 返回 `{"model":"gpt-4","messages":[...],"some_param":"value"}`
|
||||||
|
- **THEN** 除 model 外的字段 SHALL 保持原始 bytes 不变
|
||||||
|
|
||||||
|
#### Scenario: 不同 InterfaceType 的请求改写
|
||||||
|
|
||||||
|
- **WHEN** 调用 `RewriteRequestModelName(body, "gpt-4", InterfaceTypeEmbeddings)`
|
||||||
|
- **THEN** SHALL 按 Embedding 接口的请求体 model 字段位置进行改写
|
||||||
|
- **WHEN** 调用 `RewriteRequestModelName(body, "gpt-4", InterfaceTypeRerank)`
|
||||||
|
- **THEN** SHALL 按 Rerank 接口的请求体 model 字段位置进行改写
|
||||||
|
|
||||||
|
### Requirement: 最小化改写响应体 model 字段
|
||||||
|
|
||||||
|
OpenAI 适配器 SHALL 按 InterfaceType 用 `json.RawMessage` 最小化改写响应体中的 model 字段,其余字段保持原始 bytes。请求体和响应体的 model 字段位置可能不同,各自独立实现。
|
||||||
|
|
||||||
|
#### Scenario: 响应体 model 改写(上游名 → 统一 ID)
|
||||||
|
|
||||||
|
- **WHEN** 调用 `RewriteResponseModelName(body, "openai/gpt-4", InterfaceTypeChat)`,body 为上游 Chat 响应
|
||||||
|
- **THEN** SHALL 将 model 字段替换为 `"openai/gpt-4"`,其余字段原样保留
|
||||||
|
|
||||||
|
#### Scenario: 不同 InterfaceType 的响应改写
|
||||||
|
|
||||||
|
- **WHEN** 调用 `RewriteResponseModelName(body, "openai/gpt-4", InterfaceTypeEmbeddings)`
|
||||||
|
- **THEN** SHALL 按 Embedding 接口的响应体 model 字段位置进行改写
|
||||||
|
|||||||
@@ -29,7 +29,44 @@
|
|||||||
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
|
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
|
||||||
- **THEN** 错误 SHALL 指示缺少哪些字段
|
- **THEN** 错误 SHALL 指示缺少哪些字段
|
||||||
|
|
||||||
**变更说明:** handler 通过 ProviderService 调用,数据访问通过 ProviderRepository。API 接口保持不变。
|
#### Scenario: 创建供应商时 ID 包含非法字符
|
||||||
|
|
||||||
|
- **WHEN** 向 `/api/providers` 发送 POST 请求,id 包含非 `[a-zA-Z0-9_]` 字符
|
||||||
|
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
|
||||||
|
- **THEN** 错误 SHALL 指示 id 仅允许字母、数字、下划线
|
||||||
|
|
||||||
|
#### Scenario: 创建供应商时 ID 过长
|
||||||
|
|
||||||
|
- **WHEN** 向 `/api/providers` 发送 POST 请求,id 长度超过 64
|
||||||
|
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
|
||||||
|
|
||||||
|
### Requirement: 供应商 ID 不允许修改
|
||||||
|
|
||||||
|
供应商 ID 是主键,用于构建统一模型 ID,不允许修改。
|
||||||
|
|
||||||
|
#### Scenario: 尝试修改供应商 ID
|
||||||
|
|
||||||
|
- **WHEN** 向 `/api/providers/:id` 发送 PUT 请求,请求体中包含 `id` 字段
|
||||||
|
- **THEN** ProviderService.Update SHALL 在 service 层校验并返回错误
|
||||||
|
- **THEN** 网关 SHALL 返回错误,状态码为 400 (Bad Request)
|
||||||
|
- **THEN** 错误 SHALL 指示供应商 ID 不允许修改
|
||||||
|
- **THEN** 错误格式 SHALL 为:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": "供应商 ID 不允许修改",
|
||||||
|
"code": "IMMUTABLE_FIELD"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**校验位置:** Service 层(`ProviderService.Update`)
|
||||||
|
- Service 层校验保证所有调用方(handler、CLI、未来 API)统一遵守
|
||||||
|
- Handler 层负责捕获 `ErrImmutableField` 并转换为 HTTP 400 响应
|
||||||
|
|
||||||
|
**原因:**
|
||||||
|
- 供应商 ID 是主键,用于构建统一模型 ID(provider_id/model_name)
|
||||||
|
- 修改 ID 会导致所有统一模型 ID 失效
|
||||||
|
- 客户端缓存的模型 ID 全部失效
|
||||||
|
- 如需修改,应创建新供应商并迁移模型
|
||||||
|
|
||||||
### Requirement: 列出所有供应商
|
### Requirement: 列出所有供应商
|
||||||
|
|
||||||
|
|||||||
@@ -92,6 +92,44 @@
|
|||||||
- **THEN** SHALL 验证至少提供一个可更新字段
|
- **THEN** SHALL 验证至少提供一个可更新字段
|
||||||
- **THEN** SHALL 验证字段值有效性
|
- **THEN** SHALL 验证字段值有效性
|
||||||
|
|
||||||
|
### Requirement: 供应商 ID 校验
|
||||||
|
|
||||||
|
创建供应商时,SHALL 对 `id` 字段进行字符集校验。
|
||||||
|
|
||||||
|
#### Scenario: 合法字符集
|
||||||
|
|
||||||
|
- **WHEN** 创建供应商,id 仅包含 `[a-zA-Z0-9_]` 字符
|
||||||
|
- **THEN** SHALL 校验通过
|
||||||
|
|
||||||
|
#### Scenario: 非法字符
|
||||||
|
|
||||||
|
- **WHEN** 创建供应商,id 包含 `-`、`.`、`/`、空格、中文等非 `[a-zA-Z0-9_]` 字符
|
||||||
|
- **THEN** SHALL 返回 400 错误
|
||||||
|
|
||||||
|
#### Scenario: 长度限制
|
||||||
|
|
||||||
|
- **WHEN** 创建供应商,id 长度超过 64
|
||||||
|
- **THEN** SHALL 返回 400 错误
|
||||||
|
|
||||||
|
### Requirement: 模型创建校验
|
||||||
|
|
||||||
|
创建模型时,SHALL 对 `provider_id` + `model_name` 进行联合唯一性校验。
|
||||||
|
|
||||||
|
#### Scenario: 正常创建
|
||||||
|
|
||||||
|
- **WHEN** 创建模型,provider_id 存在且 provider_id + model_name 组合唯一
|
||||||
|
- **THEN** SHALL 校验通过
|
||||||
|
|
||||||
|
#### Scenario: 联合唯一冲突
|
||||||
|
|
||||||
|
- **WHEN** 创建模型,provider_id + model_name 组合已存在
|
||||||
|
- **THEN** SHALL 返回 409 错误
|
||||||
|
|
||||||
|
#### Scenario: model_name 为空
|
||||||
|
|
||||||
|
- **WHEN** 创建模型,未提供 model_name
|
||||||
|
- **THEN** SHALL 返回 400 错误
|
||||||
|
|
||||||
### Requirement: 返回友好的验证错误
|
### Requirement: 返回友好的验证错误
|
||||||
|
|
||||||
系统 SHALL 返回友好的验证错误响应。
|
系统 SHALL 返回友好的验证错误响应。
|
||||||
|
|||||||
@@ -1,4 +1,10 @@
|
|||||||
## ADDED Requirements
|
# Unified Model ID
|
||||||
|
|
||||||
|
## Purpose
|
||||||
|
|
||||||
|
定义统一模型 ID 的格式、解析、格式化和校验规则,确保跨协议的模型标识一致性。
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
### Requirement: 解析统一模型 ID
|
### Requirement: 解析统一模型 ID
|
||||||
|
|
||||||
@@ -105,4 +105,125 @@ ProxyHandler SHALL 支持 GET 请求的扩展层接口代理。
|
|||||||
- **THEN** SHALL 调用 engine.convertHttpRequest(GET 请求 body 为空)
|
- **THEN** SHALL 调用 engine.convertHttpRequest(GET 请求 body 为空)
|
||||||
- **THEN** SHALL 调用 providerClient.Send 发送请求
|
- **THEN** SHALL 调用 providerClient.Send 发送请求
|
||||||
- **THEN** SHALL 调用 engine.convertHttpResponse 转换响应格式
|
- **THEN** SHALL 调用 engine.convertHttpResponse 转换响应格式
|
||||||
- **THEN** SHALL 返回转换后的响应
|
- **THEN** SHALL 返回转换后的响应
|
||||||
|
### Requirement: 代理请求路由
|
||||||
|
|
||||||
|
ProxyHandler SHALL 使用统一模型 ID 路由所有代理请求。
|
||||||
|
|
||||||
|
#### Scenario: 提取统一模型 ID
|
||||||
|
|
||||||
|
- **WHEN** 收到 Chat、Embeddings 或 Rerank 接口的 POST 请求(含请求体)
|
||||||
|
- **THEN** SHALL 调用客户端协议 adapter 的 `ExtractModelName(body, ifaceType)` 提取 model 值
|
||||||
|
- **THEN** SHALL 调用 `ParseUnifiedModelID` 解析得到 providerID 和 modelName
|
||||||
|
- **THEN** SHALL 调用 `RoutingService.RouteByModelName(providerID, modelName)` 路由
|
||||||
|
|
||||||
|
#### Scenario: GET 请求或无请求体
|
||||||
|
|
||||||
|
- **WHEN** 收到 GET 请求或请求体为空或请求体中无法提取 model 字段
|
||||||
|
- **THEN** SHALL 走 forwardPassthrough 透传到上游供应商(兼容未适配的客户端和无 body 请求)
|
||||||
|
|
||||||
|
#### Scenario: 无效的统一模型 ID
|
||||||
|
|
||||||
|
- **WHEN** 请求体中 `model` 字段不是有效的统一模型 ID 格式(不含 `/`)
|
||||||
|
- **THEN** SHALL 走 forwardPassthrough 透传到上游供应商(兼容使用原始模型名的客户端)
|
||||||
|
|
||||||
|
#### Scenario: 模型不存在
|
||||||
|
|
||||||
|
- **WHEN** 解析统一模型 ID 后,数据库中找不到对应的 provider_id + model_name 组合
|
||||||
|
- **THEN** SHALL 返回错误响应,状态码为 404
|
||||||
|
|
||||||
|
#### Scenario: 模型已禁用
|
||||||
|
|
||||||
|
- **WHEN** 解析统一模型 ID 后,对应的模型 enabled 为 false
|
||||||
|
- **THEN** SHALL 返回错误响应,状态码为 404
|
||||||
|
|
||||||
|
#### Scenario: 供应商已禁用
|
||||||
|
|
||||||
|
- **WHEN** 解析统一模型 ID 后,对应的供应商 enabled 为 false
|
||||||
|
- **THEN** SHALL 返回错误响应,状态码为 404
|
||||||
|
|
||||||
|
### Requirement: 同协议 Smart Passthrough
|
||||||
|
|
||||||
|
当客户端协议与供应商协议相同时,ProxyHandler SHALL 使用 Smart Passthrough 处理 Chat、Embedding、Rerank 请求。
|
||||||
|
|
||||||
|
#### Scenario: 同协议非流式请求
|
||||||
|
|
||||||
|
- **WHEN** 客户端协议 == 供应商协议,且为非流式请求
|
||||||
|
- **THEN** SHALL 调用 adapter 的 `RewriteRequestModelName(body, modelName, ifaceType)` 将请求体中 model 从统一 ID 改写为上游模型名
|
||||||
|
- **THEN** SHALL 构建 URL 和 Headers(同当前透传逻辑)
|
||||||
|
- **THEN** SHALL 发送改写后的请求体到上游
|
||||||
|
- **THEN** SHALL 调用 adapter 的 `RewriteResponseModelName(resp.Body, unifiedModelID, ifaceType)` 将响应中 model 从上游名改写为统一 ID
|
||||||
|
- **THEN** SHALL NOT 对 body 做全量 decode → encode,保持未改写字段的原始 bytes
|
||||||
|
|
||||||
|
#### Scenario: 同协议流式请求
|
||||||
|
|
||||||
|
- **WHEN** 客户端协议 == 供应商协议,且为流式请求
|
||||||
|
- **THEN** SHALL 对请求体做 `RewriteRequestModelName` 改写 model 字段
|
||||||
|
- **THEN** SHALL 逐 SSE chunk 调用 `RewriteResponseModelName` 改写响应中 model 字段
|
||||||
|
- **THEN** SHALL NOT 对 chunk 做全量 decode → encode
|
||||||
|
|
||||||
|
#### Scenario: Smart Passthrough 保真性
|
||||||
|
|
||||||
|
- **WHEN** 客户端发送含未知参数的请求(如 `{"model":"openai/gpt-4","some_new_param":"value"}`)
|
||||||
|
- **THEN** 上游 SHALL 收到 `{"model":"gpt-4","some_new_param":"value"}`
|
||||||
|
- **THEN** `some_new_param` SHALL 保持原始值不变,不丢失、不改变类型
|
||||||
|
|
||||||
|
### Requirement: 跨协议完整转换
|
||||||
|
|
||||||
|
当客户端协议与供应商协议不同时,ProxyHandler SHALL 使用全量转换路径。
|
||||||
|
|
||||||
|
#### Scenario: 跨协议非流式请求
|
||||||
|
|
||||||
|
- **WHEN** 客户端协议 != 供应商协议
|
||||||
|
- **THEN** SHALL 走 `ConvertHttpRequest` 全量转换,encoder 中 provider.ModelName 覆盖 model
|
||||||
|
- **THEN** SHALL 走 `ConvertHttpResponse` 全量转换,modelOverride 参数覆写 canonical.Model
|
||||||
|
|
||||||
|
#### Scenario: 跨协议流式请求
|
||||||
|
|
||||||
|
- **WHEN** 客户端协议 != 供应商协议,且为流式请求
|
||||||
|
- **THEN** SHALL 走 `CreateStreamConverter` 全量转换,modelOverride 参数覆写流式 canonical 事件中的 Model
|
||||||
|
|
||||||
|
### Requirement: 模型列表本地聚合
|
||||||
|
|
||||||
|
ProxyHandler SHALL 从数据库聚合返回模型列表,不再透传上游。
|
||||||
|
|
||||||
|
#### Scenario: GET /v1/models
|
||||||
|
|
||||||
|
- **WHEN** 收到 `GET /{protocol}/v1/models` 请求
|
||||||
|
- **THEN** SHALL 从数据库查询所有 enabled 的模型(关联 enabled 的供应商)
|
||||||
|
- **THEN** SHALL 组装 `CanonicalModelList`,每个模型的 ID 字段为统一模型 ID(`provider_id/model_name`),Name 字段为 model_name,OwnedBy 字段为 provider_id
|
||||||
|
- **THEN** SHALL 使用客户端协议的 adapter 编码响应
|
||||||
|
- **THEN** SHALL NOT 请求上游供应商
|
||||||
|
|
||||||
|
#### Scenario: 无可用模型
|
||||||
|
|
||||||
|
- **WHEN** 数据库中没有 enabled 的模型
|
||||||
|
- **THEN** SHALL 返回空列表
|
||||||
|
|
||||||
|
### Requirement: 模型详情本地查询
|
||||||
|
|
||||||
|
ProxyHandler SHALL 从数据库查询返回模型详情,不再透传上游。
|
||||||
|
|
||||||
|
#### Scenario: GET /v1/models/{unified_id}
|
||||||
|
|
||||||
|
- **WHEN** 收到 `GET /{protocol}/v1/models/{provider_id}/{model_name}` 请求
|
||||||
|
- **THEN** SHALL 调用 adapter 的 `ExtractUnifiedModelID` 提取统一模型 ID
|
||||||
|
- **THEN** SHALL 解析统一模型 ID 得到 providerID 和 modelName
|
||||||
|
- **THEN** SHALL 从数据库查询对应的模型和供应商
|
||||||
|
- **THEN** SHALL 组装 `CanonicalModelInfo`,ID 字段为统一模型 ID(`provider_id/model_name`),Name 字段为 model_name,OwnedBy 字段为 provider_id
|
||||||
|
- **THEN** SHALL 使用客户端协议的 adapter 编码响应
|
||||||
|
- **THEN** SHALL NOT 请求上游供应商
|
||||||
|
|
||||||
|
#### Scenario: 模型详情不存在
|
||||||
|
|
||||||
|
- **WHEN** 统一模型 ID 对应的模型不存在或已禁用
|
||||||
|
- **THEN** SHALL 返回错误响应,状态码为 404
|
||||||
|
|
||||||
|
### Requirement: 统计记录
|
||||||
|
|
||||||
|
ProxyHandler SHALL 使用 providerID 和 modelName 记录使用统计。
|
||||||
|
|
||||||
|
#### Scenario: 异步记录统计
|
||||||
|
|
||||||
|
- **WHEN** 代理请求成功完成
|
||||||
|
- **THEN** SHALL 异步调用 `StatsService.Record(providerID, modelName)`
|
||||||
|
|||||||
@@ -22,7 +22,20 @@
|
|||||||
- **THEN** 网关 SHALL 增加该供应商和模型的请求计数
|
- **THEN** 网关 SHALL 增加该供应商和模型的请求计数
|
||||||
- **THEN** 网关 SHALL 在流结束后记录统计
|
- **THEN** 网关 SHALL 在流结束后记录统计
|
||||||
|
|
||||||
**变更说明:** 统计记录通过 StatsService 调用,数据访问通过 StatsRepository。API 接口保持不变。
|
### Requirement: 使用统计记录统一模型标识
|
||||||
|
|
||||||
|
系统 SHALL 使用 providerID 和 modelName(上游模型名)记录使用统计。
|
||||||
|
|
||||||
|
#### Scenario: 代理请求统计记录
|
||||||
|
|
||||||
|
- **WHEN** 代理请求成功完成
|
||||||
|
- **THEN** SHALL 记录 provider_id 和 model_name 到 usage_stats 表(参数来自路由结果)
|
||||||
|
- **THEN** SHALL 异步执行,不阻塞响应
|
||||||
|
|
||||||
|
#### Scenario: 查询统计
|
||||||
|
|
||||||
|
- **WHEN** 查询统计数据
|
||||||
|
- **THEN** 支持按 provider_id 和 model_name 过滤
|
||||||
|
|
||||||
### Requirement: 按供应商查询统计
|
### Requirement: 按供应商查询统计
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user