Compare commits
24 Commits
59179094ed
...
380586afa6
| Author | SHA1 | Date | |
|---|---|---|---|
| 380586afa6 | |||
| ebb70809bf | |||
| 7399afbc5c | |||
| c0669e4b07 | |||
| 05c04091b3 | |||
| 0b05e08705 | |||
| df253559a5 | |||
| 669cbb8c51 | |||
| 5ae9d85272 | |||
| 72aebef625 | |||
| f5e45d032e | |||
| b03e5f809f | |||
| ec563aaa16 | |||
| 873f09d3bf | |||
| 5e7267db07 | |||
| 7b28cee7a1 | |||
| 934c8dea77 | |||
| 7d91fe345e | |||
| 4e86adffb7 | |||
| 5d58acf5a6 | |||
| 81dcecb723 | |||
| 141f5f886f | |||
| 7fa5af483b | |||
| f488b9cc15 |
7
.gitignore
vendored
7
.gitignore
vendored
@@ -405,4 +405,9 @@ openspec/changes/archive
|
||||
temp
|
||||
.agents
|
||||
skills-lock.json
|
||||
.worktrees
|
||||
.worktrees
|
||||
!scripts/build/
|
||||
|
||||
# Embedfs generated
|
||||
embedfs/assets/
|
||||
embedfs/frontend-dist/
|
||||
115
Makefile
Normal file
115
Makefile
Normal file
@@ -0,0 +1,115 @@
|
||||
.PHONY: all clean \
|
||||
backend-build backend-run backend-test backend-test-unit backend-test-integration backend-test-coverage \
|
||||
backend-lint backend-deps backend-generate \
|
||||
backend-migrate-up backend-migrate-down backend-migrate-status backend-migrate-create \
|
||||
frontend-build frontend-dev frontend-test frontend-test-watch frontend-test-coverage frontend-test-e2e frontend-lint \
|
||||
desktop desktop-darwin desktop-windows desktop-linux package-macos
|
||||
|
||||
# ============================================
|
||||
# 后端
|
||||
# ============================================
|
||||
|
||||
all: backend-build
|
||||
|
||||
backend-build:
|
||||
cd backend && go build -o bin/server ./cmd/server
|
||||
|
||||
backend-run:
|
||||
cd backend && go run ./cmd/server
|
||||
|
||||
backend-test:
|
||||
cd backend && go test ./... -v
|
||||
|
||||
backend-test-unit:
|
||||
cd backend && go test ./internal/... ./pkg/... -v
|
||||
|
||||
backend-test-integration:
|
||||
cd backend && go test ./tests/... -v
|
||||
|
||||
backend-test-coverage:
|
||||
cd backend && go test ./... -coverprofile=coverage.out
|
||||
cd backend && go tool cover -html=coverage.out -o coverage.html
|
||||
@echo "Coverage report generated: backend/coverage.html"
|
||||
|
||||
backend-lint:
|
||||
cd backend && go tool golangci-lint run ./...
|
||||
|
||||
backend-deps:
|
||||
cd backend && go mod tidy
|
||||
|
||||
backend-generate:
|
||||
cd backend && go generate ./...
|
||||
|
||||
backend-migrate-up:
|
||||
cd backend && goose -dir migrations sqlite3 $(DB_PATH) up
|
||||
|
||||
backend-migrate-down:
|
||||
cd backend && goose -dir migrations sqlite3 $(DB_PATH) down
|
||||
|
||||
backend-migrate-status:
|
||||
cd backend && goose -dir migrations sqlite3 $(DB_PATH) status
|
||||
|
||||
backend-migrate-create:
|
||||
@read -p "Migration name: " name; \
|
||||
cd backend && goose -dir migrations create $$name sql
|
||||
|
||||
# ============================================
|
||||
# 前端
|
||||
# ============================================
|
||||
|
||||
frontend-build:
|
||||
cd frontend && bun install && bun run build
|
||||
|
||||
frontend-dev:
|
||||
cd frontend && bun dev
|
||||
|
||||
frontend-test:
|
||||
cd frontend && bun run test
|
||||
|
||||
frontend-test-watch:
|
||||
cd frontend && bun run test:watch
|
||||
|
||||
frontend-test-coverage:
|
||||
cd frontend && bun run test:coverage
|
||||
|
||||
frontend-test-e2e:
|
||||
cd frontend && bun run test:e2e
|
||||
|
||||
frontend-lint:
|
||||
cd frontend && bun run lint
|
||||
|
||||
# ============================================
|
||||
# 桌面应用
|
||||
# ============================================
|
||||
|
||||
desktop: frontend-build-desktop embedfs-prepare
|
||||
cd backend && CGO_ENABLED=1 go build -o ../build/nex ./cmd/desktop
|
||||
|
||||
frontend-build-desktop:
|
||||
cd frontend && cp .env.desktop .env.production.local && bun install && bun run build && rm -f .env.production.local
|
||||
|
||||
embedfs-prepare:
|
||||
rm -rf embedfs/assets embedfs/frontend-dist
|
||||
cp -r assets embedfs/assets
|
||||
cp -r frontend/dist embedfs/frontend-dist
|
||||
|
||||
desktop-darwin: frontend-build-desktop embedfs-prepare
|
||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -o ../build/nex-darwin-arm64 ./cmd/desktop
|
||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -o ../build/nex-darwin-amd64 ./cmd/desktop
|
||||
|
||||
desktop-windows: frontend-build-desktop embedfs-prepare
|
||||
cd backend && CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -ldflags "-H=windowsgui" -o ../build/nex-windows-amd64.exe ./cmd/desktop
|
||||
|
||||
desktop-linux: frontend-build-desktop embedfs-prepare
|
||||
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -o ../build/nex-linux-amd64 ./cmd/desktop
|
||||
|
||||
package-macos:
|
||||
./scripts/build/package-macos.sh
|
||||
|
||||
# ============================================
|
||||
# 清理
|
||||
# ============================================
|
||||
|
||||
clean:
|
||||
rm -rf backend/bin/ backend/coverage.out backend/coverage.html
|
||||
rm -rf build/
|
||||
148
README.md
148
README.md
@@ -7,13 +7,15 @@
|
||||
```
|
||||
nex/
|
||||
├── backend/ # Go 后端服务(分层架构)
|
||||
│ ├── cmd/server/ # 主程序入口
|
||||
│ ├── cmd/
|
||||
│ │ ├── server/ # CLI 主程序入口
|
||||
│ │ └── desktop/ # 桌面应用入口
|
||||
│ ├── internal/
|
||||
│ │ ├── handler/ # HTTP 处理器 + 中间件
|
||||
│ │ ├── service/ # 业务逻辑层
|
||||
│ │ ├── repository/ # 数据访问层
|
||||
│ │ ├── domain/ # 领域模型
|
||||
│ │ ├── protocol/ # 协议适配器(OpenAI/Anthropic)
|
||||
│ │ ├── conversion/ # 协议转换引擎(OpenAI/Anthropic 适配器)
|
||||
│ │ ├── provider/ # 供应商客户端
|
||||
│ │ └── config/ # 配置管理
|
||||
│ ├── pkg/ # 公共包(logger/errors/validator)
|
||||
@@ -32,16 +34,28 @@ nex/
|
||||
│ ├── e2e/ # Playwright E2E 测试
|
||||
│ └── package.json
|
||||
│
|
||||
├── assets/ # 应用资源
|
||||
│ ├── icon.png # 托盘图标
|
||||
│ ├── AppIcon.icns # macOS 应用图标
|
||||
│ └── icon.ico # Windows 应用图标
|
||||
│
|
||||
├── scripts/ # 构建脚本
|
||||
│ └── build/
|
||||
│ └── package-macos.sh # macOS .app 打包脚本
|
||||
│
|
||||
└── README.md # 本文件
|
||||
```
|
||||
|
||||
## 功能特性
|
||||
|
||||
- **双协议支持**:同时支持 OpenAI 和 Anthropic 协议
|
||||
- **跨协议转换**:Hub-and-Spoke 架构实现 OpenAI ↔ Anthropic 双向转换
|
||||
- **统一模型 ID**:`provider_id/model_name` 格式全局唯一标识模型(如 `openai/gpt-4`)
|
||||
- **透明代理**:对 OpenAI 兼容供应商 Smart Passthrough,最小化改写保持参数保真
|
||||
- **流式响应**:完整支持 SSE 流式传输
|
||||
- **Smart Passthrough**:同协议请求零序列化开销,仅改写 model 字段
|
||||
- **流式响应**:完整支持 SSE 流式传输,包括跨协议流式转换
|
||||
- **Function Calling**:支持工具调用(Tools)
|
||||
- **Thinking / Reasoning**:支持 OpenAI `reasoning_effort` 和 Anthropic `thinking` 配置
|
||||
- **扩展接口**:支持 Embeddings 和 Rerank 接口
|
||||
- **多供应商管理**:配置和管理多个供应商(供应商 ID 仅限字母、数字、下划线)
|
||||
- **用量统计**:按供应商、模型、日期统计请求数量
|
||||
- **Web 配置界面**:提供供应商和模型配置管理
|
||||
@@ -54,7 +68,7 @@ nex/
|
||||
- **ORM**: GORM
|
||||
- **数据库**: SQLite
|
||||
- **日志**: zap + lumberjack(结构化日志 + 日志轮转)
|
||||
- **配置**: gopkg.in/yaml.v3
|
||||
- **配置**: Viper + pflag(多层配置:CLI > 环境变量 > 配置文件 > 默认值)
|
||||
- **验证**: go-playground/validator/v10
|
||||
- **迁移**: goose
|
||||
|
||||
@@ -72,7 +86,46 @@ nex/
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 后端
|
||||
### 桌面应用(推荐)
|
||||
|
||||
**构建桌面应用**:
|
||||
|
||||
```bash
|
||||
# 当前平台
|
||||
make desktop
|
||||
|
||||
# macOS (arm64 + amd64)
|
||||
make desktop-darwin
|
||||
make package-macos # 打包为 .app
|
||||
|
||||
# Windows
|
||||
make desktop-windows
|
||||
|
||||
# Linux
|
||||
make desktop-linux
|
||||
```
|
||||
|
||||
**使用桌面应用**:
|
||||
- 双击启动应用(macOS: Nex.app,Windows: nex.exe,Linux: nex)
|
||||
- 系统托盘图标出现,浏览器自动打开管理界面
|
||||
- 点击托盘图标显示菜单,可打开管理界面或退出
|
||||
- 关闭浏览器后服务继续运行,可通过托盘重新打开
|
||||
|
||||
**注意事项**:
|
||||
- 桌面应用需要 CGO 支持
|
||||
- macOS: 自带 Xcode Command Line Tools
|
||||
- Linux: 自带 gcc,部分桌面环境需要 `libappindicator3-dev`
|
||||
- Windows: 需要 MinGW-w64 或在 Windows 环境构建
|
||||
|
||||
**Linux 桌面环境兼容性**:
|
||||
- GNOME: 需要 AppIndicator 扩展
|
||||
- KDE Plasma: 原生支持
|
||||
- Xfce: 需要 libappindicator
|
||||
- 其他支持 StatusNotifierItem 规范的环境
|
||||
|
||||
### CLI 模式
|
||||
|
||||
#### 后端
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
@@ -100,12 +153,19 @@ bun dev
|
||||
|
||||
### 代理接口(对外部应用)
|
||||
|
||||
代理接口统一使用 `provider_id/model_name` 格式的模型 ID(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough,最小化 JSON 改写保持参数保真;跨协议请求走完整 decode/encode 转换。
|
||||
代理接口统一使用 `/{protocol}/*path` 路由格式,模型 ID 使用 `provider_id/model_name` 格式(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough,最小化 JSON 改写保持参数保真;跨协议请求走完整 decode/encode 转换。
|
||||
|
||||
- `POST /v1/chat/completions` - OpenAI Chat Completions API
|
||||
- `POST /v1/messages` - Anthropic Messages API
|
||||
- `GET /v1/models` - 模型列表(本地数据库聚合,不请求上游)
|
||||
- `GET /v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
|
||||
**OpenAI 协议**(`protocol=openai`):
|
||||
- `POST /openai/chat/completions` - 对话补全
|
||||
- `GET /openai/models` - 模型列表(本地数据库聚合)
|
||||
- `GET /openai/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
|
||||
- `POST /openai/embeddings` - 嵌入
|
||||
- `POST /openai/rerank` - 重排序
|
||||
|
||||
**Anthropic 协议**(`protocol=anthropic`):
|
||||
- `POST /anthropic/v1/messages` - 消息对话
|
||||
- `GET /anthropic/v1/models` - 模型列表(本地数据库聚合)
|
||||
- `GET /anthropic/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
|
||||
|
||||
### 管理接口(对前端)
|
||||
|
||||
@@ -131,6 +191,10 @@ bun dev
|
||||
|
||||
## 配置
|
||||
|
||||
配置支持多种方式,优先级为:**CLI 参数 > 环境变量 > 配置文件 > 默认值**
|
||||
|
||||
### 配置文件
|
||||
|
||||
配置文件位于 `~/.nex/config.yaml`,首次启动自动生成:
|
||||
|
||||
```yaml
|
||||
@@ -154,48 +218,52 @@ log:
|
||||
compress: true
|
||||
```
|
||||
|
||||
数据文件:
|
||||
### 环境变量
|
||||
|
||||
所有配置项支持环境变量,使用 `NEX_` 前缀:
|
||||
|
||||
```bash
|
||||
export NEX_SERVER_PORT=9000
|
||||
export NEX_DATABASE_PATH=/data/nex.db
|
||||
export NEX_LOG_LEVEL=debug
|
||||
```
|
||||
|
||||
命名规则:配置路径转大写 + 下划线(如 `server.port` → `NEX_SERVER_PORT`)。
|
||||
|
||||
### CLI 参数
|
||||
|
||||
```bash
|
||||
./server --server-port 9000 --log-level debug --database-path /tmp/test.db
|
||||
```
|
||||
|
||||
命名规则:配置路径转 kebab-case(如 `server.port` → `--server-port`)。
|
||||
|
||||
### 数据文件
|
||||
|
||||
- `~/.nex/config.yaml` - 配置文件
|
||||
- `~/.nex/config.db` - SQLite 数据库
|
||||
- `~/.nex/log/` - 日志目录
|
||||
|
||||
## 测试
|
||||
|
||||
### 后端测试
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
make test # 运行所有测试
|
||||
make test-coverage # 生成覆盖率报告
|
||||
```
|
||||
|
||||
### 前端测试
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
bun run test # 单元测试 + 组件测试
|
||||
bun run test:watch # 监听模式
|
||||
bun run test:coverage # 生成覆盖率报告
|
||||
bun run test:e2e # E2E 测试
|
||||
make backend-test # 后端测试
|
||||
make backend-test-coverage # 后端覆盖率
|
||||
make frontend-test # 前端测试
|
||||
make frontend-test-e2e # 前端 E2E 测试
|
||||
```
|
||||
|
||||
## 开发
|
||||
|
||||
### 后端开发
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
make build # 构建
|
||||
make lint # 代码检查
|
||||
make migrate-up # 数据库迁移
|
||||
```
|
||||
make backend-build # 构建后端
|
||||
make backend-run # 运行后端
|
||||
make backend-lint # 后端代码检查
|
||||
make backend-migrate-up # 数据库迁移
|
||||
|
||||
### 前端开发
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
bun run build # 构建生产版本
|
||||
bun run lint # 代码检查
|
||||
make frontend-build # 构建前端
|
||||
make frontend-dev # 前端开发模式
|
||||
make frontend-lint # 前端代码检查
|
||||
```
|
||||
|
||||
## 开发规范
|
||||
|
||||
BIN
assets/AppIcon.icns
Normal file
BIN
assets/AppIcon.icns
Normal file
Binary file not shown.
64
assets/README.md
Normal file
64
assets/README.md
Normal file
@@ -0,0 +1,64 @@
|
||||
# Assets
|
||||
|
||||
应用资源文件目录。
|
||||
|
||||
## 文件说明
|
||||
|
||||
| 文件 | 用途 | 尺寸 | 格式 |
|
||||
|------|------|------|------|
|
||||
| `icon.svg` | 源图标 | 64x64 | SVG |
|
||||
| `icon.png` | 托盘图标 | 64x64 | PNG |
|
||||
| `AppIcon.icns` | macOS 应用图标 | 多尺寸 | ICNS |
|
||||
| `icon.ico` | Windows 应用图标 | 256x256 | ICO |
|
||||
|
||||
## 替换图标
|
||||
|
||||
### 1. 准备图标
|
||||
|
||||
推荐使用 SVG 格式的源图标,尺寸至少 256x256。
|
||||
|
||||
### 2. 生成各平台图标
|
||||
|
||||
**托盘图标 (PNG)**:
|
||||
```bash
|
||||
magick your-icon.svg -resize 64x64 icon.png
|
||||
```
|
||||
|
||||
**macOS 应用图标 (ICNS)**:
|
||||
```bash
|
||||
mkdir icon.iconset
|
||||
magick your-icon.svg -resize 16x16 icon.iconset/icon_16x16.png
|
||||
magick your-icon.svg -resize 32x32 icon.iconset/icon_16x16@2x.png
|
||||
magick your-icon.svg -resize 32x32 icon.iconset/icon_32x32.png
|
||||
magick your-icon.svg -resize 64x64 icon.iconset/icon_32x32@2x.png
|
||||
magick your-icon.svg -resize 128x128 icon.iconset/icon_128x128.png
|
||||
magick your-icon.svg -resize 256x256 icon.iconset/icon_128x128@2x.png
|
||||
iconutil -c icns icon.iconset -o AppIcon.icns
|
||||
rm -rf icon.iconset
|
||||
```
|
||||
|
||||
**Windows 应用图标 (ICO)**:
|
||||
```bash
|
||||
magick your-icon.svg -resize 256x256 icon.ico
|
||||
```
|
||||
|
||||
### 3. 替换文件
|
||||
|
||||
将生成的文件放入此目录,然后重新构建桌面应用:
|
||||
```bash
|
||||
./scripts/build/build-darwin-arm64.sh
|
||||
```
|
||||
|
||||
## macOS Template 图标
|
||||
|
||||
macOS 支持 Template 图标,自动适配深浅色模式:
|
||||
- 使用黑色 + 透明设计
|
||||
- 文件名以 `Template` 结尾(如 `iconTemplate.png`)
|
||||
- 黑色在深色模式下自动变为白色
|
||||
|
||||
## 设计建议
|
||||
|
||||
- 托盘图标应简洁,在小尺寸下清晰可辨
|
||||
- 避免过多细节和文字
|
||||
- 使用高对比度颜色
|
||||
- macOS 建议使用 Template 图标风格
|
||||
BIN
assets/icon.ico
Normal file
BIN
assets/icon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 264 KiB |
BIN
assets/icon.png
Normal file
BIN
assets/icon.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.0 KiB |
13
assets/icon.svg
Normal file
13
assets/icon.svg
Normal file
@@ -0,0 +1,13 @@
|
||||
<svg width="64" height="64" viewBox="0 0 64 64" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect width="64" height="64" rx="12" fill="#4A90D9"/>
|
||||
<polygon points="32,8 52,20 52,44 32,56 12,44 12,20" fill="none" stroke="white" stroke-width="3"/>
|
||||
<circle cx="32" cy="32" r="6" fill="white"/>
|
||||
<line x1="32" y1="32" x2="20" y2="20" stroke="white" stroke-width="2"/>
|
||||
<line x1="32" y1="32" x2="44" y2="20" stroke="white" stroke-width="2"/>
|
||||
<line x1="32" y1="32" x2="20" y2="44" stroke="white" stroke-width="2"/>
|
||||
<line x1="32" y1="32" x2="44" y2="44" stroke="white" stroke-width="2"/>
|
||||
<circle cx="20" cy="20" r="3" fill="white"/>
|
||||
<circle cx="44" cy="20" r="3" fill="white"/>
|
||||
<circle cx="20" cy="44" r="3" fill="white"/>
|
||||
<circle cx="44" cy="44" r="3" fill="white"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 779 B |
@@ -1,45 +0,0 @@
|
||||
.PHONY: build run test test-coverage clean migrate-up migrate-down migrate-status migrate-create lint
|
||||
|
||||
# 构建
|
||||
build:
|
||||
go build -o bin/server ./cmd/server
|
||||
|
||||
# 运行
|
||||
run:
|
||||
go run ./cmd/server
|
||||
|
||||
# 测试
|
||||
test:
|
||||
go test ./... -v
|
||||
|
||||
# 测试覆盖率
|
||||
test-coverage:
|
||||
go test ./... -coverprofile=coverage.out
|
||||
go tool cover -html=coverage.out -o coverage.html
|
||||
@echo "Coverage report generated: coverage.html"
|
||||
|
||||
# 清理
|
||||
clean:
|
||||
rm -rf bin/ coverage.out coverage.html
|
||||
|
||||
# 数据库迁移
|
||||
migrate-up:
|
||||
goose -dir migrations sqlite3 $(DB_PATH) up
|
||||
|
||||
migrate-down:
|
||||
goose -dir migrations sqlite3 $(DB_PATH) down
|
||||
|
||||
migrate-status:
|
||||
goose -dir migrations sqlite3 $(DB_PATH) status
|
||||
|
||||
migrate-create:
|
||||
@read -p "Migration name: " name; \
|
||||
goose -dir migrations create $$name sql
|
||||
|
||||
# 代码检查
|
||||
lint:
|
||||
golangci-lint run ./...
|
||||
|
||||
# 安装依赖
|
||||
deps:
|
||||
go mod tidy
|
||||
@@ -116,8 +116,11 @@ backend/
|
||||
├── migrations/ # 数据库迁移
|
||||
│ └── 20260421000001_initial_schema.sql
|
||||
├── tests/ # 集成测试
|
||||
│ ├── helpers.go
|
||||
│ └── integration/
|
||||
│ ├── helpers.go # 测试辅助函数
|
||||
│ ├── config/ # 测试配置
|
||||
│ ├── integration/ # 集成测试
|
||||
│ │ └── e2e_conversion_test.go # E2E 协议转换测试
|
||||
│ └── mocks/ # Mock 实现
|
||||
├── Makefile
|
||||
├── go.mod
|
||||
└── README.md
|
||||
@@ -146,6 +149,120 @@ Client Request (clientProtocol)
|
||||
|
||||
同协议时自动透传,跳过序列化开销。
|
||||
|
||||
## 协议转换架构
|
||||
|
||||
### Canonical Model 中间表示
|
||||
|
||||
所有协议转换都经过 Canonical Model 中间表示层,实现 Hub-and-Spoke 架构:
|
||||
|
||||
```
|
||||
OpenAI Request → Canonical Request → Anthropic Request
|
||||
(中间表示)
|
||||
OpenAI Response ← Canonical Response ← Anthropic Response
|
||||
```
|
||||
|
||||
**CanonicalRequest 核心字段**:
|
||||
- `Model` - 统一模型 ID
|
||||
- `Messages` - 消息列表(支持 text、tool_use、tool_result、thinking 类型)
|
||||
- `Tools` - 工具定义
|
||||
- `Thinking` - 推理配置(`budget_tokens`、`effort`)
|
||||
- `Parameters` - 通用参数(`max_tokens`、`temperature`、`top_p` 等)
|
||||
|
||||
### Smart Passthrough 机制
|
||||
|
||||
同协议请求走 Smart Passthrough 路径,**零序列化开销**:
|
||||
|
||||
```
|
||||
1. 检测 clientProtocol == providerProtocol
|
||||
2. 仅改写请求体中的 model 字段:unified_id → upstream_model_name
|
||||
3. 直接转发请求到上游
|
||||
4. 响应中仅改写 model 字段:upstream_model_name → unified_id
|
||||
```
|
||||
|
||||
### 流式转换器层次
|
||||
|
||||
```
|
||||
StreamConverter (接口)
|
||||
├── PassthroughStreamConverter # 直接透传,无任何处理
|
||||
├── SmartPassthroughStreamConverter # 同协议 + 逐 chunk 改写 model
|
||||
└── CanonicalStreamConverter # 跨协议完整转换(decode → encode)
|
||||
```
|
||||
|
||||
### InterfaceType 枚举
|
||||
|
||||
| 类型 | 说明 |
|
||||
|------|------|
|
||||
| `CHAT` | 对话补全(chat/completions、messages) |
|
||||
| `MODELS` | 模型列表 |
|
||||
| `MODEL_INFO` | 模型详情 |
|
||||
| `EMBEDDINGS` | 嵌入接口 |
|
||||
| `RERANK` | 重排序接口 |
|
||||
| `PASSTHROUGH` | 未知接口,直接透传 |
|
||||
|
||||
## 协议适配器特性
|
||||
|
||||
### OpenAI 适配器
|
||||
|
||||
**特有字段支持**:
|
||||
- `reasoning_effort` - 映射到 Canonical Thinking 配置(`none` → 禁用,其他 → `effort`)
|
||||
- `reasoning_content` - 非标准字段,映射到 Canonical thinking 块
|
||||
- `max_completion_tokens` - 新字段,优先于 `max_tokens`
|
||||
- `refusal` - 非标准字段,作为 text 块处理
|
||||
|
||||
**废弃字段兼容**:
|
||||
- `functions` / `function_call` - 自动转换为 `tools` / `tool_choice`
|
||||
|
||||
**消息处理**:
|
||||
- 合并连续同角色消息(Anthropic 不支持连续同角色)
|
||||
- 工具选择映射:`any` → `required`
|
||||
|
||||
### Anthropic 适配器
|
||||
|
||||
**特有字段支持**:
|
||||
- `thinking` - 推理配置(`type: enabled`、`budget_tokens`、`effort`)
|
||||
- `output_config` - 结构化输出配置
|
||||
- `disable_parallel_tool_use` - 禁用并行工具调用
|
||||
- `container` - 工具容器字段
|
||||
|
||||
**不支持的功能**:
|
||||
- Embeddings 接口(返回 `INTERFACE_NOT_SUPPORTED` 错误)
|
||||
|
||||
### 跨协议转换注意事项
|
||||
|
||||
| 源协议 | 目标协议 | 转换说明 |
|
||||
|--------|----------|----------|
|
||||
| OpenAI | Anthropic | `reasoning_effort` → `thinking`,消息角色合并 |
|
||||
| Anthropic | OpenAI | `thinking` 块 → `reasoning_content`,工具选择转换 |
|
||||
|
||||
## 错误码
|
||||
|
||||
### ConversionError 错误码
|
||||
|
||||
| 错误码 | 说明 |
|
||||
|--------|------|
|
||||
| `INVALID_INPUT` | 输入数据无效 |
|
||||
| `MISSING_REQUIRED_FIELD` | 缺少必填字段 |
|
||||
| `INCOMPATIBLE_FEATURE` | 功能不兼容(如跨协议不支持某特性) |
|
||||
| `FIELD_MAPPING_FAILURE` | 字段映射失败 |
|
||||
| `TOOL_CALL_PARSE_ERROR` | 工具调用解析错误 |
|
||||
| `JSON_PARSE_ERROR` | JSON 解析错误 |
|
||||
| `STREAM_STATE_ERROR` | 流式状态错误 |
|
||||
| `UTF8_DECODE_ERROR` | UTF-8 解码错误(流式 chunk 截断) |
|
||||
| `PROTOCOL_CONSTRAINT_VIOLATION` | 协议约束违反 |
|
||||
| `ENCODING_FAILURE` | 编码失败 |
|
||||
| `INTERFACE_NOT_SUPPORTED` | 接口不支持(如 Anthropic Embeddings) |
|
||||
|
||||
### AppError 预定义错误
|
||||
|
||||
| 错误 | HTTP 状态码 | 说明 |
|
||||
|------|-------------|------|
|
||||
| `ErrModelNotFound` | 404 | 模型未找到 |
|
||||
| `ErrModelDisabled` | 404 | 模型已禁用 |
|
||||
| `ErrProviderNotFound` | 404 | 供应商未找到 |
|
||||
| `ErrInvalidProviderID` | 400 | 供应商 ID 格式无效 |
|
||||
| `ErrDuplicateModel` | 409 | 同一供应商下模型名称重复 |
|
||||
| `ErrImmutableField` | 400 | 不可修改字段(如供应商 ID) |
|
||||
|
||||
## 运行方式
|
||||
|
||||
### 安装依赖
|
||||
@@ -278,10 +395,10 @@ goose -dir migrations sqlite3 ~/.nex/config.db up
|
||||
#### OpenAI 协议
|
||||
|
||||
```
|
||||
POST /openai/v1/chat/completions
|
||||
GET /openai/v1/models
|
||||
POST /openai/v1/embeddings
|
||||
POST /openai/v1/rerank
|
||||
POST /openai/chat/completions
|
||||
GET /openai/models
|
||||
POST /openai/embeddings
|
||||
POST /openai/rerank
|
||||
```
|
||||
|
||||
#### Anthropic 协议
|
||||
@@ -322,7 +439,7 @@ GET /anthropic/v1/models
|
||||
- Anthropic 协议:配置到域名,不包含版本路径,如 `https://api.anthropic.com`
|
||||
|
||||
**对外 URL 格式**:
|
||||
- OpenAI 协议:`/{protocol}/{endpoint}`,如 `/openai/chat/completions`、`/openai/models`
|
||||
- OpenAI 协议:`/{protocol}/{endpoint}`,如 `/openai/chat/completions`、`/openai/models`、`/openai/embeddings`
|
||||
- Anthropic 协议:`/{protocol}/v1/{endpoint}`,如 `/anthropic/v1/messages`、`/anthropic/v1/models`
|
||||
|
||||
#### 模型管理
|
||||
|
||||
456
backend/cmd/desktop/main.go
Normal file
456
backend/cmd/desktop/main.go
Normal file
@@ -0,0 +1,456 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/getlantern/systray"
|
||||
"github.com/pressly/goose/v3"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
"nex/backend/internal/conversion/openai"
|
||||
"nex/backend/internal/handler"
|
||||
"nex/backend/internal/handler/middleware"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/internal/service"
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
|
||||
"nex/embedfs"
|
||||
)
|
||||
|
||||
var (
|
||||
server *http.Server
|
||||
zapLogger *zap.Logger
|
||||
shutdownCtx context.Context
|
||||
shutdownCancel context.CancelFunc
|
||||
)
|
||||
|
||||
func main() {
|
||||
port := 9826
|
||||
|
||||
if err := acquireSingleInstance(); err != nil {
|
||||
showError("Nex Gateway", "已有 Nex 实例运行")
|
||||
os.Exit(1)
|
||||
}
|
||||
defer releaseSingleInstance()
|
||||
|
||||
if err := checkPortAvailable(port); err != nil {
|
||||
showError("Nex Gateway", err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig()
|
||||
if err != nil {
|
||||
showError("Nex Gateway", fmt.Sprintf("加载配置失败: %v", err))
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
zapLogger, err = pkgLogger.New(pkgLogger.Config{
|
||||
Level: cfg.Log.Level,
|
||||
Path: cfg.Log.Path,
|
||||
MaxSize: cfg.Log.MaxSize,
|
||||
MaxBackups: cfg.Log.MaxBackups,
|
||||
MaxAge: cfg.Log.MaxAge,
|
||||
Compress: cfg.Log.Compress,
|
||||
})
|
||||
if err != nil {
|
||||
showError("Nex Gateway", fmt.Sprintf("初始化日志失败: %v", err))
|
||||
os.Exit(1)
|
||||
}
|
||||
defer zapLogger.Sync()
|
||||
|
||||
db, err := initDatabase(cfg)
|
||||
if err != nil {
|
||||
showError("Nex Gateway", fmt.Sprintf("初始化数据库失败: %v", err))
|
||||
os.Exit(1)
|
||||
}
|
||||
defer closeDB(db)
|
||||
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
||||
statsService := service.NewStatsService(statsRepo)
|
||||
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
if err := registry.Register(openai.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error()))
|
||||
}
|
||||
if err := registry.Register(anthropic.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.String("error", err.Error()))
|
||||
}
|
||||
engine := conversion.NewConversionEngine(registry, zapLogger)
|
||||
|
||||
providerClient := provider.NewClient()
|
||||
|
||||
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
|
||||
providerHandler := handler.NewProviderHandler(providerService)
|
||||
modelHandler := handler.NewModelHandler(modelService)
|
||||
statsHandler := handler.NewStatsHandler(statsService)
|
||||
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
r := gin.New()
|
||||
|
||||
r.Use(middleware.RequestID())
|
||||
r.Use(middleware.Recovery(zapLogger))
|
||||
r.Use(middleware.Logging(zapLogger))
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
|
||||
setupStaticFiles(r)
|
||||
|
||||
server = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", port),
|
||||
Handler: r,
|
||||
ReadTimeout: cfg.Server.ReadTimeout,
|
||||
WriteTimeout: cfg.Server.WriteTimeout,
|
||||
}
|
||||
|
||||
shutdownCtx, shutdownCancel = context.WithCancel(context.Background())
|
||||
|
||||
go func() {
|
||||
zapLogger.Info("AI Gateway 启动", zap.String("addr", server.Addr))
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error()))
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
|
||||
zapLogger.Warn("无法打开浏览器", zap.String("error", err.Error()))
|
||||
}
|
||||
}()
|
||||
|
||||
setupSystray(port)
|
||||
}
|
||||
|
||||
func initDatabase(cfg *config.Config) (*gorm.DB, error) {
|
||||
dbDir := filepath.Dir(cfg.Database.Path)
|
||||
if err := os.MkdirAll(dbDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
|
||||
}
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(cfg.Database.Path), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Info),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := runMigrations(db); err != nil {
|
||||
return nil, fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||
log.Printf("警告: 启用 WAL 模式失败: %v", err)
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns)
|
||||
sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns)
|
||||
sqlDB.SetConnMaxLifetime(cfg.Database.ConnMaxLifetime)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func runMigrations(db *gorm.DB) error {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
migrationsDir := getMigrationsDir()
|
||||
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
|
||||
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
|
||||
}
|
||||
|
||||
goose.SetDialect("sqlite3")
|
||||
if err := goose.Up(sqlDB, migrationsDir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMigrationsDir() string {
|
||||
_, filename, _, ok := runtime.Caller(0)
|
||||
if ok {
|
||||
dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations")
|
||||
if abs, err := filepath.Abs(dir); err == nil {
|
||||
return abs
|
||||
}
|
||||
}
|
||||
return "./migrations"
|
||||
}
|
||||
|
||||
func closeDB(db *gorm.DB) {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sqlDB.Close()
|
||||
}
|
||||
|
||||
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
|
||||
r.Any("/v1/*path", proxyHandler.HandleProxy)
|
||||
|
||||
providers := r.Group("/api/providers")
|
||||
{
|
||||
providers.GET("", providerHandler.ListProviders)
|
||||
providers.POST("", providerHandler.CreateProvider)
|
||||
providers.GET("/:id", providerHandler.GetProvider)
|
||||
providers.PUT("/:id", providerHandler.UpdateProvider)
|
||||
providers.DELETE("/:id", providerHandler.DeleteProvider)
|
||||
}
|
||||
|
||||
models := r.Group("/api/models")
|
||||
{
|
||||
models.GET("", modelHandler.ListModels)
|
||||
models.POST("", modelHandler.CreateModel)
|
||||
models.GET("/:id", modelHandler.GetModel)
|
||||
models.PUT("/:id", modelHandler.UpdateModel)
|
||||
models.DELETE("/:id", modelHandler.DeleteModel)
|
||||
}
|
||||
|
||||
stats := r.Group("/api/stats")
|
||||
{
|
||||
stats.GET("", statsHandler.GetStats)
|
||||
stats.GET("/aggregate", statsHandler.AggregateStats)
|
||||
}
|
||||
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
}
|
||||
|
||||
func setupStaticFiles(r *gin.Engine) {
|
||||
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
|
||||
if err != nil {
|
||||
zapLogger.Fatal("无法加载前端资源", zap.String("error", err.Error()))
|
||||
}
|
||||
|
||||
getContentType := func(path string) string {
|
||||
if strings.HasSuffix(path, ".js") {
|
||||
return "application/javascript"
|
||||
}
|
||||
if strings.HasSuffix(path, ".css") {
|
||||
return "text/css"
|
||||
}
|
||||
if strings.HasSuffix(path, ".svg") {
|
||||
return "image/svg+xml"
|
||||
}
|
||||
if strings.HasSuffix(path, ".png") {
|
||||
return "image/png"
|
||||
}
|
||||
if strings.HasSuffix(path, ".ico") {
|
||||
return "image/x-icon"
|
||||
}
|
||||
if strings.HasSuffix(path, ".woff") || strings.HasSuffix(path, ".woff2") {
|
||||
return "font/woff2"
|
||||
}
|
||||
return "application/octet-stream"
|
||||
}
|
||||
|
||||
r.GET("/assets/*filepath", func(c *gin.Context) {
|
||||
filepath := c.Param("filepath")
|
||||
data, err := fs.ReadFile(distFS, "assets"+filepath)
|
||||
if err != nil {
|
||||
c.Status(404)
|
||||
return
|
||||
}
|
||||
c.Data(200, getContentType(filepath), data)
|
||||
})
|
||||
|
||||
r.GET("/favicon.svg", func(c *gin.Context) {
|
||||
data, err := fs.ReadFile(distFS, "favicon.svg")
|
||||
if err != nil {
|
||||
c.Status(404)
|
||||
return
|
||||
}
|
||||
c.Data(200, "image/svg+xml", data)
|
||||
})
|
||||
|
||||
r.NoRoute(func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
|
||||
if strings.HasPrefix(path, "/api/") ||
|
||||
strings.HasPrefix(path, "/v1/") ||
|
||||
strings.HasPrefix(path, "/health") {
|
||||
c.JSON(404, gin.H{"error": "not found"})
|
||||
return
|
||||
}
|
||||
|
||||
data, err := fs.ReadFile(distFS, "index.html")
|
||||
if err != nil {
|
||||
c.Status(500)
|
||||
return
|
||||
}
|
||||
c.Data(200, "text/html; charset=utf-8", data)
|
||||
})
|
||||
}
|
||||
|
||||
func setupSystray(port int) {
|
||||
systray.Run(func() {
|
||||
icon, err := embedfs.Assets.ReadFile("assets/icon.png")
|
||||
if err != nil {
|
||||
zapLogger.Error("无法加载托盘图标", zap.String("error", err.Error()))
|
||||
}
|
||||
systray.SetIcon(icon)
|
||||
systray.SetTitle("Nex Gateway")
|
||||
systray.SetTooltip("AI Gateway")
|
||||
|
||||
mOpen := systray.AddMenuItem("打开管理界面", "在浏览器中打开")
|
||||
systray.AddSeparator()
|
||||
mStatus := systray.AddMenuItem("状态: 运行中", "")
|
||||
mStatus.Disable()
|
||||
mPort := systray.AddMenuItem(fmt.Sprintf("端口: %d", port), "")
|
||||
mPort.Disable()
|
||||
systray.AddSeparator()
|
||||
mAbout := systray.AddMenuItem("关于", "")
|
||||
systray.AddSeparator()
|
||||
mQuit := systray.AddMenuItem("退出", "停止服务并退出")
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-mOpen.ClickedCh:
|
||||
openBrowser(fmt.Sprintf("http://localhost:%d", port))
|
||||
case <-mAbout.ClickedCh:
|
||||
showAbout()
|
||||
case <-mQuit.ClickedCh:
|
||||
doShutdown()
|
||||
systray.Quit()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}, nil)
|
||||
}
|
||||
|
||||
func doShutdown() {
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("正在关闭服务器...")
|
||||
}
|
||||
|
||||
if server != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
if shutdownCancel != nil {
|
||||
shutdownCancel()
|
||||
}
|
||||
}
|
||||
|
||||
func checkPortAvailable(port int) error {
|
||||
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
return fmt.Errorf("端口 %d 已被占用\n\n可能原因:\n- 已有 Nex 实例运行\n- 其他程序占用了该端口\n\n请检查并关闭占用端口的程序", port)
|
||||
}
|
||||
ln.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
var lockFile *os.File
|
||||
|
||||
func acquireSingleInstance() error {
|
||||
lockPath := filepath.Join(os.TempDir(), "nex-gateway.lock")
|
||||
|
||||
f, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0666)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return fmt.Errorf("已有实例运行")
|
||||
}
|
||||
|
||||
lockFile = f
|
||||
return nil
|
||||
}
|
||||
|
||||
func releaseSingleInstance() {
|
||||
if lockFile != nil {
|
||||
syscall.Flock(int(lockFile.Fd()), syscall.LOCK_UN)
|
||||
lockFile.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func openBrowser(url string) error {
|
||||
var cmd *exec.Cmd
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
cmd = exec.Command("open", url)
|
||||
case "windows":
|
||||
cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
|
||||
case "linux":
|
||||
browsers := []string{"xdg-open", "google-chrome", "firefox"}
|
||||
for _, browser := range browsers {
|
||||
if _, err := exec.LookPath(browser); err == nil {
|
||||
cmd = exec.Command(browser, url)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if cmd == nil {
|
||||
return fmt.Errorf("无法打开浏览器")
|
||||
}
|
||||
|
||||
return cmd.Start()
|
||||
}
|
||||
|
||||
func showError(title, message string) {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`, message, title)
|
||||
exec.Command("osascript", "-e", script).Run()
|
||||
case "windows":
|
||||
exec.Command("msg", "*", message).Run()
|
||||
case "linux":
|
||||
exec.Command("zenity", "--error", fmt.Sprintf("--title=%s", title), fmt.Sprintf("--text=%s", message)).Run()
|
||||
}
|
||||
}
|
||||
|
||||
func showAbout() {
|
||||
message := "Nex Gateway\n\nAI Gateway - 统一的大模型 API 网关\n\nhttps://github.com/nex/gateway"
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "关于 Nex Gateway"`, message)
|
||||
exec.Command("osascript", "-e", script).Run()
|
||||
case "windows":
|
||||
exec.Command("msg", "*", message).Run()
|
||||
case "linux":
|
||||
exec.Command("zenity", "--info", "--title=关于 Nex Gateway", fmt.Sprintf("--text=%s", message)).Run()
|
||||
}
|
||||
}
|
||||
69
backend/cmd/desktop/port_test.go
Normal file
69
backend/cmd/desktop/port_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCheckPortAvailable(t *testing.T) {
|
||||
port := 19826
|
||||
|
||||
err := checkPortAvailable(port)
|
||||
if err != nil {
|
||||
t.Fatalf("端口 %d 应该可用: %v", port, err)
|
||||
}
|
||||
|
||||
t.Log("端口可用测试通过")
|
||||
}
|
||||
|
||||
func TestCheckPortOccupied(t *testing.T) {
|
||||
port := 19827
|
||||
|
||||
listener, err := net.Listen("tcp", ":19827")
|
||||
if err != nil {
|
||||
t.Fatalf("无法启动测试服务器: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = checkPortAvailable(port)
|
||||
if err == nil {
|
||||
t.Fatal("端口被占用时应该返回错误")
|
||||
}
|
||||
|
||||
t.Log("端口占用检测测试通过")
|
||||
}
|
||||
|
||||
func TestCheckPortAvailableAfterClose(t *testing.T) {
|
||||
port := 19828
|
||||
|
||||
listener, err := net.Listen("tcp", ":19828")
|
||||
if err != nil {
|
||||
t.Fatalf("无法启动测试服务器: %v", err)
|
||||
}
|
||||
|
||||
server := &http.Server{}
|
||||
go server.Serve(listener)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
listener.Close()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = checkPortAvailable(port)
|
||||
if err != nil {
|
||||
t.Fatalf("端口关闭后应该可用: %v", err)
|
||||
}
|
||||
|
||||
t.Log("端口关闭后可用测试通过")
|
||||
}
|
||||
39
backend/cmd/desktop/singleton_test.go
Normal file
39
backend/cmd/desktop/singleton_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAcquireSingleInstance(t *testing.T) {
|
||||
lockPath := filepath.Join(os.TempDir(), "nex-gateway-test.lock")
|
||||
|
||||
origLockFile := lockFile
|
||||
lockFile = nil
|
||||
defer func() { lockFile = origLockFile }()
|
||||
|
||||
f, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0666)
|
||||
if err != nil {
|
||||
t.Fatalf("无法创建锁文件: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
defer os.Remove(lockPath)
|
||||
|
||||
err = syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
|
||||
if err != nil {
|
||||
t.Fatalf("无法获取文件锁: %v", err)
|
||||
}
|
||||
defer syscall.Flock(int(f.Fd()), syscall.LOCK_UN)
|
||||
|
||||
t.Log("单实例锁测试通过")
|
||||
}
|
||||
|
||||
func TestReleaseSingleInstance(t *testing.T) {
|
||||
lockFile = nil
|
||||
|
||||
releaseSingleInstance()
|
||||
|
||||
t.Log("释放空锁测试通过")
|
||||
}
|
||||
123
backend/cmd/desktop/static_test.go
Normal file
123
backend/cmd/desktop/static_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/embedfs"
|
||||
)
|
||||
|
||||
func TestSetupStaticFiles(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
|
||||
if err != nil {
|
||||
t.Skipf("跳过测试: 前端资源未构建: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
getContentType := func(path string) string {
|
||||
if strings.HasSuffix(path, ".js") {
|
||||
return "application/javascript"
|
||||
}
|
||||
if strings.HasSuffix(path, ".css") {
|
||||
return "text/css"
|
||||
}
|
||||
if strings.HasSuffix(path, ".svg") {
|
||||
return "image/svg+xml"
|
||||
}
|
||||
return "application/octet-stream"
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/assets/*filepath", func(c *gin.Context) {
|
||||
filepath := c.Param("filepath")
|
||||
data, err := fs.ReadFile(distFS, "assets"+filepath)
|
||||
if err != nil {
|
||||
c.Status(404)
|
||||
return
|
||||
}
|
||||
c.Data(200, getContentType(filepath), data)
|
||||
})
|
||||
|
||||
r.GET("/favicon.svg", func(c *gin.Context) {
|
||||
data, err := fs.ReadFile(distFS, "favicon.svg")
|
||||
if err != nil {
|
||||
c.Status(404)
|
||||
return
|
||||
}
|
||||
c.Data(200, "image/svg+xml", data)
|
||||
})
|
||||
|
||||
r.NoRoute(func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
if strings.HasPrefix(path, "/api/") ||
|
||||
strings.HasPrefix(path, "/v1/") ||
|
||||
strings.HasPrefix(path, "/health") {
|
||||
c.JSON(404, gin.H{"error": "not found"})
|
||||
return
|
||||
}
|
||||
data, err := fs.ReadFile(distFS, "index.html")
|
||||
if err != nil {
|
||||
c.Status(500)
|
||||
return
|
||||
}
|
||||
c.Data(200, "text/html; charset=utf-8", data)
|
||||
})
|
||||
|
||||
t.Run("API 404", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != 404 {
|
||||
t.Errorf("期望状态码 404, 实际 %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SPA fallback", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/providers", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MIME type for JS", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/assets/test.js", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == 200 {
|
||||
expected := "application/javascript"
|
||||
if w.Header().Get("Content-Type") != expected {
|
||||
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
|
||||
}
|
||||
} else {
|
||||
t.Log("文件不存在,跳过 MIME 类型验证")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MIME type for CSS", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/assets/test.css", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == 200 {
|
||||
expected := "text/css"
|
||||
if w.Header().Get("Content-Type") != expected {
|
||||
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
|
||||
}
|
||||
} else {
|
||||
t.Log("文件不存在,跳过 MIME 类型验证")
|
||||
}
|
||||
})
|
||||
|
||||
t.Log("静态文件服务测试通过")
|
||||
}
|
||||
@@ -67,13 +67,25 @@ func main() {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
|
||||
// 5. 初始化 service 层
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
||||
statsService := service.NewStatsService(statsRepo)
|
||||
// 5. 初始化缓存
|
||||
routingCache := service.NewRoutingCache(modelRepo, providerRepo, zapLogger)
|
||||
if err := routingCache.Preload(); err != nil {
|
||||
zapLogger.Warn("缓存预热失败,将使用懒加载", zap.Error(err))
|
||||
}
|
||||
|
||||
// 6. 创建 ConversionEngine
|
||||
// 6. 初始化统计缓冲
|
||||
statsBuffer := service.NewStatsBuffer(statsRepo, zapLogger,
|
||||
service.WithFlushInterval(5*time.Second),
|
||||
service.WithFlushThreshold(100))
|
||||
statsBuffer.Start()
|
||||
|
||||
// 7. 初始化 service 层
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
|
||||
routingService := service.NewRoutingService(routingCache)
|
||||
statsService := service.NewStatsService(statsRepo, statsBuffer)
|
||||
|
||||
// 8. 创建 ConversionEngine
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
if err := registry.Register(openai.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error()))
|
||||
@@ -83,16 +95,16 @@ func main() {
|
||||
}
|
||||
engine := conversion.NewConversionEngine(registry, zapLogger)
|
||||
|
||||
// 7. 初始化 provider client
|
||||
// 9. 初始化 provider client
|
||||
providerClient := provider.NewClient()
|
||||
|
||||
// 8. 初始化 handler 层
|
||||
// 10. 初始化 handler 层
|
||||
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
|
||||
providerHandler := handler.NewProviderHandler(providerService)
|
||||
modelHandler := handler.NewModelHandler(modelService)
|
||||
statsHandler := handler.NewStatsHandler(statsService)
|
||||
|
||||
// 9. 创建 Gin 引擎
|
||||
// 11. 创建 Gin 引擎
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
r := gin.New()
|
||||
|
||||
@@ -103,7 +115,7 @@ func main() {
|
||||
|
||||
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
|
||||
|
||||
// 10. 启动服务器
|
||||
// 12. 启动服务器
|
||||
srv := &http.Server{
|
||||
Addr: formatAddr(cfg.Server.Port),
|
||||
Handler: r,
|
||||
@@ -131,6 +143,8 @@ func main() {
|
||||
zapLogger.Fatal("服务器强制关闭", zap.String("error", err.Error()))
|
||||
}
|
||||
|
||||
statsBuffer.Stop()
|
||||
|
||||
zapLogger.Info("服务器已关闭")
|
||||
}
|
||||
|
||||
@@ -142,14 +156,14 @@ func initDatabase(cfg *config.Config) (*gorm.DB, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||
log.Printf("警告: 启用 WAL 模式失败: %v", err)
|
||||
}
|
||||
|
||||
if err := runMigrations(db); err != nil {
|
||||
return nil, fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||
log.Printf("警告: 启用 WAL 模式失败: %v", err)
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
177
backend/go.mod
177
backend/go.mod
@@ -2,7 +2,15 @@ module nex/backend
|
||||
|
||||
go 1.26.2
|
||||
|
||||
replace nex/embedfs => ../embedfs
|
||||
|
||||
tool (
|
||||
github.com/golangci/golangci-lint/cmd/golangci-lint
|
||||
go.uber.org/mock/mockgen
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/getlantern/systray v1.2.2
|
||||
github.com/gin-gonic/gin v1.12.0
|
||||
github.com/go-playground/validator/v10 v10.30.2
|
||||
github.com/google/uuid v1.6.0
|
||||
@@ -11,60 +19,229 @@ require (
|
||||
github.com/spf13/pflag v1.0.10
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
go.uber.org/mock v0.6.0
|
||||
go.uber.org/zap v1.27.1
|
||||
gopkg.in/lumberjack.v2 v2.0.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
gorm.io/gorm v1.31.1
|
||||
nex/embedfs v0.0.0-00010101000000-000000000000
|
||||
)
|
||||
|
||||
require (
|
||||
4d63.com/gocheckcompilerdirectives v1.3.0 // indirect
|
||||
4d63.com/gochecknoglobals v0.2.2 // indirect
|
||||
github.com/4meepo/tagalign v1.4.2 // indirect
|
||||
github.com/Abirdcfly/dupword v0.1.3 // indirect
|
||||
github.com/Antonboom/errname v1.0.0 // indirect
|
||||
github.com/Antonboom/nilnil v1.0.1 // indirect
|
||||
github.com/Antonboom/testifylint v1.5.2 // indirect
|
||||
github.com/BurntSushi/toml v1.6.0 // indirect
|
||||
github.com/Crocmagnon/fatcontext v0.7.1 // indirect
|
||||
github.com/Djarvur/go-err113 v0.0.0-20210108212216-aea10b59be24 // indirect
|
||||
github.com/GaijinEntertainment/go-exhaustruct/v3 v3.3.1 // indirect
|
||||
github.com/Masterminds/semver/v3 v3.3.0 // indirect
|
||||
github.com/OpenPeeDeeP/depguard/v2 v2.2.1 // indirect
|
||||
github.com/alecthomas/go-check-sumtype v0.3.1 // indirect
|
||||
github.com/alexkohler/nakedret/v2 v2.0.5 // indirect
|
||||
github.com/alexkohler/prealloc v1.0.0 // indirect
|
||||
github.com/alingse/asasalint v0.0.11 // indirect
|
||||
github.com/alingse/nilnesserr v0.1.2 // indirect
|
||||
github.com/ashanbrown/forbidigo v1.6.0 // indirect
|
||||
github.com/ashanbrown/makezero v1.2.0 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/bkielbasa/cyclop v1.2.3 // indirect
|
||||
github.com/blizzy78/varnamelen v0.8.0 // indirect
|
||||
github.com/bombsimon/wsl/v4 v4.5.0 // indirect
|
||||
github.com/breml/bidichk v0.3.2 // indirect
|
||||
github.com/breml/errchkjson v0.4.0 // indirect
|
||||
github.com/butuzov/ireturn v0.3.1 // indirect
|
||||
github.com/butuzov/mirror v1.3.0 // indirect
|
||||
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||
github.com/bytedance/sonic v1.15.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||
github.com/catenacyber/perfsprint v0.8.2 // indirect
|
||||
github.com/ccojocar/zxcvbn-go v1.0.2 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charithe/durationcheck v0.0.10 // indirect
|
||||
github.com/chavacava/garif v0.1.0 // indirect
|
||||
github.com/ckaznocha/intrange v0.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/curioswitch/go-reassign v0.3.0 // indirect
|
||||
github.com/daixiang0/gci v0.13.5 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/denis-tingaikin/go-header v0.5.0 // indirect
|
||||
github.com/ettle/strcase v0.2.0 // indirect
|
||||
github.com/fatih/color v1.18.0 // indirect
|
||||
github.com/fatih/structtag v1.2.0 // indirect
|
||||
github.com/firefart/nonamedreturns v1.0.5 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/fzipp/gocyclo v0.6.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
|
||||
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 // indirect
|
||||
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 // indirect
|
||||
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 // indirect
|
||||
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 // indirect
|
||||
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 // indirect
|
||||
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
|
||||
github.com/ghostiam/protogetter v0.3.9 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
github.com/go-critic/go-critic v0.12.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-stack/stack v1.8.0 // indirect
|
||||
github.com/go-toolsmith/astcast v1.1.0 // indirect
|
||||
github.com/go-toolsmith/astcopy v1.1.0 // indirect
|
||||
github.com/go-toolsmith/astequal v1.2.0 // indirect
|
||||
github.com/go-toolsmith/astfmt v1.1.0 // indirect
|
||||
github.com/go-toolsmith/astp v1.1.0 // indirect
|
||||
github.com/go-toolsmith/strparse v1.1.0 // indirect
|
||||
github.com/go-toolsmith/typep v1.1.0 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/go-xmlfmt/xmlfmt v1.1.3 // indirect
|
||||
github.com/gobwas/glob v0.2.3 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/gofrs/flock v0.12.1 // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/golangci/dupl v0.0.0-20250308024227-f665c8d69b32 // indirect
|
||||
github.com/golangci/go-printf-func-name v0.1.0 // indirect
|
||||
github.com/golangci/gofmt v0.0.0-20250106114630-d62b90e6713d // indirect
|
||||
github.com/golangci/golangci-lint v1.64.8 // indirect
|
||||
github.com/golangci/misspell v0.6.0 // indirect
|
||||
github.com/golangci/plugin-module-register v0.1.1 // indirect
|
||||
github.com/golangci/revgrep v0.8.0 // indirect
|
||||
github.com/golangci/unconvert v0.0.0-20240309020433-c5143eacb3ed // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/gordonklaus/ineffassign v0.1.0 // indirect
|
||||
github.com/gostaticanalysis/analysisutil v0.7.1 // indirect
|
||||
github.com/gostaticanalysis/comment v1.5.0 // indirect
|
||||
github.com/gostaticanalysis/forcetypeassert v0.2.0 // indirect
|
||||
github.com/gostaticanalysis/nilerr v0.1.1 // indirect
|
||||
github.com/hashicorp/go-immutable-radix/v2 v2.1.0 // indirect
|
||||
github.com/hashicorp/go-version v1.7.0 // indirect
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
|
||||
github.com/hexops/gotextdiff v1.0.3 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jgautheron/goconst v1.7.1 // indirect
|
||||
github.com/jingyugao/rowserrcheck v1.1.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/jjti/go-spancheck v0.6.4 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/julz/importas v0.2.0 // indirect
|
||||
github.com/karamaru-alpha/copyloopvar v1.2.1 // indirect
|
||||
github.com/kisielk/errcheck v1.9.0 // indirect
|
||||
github.com/kkHAIKE/contextcheck v1.1.6 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/kulti/thelper v0.6.3 // indirect
|
||||
github.com/kunwardeep/paralleltest v1.0.10 // indirect
|
||||
github.com/lasiar/canonicalheader v1.1.2 // indirect
|
||||
github.com/ldez/exptostd v0.4.2 // indirect
|
||||
github.com/ldez/gomoddirectives v0.6.1 // indirect
|
||||
github.com/ldez/grignotin v0.9.0 // indirect
|
||||
github.com/ldez/tagliatelle v0.7.1 // indirect
|
||||
github.com/ldez/usetesting v0.4.2 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/leonklingele/grouper v1.1.2 // indirect
|
||||
github.com/macabu/inamedparam v0.1.3 // indirect
|
||||
github.com/maratori/testableexamples v1.0.0 // indirect
|
||||
github.com/maratori/testpackage v1.1.1 // indirect
|
||||
github.com/matoous/godox v1.1.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
|
||||
github.com/mfridman/interpolate v0.0.2 // indirect
|
||||
github.com/mgechev/revive v1.7.0 // indirect
|
||||
github.com/mitchellh/go-homedir v1.1.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/moricho/tparallel v0.3.2 // indirect
|
||||
github.com/nakabonne/nestif v0.3.1 // indirect
|
||||
github.com/nishanths/exhaustive v0.12.0 // indirect
|
||||
github.com/nishanths/predeclared v0.2.2 // indirect
|
||||
github.com/nunnatsa/ginkgolinter v0.19.1 // indirect
|
||||
github.com/olekukonko/tablewriter v0.0.5 // indirect
|
||||
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/polyfloyd/go-errorlint v1.7.1 // indirect
|
||||
github.com/prometheus/client_golang v1.12.1 // indirect
|
||||
github.com/prometheus/client_model v0.2.0 // indirect
|
||||
github.com/prometheus/common v0.32.1 // indirect
|
||||
github.com/prometheus/procfs v0.19.2 // indirect
|
||||
github.com/quasilyte/go-ruleguard v0.4.3-0.20240823090925-0fe6f58b47b1 // indirect
|
||||
github.com/quasilyte/go-ruleguard/dsl v0.3.22 // indirect
|
||||
github.com/quasilyte/gogrep v0.5.0 // indirect
|
||||
github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727 // indirect
|
||||
github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567 // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
github.com/quic-go/quic-go v0.59.0 // indirect
|
||||
github.com/raeperd/recvcheck v0.2.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
github.com/ryancurrah/gomodguard v1.3.5 // indirect
|
||||
github.com/ryanrolds/sqlclosecheck v0.5.1 // indirect
|
||||
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
||||
github.com/sanposhiho/wastedassign/v2 v2.1.0 // indirect
|
||||
github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 // indirect
|
||||
github.com/sashamelentyev/interfacebloat v1.1.0 // indirect
|
||||
github.com/sashamelentyev/usestdlibvars v1.28.0 // indirect
|
||||
github.com/securego/gosec/v2 v2.22.2 // indirect
|
||||
github.com/sethvargo/go-retry v0.3.0 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/sivchari/containedctx v1.0.3 // indirect
|
||||
github.com/sivchari/tenv v1.12.1 // indirect
|
||||
github.com/sonatard/noctx v0.1.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||
github.com/sourcegraph/go-diff v0.7.0 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/spf13/cobra v1.9.1 // indirect
|
||||
github.com/ssgreg/nlreturn/v2 v2.2.1 // indirect
|
||||
github.com/stbenjam/no-sprintf-host-port v0.2.0 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tdakkota/asciicheck v0.4.1 // indirect
|
||||
github.com/tetafro/godot v1.5.0 // indirect
|
||||
github.com/timakin/bodyclose v0.0.0-20241017074812-ed6a65f985e3 // indirect
|
||||
github.com/timonwong/loggercheck v0.10.1 // indirect
|
||||
github.com/tomarrell/wrapcheck/v2 v2.10.0 // indirect
|
||||
github.com/tommy-muehle/go-mnd/v2 v2.5.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.3.1 // indirect
|
||||
github.com/ultraware/funlen v0.2.0 // indirect
|
||||
github.com/ultraware/whitespace v0.2.0 // indirect
|
||||
github.com/uudashr/gocognit v1.2.0 // indirect
|
||||
github.com/uudashr/iface v1.3.1 // indirect
|
||||
github.com/xen0n/gosmopolitan v1.2.2 // indirect
|
||||
github.com/yagipy/maintidx v1.0.0 // indirect
|
||||
github.com/yeya24/promlinter v0.3.0 // indirect
|
||||
github.com/ykadowak/zerologlint v0.1.5 // indirect
|
||||
gitlab.com/bosi/decorder v0.4.2 // indirect
|
||||
go-simpler.org/musttag v0.13.0 // indirect
|
||||
go-simpler.org/sloglint v0.9.0 // indirect
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
|
||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/arch v0.22.0 // indirect
|
||||
golang.org/x/crypto v0.49.0 // indirect
|
||||
golang.org/x/exp/typeparams v0.0.0-20250210185358-939b2ce775ac // indirect
|
||||
golang.org/x/mod v0.33.0 // indirect
|
||||
golang.org/x/net v0.51.0 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
golang.org/x/tools v0.42.0 // indirect
|
||||
golang.org/x/tools/go/expect v0.1.1-deprecated // indirect
|
||||
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
honnef.co/go/tools v0.6.1 // indirect
|
||||
mvdan.cc/gofumpt v0.7.0 // indirect
|
||||
mvdan.cc/unparam v0.0.0-20240528143540-8a5130ca722f // indirect
|
||||
)
|
||||
|
||||
913
backend/go.sum
913
backend/go.sum
File diff suppressed because it is too large
Load Diff
@@ -48,11 +48,3 @@ func (UsageStats) TableName() string {
|
||||
return "usage_stats"
|
||||
}
|
||||
|
||||
// MaskAPIKey 掩码 API Key(仅显示最后 4 个字符)
|
||||
func (p *Provider) MaskAPIKey() {
|
||||
if len(p.APIKey) > 4 {
|
||||
p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:]
|
||||
} else {
|
||||
p.APIKey = "***"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -321,3 +321,58 @@ func (m *testMiddleware) InterceptStreamEvent(event *canonical.CanonicalStreamEv
|
||||
}
|
||||
|
||||
var _ = json.Marshal
|
||||
|
||||
func TestConvertEmbeddingBody_DecodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.decodeEmbeddingReqFn = func(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
|
||||
return nil, errors.New("decode embedding failed")
|
||||
}
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
body := []byte(`{"model":"text-embedding","input":"hello"}`)
|
||||
result, err := engine.convertEmbeddingBody(clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
|
||||
func TestConvertRerankBody_DecodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.decodeRerankReqFn = func(raw []byte) (*canonical.CanonicalRerankRequest, error) {
|
||||
return nil, errors.New("decode rerank failed")
|
||||
}
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
body := []byte(`{"model":"rerank","query":"test","documents":["a"]}`)
|
||||
result, err := engine.convertRerankBody(clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
|
||||
func TestConvertBody_UnknownInterfaceType(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
body := []byte(`{"test":"data"}`)
|
||||
result, err := engine.convertBody(InterfaceType("UNKNOWN"), clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
|
||||
@@ -13,18 +13,20 @@ import (
|
||||
|
||||
// mockProtocolAdapter 模拟协议适配器
|
||||
type mockProtocolAdapter struct {
|
||||
protocolName string
|
||||
passthrough bool
|
||||
ifaceType InterfaceType
|
||||
supportsIface map[InterfaceType]bool
|
||||
decodeReqFn func([]byte) (*canonical.CanonicalRequest, error)
|
||||
encodeReqFn func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
|
||||
decodeRespFn func([]byte) (*canonical.CanonicalResponse, error)
|
||||
encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error)
|
||||
streamDecoderFn func() StreamDecoder
|
||||
streamEncoderFn func() StreamEncoder
|
||||
rewriteReqFn func([]byte, string, InterfaceType) ([]byte, error)
|
||||
rewriteRespFn func([]byte, string, InterfaceType) ([]byte, error)
|
||||
protocolName string
|
||||
passthrough bool
|
||||
ifaceType InterfaceType
|
||||
supportsIface map[InterfaceType]bool
|
||||
decodeReqFn func([]byte) (*canonical.CanonicalRequest, error)
|
||||
encodeReqFn func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
|
||||
decodeRespFn func([]byte) (*canonical.CanonicalResponse, error)
|
||||
encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error)
|
||||
streamDecoderFn func() StreamDecoder
|
||||
streamEncoderFn func() StreamEncoder
|
||||
rewriteReqFn func([]byte, string, InterfaceType) ([]byte, error)
|
||||
rewriteRespFn func([]byte, string, InterfaceType) ([]byte, error)
|
||||
decodeEmbeddingReqFn func([]byte) (*canonical.CanonicalEmbeddingRequest, error)
|
||||
decodeRerankReqFn func([]byte) (*canonical.CanonicalRerankRequest, error)
|
||||
}
|
||||
|
||||
func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
|
||||
@@ -126,6 +128,9 @@ func (m *mockProtocolAdapter) EncodeModelInfoResponse(info *canonical.CanonicalM
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
|
||||
if m.decodeEmbeddingReqFn != nil {
|
||||
return m.decodeEmbeddingReqFn(raw)
|
||||
}
|
||||
return &canonical.CanonicalEmbeddingRequest{}, nil
|
||||
}
|
||||
|
||||
@@ -142,6 +147,9 @@ func (m *mockProtocolAdapter) EncodeEmbeddingResponse(resp *canonical.CanonicalE
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
|
||||
if m.decodeRerankReqFn != nil {
|
||||
return m.decodeRerankReqFn(raw)
|
||||
}
|
||||
return &canonical.CanonicalRerankRequest{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -409,3 +409,25 @@ func TestDecodeResponse_Refusal(t *testing.T) {
|
||||
}
|
||||
assert.True(t, found)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_AssistantContentArray(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "user", "content": "hello"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "hello back"}]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, req.Messages, 2)
|
||||
assistantMsg := req.Messages[1]
|
||||
assert.Equal(t, canonical.RoleAssistant, assistantMsg.Role)
|
||||
assert.Len(t, assistantMsg.Content, 1)
|
||||
assert.Equal(t, "text", assistantMsg.Content[0].Type)
|
||||
assert.Equal(t, "hello back", assistantMsg.Content[0].Text)
|
||||
}
|
||||
|
||||
@@ -14,11 +14,3 @@ type Provider struct {
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// MaskAPIKey 掩码 API Key(仅显示最后 4 个字符)
|
||||
func (p *Provider) MaskAPIKey() {
|
||||
if len(p.APIKey) > 4 {
|
||||
p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:]
|
||||
} else {
|
||||
p.APIKey = "***"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,12 +9,19 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/tests/mocks"
|
||||
)
|
||||
|
||||
func TestProviderHandler_CreateProvider_Success(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"id": "p1",
|
||||
@@ -33,11 +40,16 @@ func TestProviderHandler_CreateProvider_Success(t *testing.T) {
|
||||
var result domain.Provider
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Equal(t, "p1", result.ID)
|
||||
assert.Contains(t, result.APIKey, "***")
|
||||
assert.Equal(t, "sk-test", result.APIKey)
|
||||
}
|
||||
|
||||
func TestProviderHandler_CreateProvider_WithProtocol(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"id": "p1",
|
||||
@@ -56,9 +68,13 @@ func TestProviderHandler_CreateProvider_WithProtocol(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
provider: &domain.Provider{ID: "p1", Name: "Updated", APIKey: "***"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(nil)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("p1")).Return(&domain.Provider{ID: "p1", Name: "Updated", APIKey: "sk-test"}, nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -72,7 +88,11 @@ func TestProviderHandler_UpdateProvider(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_InvalidBody(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -84,7 +104,12 @@ func TestProviderHandler_UpdateProvider_InvalidBody(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProviderHandler_DeleteProvider(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Delete(gomock.Eq("p1")).Return(nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -97,7 +122,12 @@ func TestProviderHandler_DeleteProvider(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_DeleteModel(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Delete(gomock.Eq("m1")).Return(nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -110,7 +140,15 @@ func TestModelHandler_DeleteModel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_Success(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).DoAndReturn(func(model *domain.Model) error {
|
||||
model.ID = "mock-uuid-1234"
|
||||
return nil
|
||||
})
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "p1",
|
||||
@@ -130,9 +168,12 @@ func TestModelHandler_CreateModel_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_GetModel(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
model: &domain.Model{ID: "m1", ModelName: "gpt-4"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ModelName: "gpt-4"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -148,9 +189,13 @@ func TestModelHandler_GetModel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_UpdateModel(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
model: &domain.Model{ID: "m1", ModelName: "gpt-4o"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("m1"), gomock.Any()).Return(nil)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ModelName: "gpt-4o"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"model_name": "gpt-4o"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -2,119 +2,36 @@ package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/provider"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
"nex/backend/tests/mocks"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// ============ Mock 实现 ============
|
||||
|
||||
type mockRoutingService struct {
|
||||
result *domain.RouteResult
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockRoutingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
|
||||
return m.result, m.err
|
||||
}
|
||||
|
||||
type mockStatsService struct {
|
||||
err error
|
||||
stats []domain.UsageStats
|
||||
aggrResult []map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *mockStatsService) Record(providerID, modelName string) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
return m.stats, nil
|
||||
}
|
||||
func (m *mockStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} {
|
||||
return m.aggrResult
|
||||
}
|
||||
|
||||
type mockProviderService struct {
|
||||
provider *domain.Provider
|
||||
providers []domain.Provider
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProviderService) ListEnabledModels() ([]domain.Model, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProviderService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err }
|
||||
func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
|
||||
return m.provider, m.err
|
||||
}
|
||||
func (m *mockProviderService) List() ([]domain.Provider, error) { return m.providers, m.err }
|
||||
func (m *mockProviderService) Update(id string, updates map[string]interface{}) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockProviderService) Delete(id string) error { return m.err }
|
||||
|
||||
type mockModelService struct {
|
||||
model *domain.Model
|
||||
models []domain.Model
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockModelService) Create(model *domain.Model) error {
|
||||
if m.err == nil {
|
||||
model.ID = "mock-uuid-1234"
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
func (m *mockModelService) Get(id string) (*domain.Model, error) {
|
||||
return m.model, m.err
|
||||
}
|
||||
func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
|
||||
return m.models, m.err
|
||||
}
|
||||
func (m *mockModelService) ListEnabled() ([]domain.Model, error) {
|
||||
return []domain.Model{}, nil
|
||||
}
|
||||
func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockModelService) Delete(id string) error { return m.err }
|
||||
|
||||
type mockProviderClient struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProviderClient) Send(ctx context.Context, spec interface{}) (interface{}, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
func (m *mockProviderClient) SendStream(ctx context.Context, spec interface{}) (<-chan provider.StreamEvent, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
|
||||
// ============ Provider Handler 测试 ============
|
||||
|
||||
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"id": "p1"})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -127,12 +44,15 @@ func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProviderHandler_ListProviders(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
providers: []domain.Provider{
|
||||
{ID: "p1", Name: "P1"},
|
||||
{ID: "p2", Name: "P2"},
|
||||
},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().List().Return([]domain.Provider{
|
||||
{ID: "p1", Name: "P1"},
|
||||
{ID: "p2", Name: "P2"},
|
||||
}, nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -142,14 +62,17 @@ func TestProviderHandler_ListProviders(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var result []domain.Provider
|
||||
json.Unmarshal(w.Body.Bytes(), &result)
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Len(t, result, 2)
|
||||
}
|
||||
|
||||
func TestProviderHandler_GetProvider(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
provider: &domain.Provider{ID: "p1", Name: "P1", APIKey: "***"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("p1")).Return(&domain.Provider{ID: "p1", Name: "P1", APIKey: "sk-test"}, nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -160,10 +83,12 @@ func TestProviderHandler_GetProvider(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
// ============ Model Handler 测试 ============
|
||||
|
||||
func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"id": "m1"})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -176,12 +101,15 @@ func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_ListModels(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
models: []domain.Model{
|
||||
{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
|
||||
{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3"},
|
||||
},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().List(gomock.Eq("")).Return([]domain.Model{
|
||||
{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
|
||||
{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3"},
|
||||
}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -198,9 +126,12 @@ func TestModelHandler_ListModels(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_GetModel_UnifiedID(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -217,7 +148,15 @@ func TestModelHandler_GetModel_UnifiedID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_UnifiedID(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).DoAndReturn(func(model *domain.Model) error {
|
||||
model.ID = "mock-uuid-1234"
|
||||
return nil
|
||||
})
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "openai",
|
||||
@@ -238,9 +177,13 @@ func TestModelHandler_CreateModel_UnifiedID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4-turbo"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("m1"), gomock.Any()).Return(nil)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4-turbo"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"enabled": false})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -257,14 +200,15 @@ func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) {
|
||||
assert.Equal(t, "openai/gpt-4-turbo", result.UnifiedModelID)
|
||||
}
|
||||
|
||||
// ============ Stats Handler 测试 ============
|
||||
|
||||
func TestStatsHandler_GetStats(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
|
||||
},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockStatsService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]domain.UsageStats{
|
||||
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
|
||||
}, nil)
|
||||
h := NewStatsHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -275,7 +219,11 @@ func TestStatsHandler_GetStats(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := NewStatsHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -286,14 +234,17 @@ func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStatsHandler_AggregateStats(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
},
|
||||
aggrResult: []map[string]interface{}{
|
||||
{"provider_id": "p1", "request_count": 10},
|
||||
},
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockStatsService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
}, nil)
|
||||
mockSvc.EXPECT().Aggregate(gomock.Any(), gomock.Eq("provider")).Return([]map[string]interface{}{
|
||||
{"provider_id": "p1", "request_count": 10},
|
||||
})
|
||||
h := NewStatsHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -303,8 +254,6 @@ func TestStatsHandler_AggregateStats(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
// ============ writeError 测试 ============
|
||||
|
||||
func TestWriteError(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -333,12 +282,13 @@ func formatMapErrors(errs map[string]string) string {
|
||||
return "请求验证失败: " + strings.Join(parts, "; ")
|
||||
}
|
||||
|
||||
// ============ 错误类型判断测试 ============
|
||||
|
||||
func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
err: appErrors.ErrConflict,
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrConflict)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"id": "p1",
|
||||
@@ -354,3 +304,158 @@ func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
|
||||
h.CreateProvider(c)
|
||||
assert.Equal(t, 409, w.Code)
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_ProviderNotFound(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrProviderNotFound)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "nonexistent",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "供应商不存在")
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_DuplicateModel(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrDuplicateModel)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 409, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "同一供应商下模型名称已存在")
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_InternalError(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(errors.New("database error"))
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_NotFound(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(gorm.ErrRecordNotFound)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateProvider(c)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_ImmutableField(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(appErrors.ErrImmutableField)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateProvider(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "供应商 ID 不允许修改")
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_InternalError(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(errors.New("database error"))
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateProvider(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_InvalidJSON(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader([]byte("{invalid json")))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_CreateProvider_InvalidJSON(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader([]byte("{invalid json")))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateProvider(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
@@ -66,7 +66,6 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
provider.MaskAPIKey()
|
||||
c.JSON(http.StatusCreated, provider)
|
||||
}
|
||||
|
||||
@@ -85,7 +84,7 @@ func (h *ProviderHandler) ListProviders(c *gin.Context) {
|
||||
func (h *ProviderHandler) GetProvider(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
provider, err := h.providerService.Get(id, true)
|
||||
provider, err := h.providerService.Get(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
@@ -131,7 +130,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := h.providerService.Get(id, true)
|
||||
provider, err := h.providerService.Get(id)
|
||||
if err != nil {
|
||||
writeError(c, err)
|
||||
return
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -50,6 +50,7 @@ type Client struct {
|
||||
}
|
||||
|
||||
// ProviderClient 供应商客户端接口
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
|
||||
type ProviderClient interface {
|
||||
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
|
||||
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error)
|
||||
|
||||
@@ -2,6 +2,8 @@ package repository
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=model_repo.go -destination=../../tests/mocks/mock_model_repository.go -package=mocks
|
||||
|
||||
// ModelRepository 模型数据仓库接口
|
||||
type ModelRepository interface {
|
||||
Create(model *domain.Model) error
|
||||
|
||||
@@ -2,6 +2,8 @@ package repository
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=provider_repo.go -destination=../../tests/mocks/mock_provider_repository.go -package=mocks
|
||||
|
||||
// ProviderRepository 供应商数据仓库接口
|
||||
type ProviderRepository interface {
|
||||
Create(provider *domain.Provider) error
|
||||
@@ -9,7 +11,4 @@ type ProviderRepository interface {
|
||||
List() ([]domain.Provider, error)
|
||||
Update(id string, updates map[string]interface{}) error
|
||||
Delete(id string) error
|
||||
// 统一模型 ID 相关方法
|
||||
ListEnabledModels() ([]domain.Model, error)
|
||||
FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error)
|
||||
}
|
||||
|
||||
@@ -71,25 +71,6 @@ func (r *providerRepository) Delete(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListEnabledModels 返回所有启用的模型(关联启用的供应商)
|
||||
func (r *providerRepository) ListEnabledModels() ([]domain.Model, error) {
|
||||
var models []domain.Model
|
||||
err := r.db.Joins("JOIN providers ON providers.id = models.provider_id").
|
||||
Where("models.enabled = ? AND providers.enabled = ?", true, true).
|
||||
Find(&models).Error
|
||||
return models, err
|
||||
}
|
||||
|
||||
// FindByProviderAndModelName 按 provider_id 和 model_name 查询模型
|
||||
func (r *providerRepository) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
|
||||
var model domain.Model
|
||||
err := r.db.Where("provider_id = ? AND model_name = ?", providerID, modelName).First(&model).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &model, nil
|
||||
}
|
||||
|
||||
func toDomainProvider(p *config.Provider) domain.Provider {
|
||||
return domain.Provider{
|
||||
ID: p.ID,
|
||||
|
||||
@@ -5,28 +5,16 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
testHelpers "nex/backend/tests"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
||||
require.NoError(t, err)
|
||||
// 关闭数据库连接以便 TempDir 清理
|
||||
t.Cleanup(func() {
|
||||
sqlDB, _ := db.DB()
|
||||
if sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
})
|
||||
return db
|
||||
return testHelpers.SetupTestDB(t)
|
||||
}
|
||||
|
||||
// ============ ProviderRepository 测试 ============
|
||||
@@ -88,7 +76,7 @@ func TestProviderRepository_Update(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewProviderRepository(db)
|
||||
|
||||
repo.Create(&domain.Provider{ID: "p1", Name: "Old", APIKey: "key", BaseURL: "https://old.com"})
|
||||
require.NoError(t, repo.Create(&domain.Provider{ID: "p1", Name: "Old", APIKey: "key", BaseURL: "https://old.com"}))
|
||||
|
||||
err := repo.Update("p1", map[string]interface{}{"name": "New"})
|
||||
require.NoError(t, err)
|
||||
@@ -109,7 +97,7 @@ func TestProviderRepository_Delete(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewProviderRepository(db)
|
||||
|
||||
repo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
|
||||
require.NoError(t, repo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
err := repo.Delete("p1")
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -129,17 +117,21 @@ func TestProviderRepository_Delete_NotFound(t *testing.T) {
|
||||
|
||||
func TestModelRepository_Create(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
err := repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestModelRepository_GetByID(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
|
||||
|
||||
result, err := repo.GetByID("m1")
|
||||
require.NoError(t, err)
|
||||
@@ -149,9 +141,11 @@ func TestModelRepository_GetByID(t *testing.T) {
|
||||
|
||||
func TestModelRepository_FindByProviderAndModelName(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
|
||||
|
||||
result, err := repo.FindByProviderAndModelName("p1", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
@@ -162,9 +156,11 @@ func TestModelRepository_FindByProviderAndModelName(t *testing.T) {
|
||||
|
||||
func TestModelRepository_FindByProviderAndModelName_NotFound(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
|
||||
|
||||
// Wrong provider_id
|
||||
_, err := repo.FindByProviderAndModelName("p2", "gpt-4")
|
||||
@@ -181,11 +177,14 @@ func TestModelRepository_FindByProviderAndModelName_NotFound(t *testing.T) {
|
||||
|
||||
func TestModelRepository_List(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
repo.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})
|
||||
repo.Create(&domain.Model{ID: "m3", ProviderID: "p2", ModelName: "claude-3"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p2", Name: "Test2", APIKey: "key", BaseURL: "https://test2.com"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m3", ProviderID: "p2", ModelName: "claude-3"}))
|
||||
|
||||
all, err := repo.List("")
|
||||
require.NoError(t, err)
|
||||
@@ -246,9 +245,11 @@ func TestModelRepository_ListEnabled(t *testing.T) {
|
||||
|
||||
func TestModelRepository_Update(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
|
||||
|
||||
err := repo.Update("m1", map[string]interface{}{"enabled": false})
|
||||
require.NoError(t, err)
|
||||
@@ -259,9 +260,11 @@ func TestModelRepository_Update(t *testing.T) {
|
||||
|
||||
func TestModelRepository_Delete(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}))
|
||||
|
||||
err := repo.Delete("m1")
|
||||
require.NoError(t, err)
|
||||
@@ -293,10 +296,32 @@ func TestStatsRepository_Query(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewStatsRepository(db)
|
||||
|
||||
repo.Record("p1", "gpt-4")
|
||||
require.NoError(t, repo.Record("p1", "gpt-4"))
|
||||
// 注意:当前 schema 只有 date 字段有唯一约束
|
||||
// 所以同一 provider + model 只能有一条记录
|
||||
stats, err := repo.Query("p1", "", nil, nil)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, stats, 1)
|
||||
}
|
||||
|
||||
func TestModelRepository_List_EmptyResult(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
result, err := repo.List("")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Empty(t, result)
|
||||
assert.Len(t, result, 0)
|
||||
}
|
||||
|
||||
func TestProviderRepository_List_EmptyResult(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewProviderRepository(db)
|
||||
|
||||
result, err := repo.List()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Empty(t, result)
|
||||
assert.Len(t, result, 0)
|
||||
}
|
||||
|
||||
@@ -6,8 +6,11 @@ import (
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=stats_repo.go -destination=../../tests/mocks/mock_stats_repository.go -package=mocks
|
||||
|
||||
// StatsRepository 统计数据仓库接口
|
||||
type StatsRepository interface {
|
||||
Record(providerID, modelName string) error
|
||||
BatchUpdate(providerID, modelName string, date time.Time, delta int) error
|
||||
Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error)
|
||||
}
|
||||
|
||||
@@ -43,6 +43,28 @@ func (r *statsRepository) Record(providerID, modelName string) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (r *statsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
var stats config.UsageStats
|
||||
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
|
||||
providerID, modelName, date).First(&stats).Error
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return tx.Create(&config.UsageStats{
|
||||
ProviderID: providerID,
|
||||
ModelName: modelName,
|
||||
RequestCount: delta,
|
||||
Date: date,
|
||||
}).Error
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Model(&stats).
|
||||
Update("request_count", gorm.Expr("request_count + ?", delta)).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (r *statsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
var stats []config.UsageStats
|
||||
query := r.db.Model(&config.UsageStats{})
|
||||
|
||||
@@ -2,6 +2,8 @@ package service
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=model_service.go -destination=../../tests/mocks/mock_model_service.go -package=mocks
|
||||
|
||||
// ModelService 模型服务接口
|
||||
type ModelService interface {
|
||||
Create(model *domain.Model) error
|
||||
|
||||
@@ -11,27 +11,30 @@ import (
|
||||
type modelService struct {
|
||||
modelRepo repository.ModelRepository
|
||||
providerRepo repository.ProviderRepository
|
||||
cache *RoutingCache
|
||||
}
|
||||
|
||||
func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) ModelService {
|
||||
return &modelService{modelRepo: modelRepo, providerRepo: providerRepo}
|
||||
func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository, cache *RoutingCache) ModelService {
|
||||
return &modelService{modelRepo: modelRepo, providerRepo: providerRepo, cache: cache}
|
||||
}
|
||||
|
||||
func (s *modelService) Create(model *domain.Model) error {
|
||||
// 校验供应商存在
|
||||
if _, err := s.providerRepo.GetByID(model.ProviderID); err != nil {
|
||||
return appErrors.ErrProviderNotFound
|
||||
}
|
||||
|
||||
// 联合唯一校验:同一供应商下 model_name 不重复
|
||||
if err := s.checkDuplicateModelName(model.ProviderID, model.ModelName, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 自动生成 UUID 作为 id
|
||||
model.ID = uuid.New().String()
|
||||
model.Enabled = true
|
||||
return s.modelRepo.Create(model)
|
||||
err := s.modelRepo.Create(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.cache.SetModel(model)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *modelService) Get(id string) (*domain.Model, error) {
|
||||
@@ -47,20 +50,17 @@ func (s *modelService) ListEnabled() ([]domain.Model, error) {
|
||||
}
|
||||
|
||||
func (s *modelService) Update(id string, updates map[string]interface{}) error {
|
||||
// 获取当前模型
|
||||
current, err := s.modelRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return appErrors.ErrModelNotFound
|
||||
}
|
||||
|
||||
// 如果更新 provider_id,校验新供应商存在
|
||||
if providerID, ok := updates["provider_id"].(string); ok {
|
||||
if _, err := s.providerRepo.GetByID(providerID); err != nil {
|
||||
return appErrors.ErrProviderNotFound
|
||||
}
|
||||
}
|
||||
|
||||
// 确定更新后的 provider_id 和 model_name
|
||||
newProviderID := current.ProviderID
|
||||
if v, ok := updates["provider_id"].(string); ok {
|
||||
newProviderID = v
|
||||
@@ -70,18 +70,37 @@ func (s *modelService) Update(id string, updates map[string]interface{}) error {
|
||||
newModelName = v
|
||||
}
|
||||
|
||||
// 如果 provider_id 或 model_name 发生变化,校验联合唯一
|
||||
if newProviderID != current.ProviderID || newModelName != current.ModelName {
|
||||
if err := s.checkDuplicateModelName(newProviderID, newModelName, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return s.modelRepo.Update(id, updates)
|
||||
err = s.modelRepo.Update(id, updates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.cache.InvalidateModel(current.ProviderID, current.ModelName)
|
||||
if newProviderID != current.ProviderID || newModelName != current.ModelName {
|
||||
s.cache.InvalidateModel(newProviderID, newModelName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *modelService) Delete(id string) error {
|
||||
return s.modelRepo.Delete(id)
|
||||
model, err := s.modelRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return appErrors.ErrModelNotFound
|
||||
}
|
||||
|
||||
err = s.modelRepo.Delete(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.cache.InvalidateModel(model.ProviderID, model.ModelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkDuplicateModelName 校验同一供应商下 model_name 是否重复
|
||||
|
||||
@@ -2,10 +2,12 @@ package service
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=provider_service.go -destination=../../tests/mocks/mock_provider_service.go -package=mocks
|
||||
|
||||
// ProviderService 供应商服务接口
|
||||
type ProviderService interface {
|
||||
Create(provider *domain.Provider) error
|
||||
Get(id string, maskKey bool) (*domain.Provider, error)
|
||||
Get(id string) (*domain.Provider, error)
|
||||
List() ([]domain.Provider, error)
|
||||
Update(id string, updates map[string]interface{}) error
|
||||
Delete(id string) error
|
||||
|
||||
@@ -13,56 +13,56 @@ import (
|
||||
type providerService struct {
|
||||
providerRepo repository.ProviderRepository
|
||||
modelRepo repository.ModelRepository
|
||||
cache *RoutingCache
|
||||
}
|
||||
|
||||
func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository) ProviderService {
|
||||
return &providerService{providerRepo: providerRepo, modelRepo: modelRepo}
|
||||
func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository, cache *RoutingCache) ProviderService {
|
||||
return &providerService{providerRepo: providerRepo, modelRepo: modelRepo, cache: cache}
|
||||
}
|
||||
|
||||
func (s *providerService) Create(provider *domain.Provider) error {
|
||||
// 校验 provider_id 字符集
|
||||
if err := modelid.ValidateProviderID(provider.ID); err != nil {
|
||||
return appErrors.ErrInvalidProviderID
|
||||
}
|
||||
provider.Enabled = true
|
||||
err := s.providerRepo.Create(provider)
|
||||
if err != nil && isUniqueConstraintError(err) {
|
||||
return appErrors.ErrConflict
|
||||
if err != nil {
|
||||
if isUniqueConstraintError(err) {
|
||||
return appErrors.ErrConflict
|
||||
}
|
||||
return err
|
||||
}
|
||||
return err
|
||||
s.cache.SetProvider(provider)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *providerService) Get(id string, maskKey bool) (*domain.Provider, error) {
|
||||
provider, err := s.providerRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if maskKey {
|
||||
provider.MaskAPIKey()
|
||||
}
|
||||
return provider, nil
|
||||
func (s *providerService) Get(id string) (*domain.Provider, error) {
|
||||
return s.providerRepo.GetByID(id)
|
||||
}
|
||||
|
||||
func (s *providerService) List() ([]domain.Provider, error) {
|
||||
providers, err := s.providerRepo.List()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range providers {
|
||||
providers[i].MaskAPIKey()
|
||||
}
|
||||
return providers, nil
|
||||
return s.providerRepo.List()
|
||||
}
|
||||
|
||||
func (s *providerService) Update(id string, updates map[string]interface{}) error {
|
||||
if _, ok := updates["id"]; ok {
|
||||
return appErrors.ErrImmutableField
|
||||
}
|
||||
return s.providerRepo.Update(id, updates)
|
||||
err := s.providerRepo.Update(id, updates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.cache.InvalidateProvider(id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *providerService) Delete(id string) error {
|
||||
return s.providerRepo.Delete(id)
|
||||
err := s.providerRepo.Delete(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.cache.InvalidateProvider(id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListEnabledModels 返回所有启用的模型(用于 Models 接口本地聚合)
|
||||
|
||||
134
backend/internal/service/routing_cache.go
Normal file
134
backend/internal/service/routing_cache.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
)
|
||||
|
||||
type RoutingCache struct {
|
||||
providers sync.Map
|
||||
models sync.Map
|
||||
|
||||
modelRepo repository.ModelRepository
|
||||
providerRepo repository.ProviderRepository
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewRoutingCache(
|
||||
modelRepo repository.ModelRepository,
|
||||
providerRepo repository.ProviderRepository,
|
||||
logger *zap.Logger,
|
||||
) *RoutingCache {
|
||||
return &RoutingCache{
|
||||
modelRepo: modelRepo,
|
||||
providerRepo: providerRepo,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
|
||||
if v, ok := c.providers.Load(id); ok {
|
||||
return v.(*domain.Provider), nil
|
||||
}
|
||||
|
||||
provider, err := c.providerRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if v, ok := c.providers.Load(id); ok {
|
||||
return v.(*domain.Provider), nil
|
||||
}
|
||||
|
||||
c.providers.Store(id, provider)
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, error) {
|
||||
key := providerID + "/" + modelName
|
||||
|
||||
if v, ok := c.models.Load(key); ok {
|
||||
return v.(*domain.Model), nil
|
||||
}
|
||||
|
||||
model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if v, ok := c.models.Load(key); ok {
|
||||
return v.(*domain.Model), nil
|
||||
}
|
||||
|
||||
c.models.Store(key, model)
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func (c *RoutingCache) SetProvider(provider *domain.Provider) {
|
||||
c.providers.Store(provider.ID, provider)
|
||||
}
|
||||
|
||||
func (c *RoutingCache) SetModel(model *domain.Model) {
|
||||
key := model.ProviderID + "/" + model.ModelName
|
||||
c.models.Store(key, model)
|
||||
}
|
||||
|
||||
func (c *RoutingCache) InvalidateProvider(id string) {
|
||||
c.providers.Delete(id)
|
||||
c.invalidateModelsByProvider(id)
|
||||
c.logger.Debug("Provider 缓存失效", zap.String("provider_id", id))
|
||||
}
|
||||
|
||||
func (c *RoutingCache) InvalidateModel(providerID, modelName string) {
|
||||
key := providerID + "/" + modelName
|
||||
c.models.Delete(key)
|
||||
c.logger.Debug("Model 缓存失效",
|
||||
zap.String("provider_id", providerID),
|
||||
zap.String("model_name", modelName))
|
||||
}
|
||||
|
||||
func (c *RoutingCache) invalidateModelsByProvider(providerID string) {
|
||||
prefix := providerID + "/"
|
||||
count := 0
|
||||
c.models.Range(func(key, value interface{}) bool {
|
||||
if strings.HasPrefix(key.(string), prefix) {
|
||||
c.models.Delete(key)
|
||||
count++
|
||||
}
|
||||
return true
|
||||
})
|
||||
if count > 0 {
|
||||
c.logger.Debug("清除 Provider 相关 Model 缓存",
|
||||
zap.String("provider_id", providerID),
|
||||
zap.Int("count", count))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RoutingCache) Preload() error {
|
||||
providers, err := c.providerRepo.List()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i := range providers {
|
||||
c.providers.Store(providers[i].ID, &providers[i])
|
||||
}
|
||||
|
||||
models, err := c.modelRepo.List("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i := range models {
|
||||
key := models[i].ProviderID + "/" + models[i].ModelName
|
||||
c.models.Store(key, &models[i])
|
||||
}
|
||||
|
||||
c.logger.Info("缓存预热完成",
|
||||
zap.Int("providers", len(providers)),
|
||||
zap.Int("models", len(models)))
|
||||
return nil
|
||||
}
|
||||
273
backend/internal/service/routing_cache_test.go
Normal file
273
backend/internal/service/routing_cache_test.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
type mockModelRepo struct {
|
||||
models map[string]*domain.Model
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) Create(model *domain.Model) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) GetByID(id string) (*domain.Model, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
|
||||
key := providerID + "/" + modelName
|
||||
if model, ok := m.models[key]; ok {
|
||||
return model, nil
|
||||
}
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) List(providerID string) ([]domain.Model, error) {
|
||||
var result []domain.Model
|
||||
for _, model := range m.models {
|
||||
if providerID == "" || model.ProviderID == providerID {
|
||||
result = append(result, *model)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) ListEnabled() ([]domain.Model, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) Update(id string, updates map[string]interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) Delete(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockProviderRepo struct {
|
||||
providers map[string]*domain.Provider
|
||||
}
|
||||
|
||||
func (m *mockProviderRepo) Create(provider *domain.Provider) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProviderRepo) GetByID(id string) (*domain.Provider, error) {
|
||||
if provider, ok := m.providers[id]; ok {
|
||||
return provider, nil
|
||||
}
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
|
||||
func (m *mockProviderRepo) List() ([]domain.Provider, error) {
|
||||
var result []domain.Provider
|
||||
for _, provider := range m.providers {
|
||||
result = append(result, *provider)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockProviderRepo) Update(id string, updates map[string]interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProviderRepo) Delete(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRoutingCache_GetProvider_CacheHit(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
||||
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
||||
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
||||
}}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
provider, err := cache.GetProvider("openai")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai", provider.ID)
|
||||
|
||||
provider2, err := cache.GetProvider("openai")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, provider, provider2)
|
||||
}
|
||||
|
||||
func TestRoutingCache_GetProvider_CacheMiss(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
||||
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
_, err := cache.GetProvider("notexist")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRoutingCache_GetModel_CacheHit(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
||||
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
||||
}}
|
||||
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
model, err := cache.GetModel("openai", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4", model.ModelName)
|
||||
|
||||
model2, err := cache.GetModel("openai", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, model, model2)
|
||||
}
|
||||
|
||||
func TestRoutingCache_GetModel_CacheMiss(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
||||
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
_, err := cache.GetModel("openai", "notexist")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRoutingCache_DoubleCheck(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
||||
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
||||
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
||||
}}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := cache.GetProvider("openai")
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
||||
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
||||
"openai/gpt-3.5": {ID: "2", ProviderID: "openai", ModelName: "gpt-3.5", Enabled: true},
|
||||
"anthropic/claude": {ID: "3", ProviderID: "anthropic", ModelName: "claude", Enabled: true},
|
||||
}}
|
||||
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
||||
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
||||
"anthropic": {ID: "anthropic", Name: "Anthropic", Enabled: true},
|
||||
}}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
_, err := cache.GetModel("openai", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
_, err = cache.GetModel("openai", "gpt-3.5")
|
||||
require.NoError(t, err)
|
||||
_, err = cache.GetModel("anthropic", "claude")
|
||||
require.NoError(t, err)
|
||||
|
||||
cache.InvalidateProvider("openai")
|
||||
|
||||
var openaiCount, anthropicCount int
|
||||
cache.models.Range(func(key, value interface{}) bool {
|
||||
if key.(string) == "anthropic/claude" {
|
||||
anthropicCount++
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, 0, openaiCount)
|
||||
assert.Equal(t, 1, anthropicCount)
|
||||
}
|
||||
|
||||
func TestRoutingCache_InvalidateModel(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
||||
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
||||
}}
|
||||
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
_, err := cache.GetModel("openai", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
cache.InvalidateModel("openai", "gpt-4")
|
||||
|
||||
var count int
|
||||
cache.models.Range(func(key, value interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, 0, count)
|
||||
}
|
||||
|
||||
func TestRoutingCache_Preload_Success(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
||||
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
||||
}}
|
||||
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
||||
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
||||
}}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
err := cache.Preload()
|
||||
require.NoError(t, err)
|
||||
|
||||
var providerCount, modelCount int
|
||||
cache.providers.Range(func(key, value interface{}) bool {
|
||||
providerCount++
|
||||
return true
|
||||
})
|
||||
cache.models.Range(func(key, value interface{}) bool {
|
||||
modelCount++
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, 1, providerCount)
|
||||
assert.Equal(t, 1, modelCount)
|
||||
}
|
||||
|
||||
func TestRoutingCache_ConcurrentAccess(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
||||
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
||||
}}
|
||||
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
||||
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
||||
}}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = cache.GetProvider("openai")
|
||||
_, _ = cache.GetModel("openai", "gpt-4")
|
||||
cache.InvalidateProvider("openai")
|
||||
cache.InvalidateModel("openai", "gpt-4")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -2,6 +2,8 @@ package service
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=routing_service.go -destination=../../tests/mocks/mock_routing_service.go -package=mocks
|
||||
|
||||
// RoutingService 路由服务接口
|
||||
type RoutingService interface {
|
||||
RouteByModelName(providerID, modelName string) (*domain.RouteResult, error)
|
||||
|
||||
@@ -4,20 +4,18 @@ import (
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
)
|
||||
|
||||
type routingService struct {
|
||||
modelRepo repository.ModelRepository
|
||||
providerRepo repository.ProviderRepository
|
||||
cache *RoutingCache
|
||||
}
|
||||
|
||||
func NewRoutingService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) RoutingService {
|
||||
return &routingService{modelRepo: modelRepo, providerRepo: providerRepo}
|
||||
func NewRoutingService(cache *RoutingCache) RoutingService {
|
||||
return &routingService{cache: cache}
|
||||
}
|
||||
|
||||
func (s *routingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
|
||||
model, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
|
||||
model, err := s.cache.GetModel(providerID, modelName)
|
||||
if err != nil {
|
||||
return nil, appErrors.ErrModelNotFound
|
||||
}
|
||||
@@ -26,7 +24,7 @@ func (s *routingService) RouteByModelName(providerID, modelName string) (*domain
|
||||
return nil, appErrors.ErrModelDisabled
|
||||
}
|
||||
|
||||
provider, err := s.providerRepo.GetByID(model.ProviderID)
|
||||
provider, err := s.cache.GetProvider(model.ProviderID)
|
||||
if err != nil {
|
||||
return nil, appErrors.ErrProviderNotFound
|
||||
}
|
||||
|
||||
@@ -14,14 +14,15 @@ func TestProviderService_Update(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"})
|
||||
require.NoError(t, svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
|
||||
err := svc.Update("p1", map[string]interface{}{"name": "Updated"})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := svc.Get("p1", false)
|
||||
result, err := svc.Get("p1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated", result.Name)
|
||||
}
|
||||
@@ -30,7 +31,8 @@ func TestProviderService_Update_NotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
err := svc.Update("nonexistent", map[string]interface{}{"name": "test"})
|
||||
assert.Error(t, err)
|
||||
@@ -40,9 +42,10 @@ func TestModelService_Get(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
require.NoError(t, svc.Create(model))
|
||||
|
||||
@@ -55,9 +58,10 @@ func TestModelService_Update(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
require.NoError(t, svc.Create(model))
|
||||
|
||||
@@ -73,9 +77,10 @@ func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
require.NoError(t, svc.Create(model))
|
||||
|
||||
@@ -87,9 +92,10 @@ func TestModelService_Delete(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
require.NoError(t, svc.Create(model))
|
||||
|
||||
@@ -104,7 +110,8 @@ func TestModelService_Delete_NotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
err := svc.Delete("nonexistent")
|
||||
assert.Error(t, err)
|
||||
@@ -112,7 +119,8 @@ func TestModelService_Delete_NotFound(t *testing.T) {
|
||||
|
||||
func TestStatsService_Aggregate_Default(t *testing.T) {
|
||||
statsRepo := repository.NewStatsRepository(nil)
|
||||
svc := NewStatsService(statsRepo)
|
||||
buffer := NewStatsBuffer(statsRepo, nil)
|
||||
svc := NewStatsService(statsRepo, buffer)
|
||||
|
||||
stats := []domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
@@ -133,7 +141,8 @@ func TestModelService_Update_NotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
err := svc.Update("nonexistent", map[string]interface{}{"model_name": "test"})
|
||||
assert.Error(t, err)
|
||||
|
||||
@@ -3,14 +3,16 @@ package service
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
testHelpers "nex/backend/tests"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
@@ -18,18 +20,14 @@ import (
|
||||
|
||||
func setupServiceTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
sqlDB, _ := db.DB()
|
||||
if sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
})
|
||||
return db
|
||||
return testHelpers.SetupTestDB(t)
|
||||
}
|
||||
|
||||
func setupRoutingCache(t *testing.T, db *gorm.DB) *RoutingCache {
|
||||
t.Helper()
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
return NewRoutingCache(modelRepo, providerRepo, zap.NewNop())
|
||||
}
|
||||
|
||||
// ============ RoutingService - RouteByModelName 测试 ============
|
||||
@@ -38,11 +36,11 @@ func TestRoutingService_RouteByModelName_Success(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewRoutingService(cache)
|
||||
|
||||
// 创建供应商和模型
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
|
||||
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
|
||||
|
||||
result, err := svc.RouteByModelName("openai", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
@@ -52,9 +50,8 @@ func TestRoutingService_RouteByModelName_Success(t *testing.T) {
|
||||
|
||||
func TestRoutingService_RouteByModelName_NotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewRoutingService(cache)
|
||||
|
||||
_, err := svc.RouteByModelName("openai", "nonexistent-model")
|
||||
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
|
||||
@@ -64,12 +61,12 @@ func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewRoutingService(cache)
|
||||
|
||||
// 创建启用的供应商和禁用的模型
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
|
||||
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
|
||||
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
|
||||
require.NoError(t, modelRepo.Update("m1", map[string]interface{}{"enabled": false}))
|
||||
|
||||
_, err := svc.RouteByModelName("openai", "gpt-4")
|
||||
assert.True(t, errors.Is(err, appErrors.ErrModelDisabled))
|
||||
@@ -79,12 +76,12 @@ func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewRoutingService(cache)
|
||||
|
||||
// 创建启用的供应商和模型,然后禁用供应商
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
|
||||
providerRepo.Update("openai", map[string]interface{}{"enabled": false})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
|
||||
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
|
||||
require.NoError(t, providerRepo.Update("openai", map[string]interface{}{"enabled": false}))
|
||||
|
||||
_, err := svc.RouteByModelName("openai", "gpt-4")
|
||||
assert.True(t, errors.Is(err, appErrors.ErrProviderDisabled))
|
||||
@@ -96,20 +93,19 @@ func TestModelService_Create_GeneratesUUID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
|
||||
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证返回的 model 拥有有效的 UUID
|
||||
assert.NotEmpty(t, model.ID)
|
||||
_, err = uuid.Parse(model.ID)
|
||||
assert.NoError(t, err, "model.ID should be a valid UUID")
|
||||
|
||||
// 通过 Get 验证持久化
|
||||
stored, err := svc.Get(model.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, model.ID, stored.ID)
|
||||
@@ -120,15 +116,15 @@ func TestModelService_Create_DuplicateModelName(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
|
||||
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 使用相同的 (providerID, modelName) 创建第二个模型应失败
|
||||
model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err = svc.Create(model2)
|
||||
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
|
||||
@@ -138,7 +134,8 @@ func TestModelService_Create_ProviderNotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
|
||||
err := svc.Create(model)
|
||||
@@ -151,7 +148,8 @@ func TestProviderService_Create_InvalidID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
@@ -162,7 +160,8 @@ func TestProviderService_Create_ValidID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
@@ -177,10 +176,11 @@ func TestModelService_Update_DuplicateModelName(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||
providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"}))
|
||||
|
||||
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model1)
|
||||
@@ -202,7 +202,8 @@ func TestModelService_Update_ModelNotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
err := svc.Update("nonexistent-id", map[string]interface{}{
|
||||
"model_name": "gpt-4",
|
||||
@@ -214,9 +215,10 @@ func TestModelService_Update_Success(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
|
||||
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model)
|
||||
@@ -239,7 +241,8 @@ func TestProviderService_Update_ImmutableID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
@@ -256,7 +259,8 @@ func TestProviderService_Update_Success(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewProviderService(repo, modelRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
@@ -268,7 +272,233 @@ func TestProviderService_Update_Success(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := svc.Get("openai", false)
|
||||
updated, err := svc.Get("openai")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "OpenAI Updated", updated.Name)
|
||||
}
|
||||
|
||||
// ============ StatsService - Aggregate ByModel 测试 ============
|
||||
|
||||
func TestStatsService_Aggregate_ByModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stats []domain.UsageStats
|
||||
expected []map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "multiple providers with same model name",
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10},
|
||||
{ProviderID: "azure", ModelName: "gpt-4", RequestCount: 20},
|
||||
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 5},
|
||||
},
|
||||
expected: []map[string]interface{}{
|
||||
{"provider_id": "openai", "model_name": "gpt-4", "request_count": 15},
|
||||
{"provider_id": "azure", "model_name": "gpt-4", "request_count": 20},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty providerID",
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "", ModelName: "gpt-4", RequestCount: 10},
|
||||
{ProviderID: "", ModelName: "gpt-4", RequestCount: 5},
|
||||
},
|
||||
expected: []map[string]interface{}{
|
||||
{"provider_id": "", "model_name": "gpt-4", "request_count": 15},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty result set",
|
||||
stats: []domain.UsageStats{},
|
||||
expected: []map[string]interface{}{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
buffer := NewStatsBuffer(statsRepo, nil); svc := NewStatsService(statsRepo, buffer)
|
||||
|
||||
result := svc.Aggregate(tt.stats, "model")
|
||||
|
||||
assert.Len(t, result, len(tt.expected))
|
||||
for _, exp := range tt.expected {
|
||||
found := false
|
||||
for _, r := range result {
|
||||
if r["provider_id"] == exp["provider_id"] && r["model_name"] == exp["model_name"] {
|
||||
assert.Equal(t, exp["request_count"], r["request_count"])
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected result not found: %v", exp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============ StatsService - Aggregate ByDate 测试 ============
|
||||
|
||||
func TestStatsService_Aggregate_ByDate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stats []domain.UsageStats
|
||||
expected []map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "normal date grouping",
|
||||
stats: []domain.UsageStats{
|
||||
{Date: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), RequestCount: 10},
|
||||
{Date: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), RequestCount: 5},
|
||||
{Date: time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC), RequestCount: 20},
|
||||
},
|
||||
expected: []map[string]interface{}{
|
||||
{"date": "2024-01-01", "request_count": 15},
|
||||
{"date": "2024-01-02", "request_count": 20},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "zero-value time",
|
||||
stats: []domain.UsageStats{
|
||||
{Date: time.Time{}, RequestCount: 10},
|
||||
{Date: time.Time{}, RequestCount: 5},
|
||||
},
|
||||
expected: []map[string]interface{}{
|
||||
{"date": "0001-01-01", "request_count": 15},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty result set",
|
||||
stats: []domain.UsageStats{},
|
||||
expected: []map[string]interface{}{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
buffer := NewStatsBuffer(statsRepo, nil); svc := NewStatsService(statsRepo, buffer)
|
||||
|
||||
result := svc.Aggregate(tt.stats, "date")
|
||||
|
||||
assert.Len(t, result, len(tt.expected))
|
||||
for _, exp := range tt.expected {
|
||||
found := false
|
||||
for _, r := range result {
|
||||
if r["date"] == exp["date"] {
|
||||
assert.Equal(t, exp["request_count"], r["request_count"])
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected result not found: %v", exp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============ ProviderService - isUniqueConstraintError 测试 ============
|
||||
|
||||
func TestProviderService_isUniqueConstraintError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "UNIQUE constraint failed",
|
||||
err: errors.New("UNIQUE constraint failed"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "duplicate key value",
|
||||
err: errors.New("duplicate key value"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "UNIQUE constraint case insensitive",
|
||||
err: errors.New("unique constraint violation"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
err: errors.New("some other error"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isUniqueConstraintError(tt.err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============ ProviderService - List API Key 测试 ============
|
||||
|
||||
func TestProviderService_List_APIKeyNotMasked(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider1 := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"}
|
||||
provider2 := &domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "sk-anthropic1234", BaseURL: "https://api.anthropic.com"}
|
||||
require.NoError(t, svc.Create(provider1))
|
||||
require.NoError(t, svc.Create(provider2))
|
||||
|
||||
providers, err := svc.List()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, providers, 2)
|
||||
|
||||
expectedKeys := map[string]string{
|
||||
"openai": "sk-1234567890",
|
||||
"anthropic": "sk-anthropic1234",
|
||||
}
|
||||
for _, p := range providers {
|
||||
assert.NotContains(t, p.APIKey, "***")
|
||||
assert.Equal(t, expectedKeys[p.ID], p.APIKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelService_ConcurrentCreate(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
|
||||
results := make(chan error, 2)
|
||||
for i := 0; i < 2; i++ {
|
||||
go func() {
|
||||
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
results <- svc.Create(model)
|
||||
}()
|
||||
}
|
||||
|
||||
err1 := <-results
|
||||
err2 := <-results
|
||||
|
||||
successCount := 0
|
||||
errorCount := 0
|
||||
for _, err := range []error{err1, err2} {
|
||||
if err == nil {
|
||||
successCount++
|
||||
} else {
|
||||
errorCount++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, successCount)
|
||||
assert.Equal(t, 1, errorCount)
|
||||
}
|
||||
|
||||
169
backend/internal/service/stats_buffer.go
Normal file
169
backend/internal/service/stats_buffer.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/repository"
|
||||
)
|
||||
|
||||
type StatsBuffer struct {
|
||||
counters sync.Map
|
||||
|
||||
flushInterval time.Duration
|
||||
flushThreshold int
|
||||
totalCount atomic.Int64
|
||||
|
||||
statsRepo repository.StatsRepository
|
||||
logger *zap.Logger
|
||||
|
||||
stopCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
}
|
||||
|
||||
type StatsBufferOption func(*StatsBuffer)
|
||||
|
||||
func WithFlushInterval(d time.Duration) StatsBufferOption {
|
||||
return func(b *StatsBuffer) {
|
||||
b.flushInterval = d
|
||||
}
|
||||
}
|
||||
|
||||
func WithFlushThreshold(threshold int) StatsBufferOption {
|
||||
return func(b *StatsBuffer) {
|
||||
b.flushThreshold = threshold
|
||||
}
|
||||
}
|
||||
|
||||
func NewStatsBuffer(
|
||||
statsRepo repository.StatsRepository,
|
||||
logger *zap.Logger,
|
||||
opts ...StatsBufferOption,
|
||||
) *StatsBuffer {
|
||||
b := &StatsBuffer{
|
||||
statsRepo: statsRepo,
|
||||
logger: logger,
|
||||
flushInterval: 5 * time.Second,
|
||||
flushThreshold: 100,
|
||||
stopCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(b)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *StatsBuffer) Increment(providerID, modelName string) {
|
||||
today := time.Now().Format("2006-01-02")
|
||||
key := providerID + "/" + modelName + "/" + today
|
||||
|
||||
var counter *int64
|
||||
if v, ok := b.counters.Load(key); ok {
|
||||
counter = v.(*int64)
|
||||
} else {
|
||||
val := int64(0)
|
||||
counter = &val
|
||||
actual, loaded := b.counters.LoadOrStore(key, counter)
|
||||
if loaded {
|
||||
counter = actual.(*int64)
|
||||
}
|
||||
}
|
||||
|
||||
atomic.AddInt64(counter, 1)
|
||||
|
||||
if b.totalCount.Add(1) >= int64(b.flushThreshold) {
|
||||
go b.flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *StatsBuffer) Start() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(b.flushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
b.flush()
|
||||
case <-b.stopCh:
|
||||
b.flush()
|
||||
close(b.doneCh)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (b *StatsBuffer) Stop() {
|
||||
close(b.stopCh)
|
||||
<-b.doneCh
|
||||
}
|
||||
|
||||
func (b *StatsBuffer) flush() {
|
||||
type statEntry struct {
|
||||
providerID string
|
||||
modelName string
|
||||
date string
|
||||
count int64
|
||||
}
|
||||
|
||||
var entries []statEntry
|
||||
b.counters.Range(func(key, value interface{}) bool {
|
||||
keyStr := key.(string)
|
||||
parts := strings.Split(keyStr, "/")
|
||||
if len(parts) != 3 {
|
||||
return true
|
||||
}
|
||||
|
||||
counter := value.(*int64)
|
||||
count := atomic.SwapInt64(counter, 0)
|
||||
|
||||
if count > 0 {
|
||||
entries = append(entries, statEntry{
|
||||
providerID: parts[0],
|
||||
modelName: parts[1],
|
||||
date: parts[2],
|
||||
count: count,
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if len(entries) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
success := 0
|
||||
for _, entry := range entries {
|
||||
date, _ := time.Parse("2006-01-02", entry.date)
|
||||
err := b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count))
|
||||
if err != nil {
|
||||
b.logger.Error("批量更新统计失败",
|
||||
zap.String("provider_id", entry.providerID),
|
||||
zap.String("model_name", entry.modelName),
|
||||
zap.Int64("count", entry.count),
|
||||
zap.Error(err))
|
||||
|
||||
key := entry.providerID + "/" + entry.modelName + "/" + entry.date
|
||||
if v, ok := b.counters.Load(key); ok {
|
||||
counter := v.(*int64)
|
||||
atomic.AddInt64(counter, entry.count)
|
||||
}
|
||||
} else {
|
||||
success++
|
||||
}
|
||||
}
|
||||
|
||||
b.totalCount.Store(0)
|
||||
b.logger.Debug("统计刷新完成",
|
||||
zap.Int("total", len(entries)),
|
||||
zap.Int("success", success),
|
||||
zap.Int("failed", len(entries)-success))
|
||||
}
|
||||
251
backend/internal/service/stats_buffer_test.go
Normal file
251
backend/internal/service/stats_buffer_test.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
type mockStatsRepo struct {
|
||||
records []struct {
|
||||
providerID string
|
||||
modelName string
|
||||
date time.Time
|
||||
delta int
|
||||
}
|
||||
fail bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (m *mockStatsRepo) Record(providerID, modelName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStatsRepo) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.fail {
|
||||
return errors.New("db error")
|
||||
}
|
||||
m.records = append(m.records, struct {
|
||||
providerID string
|
||||
modelName string
|
||||
date time.Time
|
||||
delta int
|
||||
}{providerID, modelName, date, delta})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStatsRepo) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestStatsBuffer_Increment(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
buffer.Increment("openai", "gpt-3.5")
|
||||
|
||||
var count int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
count += atomic.LoadInt64(counter)
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(3), count)
|
||||
}
|
||||
|
||||
func TestStatsBuffer_ConcurrentIncrement(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
var count int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
count = atomic.LoadInt64(counter)
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(100), count)
|
||||
}
|
||||
|
||||
func TestStatsBuffer_LoadOrStore(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
var counterCount int
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counterCount++
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, 1, counterCount)
|
||||
}
|
||||
|
||||
func TestStatsBuffer_FlushByInterval(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger,
|
||||
WithFlushInterval(100*time.Millisecond))
|
||||
|
||||
buffer.Start()
|
||||
defer buffer.Stop()
|
||||
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
statsRepo.mu.Lock()
|
||||
assert.GreaterOrEqual(t, len(statsRepo.records), 1)
|
||||
statsRepo.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestStatsBuffer_FlushByThreshold(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger,
|
||||
WithFlushThreshold(10))
|
||||
|
||||
buffer.Start()
|
||||
defer buffer.Stop()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
statsRepo.mu.Lock()
|
||||
assert.GreaterOrEqual(t, len(statsRepo.records), 1)
|
||||
statsRepo.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestStatsBuffer_SwapInt64(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
|
||||
var beforeCount int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
beforeCount = atomic.LoadInt64(counter)
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(2), beforeCount)
|
||||
|
||||
buffer.flush()
|
||||
|
||||
var afterCount int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
afterCount = atomic.LoadInt64(counter)
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(0), afterCount)
|
||||
}
|
||||
|
||||
func TestStatsBuffer_FailRetry(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{fail: true}
|
||||
buffer := NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
|
||||
buffer.flush()
|
||||
|
||||
var count int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
count = atomic.LoadInt64(counter)
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(2), count)
|
||||
}
|
||||
|
||||
func TestStatsBuffer_Stop(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger,
|
||||
WithFlushInterval(10*time.Second))
|
||||
|
||||
buffer.Start()
|
||||
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
|
||||
start := time.Now()
|
||||
buffer.Stop()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.Less(t, elapsed, 1*time.Second)
|
||||
|
||||
statsRepo.mu.Lock()
|
||||
assert.GreaterOrEqual(t, len(statsRepo.records), 1)
|
||||
statsRepo.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestStatsBuffer_ConcurrentIncrementAndFlush(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger,
|
||||
WithFlushInterval(50*time.Millisecond))
|
||||
|
||||
buffer.Start()
|
||||
defer buffer.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
statsRepo.mu.Lock()
|
||||
totalDelta := 0
|
||||
for _, r := range statsRepo.records {
|
||||
totalDelta += r.delta
|
||||
}
|
||||
statsRepo.mu.Unlock()
|
||||
|
||||
assert.Equal(t, 100, totalDelta)
|
||||
}
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=stats_service.go -destination=../../tests/mocks/mock_stats_service.go -package=mocks
|
||||
|
||||
// StatsService 统计服务接口
|
||||
type StatsService interface {
|
||||
Record(providerID, modelName string) error
|
||||
|
||||
@@ -10,14 +10,16 @@ import (
|
||||
|
||||
type statsService struct {
|
||||
statsRepo repository.StatsRepository
|
||||
buffer *StatsBuffer
|
||||
}
|
||||
|
||||
func NewStatsService(statsRepo repository.StatsRepository) StatsService {
|
||||
return &statsService{statsRepo: statsRepo}
|
||||
func NewStatsService(statsRepo repository.StatsRepository, buffer *StatsBuffer) StatsService {
|
||||
return &statsService{statsRepo: statsRepo, buffer: buffer}
|
||||
}
|
||||
|
||||
func (s *statsService) Record(providerID, modelName string) error {
|
||||
return s.statsRepo.Record(providerID, modelName)
|
||||
s.buffer.Increment(providerID, modelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *statsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
|
||||
193
backend/tests/config/config_test.go
Normal file
193
backend/tests/config/config_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
)
|
||||
|
||||
func TestLoadConfig_DefaultValues(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
cfg, err := config.LoadConfigFromPath(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 9826, cfg.Server.Port)
|
||||
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
|
||||
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
|
||||
|
||||
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
|
||||
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
|
||||
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
|
||||
|
||||
assert.Equal(t, "info", cfg.Log.Level)
|
||||
assert.Equal(t, 100, cfg.Log.MaxSize)
|
||||
assert.Equal(t, 10, cfg.Log.MaxBackups)
|
||||
assert.Equal(t, 30, cfg.Log.MaxAge)
|
||||
assert.True(t, cfg.Log.Compress)
|
||||
}
|
||||
|
||||
func TestLoadConfig_EnvOverride(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
t.Setenv("NEX_SERVER_PORT", "9000")
|
||||
t.Setenv("NEX_LOG_LEVEL", "debug")
|
||||
t.Setenv("NEX_DATABASE_MAX_IDLE_CONNS", "20")
|
||||
|
||||
cfg, err := config.LoadConfigFromPath(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 9000, cfg.Server.Port)
|
||||
assert.Equal(t, "debug", cfg.Log.Level)
|
||||
assert.Equal(t, 20, cfg.Database.MaxIdleConns)
|
||||
}
|
||||
|
||||
func TestLoadConfig_YAMLFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
yamlContent := `
|
||||
server:
|
||||
port: 8080
|
||||
read_timeout: 60s
|
||||
write_timeout: 60s
|
||||
database:
|
||||
path: /custom/path.db
|
||||
max_idle_conns: 5
|
||||
max_open_conns: 50
|
||||
conn_max_lifetime: 2h
|
||||
log:
|
||||
level: warn
|
||||
path: /custom/log
|
||||
max_size: 200
|
||||
max_backups: 5
|
||||
max_age: 7
|
||||
compress: false
|
||||
`
|
||||
err := os.WriteFile(configPath, []byte(yamlContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := config.LoadConfigFromPath(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 8080, cfg.Server.Port)
|
||||
assert.Equal(t, 60*time.Second, cfg.Server.ReadTimeout)
|
||||
assert.Equal(t, 60*time.Second, cfg.Server.WriteTimeout)
|
||||
assert.Equal(t, "/custom/path.db", cfg.Database.Path)
|
||||
assert.Equal(t, 5, cfg.Database.MaxIdleConns)
|
||||
assert.Equal(t, 50, cfg.Database.MaxOpenConns)
|
||||
assert.Equal(t, 2*time.Hour, cfg.Database.ConnMaxLifetime)
|
||||
assert.Equal(t, "warn", cfg.Log.Level)
|
||||
assert.Equal(t, "/custom/log", cfg.Log.Path)
|
||||
assert.Equal(t, 200, cfg.Log.MaxSize)
|
||||
assert.Equal(t, 5, cfg.Log.MaxBackups)
|
||||
assert.Equal(t, 7, cfg.Log.MaxAge)
|
||||
assert.False(t, cfg.Log.Compress)
|
||||
}
|
||||
|
||||
func TestLoadConfig_PriorityChain(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
yamlContent := `
|
||||
server:
|
||||
port: 8080
|
||||
log:
|
||||
level: warn
|
||||
`
|
||||
err := os.WriteFile(configPath, []byte(yamlContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Setenv("NEX_SERVER_PORT", "9000")
|
||||
|
||||
originalArgs := os.Args
|
||||
defer func() { os.Args = originalArgs }()
|
||||
os.Args = []string{"test", "--server-port", "9999"}
|
||||
|
||||
cfg, err := config.LoadConfigFromPath(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 9999, cfg.Server.Port, "CLI should override ENV and YAML")
|
||||
assert.Equal(t, "warn", cfg.Log.Level, "YAML value should be used when no CLI/ENV override")
|
||||
}
|
||||
|
||||
func TestLoadConfig_AutoCreate(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
_, err := os.Stat(configPath)
|
||||
assert.True(t, os.IsNotExist(err), "config file should not exist before load")
|
||||
|
||||
cfg, err := config.LoadConfigFromPath(configPath)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.Equal(t, 9826, cfg.Server.Port, "should load with default values")
|
||||
}
|
||||
|
||||
func TestSaveAndLoadConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
homeDir, err := os.UserHomeDir()
|
||||
require.NoError(t, err)
|
||||
nexDir := filepath.Join(homeDir, ".nex")
|
||||
configPath := filepath.Join(nexDir, "config.yaml")
|
||||
|
||||
originalConfig, err := os.ReadFile(configPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
defer func() {
|
||||
if originalConfig != nil {
|
||||
_ = os.WriteFile(configPath, originalConfig, 0644)
|
||||
}
|
||||
}()
|
||||
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Port: 7777,
|
||||
ReadTimeout: 45 * time.Second,
|
||||
WriteTimeout: 45 * time.Second,
|
||||
},
|
||||
Database: config.DatabaseConfig{
|
||||
Path: filepath.Join(tmpDir, "test.db"),
|
||||
MaxIdleConns: 15,
|
||||
MaxOpenConns: 150,
|
||||
ConnMaxLifetime: 2 * time.Hour,
|
||||
},
|
||||
Log: config.LogConfig{
|
||||
Level: "debug",
|
||||
Path: filepath.Join(tmpDir, "log"),
|
||||
MaxSize: 50,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 14,
|
||||
Compress: false,
|
||||
},
|
||||
}
|
||||
|
||||
err = config.SaveConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
loaded, err := config.LoadConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, cfg.Server.Port, loaded.Server.Port)
|
||||
assert.Equal(t, cfg.Server.ReadTimeout, loaded.Server.ReadTimeout)
|
||||
assert.Equal(t, cfg.Server.WriteTimeout, loaded.Server.WriteTimeout)
|
||||
assert.Equal(t, cfg.Database.MaxIdleConns, loaded.Database.MaxIdleConns)
|
||||
assert.Equal(t, cfg.Database.MaxOpenConns, loaded.Database.MaxOpenConns)
|
||||
assert.Equal(t, cfg.Database.ConnMaxLifetime, loaded.Database.ConnMaxLifetime)
|
||||
assert.Equal(t, cfg.Log.Level, loaded.Log.Level)
|
||||
assert.Equal(t, cfg.Log.MaxSize, loaded.Log.MaxSize)
|
||||
assert.Equal(t, cfg.Log.MaxBackups, loaded.Log.MaxBackups)
|
||||
assert.Equal(t, cfg.Log.MaxAge, loaded.Log.MaxAge)
|
||||
assert.Equal(t, cfg.Log.Compress, loaded.Log.Compress)
|
||||
}
|
||||
@@ -7,49 +7,40 @@ import (
|
||||
|
||||
"nex/backend/internal/config"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SetupTestDB initializes an in-memory SQLite database with auto-migration.
|
||||
// Uses :memory: mode with MaxOpenConns(1) to ensure all operations share the
|
||||
// same connection, avoiding "database is closed" errors from connection pool.
|
||||
// Enables foreign key constraints for SQLite.
|
||||
func SetupTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(":memory:?_foreign_keys=on"), &gorm.Config{})
|
||||
assert.NoError(t, err, "failed to open test database")
|
||||
require.NoError(t, err, "failed to open test database")
|
||||
|
||||
// 限制为单连接,确保 :memory: 数据库不被连接池丢弃
|
||||
sqlDB, err := db.DB()
|
||||
assert.NoError(t, err, "failed to get underlying sql.DB")
|
||||
require.NoError(t, err, "failed to get underlying sql.DB")
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
sqlDB.SetConnMaxLifetime(0)
|
||||
|
||||
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
||||
assert.NoError(t, err, "failed to auto-migrate test database")
|
||||
require.NoError(t, err, "failed to auto-migrate test database")
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// CleanupTestDB closes the database after a brief delay to allow async
|
||||
// goroutines (e.g. stats recording) to finish.
|
||||
func CleanupTestDB(t *testing.T, db *gorm.DB) {
|
||||
t.Helper()
|
||||
|
||||
// 等待异步 goroutine(如 statsService.Record)完成
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
assert.NoError(t, err, "failed to get underlying sql.DB")
|
||||
require.NoError(t, err, "failed to get underlying sql.DB")
|
||||
|
||||
err = sqlDB.Close()
|
||||
assert.NoError(t, err, "failed to close test database")
|
||||
require.NoError(t, err, "failed to close test database")
|
||||
}
|
||||
|
||||
// CreateTestProvider creates a test provider and returns it.
|
||||
func CreateTestProvider(t *testing.T, db *gorm.DB, id string) config.Provider {
|
||||
t.Helper()
|
||||
|
||||
@@ -62,13 +53,11 @@ func CreateTestProvider(t *testing.T, db *gorm.DB, id string) config.Provider {
|
||||
}
|
||||
|
||||
err := db.Create(&provider).Error
|
||||
assert.NoError(t, err, "failed to create test provider")
|
||||
require.NoError(t, err, "failed to create test provider")
|
||||
|
||||
return provider
|
||||
}
|
||||
|
||||
// CreateTestModel creates a test model and returns it.
|
||||
// Does NOT assert on error - returns the model and error for caller to verify.
|
||||
func CreateTestModel(t *testing.T, db *gorm.DB, id string, providerID string, modelName string) (config.Model, error) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
@@ -54,10 +55,14 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
||||
statsService := service.NewStatsService(statsRepo)
|
||||
logger := zap.NewNop()
|
||||
routingCache := service.NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
statsBuffer := service.NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
|
||||
routingService := service.NewRoutingService(routingCache)
|
||||
statsService := service.NewStatsService(statsRepo, statsBuffer)
|
||||
|
||||
// 创建 ConversionEngine
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
@@ -533,8 +538,7 @@ func TestConversion_ProviderWithProtocol(t *testing.T) {
|
||||
|
||||
var created map[string]any
|
||||
json.Unmarshal(w.Body.Bytes(), &created)
|
||||
// API Key 被掩码
|
||||
assert.Contains(t, created["api_key"], "***")
|
||||
assert.Equal(t, "sk-test", created["api_key"])
|
||||
|
||||
// 获取时应包含 protocol
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
@@ -48,10 +49,14 @@ func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
||||
statsService := service.NewStatsService(statsRepo)
|
||||
logger := zap.NewNop()
|
||||
routingCache := service.NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
statsBuffer := service.NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
|
||||
routingService := service.NewRoutingService(routingCache)
|
||||
statsService := service.NewStatsService(statsRepo, statsBuffer)
|
||||
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
require.NoError(t, registry.Register(openaiConv.NewAdapter()))
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
@@ -30,10 +31,14 @@ func setupIntegrationTest(t *testing.T) (*gin.Engine, *gorm.DB) {
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||
_ = service.NewRoutingService(modelRepo, providerRepo)
|
||||
statsService := service.NewStatsService(statsRepo)
|
||||
logger := zap.NewNop()
|
||||
routingCache := service.NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
statsBuffer := service.NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
|
||||
_ = service.NewRoutingService(routingCache)
|
||||
statsService := service.NewStatsService(statsRepo, statsBuffer)
|
||||
|
||||
providerHandler := handler.NewProviderHandler(providerService)
|
||||
modelHandler := handler.NewModelHandler(modelService)
|
||||
@@ -103,7 +108,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
|
||||
var providers []domain.Provider
|
||||
json.Unmarshal(w.Body.Bytes(), &providers)
|
||||
assert.Len(t, providers, 1)
|
||||
assert.Contains(t, providers[0].APIKey, "***") // 已掩码
|
||||
assert.Equal(t, "sk-test-key", providers[0].APIKey)
|
||||
|
||||
// 4. 列出 Model
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
143
backend/tests/mocks/mock_model_repository.go
Normal file
143
backend/tests/mocks/mock_model_repository.go
Normal file
@@ -0,0 +1,143 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: model_repo.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=model_repo.go -destination=../../tests/mocks/mock_model_repository.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
domain "nex/backend/internal/domain"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockModelRepository is a mock of ModelRepository interface.
|
||||
type MockModelRepository struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockModelRepositoryMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockModelRepositoryMockRecorder is the mock recorder for MockModelRepository.
|
||||
type MockModelRepositoryMockRecorder struct {
|
||||
mock *MockModelRepository
|
||||
}
|
||||
|
||||
// NewMockModelRepository creates a new mock instance.
|
||||
func NewMockModelRepository(ctrl *gomock.Controller) *MockModelRepository {
|
||||
mock := &MockModelRepository{ctrl: ctrl}
|
||||
mock.recorder = &MockModelRepositoryMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockModelRepository) EXPECT() *MockModelRepositoryMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Create mocks base method.
|
||||
func (m *MockModelRepository) Create(model *domain.Model) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Create", model)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Create indicates an expected call of Create.
|
||||
func (mr *MockModelRepositoryMockRecorder) Create(model any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockModelRepository)(nil).Create), model)
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockModelRepository) Delete(id string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete", id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockModelRepositoryMockRecorder) Delete(id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockModelRepository)(nil).Delete), id)
|
||||
}
|
||||
|
||||
// FindByProviderAndModelName mocks base method.
|
||||
func (m *MockModelRepository) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FindByProviderAndModelName", providerID, modelName)
|
||||
ret0, _ := ret[0].(*domain.Model)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindByProviderAndModelName indicates an expected call of FindByProviderAndModelName.
|
||||
func (mr *MockModelRepositoryMockRecorder) FindByProviderAndModelName(providerID, modelName any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByProviderAndModelName", reflect.TypeOf((*MockModelRepository)(nil).FindByProviderAndModelName), providerID, modelName)
|
||||
}
|
||||
|
||||
// GetByID mocks base method.
|
||||
func (m *MockModelRepository) GetByID(id string) (*domain.Model, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetByID", id)
|
||||
ret0, _ := ret[0].(*domain.Model)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetByID indicates an expected call of GetByID.
|
||||
func (mr *MockModelRepositoryMockRecorder) GetByID(id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByID", reflect.TypeOf((*MockModelRepository)(nil).GetByID), id)
|
||||
}
|
||||
|
||||
// List mocks base method.
|
||||
func (m *MockModelRepository) List(providerID string) ([]domain.Model, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "List", providerID)
|
||||
ret0, _ := ret[0].([]domain.Model)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// List indicates an expected call of List.
|
||||
func (mr *MockModelRepositoryMockRecorder) List(providerID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockModelRepository)(nil).List), providerID)
|
||||
}
|
||||
|
||||
// ListEnabled mocks base method.
|
||||
func (m *MockModelRepository) ListEnabled() ([]domain.Model, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListEnabled")
|
||||
ret0, _ := ret[0].([]domain.Model)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListEnabled indicates an expected call of ListEnabled.
|
||||
func (mr *MockModelRepositoryMockRecorder) ListEnabled() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEnabled", reflect.TypeOf((*MockModelRepository)(nil).ListEnabled))
|
||||
}
|
||||
|
||||
// Update mocks base method.
|
||||
func (m *MockModelRepository) Update(id string, updates map[string]any) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Update", id, updates)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Update indicates an expected call of Update.
|
||||
func (mr *MockModelRepositoryMockRecorder) Update(id, updates any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockModelRepository)(nil).Update), id, updates)
|
||||
}
|
||||
128
backend/tests/mocks/mock_model_service.go
Normal file
128
backend/tests/mocks/mock_model_service.go
Normal file
@@ -0,0 +1,128 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: model_service.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=model_service.go -destination=../../tests/mocks/mock_model_service.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
domain "nex/backend/internal/domain"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockModelService is a mock of ModelService interface.
|
||||
type MockModelService struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockModelServiceMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockModelServiceMockRecorder is the mock recorder for MockModelService.
|
||||
type MockModelServiceMockRecorder struct {
|
||||
mock *MockModelService
|
||||
}
|
||||
|
||||
// NewMockModelService creates a new mock instance.
|
||||
func NewMockModelService(ctrl *gomock.Controller) *MockModelService {
|
||||
mock := &MockModelService{ctrl: ctrl}
|
||||
mock.recorder = &MockModelServiceMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockModelService) EXPECT() *MockModelServiceMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Create mocks base method.
|
||||
func (m *MockModelService) Create(model *domain.Model) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Create", model)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Create indicates an expected call of Create.
|
||||
func (mr *MockModelServiceMockRecorder) Create(model any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockModelService)(nil).Create), model)
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockModelService) Delete(id string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete", id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockModelServiceMockRecorder) Delete(id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockModelService)(nil).Delete), id)
|
||||
}
|
||||
|
||||
// Get mocks base method.
|
||||
func (m *MockModelService) Get(id string) (*domain.Model, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Get", id)
|
||||
ret0, _ := ret[0].(*domain.Model)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Get indicates an expected call of Get.
|
||||
func (mr *MockModelServiceMockRecorder) Get(id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockModelService)(nil).Get), id)
|
||||
}
|
||||
|
||||
// List mocks base method.
|
||||
func (m *MockModelService) List(providerID string) ([]domain.Model, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "List", providerID)
|
||||
ret0, _ := ret[0].([]domain.Model)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// List indicates an expected call of List.
|
||||
func (mr *MockModelServiceMockRecorder) List(providerID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockModelService)(nil).List), providerID)
|
||||
}
|
||||
|
||||
// ListEnabled mocks base method.
|
||||
func (m *MockModelService) ListEnabled() ([]domain.Model, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListEnabled")
|
||||
ret0, _ := ret[0].([]domain.Model)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListEnabled indicates an expected call of ListEnabled.
|
||||
func (mr *MockModelServiceMockRecorder) ListEnabled() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEnabled", reflect.TypeOf((*MockModelService)(nil).ListEnabled))
|
||||
}
|
||||
|
||||
// Update mocks base method.
|
||||
func (m *MockModelService) Update(id string, updates map[string]any) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Update", id, updates)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Update indicates an expected call of Update.
|
||||
func (mr *MockModelServiceMockRecorder) Update(id, updates any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockModelService)(nil).Update), id, updates)
|
||||
}
|
||||
73
backend/tests/mocks/mock_provider_client.go
Normal file
73
backend/tests/mocks/mock_provider_client.go
Normal file
@@ -0,0 +1,73 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: client.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
conversion "nex/backend/internal/conversion"
|
||||
provider "nex/backend/internal/provider"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockProviderClient is a mock of ProviderClient interface.
|
||||
type MockProviderClient struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockProviderClientMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockProviderClientMockRecorder is the mock recorder for MockProviderClient.
|
||||
type MockProviderClientMockRecorder struct {
|
||||
mock *MockProviderClient
|
||||
}
|
||||
|
||||
// NewMockProviderClient creates a new mock instance.
|
||||
func NewMockProviderClient(ctrl *gomock.Controller) *MockProviderClient {
|
||||
mock := &MockProviderClient{ctrl: ctrl}
|
||||
mock.recorder = &MockProviderClientMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockProviderClient) EXPECT() *MockProviderClientMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Send mocks base method.
|
||||
func (m *MockProviderClient) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Send", ctx, spec)
|
||||
ret0, _ := ret[0].(*conversion.HTTPResponseSpec)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Send indicates an expected call of Send.
|
||||
func (mr *MockProviderClientMockRecorder) Send(ctx, spec any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockProviderClient)(nil).Send), ctx, spec)
|
||||
}
|
||||
|
||||
// SendStream mocks base method.
|
||||
func (m *MockProviderClient) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SendStream", ctx, spec)
|
||||
ret0, _ := ret[0].(<-chan provider.StreamEvent)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SendStream indicates an expected call of SendStream.
|
||||
func (mr *MockProviderClientMockRecorder) SendStream(ctx, spec any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendStream", reflect.TypeOf((*MockProviderClient)(nil).SendStream), ctx, spec)
|
||||
}
|
||||
113
backend/tests/mocks/mock_provider_repository.go
Normal file
113
backend/tests/mocks/mock_provider_repository.go
Normal file
@@ -0,0 +1,113 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: provider_repo.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=provider_repo.go -destination=../../tests/mocks/mock_provider_repository.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
domain "nex/backend/internal/domain"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockProviderRepository is a mock of ProviderRepository interface.
|
||||
type MockProviderRepository struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockProviderRepositoryMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockProviderRepositoryMockRecorder is the mock recorder for MockProviderRepository.
|
||||
type MockProviderRepositoryMockRecorder struct {
|
||||
mock *MockProviderRepository
|
||||
}
|
||||
|
||||
// NewMockProviderRepository creates a new mock instance.
|
||||
func NewMockProviderRepository(ctrl *gomock.Controller) *MockProviderRepository {
|
||||
mock := &MockProviderRepository{ctrl: ctrl}
|
||||
mock.recorder = &MockProviderRepositoryMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockProviderRepository) EXPECT() *MockProviderRepositoryMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Create mocks base method.
|
||||
func (m *MockProviderRepository) Create(provider *domain.Provider) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Create", provider)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Create indicates an expected call of Create.
|
||||
func (mr *MockProviderRepositoryMockRecorder) Create(provider any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockProviderRepository)(nil).Create), provider)
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockProviderRepository) Delete(id string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete", id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockProviderRepositoryMockRecorder) Delete(id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockProviderRepository)(nil).Delete), id)
|
||||
}
|
||||
|
||||
// GetByID mocks base method.
|
||||
func (m *MockProviderRepository) GetByID(id string) (*domain.Provider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetByID", id)
|
||||
ret0, _ := ret[0].(*domain.Provider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetByID indicates an expected call of GetByID.
|
||||
func (mr *MockProviderRepositoryMockRecorder) GetByID(id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByID", reflect.TypeOf((*MockProviderRepository)(nil).GetByID), id)
|
||||
}
|
||||
|
||||
// List mocks base method.
|
||||
func (m *MockProviderRepository) List() ([]domain.Provider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "List")
|
||||
ret0, _ := ret[0].([]domain.Provider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// List indicates an expected call of List.
|
||||
func (mr *MockProviderRepositoryMockRecorder) List() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockProviderRepository)(nil).List))
|
||||
}
|
||||
|
||||
// Update mocks base method.
|
||||
func (m *MockProviderRepository) Update(id string, updates map[string]any) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Update", id, updates)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Update indicates an expected call of Update.
|
||||
func (mr *MockProviderRepositoryMockRecorder) Update(id, updates any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockProviderRepository)(nil).Update), id, updates)
|
||||
}
|
||||
143
backend/tests/mocks/mock_provider_service.go
Normal file
143
backend/tests/mocks/mock_provider_service.go
Normal file
@@ -0,0 +1,143 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: provider_service.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=provider_service.go -destination=../../tests/mocks/mock_provider_service.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
domain "nex/backend/internal/domain"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockProviderService is a mock of ProviderService interface.
|
||||
type MockProviderService struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockProviderServiceMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockProviderServiceMockRecorder is the mock recorder for MockProviderService.
|
||||
type MockProviderServiceMockRecorder struct {
|
||||
mock *MockProviderService
|
||||
}
|
||||
|
||||
// NewMockProviderService creates a new mock instance.
|
||||
func NewMockProviderService(ctrl *gomock.Controller) *MockProviderService {
|
||||
mock := &MockProviderService{ctrl: ctrl}
|
||||
mock.recorder = &MockProviderServiceMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockProviderService) EXPECT() *MockProviderServiceMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Create mocks base method.
|
||||
func (m *MockProviderService) Create(provider *domain.Provider) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Create", provider)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Create indicates an expected call of Create.
|
||||
func (mr *MockProviderServiceMockRecorder) Create(provider any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockProviderService)(nil).Create), provider)
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockProviderService) Delete(id string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete", id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockProviderServiceMockRecorder) Delete(id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockProviderService)(nil).Delete), id)
|
||||
}
|
||||
|
||||
// Get mocks base method.
|
||||
func (m *MockProviderService) Get(id string) (*domain.Provider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Get", id)
|
||||
ret0, _ := ret[0].(*domain.Provider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Get indicates an expected call of Get.
|
||||
func (mr *MockProviderServiceMockRecorder) Get(id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockProviderService)(nil).Get), id)
|
||||
}
|
||||
|
||||
// GetModelByProviderAndName mocks base method.
|
||||
func (m *MockProviderService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetModelByProviderAndName", providerID, modelName)
|
||||
ret0, _ := ret[0].(*domain.Model)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetModelByProviderAndName indicates an expected call of GetModelByProviderAndName.
|
||||
func (mr *MockProviderServiceMockRecorder) GetModelByProviderAndName(providerID, modelName any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetModelByProviderAndName", reflect.TypeOf((*MockProviderService)(nil).GetModelByProviderAndName), providerID, modelName)
|
||||
}
|
||||
|
||||
// List mocks base method.
|
||||
func (m *MockProviderService) List() ([]domain.Provider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "List")
|
||||
ret0, _ := ret[0].([]domain.Provider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// List indicates an expected call of List.
|
||||
func (mr *MockProviderServiceMockRecorder) List() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockProviderService)(nil).List))
|
||||
}
|
||||
|
||||
// ListEnabledModels mocks base method.
|
||||
func (m *MockProviderService) ListEnabledModels() ([]domain.Model, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListEnabledModels")
|
||||
ret0, _ := ret[0].([]domain.Model)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListEnabledModels indicates an expected call of ListEnabledModels.
|
||||
func (mr *MockProviderServiceMockRecorder) ListEnabledModels() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEnabledModels", reflect.TypeOf((*MockProviderService)(nil).ListEnabledModels))
|
||||
}
|
||||
|
||||
// Update mocks base method.
|
||||
func (m *MockProviderService) Update(id string, updates map[string]any) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Update", id, updates)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Update indicates an expected call of Update.
|
||||
func (mr *MockProviderServiceMockRecorder) Update(id, updates any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockProviderService)(nil).Update), id, updates)
|
||||
}
|
||||
56
backend/tests/mocks/mock_routing_service.go
Normal file
56
backend/tests/mocks/mock_routing_service.go
Normal file
@@ -0,0 +1,56 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: routing_service.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=routing_service.go -destination=../../tests/mocks/mock_routing_service.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
domain "nex/backend/internal/domain"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockRoutingService is a mock of RoutingService interface.
|
||||
type MockRoutingService struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRoutingServiceMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockRoutingServiceMockRecorder is the mock recorder for MockRoutingService.
|
||||
type MockRoutingServiceMockRecorder struct {
|
||||
mock *MockRoutingService
|
||||
}
|
||||
|
||||
// NewMockRoutingService creates a new mock instance.
|
||||
func NewMockRoutingService(ctrl *gomock.Controller) *MockRoutingService {
|
||||
mock := &MockRoutingService{ctrl: ctrl}
|
||||
mock.recorder = &MockRoutingServiceMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockRoutingService) EXPECT() *MockRoutingServiceMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// RouteByModelName mocks base method.
|
||||
func (m *MockRoutingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RouteByModelName", providerID, modelName)
|
||||
ret0, _ := ret[0].(*domain.RouteResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// RouteByModelName indicates an expected call of RouteByModelName.
|
||||
func (mr *MockRoutingServiceMockRecorder) RouteByModelName(providerID, modelName any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteByModelName", reflect.TypeOf((*MockRoutingService)(nil).RouteByModelName), providerID, modelName)
|
||||
}
|
||||
85
backend/tests/mocks/mock_stats_repository.go
Normal file
85
backend/tests/mocks/mock_stats_repository.go
Normal file
@@ -0,0 +1,85 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: stats_repo.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=stats_repo.go -destination=../../tests/mocks/mock_stats_repository.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
domain "nex/backend/internal/domain"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockStatsRepository is a mock of StatsRepository interface.
|
||||
type MockStatsRepository struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStatsRepositoryMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockStatsRepositoryMockRecorder is the mock recorder for MockStatsRepository.
|
||||
type MockStatsRepositoryMockRecorder struct {
|
||||
mock *MockStatsRepository
|
||||
}
|
||||
|
||||
// NewMockStatsRepository creates a new mock instance.
|
||||
func NewMockStatsRepository(ctrl *gomock.Controller) *MockStatsRepository {
|
||||
mock := &MockStatsRepository{ctrl: ctrl}
|
||||
mock.recorder = &MockStatsRepositoryMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockStatsRepository) EXPECT() *MockStatsRepositoryMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// BatchUpdate mocks base method.
|
||||
func (m *MockStatsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BatchUpdate", providerID, modelName, date, delta)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BatchUpdate indicates an expected call of BatchUpdate.
|
||||
func (mr *MockStatsRepositoryMockRecorder) BatchUpdate(providerID, modelName, date, delta any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdate", reflect.TypeOf((*MockStatsRepository)(nil).BatchUpdate), providerID, modelName, date, delta)
|
||||
}
|
||||
|
||||
// Query mocks base method.
|
||||
func (m *MockStatsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Query", providerID, modelName, startDate, endDate)
|
||||
ret0, _ := ret[0].([]domain.UsageStats)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Query indicates an expected call of Query.
|
||||
func (mr *MockStatsRepositoryMockRecorder) Query(providerID, modelName, startDate, endDate any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStatsRepository)(nil).Query), providerID, modelName, startDate, endDate)
|
||||
}
|
||||
|
||||
// Record mocks base method.
|
||||
func (m *MockStatsRepository) Record(providerID, modelName string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Record", providerID, modelName)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Record indicates an expected call of Record.
|
||||
func (mr *MockStatsRepositoryMockRecorder) Record(providerID, modelName any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Record", reflect.TypeOf((*MockStatsRepository)(nil).Record), providerID, modelName)
|
||||
}
|
||||
85
backend/tests/mocks/mock_stats_service.go
Normal file
85
backend/tests/mocks/mock_stats_service.go
Normal file
@@ -0,0 +1,85 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: stats_service.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=stats_service.go -destination=../../tests/mocks/mock_stats_service.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
domain "nex/backend/internal/domain"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockStatsService is a mock of StatsService interface.
|
||||
type MockStatsService struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStatsServiceMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockStatsServiceMockRecorder is the mock recorder for MockStatsService.
|
||||
type MockStatsServiceMockRecorder struct {
|
||||
mock *MockStatsService
|
||||
}
|
||||
|
||||
// NewMockStatsService creates a new mock instance.
|
||||
func NewMockStatsService(ctrl *gomock.Controller) *MockStatsService {
|
||||
mock := &MockStatsService{ctrl: ctrl}
|
||||
mock.recorder = &MockStatsServiceMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockStatsService) EXPECT() *MockStatsServiceMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Aggregate mocks base method.
|
||||
func (m *MockStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]any {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Aggregate", stats, groupBy)
|
||||
ret0, _ := ret[0].([]map[string]any)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Aggregate indicates an expected call of Aggregate.
|
||||
func (mr *MockStatsServiceMockRecorder) Aggregate(stats, groupBy any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Aggregate", reflect.TypeOf((*MockStatsService)(nil).Aggregate), stats, groupBy)
|
||||
}
|
||||
|
||||
// Get mocks base method.
|
||||
func (m *MockStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Get", providerID, modelName, startDate, endDate)
|
||||
ret0, _ := ret[0].([]domain.UsageStats)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Get indicates an expected call of Get.
|
||||
func (mr *MockStatsServiceMockRecorder) Get(providerID, modelName, startDate, endDate any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockStatsService)(nil).Get), providerID, modelName, startDate, endDate)
|
||||
}
|
||||
|
||||
// Record mocks base method.
|
||||
func (m *MockStatsService) Record(providerID, modelName string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Record", providerID, modelName)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Record indicates an expected call of Record.
|
||||
func (mr *MockStatsServiceMockRecorder) Record(providerID, modelName any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Record", reflect.TypeOf((*MockStatsService)(nil).Record), providerID, modelName)
|
||||
}
|
||||
@@ -126,24 +126,54 @@
|
||||
|
||||
### 2.3 请求处理流程
|
||||
|
||||
每个 HTTP 请求的转换流程:
|
||||
#### 2.3.1 三车道数据流模型
|
||||
|
||||
引擎根据协议匹配情况和参数条件,选择三条数据流通道之一:
|
||||
|
||||
```
|
||||
客户端入站 调用方(协议识别+前缀剥离) SDK 内部处理 上游出站
|
||||
┌──────────────────┐ ┌──────────────────┐
|
||||
│ URL: │ 调用方完成: 1. 接口识别: CHAT │ URL: │
|
||||
│ /<protocol>/ │ · clientProtocol 2. 同协议? ──yes──▶ 直接转发│ 目标协议 │
|
||||
│ v1/... │ · nativePath └──no──▶ 继续转换 │ 原生路径 │
|
||||
│ Headers: │ · providerProtocol 3. URL 映射: 目标路径 │ Headers: │
|
||||
│ 协议原生格式 │ 4. Header 构建: 目标格式 │ 目标协议格式 │
|
||||
│ Body: │ 5. Body 转换: D→C→E │ Body: │
|
||||
│ 协议原生格式 │ │ 目标协议格式 │
|
||||
└──────────────────┘ └──────────────────┘
|
||||
┌──────────────────────────────────────────────────────────────────────────────────┐
|
||||
│ 数据流三车道模型 │
|
||||
├────────────────┬─────────────────────────┬────────────────────────────────────────┤
|
||||
│ 透传车道 │ 智能透传车道 │ 完整转换车道 │
|
||||
├────────────────┼─────────────────────────┼────────────────────────────────────────┤
|
||||
│ 触发条件 │ 同协议 │ 同协议 + 接口∈{Chat,Embed,Rerank} │ 不同协议 │
|
||||
│ │ │ + Body非空 + ModelName非空 │ │
|
||||
│ │ │ │ │
|
||||
│ 请求处理 │ 重建Headers │ 重建Headers │ Decode→Middleware→Encode│
|
||||
│ │ Body原样转发 │ RewriteRequestModelName(body) │ │
|
||||
│ │ │ (最小化JSON字段手术) │ │
|
||||
│ │ │ │ │
|
||||
│ 响应处理 │ 原样返回 │ modelOverride非空时 │ Decode→modelOverride │
|
||||
│ │ │ RewriteResponseModelName(body) │ →Encode │
|
||||
│ │ │ │ │
|
||||
│ 流式处理 │ chunk→[chunk] │ chunk→RewriteResponseModelName │ Decode→Middleware │
|
||||
│ │ │ →[rewritten] │ →modelOverride→Encode │
|
||||
│ │ │ │ │
|
||||
│ 性能开销 │ 最低 │ 低(仅JSON字段改写) │ 高(完整序列化) │
|
||||
└────────────────┴─────────────────────────┴────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
响应方向同理(含流式)。D=Decoder, C=Canonical, E=Encoder。
|
||||
**智能透传的设计动机**:同协议场景下,若仅需改写 `model` 字段(如客户端请求模型 "X",上游需要模型 "Y"),无需完整解码/编码。直接在 JSON 层面手术式改写该字段,既保留原始请求的所有细节,又避免序列化开销。
|
||||
|
||||
#### 2.3.2 完整请求处理流程
|
||||
|
||||
```
|
||||
客户端入站 调用方职责 SDK 内部处理 上游出站
|
||||
┌──────────────┐ ┌──────────────┐
|
||||
│ URL: │ 调用方完成: 1. 接口识别 │ URL: │
|
||||
│ /<protocol>/ │ · clientProtocol 2. IsPassthrough? │ 目标协议 │
|
||||
│ v1/... │ · nativePath ├─ yes ─┬─ 无 modelOverride → 透传车道 │ 原生路径 │
|
||||
│ Headers: │ · providerProtocol │ └─ 有 modelOverride → 智能透传车道│ Headers: │
|
||||
│ 协议原生格式 │ │ │ 目标协议格式 │
|
||||
│ Body: │ └─ no → 完整转换车道 │ Body: │
|
||||
│ 协议原生格式 │ │ 目标协议格式 │
|
||||
└──────────────┘ └──────────────┘
|
||||
```
|
||||
|
||||
响应方向同理(含流式)。
|
||||
|
||||
**同协议透传**:client == provider 时,仅重建 Header 后原样转发到上游。
|
||||
**智能透传**:同协议且需改写 model 字段时,最小化 JSON 改写后转发。
|
||||
**未知接口透传**:无法识别的路径,URL+Header 适配后 Body 原样转发。
|
||||
|
||||
---
|
||||
@@ -229,13 +259,15 @@ ContentBlock = Union<
|
||||
content: Union<String, Array<ContentBlock>>,
|
||||
is_error: Option<Boolean> }
|
||||
ThinkingBlock, { type: "thinking", thinking: String }
|
||||
ImageBlock, { type: "image", source: ... } // 多模态预留
|
||||
AudioBlock, { type: "audio", source: ... } // 多模态预留
|
||||
VideoBlock, { type: "video", source: ... } // 多模态预留
|
||||
FileBlock { type: "file", source: ... } // 多模态预留
|
||||
ImageBlock, { type: "image", source: ... } // Deferred:多模态预留
|
||||
AudioBlock, { type: "audio", source: ... } // Deferred:多模态预留
|
||||
VideoBlock, { type: "video", source: ... } // Deferred:多模态预留
|
||||
FileBlock { type: "file", source: ... } // Deferred:多模态预留
|
||||
>
|
||||
```
|
||||
|
||||
**当前实现状态**:仅实现了 TextBlock、ToolUseBlock、ToolResultBlock、ThinkingBlock。多模态类型(ImageBlock、AudioBlock、VideoBlock、FileBlock)为预留扩展点。
|
||||
|
||||
### 4.4 CanonicalTool / ToolChoice
|
||||
|
||||
```
|
||||
@@ -400,7 +432,7 @@ interface ProtocolAdapter {
|
||||
createStreamEncoder(): StreamEncoder
|
||||
|
||||
// 错误编码
|
||||
encodeError(error: ConversionError): RawResponse
|
||||
encodeError(error: ConversionError): (body, statusCode)
|
||||
|
||||
// 扩展层
|
||||
decodeModelsResponse(raw): CanonicalModelList
|
||||
@@ -415,20 +447,40 @@ interface ProtocolAdapter {
|
||||
encodeRerankRequest(canonical, provider): RawRequest
|
||||
decodeRerankResponse(raw): CanonicalRerankResponse
|
||||
encodeRerankResponse(canonical): RawResponse
|
||||
|
||||
// 智能透传支持
|
||||
extractUnifiedModelID(nativePath: String): (modelID, error) // 从路径提取统一模型 ID
|
||||
extractModelName(body: Raw, interfaceType: InterfaceType): (modelName, error) // 从请求体提取 model 字段值
|
||||
rewriteRequestModelName(body: Raw, newModel: String, interfaceType: InterfaceType): RawRequest // 最小化改写请求体中的 model 字段
|
||||
rewriteResponseModelName(body: Raw, newModel: String, interfaceType: InterfaceType): RawResponse // 最小化改写响应体中的 model 字段
|
||||
}
|
||||
```
|
||||
|
||||
**`buildHeaders` 的设计**:Adapter 只需从 `provider` 中提取自己协议需要的认证和配置信息,构建自己的 Header 格式。不再需要理解其他协议的 Header。
|
||||
|
||||
**智能透传方法的契约**:
|
||||
- `rewriteRequestModelName` / `rewriteResponseModelName` 必须**幂等**(多次调用结果相同)
|
||||
- Rewrite 方法必须**最小化**(仅修改 model 字段,不触碰其他字段)
|
||||
- Rewrite 失败时,引擎使用宽容策略:记录警告日志,使用原始 body 继续处理
|
||||
- `extractModelName` 支持的接口类型:CHAT、EMBEDDINGS、RERANK(这些接口的请求体包含 model 字段)
|
||||
|
||||
### 5.3 InterfaceType
|
||||
|
||||
```
|
||||
InterfaceType = Enum<
|
||||
CHAT, MODELS, MODEL_INFO, EMBEDDINGS, RERANK,
|
||||
AUDIO, IMAGES // 预留:多模态扩展时启用
|
||||
CHAT, MODELS, MODEL_INFO, EMBEDDINGS, RERANK, PASSTHROUGH
|
||||
>
|
||||
```
|
||||
|
||||
| 类型 | 说明 |
|
||||
|------|------|
|
||||
| CHAT | 核心对话接口(各协议的 Chat/Messages 接口) |
|
||||
| MODELS | 模型列表接口 |
|
||||
| MODEL_INFO | 模型详情接口 |
|
||||
| EMBEDDINGS | 向量嵌入接口 |
|
||||
| RERANK | 重排序接口 |
|
||||
| PASSTHROUGH | 未知接口,透传处理 |
|
||||
|
||||
### 5.4 StreamDecoder / StreamEncoder
|
||||
|
||||
```
|
||||
@@ -463,68 +515,138 @@ ConversionEngine 是无状态的格式转换工具,仅做协议间的编解码
|
||||
|
||||
**协议识别**:`clientProtocol` 和 `providerProtocol` 由调用方确定并传入引擎(详见 §2.2)。
|
||||
|
||||
#### 6.1.1 HTTP 规格
|
||||
|
||||
```
|
||||
HTTPRequestSpec {
|
||||
url: String // 请求路径(不含 base_url)
|
||||
method: String // HTTP 方法
|
||||
headers: Map<String, String>
|
||||
body: ByteArray // 原始请求体
|
||||
}
|
||||
|
||||
HTTPResponseSpec {
|
||||
statusCode: Integer
|
||||
headers: Map<String, String>
|
||||
body: ByteArray // 原始响应体
|
||||
}
|
||||
```
|
||||
|
||||
#### 6.1.2 引擎接口
|
||||
|
||||
```
|
||||
class ConversionEngine {
|
||||
registry: AdapterRegistry
|
||||
middlewareChain: MiddlewareChain
|
||||
|
||||
// 生命周期
|
||||
registerAdapter(adapter): void
|
||||
use(middleware): void
|
||||
getRegistry(): AdapterRegistry
|
||||
|
||||
// 核心转换
|
||||
isPassthrough(clientProtocol, providerProtocol): Boolean {
|
||||
return clientProtocol == providerProtocol && registry.get(clientProtocol).supportsPassthrough()
|
||||
}
|
||||
|
||||
// 非流式请求转换
|
||||
convertHttpRequest(request, clientProtocol, providerProtocol, provider): HttpRequest {
|
||||
nativePath = request.url
|
||||
convertHttpRequest(request, clientProtocol, providerProtocol, provider): HTTPRequestSpec
|
||||
convertHttpResponse(response, clientProtocol, providerProtocol, interfaceType, modelOverride): HTTPResponseSpec
|
||||
createStreamConverter(clientProtocol, providerProtocol, modelOverride, interfaceType): StreamConverter
|
||||
|
||||
if isPassthrough(clientProtocol, providerProtocol):
|
||||
providerAdapter = registry.get(providerProtocol)
|
||||
return {url: provider.base_url + nativePath, method: request.method,
|
||||
headers: providerAdapter.buildHeaders(provider), body: request.body}
|
||||
|
||||
clientAdapter = registry.get(clientProtocol)
|
||||
providerAdapter = registry.get(providerProtocol)
|
||||
|
||||
// 使用 clientAdapter 识别接口类型
|
||||
interfaceType = clientAdapter.detectInterfaceType(nativePath)
|
||||
|
||||
providerUrl = providerAdapter.buildUrl(nativePath, interfaceType)
|
||||
providerHeaders = providerAdapter.buildHeaders(provider)
|
||||
providerBody = convertBody(interfaceType, clientAdapter, providerAdapter, provider, request.body)
|
||||
|
||||
return {url: provider.base_url + providerUrl, method: request.method,
|
||||
headers: providerHeaders, body: providerBody}
|
||||
}
|
||||
|
||||
// 非流式响应转换
|
||||
convertHttpResponse(response, clientProtocol, providerProtocol, interfaceType): HttpResponse {
|
||||
if isPassthrough(clientProtocol, providerProtocol): return response
|
||||
|
||||
clientAdapter = registry.get(clientProtocol)
|
||||
providerAdapter = registry.get(providerProtocol)
|
||||
providerBody = convertResponseBody(interfaceType, clientAdapter, providerAdapter, response.body)
|
||||
|
||||
return {status: response.status, headers: response.headers, body: providerBody}
|
||||
}
|
||||
|
||||
// 流式转换:从 provider 协议解码,编码为 client 协议
|
||||
createStreamConverter(clientProtocol, providerProtocol, provider): StreamConverter {
|
||||
if isPassthrough(clientProtocol, providerProtocol):
|
||||
providerAdapter = registry.get(providerProtocol)
|
||||
return new PassthroughStreamConverter(providerAdapter.buildHeaders(provider))
|
||||
|
||||
providerAdapter = registry.get(providerProtocol)
|
||||
clientAdapter = registry.get(clientProtocol)
|
||||
return new CanonicalStreamConverter(
|
||||
providerAdapter.createStreamDecoder(), clientAdapter.createStreamEncoder(), middlewareChain)
|
||||
}
|
||||
// 辅助方法
|
||||
detectInterfaceType(nativePath, clientProtocol): InterfaceType
|
||||
encodeError(error, clientProtocol): (body, statusCode)
|
||||
}
|
||||
```
|
||||
|
||||
#### 6.1.3 请求转换流程
|
||||
|
||||
```
|
||||
function convertHttpRequest(request, clientProtocol, providerProtocol, provider):
|
||||
nativePath = request.url
|
||||
|
||||
if isPassthrough(clientProtocol, providerProtocol):
|
||||
providerAdapter = registry.get(providerProtocol)
|
||||
interfaceType = providerAdapter.detectInterfaceType(nativePath)
|
||||
|
||||
// 智能透传:对 Chat/Embeddings/Rerank 接口改写 model 字段
|
||||
if interfaceType in {CHAT, EMBEDDINGS, RERANK} and request.body非空 and provider.modelName非空:
|
||||
rewrittenBody = providerAdapter.rewriteRequestModelName(request.body, provider.modelName, interfaceType)
|
||||
if rewrite失败:
|
||||
log.warn("智能透传改写失败,使用原始请求体")
|
||||
rewrittenBody = request.body
|
||||
else:
|
||||
rewrittenBody = request.body
|
||||
|
||||
return HTTPRequestSpec {
|
||||
url: provider.base_url + nativePath,
|
||||
method: request.method,
|
||||
headers: providerAdapter.buildHeaders(provider),
|
||||
body: rewrittenBody
|
||||
}
|
||||
|
||||
// 完整转换车道
|
||||
clientAdapter = registry.get(clientProtocol)
|
||||
providerAdapter = registry.get(providerProtocol)
|
||||
interfaceType = clientAdapter.detectInterfaceType(nativePath)
|
||||
|
||||
providerUrl = providerAdapter.buildUrl(nativePath, interfaceType)
|
||||
providerHeaders = providerAdapter.buildHeaders(provider)
|
||||
providerBody = convertBody(interfaceType, clientAdapter, providerAdapter, provider, request.body)
|
||||
|
||||
return HTTPRequestSpec {
|
||||
url: provider.base_url + providerUrl,
|
||||
method: request.method,
|
||||
headers: providerHeaders,
|
||||
body: providerBody
|
||||
}
|
||||
```
|
||||
|
||||
#### 6.1.4 响应转换流程
|
||||
|
||||
```
|
||||
function convertHttpResponse(response, clientProtocol, providerProtocol, interfaceType, modelOverride):
|
||||
if isPassthrough(clientProtocol, providerProtocol):
|
||||
// 智能透传:改写响应体中的 model 字段
|
||||
if modelOverride非空 and response.body非空:
|
||||
adapter = registry.get(clientProtocol)
|
||||
rewrittenBody = adapter.rewriteResponseModelName(response.body, modelOverride, interfaceType)
|
||||
if rewrite失败:
|
||||
log.warn("智能透传改写失败,使用原始响应体")
|
||||
return response
|
||||
return HTTPResponseSpec { statusCode: response.statusCode, headers: response.headers, body: rewrittenBody }
|
||||
return response
|
||||
|
||||
// 完整转换车道
|
||||
clientAdapter = registry.get(clientProtocol)
|
||||
providerAdapter = registry.get(providerProtocol)
|
||||
convertedBody = convertResponseBody(interfaceType, clientAdapter, providerAdapter, response.body, modelOverride)
|
||||
|
||||
return HTTPResponseSpec { statusCode: response.statusCode, headers: response.headers, body: convertedBody }
|
||||
```
|
||||
|
||||
**modelOverride 参数语义**:跨协议场景下,客户端期望看到的模型名。在响应转换时直接覆写 `canonicalResponse.model = modelOverride`,在流式转换时覆写 `event.message.model = modelOverride`。
|
||||
|
||||
### 6.2 Body 转换分发
|
||||
|
||||
#### 6.2.1 接口类型分发策略
|
||||
|
||||
| 接口类型 | 请求转换 | 响应转换 | 中间件 | modelOverride |
|
||||
|---------|---------|---------|--------|---------------|
|
||||
| CHAT | Decode→**Middleware**→Encode | Decode→modelOverride→Encode | **应用** | 支持(响应) |
|
||||
| MODELS | Body透传(GET) | Decode→Encode | 不应用 | 不支持 |
|
||||
| MODEL_INFO | Body透传(GET) | Decode→Encode | 不应用 | 不支持 |
|
||||
| EMBEDDINGS | Decode→Encode | Decode→modelOverride→Encode | 不应用 | 支持(响应) |
|
||||
| RERANK | Decode→Encode | Decode→modelOverride→Encode | 不应用 | 支持(响应) |
|
||||
| PASSTHROUGH | Body透传 | Body透传 | 不应用 | 不支持 |
|
||||
|
||||
**关键说明**:
|
||||
- **只有 CHAT 接口**走完整的 `decode → middleware → encode` 管道
|
||||
- **扩展层接口**(EMBEDDINGS、RERANK)**跳过中间件**,直接 decode → encode
|
||||
- **扩展层响应支持 modelOverride**,在 encode 前直接覆写 canonical 字段
|
||||
|
||||
#### 6.2.2 请求体转换
|
||||
|
||||
```
|
||||
function convertBody(interfaceType, clientAdapter, providerAdapter, provider, body):
|
||||
switch interfaceType:
|
||||
@@ -546,41 +668,146 @@ function convertBody(interfaceType, clientAdapter, providerAdapter, provider, bo
|
||||
// 同 EMBEDDINGS 模式
|
||||
default:
|
||||
return body // 透传层:原样转发
|
||||
```
|
||||
|
||||
function convertResponseBody(interfaceType, clientAdapter, providerAdapter, body):
|
||||
// 结构与 convertBody 对称,CHAT 走 Canonical 深度转换,扩展层走轻量映射,默认透传
|
||||
// 各接口的具体响应转换逻辑详见各协议适配文档(附录 E)
|
||||
#### 6.2.3 响应体转换
|
||||
|
||||
```
|
||||
function convertResponseBody(interfaceType, clientAdapter, providerAdapter, body, modelOverride):
|
||||
switch interfaceType:
|
||||
case CHAT:
|
||||
canonical = providerAdapter.decodeResponse(body)
|
||||
if modelOverride非空: canonical.model = modelOverride
|
||||
return clientAdapter.encodeResponse(canonical)
|
||||
case MODELS:
|
||||
if !clientAdapter.supportsInterface(MODELS) || !providerAdapter.supportsInterface(MODELS):
|
||||
return body
|
||||
return clientAdapter.encodeModelsResponse(providerAdapter.decodeModelsResponse(body))
|
||||
case MODEL_INFO:
|
||||
// 同 MODELS 模式
|
||||
case EMBEDDINGS:
|
||||
if !clientAdapter.supportsInterface(EMBEDDINGS) || !providerAdapter.supportsInterface(EMBEDDINGS):
|
||||
return body
|
||||
canonical = providerAdapter.decodeEmbeddingResponse(body)
|
||||
if modelOverride非空: canonical.model = modelOverride
|
||||
return clientAdapter.encodeEmbeddingResponse(canonical)
|
||||
case RERANK:
|
||||
// 同 EMBEDDINGS 模式,支持 modelOverride
|
||||
default:
|
||||
return body // 透传层:原样转发
|
||||
```
|
||||
|
||||
### 6.3 StreamConverter
|
||||
|
||||
#### 6.3.1 接口定义
|
||||
|
||||
```
|
||||
interface StreamConverter {
|
||||
processChunk(rawChunk): Array<RawSSEChunk>
|
||||
flush(): Array<RawSSEChunk>
|
||||
}
|
||||
```
|
||||
|
||||
#### 6.3.2 三种转换器变体
|
||||
|
||||
| 转换器 | 触发条件 | processChunk | flush |
|
||||
|--------|---------|--------------|-------|
|
||||
| `PassthroughStreamConverter` | 同协议 + 无 modelOverride | `[rawChunk]` | `[]` |
|
||||
| `SmartPassthroughStreamConverter` | 同协议 + 有 modelOverride | `[rewriteResponseModelName(rawChunk)]` | `[]` |
|
||||
| `CanonicalStreamConverter` | 不同协议 | Decode→Middleware→modelOverride→Encode | decoder.flush()→encoder.flush() |
|
||||
|
||||
#### 6.3.3 PassthroughStreamConverter
|
||||
|
||||
```
|
||||
class PassthroughStreamConverter implements StreamConverter {
|
||||
headers: Map<String, String>
|
||||
constructor(headers) { this.headers = headers }
|
||||
processChunk(rawChunk): Array<RawSSEChunk> { return [rawChunk] }
|
||||
flush(): Array<RawSSEChunk> { return [] }
|
||||
}
|
||||
```
|
||||
|
||||
#### 6.3.4 SmartPassthroughStreamConverter
|
||||
|
||||
```
|
||||
class SmartPassthroughStreamConverter implements StreamConverter {
|
||||
adapter: ProtocolAdapter
|
||||
modelOverride: String
|
||||
interfaceType: InterfaceType
|
||||
|
||||
processChunk(rawChunk): Array<RawSSEChunk> {
|
||||
if rawChunk为空: return []
|
||||
rewrittenChunk = adapter.rewriteResponseModelName(rawChunk, modelOverride, interfaceType)
|
||||
if rewrite失败:
|
||||
log.warn("智能透传改写失败,使用原始 chunk")
|
||||
return [rawChunk]
|
||||
return [rewrittenChunk]
|
||||
}
|
||||
flush(): Array<RawSSEChunk> { return [] }
|
||||
}
|
||||
```
|
||||
|
||||
#### 6.3.5 CanonicalStreamConverter
|
||||
|
||||
```
|
||||
class CanonicalStreamConverter implements StreamConverter {
|
||||
decoder: StreamDecoder
|
||||
encoder: StreamEncoder
|
||||
middleware: MiddlewareChain
|
||||
context: ConversionContext
|
||||
clientProtocol: String
|
||||
providerProtocol: String
|
||||
modelOverride: String
|
||||
|
||||
processChunk(rawChunk):
|
||||
events = decoder.processChunk(rawChunk).map(e => middleware.applyStreamEvent(e))
|
||||
return events.flatMap(e => encoder.encodeEvent(e))
|
||||
events = decoder.processChunk(rawChunk)
|
||||
result = []
|
||||
for each event in events:
|
||||
// 中间件:转换 canonical 事件
|
||||
if middleware != null:
|
||||
processed, err = middleware.applyStreamEvent(event, clientProtocol, providerProtocol, context)
|
||||
if err != null:
|
||||
continue // 宽容策略:跳过错误事件,继续处理
|
||||
event = processed
|
||||
|
||||
// modelOverride:覆写 model 字段
|
||||
if modelOverride非空 and event.message != null:
|
||||
event.message.model = modelOverride
|
||||
|
||||
// 编码为目标协议 SSE
|
||||
chunks = encoder.encodeEvent(event)
|
||||
result.append(chunks)
|
||||
return result
|
||||
|
||||
flush():
|
||||
return decoder.flush().flatMap(e => encoder.encodeEvent(e)) + encoder.flush()
|
||||
events = decoder.flush()
|
||||
// 同 processChunk 的中间件 + modelOverride + encode 管道
|
||||
result = [经过管道处理的所有事件]
|
||||
encoderChunks = encoder.flush()
|
||||
result.append(encoderChunks)
|
||||
return result
|
||||
}
|
||||
```
|
||||
|
||||
#### 6.3.6 流式转换器创建
|
||||
|
||||
```
|
||||
function createStreamConverter(clientProtocol, providerProtocol, modelOverride, interfaceType):
|
||||
if isPassthrough(clientProtocol, providerProtocol):
|
||||
if modelOverride非空:
|
||||
adapter = registry.get(clientProtocol)
|
||||
return new SmartPassthroughStreamConverter(adapter, modelOverride, interfaceType)
|
||||
return new PassthroughStreamConverter()
|
||||
|
||||
providerAdapter = registry.get(providerProtocol)
|
||||
clientAdapter = registry.get(clientProtocol)
|
||||
|
||||
return new CanonicalStreamConverter(
|
||||
decoder: providerAdapter.createStreamDecoder(),
|
||||
encoder: clientAdapter.createStreamEncoder(),
|
||||
middleware: middlewareChain,
|
||||
modelOverride: modelOverride
|
||||
)
|
||||
```
|
||||
|
||||
### 6.4 Middleware
|
||||
|
||||
引擎内部的拦截钩子,在 decode → encode 之间对 Canonical 进行变换。
|
||||
@@ -588,48 +815,73 @@ class CanonicalStreamConverter implements StreamConverter {
|
||||
```
|
||||
interface ConversionMiddleware {
|
||||
intercept(canonical, clientProtocol, providerProtocol, context): canonical | error
|
||||
interceptStreamEvent?(event, clientProtocol, providerProtocol, context): event | error
|
||||
interceptStreamEvent(event, clientProtocol, providerProtocol, context): event | error
|
||||
}
|
||||
|
||||
ConversionContext { conversionId, interfaceType, timestamp, metadata }
|
||||
ConversionContext {
|
||||
conversionId: String // 唯一转换 ID(UUID)
|
||||
interfaceType: InterfaceType
|
||||
timestamp: DateTime
|
||||
metadata: Map<String, Any>
|
||||
}
|
||||
```
|
||||
|
||||
#### 6.4.1 中间件执行规则
|
||||
|
||||
- `intercept` 返回修改后的 canonical,或返回 ConversionError 以**中断转换**
|
||||
- `interceptStreamEvent` 同理,返回错误可中断流式转换
|
||||
- `interceptStreamEvent` 返回修改后的 event,或返回 error
|
||||
- 多个 Middleware 按注册顺序链式执行,任一中断则后续不再执行
|
||||
|
||||
#### 6.4.2 错误处理差异
|
||||
|
||||
| 场景 | 错误处理策略 |
|
||||
|------|-------------|
|
||||
| 请求中间件 `intercept` 返回 error | **严格模式**:中断整个转换,返回错误 |
|
||||
| 流式中间件 `interceptStreamEvent` 返回 error | **宽容模式**:跳过该事件,继续处理后续事件 |
|
||||
|
||||
**设计动机**:流式场景下,单个事件的错误不应中断整个流。请求场景下,错误请求应被明确拒绝。
|
||||
|
||||
### 6.5 使用示例
|
||||
|
||||
```
|
||||
engine = new ConversionEngine()
|
||||
engine.registerAdapter(new ProtocolAAdapter())
|
||||
engine.registerAdapter(new ProtocolBAdapter())
|
||||
engine.registerAdapter(new OpenAIAdapter())
|
||||
engine.registerAdapter(new AnthropicAdapter())
|
||||
|
||||
// 场景1: 跨协议 Chat 转换
|
||||
// 入站: /protocol_a/v1/chat/completions
|
||||
// 入站: /openai/v1/chat/completions
|
||||
provider = TargetProvider {
|
||||
base_url: "https://api.protocol-b.com",
|
||||
base_url: "https://api.anthropic.com",
|
||||
api_key: "xxx",
|
||||
model_name: "model-b",
|
||||
adapter_config: { ... }
|
||||
model_name: "claude-3-opus",
|
||||
adapter_config: { anthropic_version: "2023-06-01" }
|
||||
}
|
||||
out = engine.convertHttpRequest(inRequest, "protocol_a", "protocol_b", provider)
|
||||
// 出站: 目标协议路径 + 目标协议 headers + 转换后的 body
|
||||
out = engine.convertHttpRequest(inRequest, "openai", "anthropic", provider)
|
||||
// 出站: /v1/messages + Anthropic headers + 转换后的 body
|
||||
|
||||
// 场景2: /models 跨协议
|
||||
out = engine.convertHttpRequest(inRequest, "protocol_a", "protocol_b", provider)
|
||||
// URL: /v1/models(通常不变), headers 按目标协议格式重建
|
||||
out = engine.convertHttpRequest(inRequest, "openai", "anthropic", provider)
|
||||
// URL: /v1/models, headers 按 Anthropic 格式重建
|
||||
|
||||
// 场景3: 同协议透传
|
||||
out = engine.convertHttpRequest(inRequest, "protocol_a", "protocol_a", provider)
|
||||
// client == provider → 剥离前缀, 用 provider 重建 headers 后原样转发
|
||||
out = engine.convertHttpRequest(inRequest, "openai", "openai", provider)
|
||||
// client == provider → 透传车道:重建 headers 后原样转发
|
||||
|
||||
// 场景4: 流式转换(从 provider 协议解码,编码为 client 协议)
|
||||
converter = engine.createStreamConverter("protocol_a", "protocol_b", provider)
|
||||
// 场景4: 同协议智能透传(改写 model 字段)
|
||||
provider = TargetProvider { model_name: "gpt-4-turbo", ... }
|
||||
out = engine.convertHttpRequest(inRequest, "openai", "openai", provider)
|
||||
// 智能透传车道:请求体中的 model 字段改写为 "gpt-4-turbo"
|
||||
|
||||
// 场景5: 流式转换(从 provider 协议解码,编码为 client 协议)
|
||||
converter = engine.createStreamConverter("openai", "anthropic", "claude-3-opus", CHAT)
|
||||
for chunk in upstreamSSE {
|
||||
for out in converter.processChunk(chunk) { sendToClient(out) }
|
||||
}
|
||||
converter.flush()
|
||||
|
||||
// 场景6: 同协议流式智能透传
|
||||
converter = engine.createStreamConverter("openai", "openai", "gpt-4-turbo", CHAT)
|
||||
// 使用 SmartPassthroughStreamConverter,逐 chunk 改写 model 字段
|
||||
```
|
||||
|
||||
---
|
||||
@@ -641,9 +893,13 @@ converter.flush()
|
||||
```
|
||||
上游 SSE 流
|
||||
│
|
||||
├── 同协议: PassthroughStreamConverter(用 provider 重建 Headers 后逐块转发)
|
||||
├─ 同协议 + 无 modelOverride: PassthroughStreamConverter
|
||||
│ chunk → [chunk]
|
||||
│
|
||||
└── 跨协议: CanonicalStreamConverter
|
||||
├─ 同协议 + 有 modelOverride: SmartPassthroughStreamConverter
|
||||
│ chunk → [rewriteResponseModelName(chunk)]
|
||||
│
|
||||
└─ 不同协议: CanonicalStreamConverter
|
||||
StreamDecoder StreamEncoder
|
||||
┌───────────┐ ┌───────────┐
|
||||
│ SSE Parser│ │SSE Writer │
|
||||
@@ -653,9 +909,17 @@ converter.flush()
|
||||
│ Event │──────────────────────▶│ Event │
|
||||
│ Translator │ ┌──────────┐ │ Translator │
|
||||
│ (状态机) │ │Middleware│ │ │
|
||||
└───────────┘ └──────────┘ └───────────┘
|
||||
└───────────┘ │(宽容模式) │ └───────────┘
|
||||
└──────────┘
|
||||
│
|
||||
┌─────▼─────┐
|
||||
│modelOverride│
|
||||
│(覆写 model) │
|
||||
└───────────┘
|
||||
```
|
||||
|
||||
**流式中间件错误处理**:`interceptStreamEvent` 返回 error 时,跳过该事件继续处理后续事件(宽容模式),而非中断整个流。
|
||||
|
||||
### 7.2 StreamDecoder 通用状态
|
||||
|
||||
StreamDecoder 需要跟踪以下通用状态。具体协议的 Decoder 可根据需要扩展:
|
||||
@@ -697,12 +961,16 @@ StreamDecoder 将协议原生 SSE 事件翻译为 CanonicalStreamEvent,StreamE
|
||||
|
||||
### 8.2 多模态扩展
|
||||
|
||||
**状态**:Deferred(未实现)
|
||||
|
||||
Canonical Model 已预留 ImageBlock / AudioBlock / VideoBlock / FileBlock。实现路径:
|
||||
1. 在各 ProtocolAdapter 中实现多模态块的编解码
|
||||
2. 在 StreamDecoder/StreamEncoder 中处理多模态增量数据
|
||||
|
||||
### 8.3 有状态特性扩展
|
||||
|
||||
**状态**:Deferred(未实现)
|
||||
|
||||
```
|
||||
interface StatefulMiddleware extends ConversionMiddleware {
|
||||
stateStore: SessionStateStore
|
||||
@@ -722,6 +990,8 @@ interface StatefulMiddleware extends ConversionMiddleware {
|
||||
|
||||
### 8.5 自定义接口支持
|
||||
|
||||
**状态**:Deferred(未实现)
|
||||
|
||||
```
|
||||
interface CustomInterfaceHandler {
|
||||
interfaceType(): InterfaceType
|
||||
@@ -732,6 +1002,8 @@ interface CustomInterfaceHandler {
|
||||
engine.registerCustomHandler(handler)
|
||||
```
|
||||
|
||||
当前实现中,未知接口直接走 PASSTHROUGH 透传。
|
||||
|
||||
---
|
||||
|
||||
## 9. 错误处理
|
||||
@@ -759,16 +1031,36 @@ ErrorCode = Enum<
|
||||
|
||||
### 9.2 错误处理策略
|
||||
|
||||
```
|
||||
ErrorHandler { mode: "strict" | "lenient" }
|
||||
引擎采用**分层宽容策略**,根据接口层级和场景选择不同的错误处理方式:
|
||||
|
||||
strict: 任何错误抛出异常
|
||||
lenient: 尽力继续
|
||||
INCOMPATIBLE_FEATURE → 降级继续
|
||||
INTERFACE_NOT_SUPPORTED → 透传或返回空响应
|
||||
TOOL_CALL_PARSE_ERROR → 保留原始内容继续
|
||||
PROTOCOL_CONSTRAINT_VIOLATION → 自动修复
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────┐
|
||||
│ 错误处理分层策略 │
|
||||
├─────────────────┬───────────────────────────────────────────────────┤
|
||||
│ 核心层(CHAT) │ 严格模式:任何错误返回 ConversionError,中断转换 │
|
||||
│ │ - Decode 失败 → 返回 JSON_PARSE_ERROR │
|
||||
│ │ - Middleware 失败 → 返回错误 │
|
||||
│ │ - Encode 失败 → 返回 ENCODING_FAILURE │
|
||||
├─────────────────┼───────────────────────────────────────────────────┤
|
||||
│ 扩展层 │ 宽容模式:记录警告日志,返回原始 body 透传 │
|
||||
│ (Models/Embed/ │ - Decode 失败 → log.warn + 返回原始 body │
|
||||
│ Rerank) │ - Encode 失败 → log.warn + 返回原始 body │
|
||||
├─────────────────┼───────────────────────────────────────────────────┤
|
||||
│ 流式中间件 │ 宽容模式:跳过错误事件,继续处理后续事件 │
|
||||
│ │ - interceptStreamEvent 返回 error → continue │
|
||||
├─────────────────┼───────────────────────────────────────────────────┤
|
||||
│ 智能透传 │ 宽容模式:重写失败则使用原始 body/chunk │
|
||||
│ │ - Rewrite 失败 → log.warn + 返回原始 body/chunk │
|
||||
├─────────────────┼───────────────────────────────────────────────────┤
|
||||
│ 请求中间件 │ 严格模式:返回错误则中断整个转换 │
|
||||
│ │ - intercept 返回 error → 返回 error │
|
||||
└─────────────────┴───────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**设计动机**:
|
||||
- 核心接口(CHAT)必须保证语义正确性,错误应明确暴露
|
||||
- 扩展层接口优先保证可用性,错误应降级处理
|
||||
- 流式场景不应因单个事件错误中断整个流
|
||||
|
||||
**不支持接口的处理**(`INTERFACE_NOT_SUPPORTED`):
|
||||
|
||||
@@ -778,7 +1070,7 @@ lenient: 尽力继续
|
||||
| 返回空响应 | 不影响核心功能 | 返回空列表 `{data: []}` |
|
||||
| 返回错误 | 客户端明确需要此功能 | 返回 501 或协议格式错误 |
|
||||
|
||||
具体策略通过配置或 Middleware 决定。
|
||||
具体策略由 `supportsInterface` 返回值决定:返回 false 时引擎直接透传 body。
|
||||
|
||||
### 9.3 错误响应格式
|
||||
|
||||
@@ -786,6 +1078,24 @@ lenient: 尽力继续
|
||||
|
||||
Middleware 中断转换时同理,引擎调用 clientAdapter.encodeError 将 ConversionError 编码为客户端可理解的格式。
|
||||
|
||||
#### 9.3.1 EncodeError Fallback 行为
|
||||
|
||||
当客户端适配器不可用时,引擎使用通用 JSON 错误作为 fallback:
|
||||
|
||||
```
|
||||
function encodeError(error, clientProtocol):
|
||||
adapter = registry.get(clientProtocol)
|
||||
if adapter不存在:
|
||||
// Fallback: 通用 JSON 错误
|
||||
return {
|
||||
"error": {
|
||||
"message": error.message,
|
||||
"type": "internal_error"
|
||||
}
|
||||
}, 500
|
||||
return adapter.encodeError(error)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 附录 A:模块依赖
|
||||
@@ -796,21 +1106,25 @@ Middleware 中断转换时同理,引擎调用 clientAdapter.encodeError 将 Co
|
||||
│ 门面:HTTP 转换 / 透传判断 / 流式转换 │
|
||||
│ 无状态;协议识别见 §2.2 │
|
||||
├──────────────────────────────────────────────────┤
|
||||
│ HTTPRequestSpec / HTTPResponseSpec │
|
||||
│ url, method, headers, body / statusCode, ... │
|
||||
├──────────────────────────────────────────────────┤
|
||||
│ TargetProvider │
|
||||
│ base_url / api_key / model_name / adapter_config │
|
||||
├──────────────────┬───────────────────────────────┤
|
||||
│ AdapterRegistry │ MiddlewareChain │
|
||||
├──────────────────┴───────────────────────────────┤
|
||||
│ StreamConverter: Passthrough | Canonical │
|
||||
│ StreamConverter: Passthrough | SmartPassthrough | Canonical │
|
||||
├──────────────────────────────────────────────────┤
|
||||
│ ProtocolAdapter: 各协议实现 │
|
||||
│ · buildHeaders(provider) · URL 映射 │
|
||||
│ · Chat/Models/ModelInfo/Embeddings/Rerank/... 编解码 │
|
||||
│ · encodeError · StreamDecoder / StreamEncoder │
|
||||
│ · rewriteRequestModelName / rewriteResponseModelName (智能透传) │
|
||||
├──────────────────────────────────────────────────┤
|
||||
│ Canonical Model (Core + Extended) │
|
||||
├──────────────────────────────────────────────────┤
|
||||
│ Error Handling │
|
||||
│ Error Handling (分层宽容策略) │
|
||||
├──────────────────────────────────────────────────┤
|
||||
│ Utility: UTF-8 Buffer / SSE Parser / Detector │
|
||||
└──────────────────────────────────────────────────┘
|
||||
@@ -823,22 +1137,25 @@ Middleware 中断转换时同理,引擎调用 clientAdapter.encodeError 将 Co
|
||||
```
|
||||
// ─── 核心入口 ───
|
||||
ConversionEngine
|
||||
.registerAdapter(adapter)
|
||||
.use(middleware)
|
||||
.registerAdapter(adapter): void
|
||||
.use(middleware): void
|
||||
.getRegistry(): AdapterRegistry
|
||||
.isPassthrough(clientProtocol, providerProtocol): Boolean
|
||||
.convertHttpRequest(request, clientProtocol, providerProtocol, provider): HttpRequest
|
||||
.convertHttpResponse(response, clientProtocol, providerProtocol, interfaceType): HttpResponse
|
||||
.createStreamConverter(clientProtocol, providerProtocol, provider): StreamConverter
|
||||
.convertHttpRequest(request, clientProtocol, providerProtocol, provider): HTTPRequestSpec
|
||||
.convertHttpResponse(response, clientProtocol, providerProtocol, interfaceType, modelOverride): HTTPResponseSpec
|
||||
.createStreamConverter(clientProtocol, providerProtocol, modelOverride, interfaceType): StreamConverter
|
||||
.detectInterfaceType(nativePath, clientProtocol): InterfaceType
|
||||
.encodeError(error, clientProtocol): (body, statusCode)
|
||||
|
||||
// ─── HTTP 规格 ───
|
||||
HTTPRequestSpec { url, method, headers, body }
|
||||
HTTPResponseSpec { statusCode, headers, body }
|
||||
|
||||
// ─── 目标上游信息 ───
|
||||
TargetProvider
|
||||
.base_url: String
|
||||
.api_key: String
|
||||
.model_name: String
|
||||
.adapter_config: Map<String, Any>
|
||||
TargetProvider { base_url, api_key, model_name, adapter_config }
|
||||
|
||||
// ─── URL 路由 ───
|
||||
// 协议识别见 §2.2;出站: provider.base_url + 目标协议原生路径
|
||||
// ─── 接口类型 ───
|
||||
InterfaceType = CHAT | MODELS | MODEL_INFO | EMBEDDINGS | RERANK | PASSTHROUGH
|
||||
|
||||
// ─── 协议适配器 ───
|
||||
ProtocolAdapter
|
||||
@@ -847,19 +1164,30 @@ ProtocolAdapter
|
||||
.decodeRequest(raw) / .encodeRequest(canonical, provider)
|
||||
.decodeResponse(raw) / .encodeResponse(canonical)
|
||||
.createStreamDecoder() / .createStreamEncoder()
|
||||
.encodeError(error): RawResponse
|
||||
.encodeError(error): (body, statusCode)
|
||||
// 扩展层
|
||||
.decodeModelsResponse / .encodeModelsResponse
|
||||
.decodeModelInfoResponse / .encodeModelInfoResponse
|
||||
.decodeEmbeddingRequest / .encodeEmbeddingRequest(canonical, provider) / ...Response
|
||||
.decodeRerankRequest / .encodeRerankRequest(canonical, provider) / ...Response
|
||||
.decodeEmbeddingRequest / .encodeEmbeddingRequest / ...Response
|
||||
.decodeRerankRequest / .encodeRerankRequest / ...Response
|
||||
// 智能透传支持
|
||||
.extractUnifiedModelID(nativePath)
|
||||
.extractModelName(body, interfaceType)
|
||||
.rewriteRequestModelName(body, newModel, interfaceType)
|
||||
.rewriteResponseModelName(body, newModel, interfaceType)
|
||||
|
||||
// ─── 流式处理 ───
|
||||
StreamConverter: .processChunk(raw) / .flush()
|
||||
├─ PassthroughStreamConverter [raw] → [raw](用 provider 重建 Headers)
|
||||
└─ CanonicalStreamConverter decode → middleware → encode
|
||||
├─ PassthroughStreamConverter [raw] → [raw]
|
||||
├─ SmartPassthroughStreamConverter [raw] → [rewrite(raw)]
|
||||
└─ CanonicalStreamConverter decode → middleware → modelOverride → encode
|
||||
|
||||
// ─── 接口类型 ───
|
||||
InterfaceType = CHAT | MODELS | MODEL_INFO | EMBEDDINGS | RERANK | AUDIO | IMAGES
|
||||
// ─── 中间件 ───
|
||||
ConversionMiddleware
|
||||
.intercept(req, clientProtocol, providerProtocol, ctx): (req, error)
|
||||
.interceptStreamEvent(event, clientProtocol, providerProtocol, ctx): (event, error)
|
||||
|
||||
ConversionContext { conversionId, interfaceType, timestamp, metadata }
|
||||
```
|
||||
|
||||
---
|
||||
@@ -1070,6 +1398,28 @@ Canonical Model 是**活的公共契约**,不是固定不变的。其字段集
|
||||
| D.6 | [ ] 流式 StreamDecoder 和 StreamEncoder 已实现(对照 §4.8) |
|
||||
| D.7 | [ ] 扩展层接口的编解码已实现(支持的接口) |
|
||||
| D.8 | [ ] `encodeError` 已实现 |
|
||||
| D.10 | [ ] `extractUnifiedModelID(nativePath)` 已实现 |
|
||||
| D.10 | [ ] `extractModelName(body, interfaceType)` 已实现(覆盖 Chat/Embeddings/Rerank) |
|
||||
| D.10 | [ ] `rewriteRequestModelName` 已实现(幂等、最小化) |
|
||||
| D.10 | [ ] `rewriteResponseModelName` 已实现(按接口类型处理 model 字段存在性) |
|
||||
|
||||
### D.10 智能透传支持
|
||||
|
||||
| 项目 | 说明 |
|
||||
|------|------|
|
||||
| extractUnifiedModelID | 从路径提取统一模型 ID 的规则(如 `/models/{provider_id}/{model_name}`) |
|
||||
| extractModelName | 从请求体提取 model 字段值的规则(按接口类型:Chat/Embeddings/Rerank) |
|
||||
| rewriteRequestModelName | 请求体 model 字段改写规则(最小化 JSON 手术,仅修改 model 字段) |
|
||||
| rewriteResponseModelName | 响应体 model 字段改写规则(按接口类型处理 model 字段存在性) |
|
||||
|
||||
**rewriteResponseModelName 的接口类型差异**:
|
||||
|
||||
| 接口类型 | model 字段处理 |
|
||||
|---------|---------------|
|
||||
| CHAT | 存在则改写,不存在则添加(协议要求必须有 model 字段) |
|
||||
| EMBEDDINGS | 存在则改写,不存在则添加(协议要求必须有 model 字段) |
|
||||
| RERANK | 存在则改写,不存在则不添加(model 字段可选) |
|
||||
| 其他 | 直接返回原始 body |
|
||||
|
||||
---
|
||||
|
||||
|
||||
61
docs/prompts/README.md
Normal file
61
docs/prompts/README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# Prompts
|
||||
|
||||
面向 AI 大模型的提示词集合,每份提示词独立可用,完整复制给 AI 工具即可启动对应流程。
|
||||
|
||||
## 命名规则
|
||||
|
||||
文件名格式:`prompt-{action}.md`,`{action}` 使用明确无歧义的英文单词或短语,用连字符连接,例如 `prompt-smart-merge.md`、`prompt-spec-review.md`。
|
||||
|
||||
## 提示词
|
||||
|
||||
| 文件 | 用途 |
|
||||
| ---- | ---- |
|
||||
| [prompt-smart-merge.md](prompt-smart-merge.md) | 批量合并 dev 分支到主干,含依赖分析、冲突处理、安全回退 |
|
||||
| [prompt-spec-review.md](prompt-spec-review.md) | 审查和重构 openspec/specs/ 下的规范文件 |
|
||||
| [prompt-proposal-review.md](prompt-proposal-review.md) | 审查 openspec 变更文档与讨论内容的一致性 |
|
||||
| [prompt-apply-review.md](prompt-apply-review.md) | 审查代码实现与 openspec 变更文档的一致性 |
|
||||
|
||||
## 书写原则
|
||||
|
||||
### 面向 AI 而非人类
|
||||
|
||||
- 不写背景知识、适用场景、定期审查节奏等"补充说明",AI 只需要执行指令
|
||||
- 不写解释性注释(为什么用 merge 不用 rebase),直接作为约束声明
|
||||
- 不写示例输出模板,AI 自行推断格式
|
||||
|
||||
### 结构
|
||||
|
||||
```
|
||||
一句话描述任务目标
|
||||
|
||||
## 约束
|
||||
全局不可违反的规则,顶部声明,不在步骤中重复
|
||||
|
||||
## 1. 收集/准备
|
||||
## 2. 分析
|
||||
## 3. 计划(用户确认)
|
||||
## 4. 执行(用户确认)
|
||||
## 5. 清理/收尾
|
||||
```
|
||||
|
||||
编号步骤,不用"第X步"(省 token)。步骤之间有确认节点的明确标注。
|
||||
|
||||
### 精简
|
||||
|
||||
- 每句话只含一条指令,不嵌套子句
|
||||
- 表格优于列表,列表优于段落
|
||||
- 规则只声明一次,不在多处重复
|
||||
- 单文件控制在 150 行以内
|
||||
|
||||
### 安全
|
||||
|
||||
- 破坏性操作(删除、重写、推送、合并提交)执行前必须用提问工具获得用户确认
|
||||
- 提供回退机制(安全锚点、备份标记、abort 路径)
|
||||
- 危险命令的约束直接写在约束块中,不用"严禁""务必"等修饰词,用"禁止"即可
|
||||
- 信息展示分层渐进(概览 → 详情 → 原始数据),避免一次输出过多内容
|
||||
|
||||
### 可操作性
|
||||
|
||||
- 给出具体命令或工具调用方式,不抽象描述("分析分支"→ "git diff --name-status target...branch")
|
||||
- 标注并行/串行:可并行的步骤明确写"并行",有副作用的操作标注"串行"
|
||||
- 用 `{占位符}` 标记需要 AI 替换的参数
|
||||
71
docs/prompts/prompt-apply-review.md
Normal file
71
docs/prompts/prompt-apply-review.md
Normal file
@@ -0,0 +1,71 @@
|
||||
审查 openspec apply 完成后的实现是否与变更文档一致,按以下流程执行。
|
||||
|
||||
## 约束
|
||||
|
||||
- 仅审查代码和文档,不修改源码
|
||||
- 每批修改建议执行前用提问工具获得用户确认
|
||||
- 涉及删除/重写操作前必须备份原文件
|
||||
|
||||
## 1. 收集
|
||||
|
||||
并行读取:
|
||||
- 本次变更涉及的所有文档(proposal.md、design.md、tasks.md、specs/*.md)
|
||||
- 实际变更的源码文件
|
||||
- 测试文件和测试结果
|
||||
- openspec/config.yaml
|
||||
|
||||
## 2. 分析
|
||||
|
||||
根据上下文判断 apply 后是否有手动改动,将实现与文档双向对照:
|
||||
|
||||
| 维度 | 检查点 |
|
||||
| ---- | ------ |
|
||||
| 目标覆盖 | 代码实现是否覆盖 proposal 中的所有目标;是否有遗漏的功能点 |
|
||||
| 方案一致性 | 实现是否与 design 中的技术方案一致;是否有偏离设计的地方 |
|
||||
| 规范遵循 | 代码是否遵循 specs 中的规范要求;是否有违反 SHALL 约束的地方 |
|
||||
| 任务完成度 | tasks 中每项任务是否真正完成;是否有未完成但标记完成的任务 |
|
||||
| 测试完整性 | 测试是否覆盖所有场景;是否有跳过、取消、降低难度的测试;测试是否真正验证了功能 |
|
||||
| 代码质量 | 是否有明显的代码问题(重复、复杂度过高、命名不清等);是否有可优化的地方 |
|
||||
|
||||
重点识别:
|
||||
- 文档要求但未实现的功能 → 需补充代码
|
||||
- 实现与文档描述不一致的地方 → 根据上下文判断:apply 后手动改动则更新文档,否则确认修正方向
|
||||
- 实现了但文档未提及的功能 → 根据上下文判断:apply 后手动改动则补充文档,否则标记为未讨论的新增
|
||||
- 标记完成但实际未完成的任务
|
||||
- 掩盖错误的测试(skip、only、降低断言强度等)
|
||||
|
||||
输出审查结果:
|
||||
1. **问题总览表**:问题类型 × 涉及文件数
|
||||
2. **逐项分析**:每个问题文件,说明问题、影响和建议
|
||||
3. **未覆盖清单**:哪些文档要求未在代码中实现(需补充代码)
|
||||
4. **不一致清单**:哪些实现与文档描述不一致(需确认修正方向)
|
||||
5. **需补充文档清单**:哪些代码改动未在文档中体现(需补充文档)
|
||||
6. **任务问题清单**:哪些任务未真正完成或标记错误
|
||||
7. **测试问题清单**:哪些测试存在问题或掩盖了错误
|
||||
8. **优化建议清单**:哪些代码可以优化
|
||||
|
||||
若所有清单均为空,输出"审查通过,未发现问题",跳至步骤 5。
|
||||
|
||||
## 3. 计划(用户确认)
|
||||
|
||||
针对发现的问题,分类提出修复方案:
|
||||
|
||||
**需补充代码**:文档要求但未实现的功能,建议补充实现。
|
||||
|
||||
**需补充文档**:代码已实现但文档未记录,且上下文确认为 apply 后手动改动,建议更新 proposal/design/tasks/specs。
|
||||
|
||||
**需确认修正方向**:实现与文档不一致,且上下文无法判断:
|
||||
- 若为 apply 后手动改动 → 更新文档
|
||||
- 否则用提问工具确认以文档还是代码为准
|
||||
|
||||
**任务和测试问题**:逐项说明未完成原因、测试掩盖错误的风险,提出修复方案。
|
||||
|
||||
用提问工具展示完整修复方案,获得用户确认后执行。
|
||||
|
||||
## 4. 执行
|
||||
|
||||
逐项执行修复方案。
|
||||
|
||||
## 5. 收尾
|
||||
|
||||
列出所有修改的文件和变更摘要。
|
||||
56
docs/prompts/prompt-proposal-review.md
Normal file
56
docs/prompts/prompt-proposal-review.md
Normal file
@@ -0,0 +1,56 @@
|
||||
审查 openspec 变更文档(proposal、design、tasks、specs)是否完整准确地记录技术方案,按以下流程执行。
|
||||
|
||||
## 约束
|
||||
|
||||
- 仅审查文档内容,不修改源码
|
||||
- 每批修改建议执行前用提问工具获得用户确认
|
||||
- 涉及删除/重写操作前必须备份原文件
|
||||
|
||||
## 1. 收集
|
||||
|
||||
并行读取:
|
||||
- 本次变更涉及的所有文档(proposal.md、design.md、tasks.md、specs/*.md)
|
||||
- 之前上下文中讨论的内容
|
||||
- openspec/config.yaml
|
||||
- 与变更规范有依赖关系的其他规范
|
||||
|
||||
## 2. 分析
|
||||
|
||||
将文档与讨论内容逐项对照,检查:
|
||||
|
||||
| 文档 | 检查点 |
|
||||
| ---- | ------ |
|
||||
| proposal.md | 是否完整记录讨论确定的目标、范围、影响;是否遗漏决策点 |
|
||||
| design.md | 是否覆盖讨论中所有技术方案;边界条件和异常处理是否与讨论一致 |
|
||||
| tasks.md | 是否覆盖 design 中所有方案;任务划分是否合理;依赖关系是否明确 |
|
||||
| specs/*.md | 是否严格遵循 OpenSpec 格式;术语是否一致;依赖声明是否完整;无实现细节混入 |
|
||||
|
||||
重点识别:
|
||||
- 讨论中确定但文档未记录的内容
|
||||
- 文档描述与讨论不一致的地方
|
||||
- 文档新增但未在讨论中提及的内容
|
||||
|
||||
输出审查结果:
|
||||
1. **问题总览表**:问题类型 × 涉及文档数
|
||||
2. **逐项分析**:每个问题文档,说明问题、影响和建议
|
||||
3. **缺失清单**:哪些技术方案/需求未在文档中体现
|
||||
4. **冲突清单**:哪些描述与其他文档或讨论结果不一致
|
||||
5. **待澄清清单**:哪些事项描述不明确,需要进一步确认
|
||||
|
||||
若所有清单均为空,输出"审查通过,未发现问题",跳至步骤 5。
|
||||
|
||||
## 3. 计划(用户确认)
|
||||
|
||||
针对发现的问题,提出修复方案(补充遗漏、修正冲突、优化表述)。
|
||||
|
||||
针对待澄清清单,用提问工具逐项向用户确认,根据反馈更新文档。
|
||||
|
||||
用提问工具展示完整修复方案,获得用户确认后执行。
|
||||
|
||||
## 4. 执行
|
||||
|
||||
逐项执行修复方案。
|
||||
|
||||
## 5. 收尾
|
||||
|
||||
列出所有修改的文件和变更摘要。
|
||||
120
docs/prompts/prompt-smart-merge.md
Normal file
120
docs/prompts/prompt-smart-merge.md
Normal file
@@ -0,0 +1,120 @@
|
||||
请对当前项目中所有 `dev*` 分支进行智能合并到目标分支(默认 main),按以下流程执行。
|
||||
|
||||
## 约束(全局,不可违反)
|
||||
|
||||
- 所有操作(合并、删除)执行前必须用提问工具获得用户确认
|
||||
- 冲突文件严禁自主编辑,仅分析方案后交用户选择
|
||||
- 全程仅使用 `git merge`,禁止 rebase(rebase 会重写目标分支历史)
|
||||
- `git add` 仅指定已解决冲突的文件路径,禁止 `git add .`/`git add -A`
|
||||
- `git reset --hard` 仅配合安全锚点 tag 使用,禁止裸用
|
||||
- 禁止自动 `git stash` `git push`
|
||||
|
||||
## 1. 环境检查
|
||||
|
||||
- `git status` 确认工作区干净,不干净则提示用户处理
|
||||
- 确认目标分支,拉取最新:`git pull`
|
||||
- 列出 dev 分支:`git branch --list 'dev*'`,无则结束
|
||||
- 创建全局安全锚点:`git tag pre-merge-backup-{timestamp}`,报告标签名
|
||||
|
||||
## 2. 分支分析
|
||||
|
||||
对每个 dev 分支并行收集:
|
||||
|
||||
| 维度 | 内容 |
|
||||
| ------ | ---------------------------------------------------------------------------------------- |
|
||||
| 基础 | 分支名、分叉 commit、commit 数/消息、是否推远端、是否已合并(`git branch --merged`) |
|
||||
| 变更 | 文件列表(`git diff --name-status target...branch`)、所属模块、行数统计 |
|
||||
| 依赖 | 是否依赖/被依赖其他 dev 分支、是否修改公共文件(共享类型、工具函数、配置) |
|
||||
| 冲突 | dry-run 预测(逐个串行,因需修改工作区):`git merge --no-commit --no-ff branch` → 收集冲突 → `git merge --abort`;与其他 dev 分支文件重叠 |
|
||||
|
||||
## 3. 合并顺序
|
||||
|
||||
按以下优先级排序:已合并(跳过) → 公共/基础设施变更 → 独立模块 → 有依赖的 → 高冲突/跨模块。
|
||||
|
||||
输出计划表(分支名、模块、文件数、依赖、预估冲突、风险),用提问工具让用户确认,用户可调整顺序或排除。
|
||||
|
||||
## 4. 逐个合并
|
||||
|
||||
对每个分支重复以下流程:
|
||||
|
||||
### 准备
|
||||
|
||||
1. 确认工作区干净、当前在目标分支
|
||||
2. `git tag merge-before-{分支名}` 创建分支级安全锚点
|
||||
3. 向用户确认即将合并的分支及风险
|
||||
|
||||
### 执行
|
||||
|
||||
`git merge {分支} --no-ff`
|
||||
|
||||
- 无冲突 → 进入验证
|
||||
- 有冲突 → 进入冲突处理
|
||||
|
||||
### 冲突处理(三层渐进)
|
||||
|
||||
`git diff --name-only --diff-filter=U` 列出冲突文件,然后按以下三层逐步展开:
|
||||
|
||||
**第一层:冲突概览表**
|
||||
|
||||
向用户展示所有冲突文件的摘要,每文件一行:
|
||||
|
||||
| 文件 | 冲突区域 | 冲突类型 | 目标分支改动 | 合并分支改动 | 推荐方案 |
|
||||
| ---- | -------- | -------- | ------------ | ------------ | -------- |
|
||||
|
||||
- 冲突类型:双方修改同一区域 / 一方删除一方修改 / 文件重命名冲突等
|
||||
- 改动描述:精简到关键差异(如"新增 rateLimit 字段"而非展示原文)
|
||||
- 推荐方案:根据分析给出最合理的选项
|
||||
|
||||
用提问工具让用户选择处理方式:
|
||||
- 批量处理:对推荐方案无异议的文件一键确认
|
||||
- 逐个审查:用户指定要详细审查的文件,进入第二层
|
||||
- 放弃合并:`git merge --abort`,跳过当前分支,继续下一个
|
||||
|
||||
**第二层:单个文件详情(按需)**
|
||||
|
||||
仅展示用户指定审查的文件:
|
||||
- 冲突区域的上下文(前后几行非冲突代码)
|
||||
- HEAD 侧与分支侧的具体差异(精简 diff,非完整文件内容)
|
||||
- 方案选项:双保留 / 保留目标(--ours) / 保留分支(--theirs) / 用户编辑 / 放弃合并
|
||||
|
||||
若用户仍觉信息不足,进入第三层。
|
||||
|
||||
**第三层:原始冲突标记(按需)**
|
||||
|
||||
展示用户指定文件的完整 `<<<<<<<`/`=======`/`>>>>>>>` 原始标记内容。
|
||||
|
||||
**确认与提交**
|
||||
|
||||
1. 方案全部确定后逐个执行:
|
||||
- 双保留:AI 生成合并后的文件内容,展示给用户确认后才写入,严禁未经确认直接写入
|
||||
- 保留目标/分支:`git checkout --ours/--theirs {file}`
|
||||
- 用户编辑:等待用户编辑完成后 `git add {file}`
|
||||
2. 逐个 add 已解决文件:`git add {file}`(禁止 `git add .`)
|
||||
3. 展示 `git diff --cached --stat`,用户确认后完成提交
|
||||
|
||||
### 验证
|
||||
|
||||
- `cd backend && go build ./...`
|
||||
- `cd frontend && bun run build`
|
||||
- 失败则提供回退选项:`git reset --soft HEAD~1` 或 `git reset --hard merge-before-{分支名}`,由用户决定
|
||||
|
||||
### 断点
|
||||
|
||||
每个分支完成后询问是否继续,暂停则记录进度。
|
||||
|
||||
## 5. 清理
|
||||
|
||||
### 删除分支
|
||||
|
||||
输出合并结果表,逐个确认删除:
|
||||
|
||||
- 本地:`git branch -d {分支}`
|
||||
- 远端:独立确认后 `git push origin --delete {分支}`
|
||||
|
||||
### 锚点
|
||||
|
||||
询问是否保留安全锚点标签。
|
||||
|
||||
### 总结
|
||||
|
||||
输出:目标分支、安全锚点标签、成功/失败/跳过数量、冲突解决文件数、已删除分支、保留分支及原因。
|
||||
58
docs/prompts/prompt-spec-review.md
Normal file
58
docs/prompts/prompt-spec-review.md
Normal file
@@ -0,0 +1,58 @@
|
||||
请对 openspec/specs/ 下所有规范文件进行审查和整理,按以下流程执行。
|
||||
|
||||
## 约束
|
||||
|
||||
- 规范描述"应该是什么",不含实现细节(具体文件路径、代码引用)和变更记录(ADDED/MODIFIED、"移除以下列"等措辞)
|
||||
- 每批重构操作执行前用提问工具获得用户确认
|
||||
- 仅删除内容已完全覆盖在其他规范中的冗余规范,非冗余内容仅迁移/合并/重命名
|
||||
|
||||
## 1. 收集
|
||||
|
||||
并行读取以下内容:
|
||||
- `openspec/specs/` 每个子目录的 `spec.md`
|
||||
- 项目源码(backend/、frontend/),理解实际实现
|
||||
- `openspec/config.yaml`,了解项目约束
|
||||
|
||||
## 2. 审查
|
||||
|
||||
将每个规范与代码和命名约定对比,按以下维度逐项检查:
|
||||
|
||||
| 维度 | 检查点 |
|
||||
| ---- | ------ |
|
||||
| 过时 | 描述的功能/组件是否仍存在于代码;引用的路径、类名、API 是否一致;交互流程是否匹配当前实现 |
|
||||
| 重复 | 不同规范是否描述同一组件/功能/场景(场景级或概念级) |
|
||||
| 错位 | 场景是否放错了功能域;Requirement 是否归属错误的规范 |
|
||||
| 合并 | 同一主题是否分散在多个规范;某个规范是否是另一个的子集 |
|
||||
| 命名 | 是否准确反映内容;是否符合命名约定(见下);是否暴露可搜索的业务关键词 |
|
||||
| 格式 | 是否使用 SHALL/WHEN/THEN;是否混入变更记录或实现细节;是否有空目录 |
|
||||
|
||||
### 命名约定
|
||||
|
||||
| 类型 | 模式 | 示例 |
|
||||
| ---- | ---- | ---- |
|
||||
| 平台专属 | `{平台}-{功能}` | admin-platform、console-my-skills |
|
||||
| 跨平台组件 | `{类别}` | component-library、layout-system |
|
||||
| 技能领域 | `skill-{方面}` | skill-market、skill-status-rules |
|
||||
| 业务功能 | `{业务名词}` | account-management、chat-scenarios |
|
||||
|
||||
命名原则:统一平台前缀(admin-/console-/developer-)、统一领域前缀(skill-)、2-3 词、避免泛化词(display/basic/general/info/data)和实现模式词(crud/list/table)。
|
||||
|
||||
## 3. 报告
|
||||
|
||||
输出分析结果:
|
||||
|
||||
1. **问题总览表**:问题类型 × 涉及规范数
|
||||
2. **逐项分析**:每个有问题的规范,说明具体问题和建议(涉及文件、冲突点、推荐操作)
|
||||
3. **重构方案**:按优先级分批:
|
||||
- P0:删除空目录和完全冗余规范
|
||||
- P1:合并重复/子集规范到主规范
|
||||
- P2:重命名不精准的规范、拆分错位内容
|
||||
- P3:修正与代码不匹配的描述、清理实现细节和变更记录
|
||||
4. **重构后目录结构**:预期的新 specs/ 目录树
|
||||
|
||||
## 4. 执行
|
||||
|
||||
逐批执行重构(P0→P3),每批执行前:
|
||||
- 展示该批次的具体操作列表(源路径 → 目标路径、操作类型)
|
||||
- 用提问工具获得用户确认
|
||||
- 执行后确认目录结构完整,规范文件可正常读取
|
||||
@@ -1,106 +0,0 @@
|
||||
# 规范文件整理流程
|
||||
|
||||
## 使用方式
|
||||
|
||||
将下方提示词完整复制给 AI 工具,即可启动一次规范文件的全面审查和整理。
|
||||
|
||||
---
|
||||
|
||||
## 提示词
|
||||
|
||||
```
|
||||
请对 openspec/specs/ 下的所有规范文件进行审查和整理,按以下流程执行:
|
||||
|
||||
## 第一步:全面阅读
|
||||
|
||||
1. 逐个读取 openspec/specs/ 下每个子目录的 spec.md,理解每个规范的覆盖范围
|
||||
2. 读取项目源码,理解实际代码实现
|
||||
3. 读取 openspec/config.yaml,了解项目约束和规范
|
||||
|
||||
## 第二步:对比分析
|
||||
|
||||
将每个规范与实际代码对比,按以下维度逐项检查:
|
||||
|
||||
### A. 过时检查
|
||||
- 规范描述的功能/组件/样式是否在当前代码中仍然存在
|
||||
- 规范引用的文件路径、类名、API 接口是否与代码一致
|
||||
- 规范描述的交互流程是否仍是当前的实现方式
|
||||
|
||||
### B. 重复检查
|
||||
- 不同规范是否描述了相同的组件/功能/场景
|
||||
- 场景级别的重复(A 规范的 Scenario 与 B 规范的 Scenario 重复)
|
||||
- 概念级别的重复(A 规范整体描述的就是 B 规范已覆盖的内容)
|
||||
|
||||
### C. 错位检查
|
||||
- A 规范中是否有场景应该属于 B 规范
|
||||
- 某个 Requirement 是否放在了错误的功能域下
|
||||
|
||||
### D. 合并检查
|
||||
- 描述同一类主题的规范是否分散在多个文件中
|
||||
- 某个规范是否可以作为子集被另一个更大的规范吸收
|
||||
|
||||
### E. 命名检查
|
||||
- 规范名称是否准确反映其实际内容
|
||||
- 命名是否遵循统一的前缀约定(平台前缀:admin- / developer- / console-)
|
||||
- 名称是否便于 AI 工具搜索匹配(暴露关键业务词和组件名)
|
||||
|
||||
### F. 格式检查
|
||||
- 是否使用标准的 SHALL/WHEN/THEN 规范格式
|
||||
- 是否混入了变更记录(如"移除以下列"、"ADDED Requirements")而非功能规范
|
||||
- 是否存在空目录
|
||||
|
||||
## 第三步:输出分析报告
|
||||
|
||||
按以下结构输出:
|
||||
|
||||
1. 问题总览表(问题类型 × 涉及规范数)
|
||||
2. 逐项分析(每个有问题的规范,说明具体问题和建议)
|
||||
3. 重构方案(删除/合并/重命名/内容调整的具体操作)
|
||||
4. 重构后的规范目录结构
|
||||
|
||||
## 第四步:执行重构
|
||||
|
||||
按优先级分批执行:
|
||||
- P0:删除空目录和完全冗余的规范
|
||||
- P1:合并重复/子集规范到主规范中
|
||||
- P2:重命名不精准的规范、拆分错位的内容
|
||||
- P3:修正与代码不匹配的细节描述
|
||||
|
||||
每步执行后确认目录结构完整。
|
||||
|
||||
## 命名约定
|
||||
|
||||
规范目录命名遵循以下规则,确保 AI 工具搜索时能精准匹配:
|
||||
|
||||
| 类型 | 命名模式 | 示例 |
|
||||
|------|---------|------|
|
||||
| 平台专属功能 | `{平台}-{功能}` | `admin-platform`、`console-my-skills`、`developer-platform` |
|
||||
| 跨平台组件/架构 | `{类别}` | `component-library`、`layout-system`、`design-tokens` |
|
||||
| 技能领域 | `skill-{方面}` | `skill-market`、`skill-status-rules`、`skill-version-management` |
|
||||
| 业务功能 | `{业务名词}` | `account-management`、`chat-scenarios` |
|
||||
|
||||
命名原则(提升 AI 检索命中率):
|
||||
- 名称中暴露可搜索的业务关键词(如 skill、modal、toast、account)
|
||||
- 同一平台的功能使用统一前缀(admin- / console- / developer-)
|
||||
- 同一领域的功能使用统一领域词前缀(skill-)
|
||||
- 避免泛化词(display → rules/behavior,basic → 删掉,general → 删掉)
|
||||
- 避免实现模式词(crud、list、table)而使用业务领域词
|
||||
- 避免同一关键词在不同规范中重复出现导致歧义(如 layout 只出现在一个规范名中)
|
||||
- 长度控制在 2-3 个词,去掉不影响检索的冗余词(info、data 等)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 补充说明
|
||||
|
||||
### 审查时的判断边界
|
||||
|
||||
- **规范 vs 代码**:规范描述"应该是什么",不描述"代码怎么写"。如果规范中出现了具体文件路径(如 `src/data/adminData.js`),通常是实现细节而非规范,应该清理
|
||||
- **规范 vs 变更记录**:规范用 SHALL/WHEN/THEN 格式描述功能需求。如果出现"移除以下列"、"保持现有样式"、"ADDED/MODIFIED Requirements"等措辞,说明混入了变更指令,需要改写
|
||||
- **规范 vs 文档**:规范不替代 README 或开发文档,不需要描述项目背景、技术选型等宏观信息
|
||||
|
||||
### 建议的定期审查节奏
|
||||
|
||||
- 每完成一批功能变更后,对照新代码检查相关规范是否需要更新
|
||||
- 规范数量超过 30 个时,建议做一次全面审查
|
||||
- 新增规范前,先搜索现有规范名称和内容,确认是否有可复用/扩展的规范
|
||||
9
embedfs/embedfs.go
Normal file
9
embedfs/embedfs.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package embedfs
|
||||
|
||||
import "embed"
|
||||
|
||||
//go:embed assets/*
|
||||
var Assets embed.FS
|
||||
|
||||
//go:embed frontend-dist/*
|
||||
var FrontendDist embed.FS
|
||||
3
embedfs/go.mod
Normal file
3
embedfs/go.mod
Normal file
@@ -0,0 +1,3 @@
|
||||
module nex/embedfs
|
||||
|
||||
go 1.26.2
|
||||
1
frontend/.env.desktop
Normal file
1
frontend/.env.desktop
Normal file
@@ -0,0 +1 @@
|
||||
VITE_API_BASE=
|
||||
@@ -14,6 +14,68 @@ AI 网关管理前端,提供供应商配置和用量统计界面。
|
||||
- **样式**: SCSS Modules(禁止使用纯 CSS)
|
||||
- **测试**: Vitest + React Testing Library + Playwright
|
||||
|
||||
## API 层
|
||||
|
||||
### 字段转换机制
|
||||
|
||||
后端使用 `snake_case`,前端使用 `camelCase`,API 层自动转换:
|
||||
|
||||
```typescript
|
||||
// 发送请求时:camelCase → snake_case
|
||||
toApi({ providerId: "openai" }) // → { provider_id: "openai" }
|
||||
|
||||
// 接收响应时:snake_case → camelCase
|
||||
fromApi({ provider_id: "openai" }) // → { providerId: "openai" }
|
||||
```
|
||||
|
||||
### 统一请求函数
|
||||
|
||||
```typescript
|
||||
export async function request<T>(method: string, path: string, body?: unknown): Promise<T>
|
||||
```
|
||||
|
||||
- 自动处理字段转换
|
||||
- 自动处理 204 响应(无 body)
|
||||
- 抛出 `ApiError` 包含 `status`、`code`、`message`
|
||||
|
||||
### 错误处理
|
||||
|
||||
```typescript
|
||||
class ApiError extends Error {
|
||||
status: number; // HTTP 状态码
|
||||
code?: string; // 业务错误码
|
||||
message: string; // 错误消息
|
||||
}
|
||||
```
|
||||
|
||||
## TanStack Query 模式
|
||||
|
||||
### Query Keys 定义
|
||||
|
||||
```typescript
|
||||
// src/hooks/useProviders.ts
|
||||
export const providerKeys = {
|
||||
all: ['providers'] as const,
|
||||
};
|
||||
|
||||
// src/hooks/useModels.ts
|
||||
export const modelKeys = {
|
||||
all: ['models'] as const,
|
||||
byProvider: (providerId: string) => [...modelKeys.all, { providerId }] as const,
|
||||
};
|
||||
```
|
||||
|
||||
### Mutation 使用
|
||||
|
||||
```typescript
|
||||
const mutation = useMutation({
|
||||
mutationFn: createProvider,
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: providerKeys.all });
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
@@ -25,7 +87,7 @@ frontend/
|
||||
│ │ ├── models.ts # Model CRUD
|
||||
│ │ └── stats.ts # Stats 查询
|
||||
│ ├── components/
|
||||
│ │ └── AppLayout/ # 顶部导航布局
|
||||
│ │ └── AppLayout/ # 侧边栏导航布局
|
||||
│ ├── hooks/ # TanStack Query hooks
|
||||
│ │ ├── useProviders.ts
|
||||
│ │ ├── useModels.ts
|
||||
@@ -33,6 +95,7 @@ frontend/
|
||||
│ ├── pages/
|
||||
│ │ ├── Providers/ # 供应商管理(含内嵌模型管理)
|
||||
│ │ ├── Stats/ # 用量统计
|
||||
│ │ ├── Settings/ # 设置(开发中)
|
||||
│ │ └── NotFound.tsx
|
||||
│ ├── routes/
|
||||
│ │ └── index.tsx # 路由配置
|
||||
@@ -125,11 +188,50 @@ bun run test:e2e
|
||||
- 按模型筛选
|
||||
- 按日期范围筛选(DatePicker.RangePicker)
|
||||
|
||||
## 测试策略
|
||||
|
||||
### 目录结构
|
||||
|
||||
```
|
||||
__tests__/
|
||||
├── setup.ts # 测试配置(happy-dom)
|
||||
├── api/ # API 层测试
|
||||
│ └── client.test.ts
|
||||
├── hooks/ # TanStack Query Hook 测试
|
||||
│ ├── useProviders.test.ts
|
||||
│ └── useModels.test.ts
|
||||
└── components/ # 组件测试
|
||||
└── AppLayout.test.tsx
|
||||
```
|
||||
|
||||
### E2E 测试
|
||||
|
||||
- 位于 `e2e/` 目录
|
||||
- 使用 Playwright
|
||||
- 自动启动后端服务(临时端口 19026)
|
||||
- 配置文件:`playwright.config.ts`
|
||||
|
||||
### Mock 策略
|
||||
|
||||
- API 层测试使用 MSW(Mock Service Worker)
|
||||
- Hook 测试使用 `@testing-library/react-hooks`
|
||||
- 组件测试使用 `@testing-library/react`
|
||||
|
||||
## 环境变量
|
||||
|
||||
| 变量 | 开发环境 | 生产环境 | 说明 |
|
||||
|------|----------|----------|------|
|
||||
| `VITE_API_BASE` | (空) | `/api` | API 基础路径,空则走 Vite proxy |
|
||||
|
||||
**E2E 测试特有**:
|
||||
- `NEX_BACKEND_PORT` - E2E 后端端口(默认 19026)
|
||||
- `NEX_E2E_TEMP_DIR` - E2E 临时目录
|
||||
|
||||
## 开发规范
|
||||
|
||||
- 所有样式使用 SCSS,禁止使用纯 CSS 文件
|
||||
- 组件级样式使用 SCSS Modules(*.module.scss)
|
||||
- 图标优先使用 @ant-design/icons
|
||||
- 图标优先使用 TDesign 图标(tdesign-icons-react)
|
||||
- TypeScript strict 模式,禁止 any 类型
|
||||
- API 层自动处理 snake_case ↔ camelCase 字段转换
|
||||
- 使用路径别名 `@/` 引用 src 目录
|
||||
@@ -143,4 +245,4 @@ bun run test:e2e
|
||||
1. 在 `src/pages/` 创建页面目录和组件
|
||||
2. 在 `src/hooks/` 创建对应的 TanStack Query hook
|
||||
3. 在 `src/routes/index.tsx` 添加路由
|
||||
4. 在 `src/components/AppLayout/index.tsx` 的 menuItems 添加导航项
|
||||
4. 在 `src/components/AppLayout/index.tsx` 的 Menu 中添加 MenuItem
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import fs from 'node:fs'
|
||||
import os from 'node:os'
|
||||
import path from 'node:path'
|
||||
import initSqlite from 'sql.js'
|
||||
|
||||
@@ -26,6 +27,18 @@ export interface SeedStatsInput {
|
||||
date: string
|
||||
}
|
||||
|
||||
export async function clearDatabase(
|
||||
request: import('@playwright/test').APIRequestContext,
|
||||
) {
|
||||
const providers = await request.get(`${API_BASE}/api/providers`)
|
||||
if (providers.ok()) {
|
||||
const data = await providers.json()
|
||||
for (const p of data) {
|
||||
await request.delete(`${API_BASE}/api/providers/${p.id}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export async function seedProvider(
|
||||
request: import('@playwright/test').APIRequestContext,
|
||||
data: SeedProviderInput,
|
||||
@@ -64,12 +77,14 @@ export async function seedModel(
|
||||
}
|
||||
|
||||
export async function seedUsageStats(statsData: SeedStatsInput[]) {
|
||||
const tempDir = process.env.NEX_E2E_TEMP_DIR
|
||||
if (!tempDir) {
|
||||
throw new Error('NEX_E2E_TEMP_DIR not set')
|
||||
}
|
||||
const tempDir = path.join(os.tmpdir(), 'nex-e2e')
|
||||
|
||||
const dbPath = path.join(tempDir, 'test.db')
|
||||
|
||||
if (!fs.existsSync(dbPath)) {
|
||||
throw new Error(`Database file not found at ${dbPath}. Backend may not have created it yet.`)
|
||||
}
|
||||
|
||||
const SQL = await initSqlite()
|
||||
const buf = fs.readFileSync(dbPath)
|
||||
const db = new SQL.Database(buf)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import fs from 'node:fs'
|
||||
import os from 'node:os'
|
||||
import path from 'node:path'
|
||||
|
||||
async function globalSetup() {
|
||||
const tempDir = process.env.NEX_E2E_TEMP_DIR
|
||||
if (tempDir && fs.existsSync(tempDir)) {
|
||||
const tempDir = path.join(os.tmpdir(), 'nex-e2e')
|
||||
if (fs.existsSync(tempDir)) {
|
||||
console.log(`E2E temp dir: ${tempDir}`)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import fs from 'node:fs'
|
||||
import os from 'node:os'
|
||||
import path from 'node:path'
|
||||
|
||||
async function globalTeardown() {
|
||||
const tempDir = process.env.NEX_E2E_TEMP_DIR
|
||||
if (tempDir && fs.existsSync(tempDir)) {
|
||||
const tempDir = path.join(os.tmpdir(), 'nex-e2e')
|
||||
if (fs.existsSync(tempDir)) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 500))
|
||||
try {
|
||||
fs.rmSync(tempDir, { recursive: true, force: true })
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { test, expect } from '@playwright/test'
|
||||
import { API_BASE } from './fixtures'
|
||||
import { API_BASE, clearDatabase } from './fixtures'
|
||||
|
||||
let uid = Date.now()
|
||||
function nextId() {
|
||||
@@ -10,6 +10,7 @@ function modelFormInputs(page: import('@playwright/test').Page) {
|
||||
const dialog = page.locator('.t-dialog:visible')
|
||||
return {
|
||||
modelName: dialog.locator('input[placeholder="例如: gpt-4o"]'),
|
||||
providerSelect: dialog.locator('.t-select'),
|
||||
saveBtn: dialog.locator('.t-dialog__footer').getByRole('button', { name: '保存' }),
|
||||
cancelBtn: dialog.locator('.t-dialog__footer').getByRole('button', { name: '取消' }),
|
||||
}
|
||||
@@ -19,6 +20,7 @@ test.describe('模型管理', () => {
|
||||
let providerId: string
|
||||
|
||||
test.beforeEach(async ({ page, request }) => {
|
||||
await clearDatabase(request)
|
||||
providerId = nextId()
|
||||
await request.post(`${API_BASE}/api/providers`, {
|
||||
data: {
|
||||
@@ -36,24 +38,27 @@ test.describe('模型管理', () => {
|
||||
})
|
||||
|
||||
test('应能展开供应商查看模型空状态', async ({ page }) => {
|
||||
await page.locator('.t-table__expandable-icon').first().click()
|
||||
await page.locator('.t-table__expand-box').first().click()
|
||||
await expect(page.locator('.t-table__expanded-row').first()).toBeVisible()
|
||||
await expect(page.getByText('暂无模型,点击上方按钮添加')).toBeVisible()
|
||||
})
|
||||
|
||||
test('应能为供应商添加模型', async ({ page }) => {
|
||||
await page.locator('.t-table__expandable-icon').first().click()
|
||||
await page.locator('.t-table__expand-box').first().click()
|
||||
await expect(page.locator('.t-table__expanded-row').first()).toBeVisible()
|
||||
|
||||
await page.locator('.t-dialog:visible').waitFor({ state: 'hidden', timeout: 3000 }).catch(() => {})
|
||||
|
||||
await page.locator('.t-table__expanded-row button:has-text("添加模型")').first().click()
|
||||
await expect(page.locator('.t-dialog:visible')).toBeVisible()
|
||||
|
||||
const inputs = modelFormInputs(page)
|
||||
await inputs.modelName.fill('gpt_4_turbo')
|
||||
|
||||
const responsePromise = page.waitForResponse(resp => resp.url().includes('/api/models') && resp.request().method() === 'POST')
|
||||
await inputs.saveBtn.click()
|
||||
await expect(page.locator('.t-dialog:visible')).not.toBeVisible({ timeout: 5000 })
|
||||
|
||||
await expect(page.locator('.t-table__expanded-row').getByText('gpt_4_turbo')).toBeVisible()
|
||||
await responsePromise
|
||||
await expect(page.locator('.t-table__expanded-row').getByText('gpt_4_turbo', { exact: true })).toBeVisible({ timeout: 5000 })
|
||||
})
|
||||
|
||||
test('应显示统一模型 ID', async ({ page, request }) => {
|
||||
@@ -68,7 +73,7 @@ test.describe('模型管理', () => {
|
||||
await page.reload()
|
||||
await expect(page.getByRole('heading', { name: '供应商管理' })).toBeVisible()
|
||||
|
||||
await page.locator('.t-table__expandable-icon').first().click()
|
||||
await page.locator('.t-table__expand-box').first().click()
|
||||
await expect(page.locator('.t-table__expanded-row').first()).toBeVisible()
|
||||
|
||||
await expect(page.locator('.t-table__expanded-row').getByText(`${providerId}/claude_3`)).toBeVisible()
|
||||
@@ -86,7 +91,7 @@ test.describe('模型管理', () => {
|
||||
await page.reload()
|
||||
await expect(page.getByRole('heading', { name: '供应商管理' })).toBeVisible()
|
||||
|
||||
await page.locator('.t-table__expandable-icon').first().click()
|
||||
await page.locator('.t-table__expand-box').first().click()
|
||||
await expect(page.locator('.t-table__expanded-row').first()).toBeVisible()
|
||||
|
||||
await page.locator('.t-table__expanded-row button:has-text("编辑")').first().click()
|
||||
@@ -95,10 +100,11 @@ test.describe('模型管理', () => {
|
||||
const inputs = modelFormInputs(page)
|
||||
await inputs.modelName.clear()
|
||||
await inputs.modelName.fill('gpt_4o')
|
||||
|
||||
const responsePromise = page.waitForResponse(resp => resp.url().includes('/api/models') && resp.request().method() === 'PUT')
|
||||
await inputs.saveBtn.click()
|
||||
await expect(page.locator('.t-dialog:visible')).not.toBeVisible({ timeout: 5000 })
|
||||
|
||||
await expect(page.locator('.t-table__expanded-row').getByText('gpt_4o')).toBeVisible()
|
||||
await responsePromise
|
||||
await expect(page.locator('.t-table__expanded-row').getByText('gpt_4o', { exact: true })).toBeVisible({ timeout: 5000 })
|
||||
})
|
||||
|
||||
test('应能删除模型', async ({ page, request }) => {
|
||||
@@ -113,13 +119,13 @@ test.describe('模型管理', () => {
|
||||
await page.reload()
|
||||
await expect(page.getByRole('heading', { name: '供应商管理' })).toBeVisible()
|
||||
|
||||
await page.locator('.t-table__expandable-icon').first().click()
|
||||
await page.locator('.t-table__expand-box').first().click()
|
||||
await expect(page.locator('.t-table__expanded-row').first()).toBeVisible()
|
||||
await expect(page.locator('.t-table__expanded-row').getByText('to_delete_model')).toBeVisible()
|
||||
await expect(page.locator('.t-table__expanded-row').getByText('to_delete_model', { exact: true })).toBeVisible()
|
||||
|
||||
await page.locator('.t-table__expanded-row button:has-text("删除")').first().click()
|
||||
await expect(page.getByText(/确定要删除/)).toBeVisible()
|
||||
await page.locator('.t-popconfirm').getByRole('button', { name: '确定' }).click()
|
||||
await expect(page.locator('.t-table__expanded-row').getByText('to_delete_model')).not.toBeVisible({ timeout: 5000 })
|
||||
await expect(page.locator('.t-table__expanded-row').getByText('to_delete_model', { exact: true })).not.toBeVisible({ timeout: 5000 })
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { test, expect } from '@playwright/test'
|
||||
import { clearDatabase } from './fixtures'
|
||||
|
||||
let uid = Date.now()
|
||||
|
||||
@@ -11,15 +12,17 @@ function formInputs(page: import('@playwright/test').Page) {
|
||||
return {
|
||||
id: dialog.locator('input[placeholder="例如: openai"]'),
|
||||
name: dialog.locator('input[placeholder="例如: OpenAI"]'),
|
||||
apiKey: dialog.locator('input[type="password"]'),
|
||||
apiKey: dialog.locator('input[placeholder="sk-..."]'),
|
||||
baseUrl: dialog.locator('input[placeholder="例如: https://api.openai.com/v1"]'),
|
||||
protocol: dialog.locator('.t-select'),
|
||||
saveBtn: dialog.locator('.t-dialog__footer').getByRole('button', { name: '保存' }),
|
||||
cancelBtn: dialog.locator('.t-dialog__footer').getByRole('button', { name: '取消' }),
|
||||
}
|
||||
}
|
||||
|
||||
test.describe('供应商管理', () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
test.beforeEach(async ({ page, request }) => {
|
||||
await clearDatabase(request)
|
||||
await page.goto('/providers')
|
||||
await expect(page.getByRole('heading', { name: '供应商管理' })).toBeVisible()
|
||||
})
|
||||
@@ -34,12 +37,14 @@ test.describe('供应商管理', () => {
|
||||
await inputs.name.fill('Test Provider')
|
||||
await inputs.apiKey.fill('sk_test_key_12345')
|
||||
await inputs.baseUrl.fill('https://api.openai.com/v1')
|
||||
await inputs.protocol.click()
|
||||
await page.waitForSelector('.t-select__dropdown', { timeout: 3000 })
|
||||
await page.locator('.t-select__dropdown .t-select-option').first().click()
|
||||
await page.waitForSelector('.t-select__dropdown', { state: 'hidden', timeout: 3000 })
|
||||
|
||||
await inputs.saveBtn.click()
|
||||
await expect(page.locator('.t-dialog:visible')).not.toBeVisible({ timeout: 5000 })
|
||||
|
||||
await expect(page.locator('.t-table__body td').getByText(testId)).toBeVisible()
|
||||
await expect(page.locator('.t-table__body td').getByText('Test Provider')).toBeVisible()
|
||||
|
||||
await expect(page.locator('.t-table__body').getByText('Test Provider')).toBeVisible({ timeout: 10000 })
|
||||
})
|
||||
|
||||
test('应能编辑供应商并验证更新生效', async ({ page }) => {
|
||||
@@ -51,8 +56,15 @@ test.describe('供应商管理', () => {
|
||||
await inputs.name.fill('Before Edit')
|
||||
await inputs.apiKey.fill('sk_key')
|
||||
await inputs.baseUrl.fill('https://api.example.com/v1')
|
||||
await inputs.protocol.click()
|
||||
await page.waitForSelector('.t-select__dropdown', { timeout: 3000 })
|
||||
await page.locator('.t-select__dropdown .t-select-option').first().click()
|
||||
await page.waitForSelector('.t-select__dropdown', { state: 'hidden', timeout: 3000 })
|
||||
|
||||
const responsePromise = page.waitForResponse(resp => resp.url().includes('/api/providers') && resp.request().method() === 'POST')
|
||||
await inputs.saveBtn.click()
|
||||
await expect(page.locator('.t-dialog:visible')).not.toBeVisible({ timeout: 5000 })
|
||||
await responsePromise
|
||||
await expect(page.locator('.t-table__body').getByText('Before Edit')).toBeVisible({ timeout: 5000 })
|
||||
|
||||
await page.locator('.t-table__body button:has-text("编辑")').first().click()
|
||||
await expect(page.locator('.t-dialog:visible')).toBeVisible()
|
||||
@@ -60,10 +72,11 @@ test.describe('供应商管理', () => {
|
||||
const editInputs = formInputs(page)
|
||||
await editInputs.name.clear()
|
||||
await editInputs.name.fill('After Edit')
|
||||
|
||||
const updateResponsePromise = page.waitForResponse(resp => resp.url().includes('/api/providers') && resp.request().method() === 'PUT')
|
||||
await editInputs.saveBtn.click()
|
||||
await expect(page.locator('.t-dialog:visible')).not.toBeVisible({ timeout: 5000 })
|
||||
|
||||
await expect(page.locator('.t-table__body td').getByText('After Edit')).toBeVisible()
|
||||
await updateResponsePromise
|
||||
await expect(page.locator('.t-table__body').getByText('After Edit')).toBeVisible({ timeout: 5000 })
|
||||
})
|
||||
|
||||
test('应能删除供应商并验证消失', async ({ page }) => {
|
||||
@@ -75,18 +88,20 @@ test.describe('供应商管理', () => {
|
||||
await inputs.name.fill('To Delete')
|
||||
await inputs.apiKey.fill('sk_key')
|
||||
await inputs.baseUrl.fill('https://api.example.com/v1')
|
||||
await inputs.protocol.click()
|
||||
await page.waitForSelector('.t-select__dropdown', { timeout: 3000 })
|
||||
await page.locator('.t-select__dropdown .t-select-option').first().click()
|
||||
await page.waitForSelector('.t-select__dropdown', { state: 'hidden', timeout: 3000 })
|
||||
await inputs.saveBtn.click()
|
||||
await expect(page.locator('.t-dialog:visible')).not.toBeVisible({ timeout: 5000 })
|
||||
await expect(page.locator('.t-table__body').getByText('To Delete')).toBeVisible({ timeout: 10000 })
|
||||
|
||||
await page.locator('.t-table__body button:has-text("删除")').first().click()
|
||||
await expect(page.getByText('确定要删除这个供应商吗?')).toBeVisible()
|
||||
await page.locator('.t-popconfirm').getByRole('button', { name: '确定' }).click()
|
||||
await expect(page.getByText('确定要删除这个供应商吗?')).not.toBeVisible({ timeout: 3000 })
|
||||
|
||||
await expect(page.locator('.t-table__body td').getByText(testId)).not.toBeVisible({ timeout: 5000 })
|
||||
await expect(page.locator('.t-table__body').getByText('To Delete')).not.toBeVisible({ timeout: 5000 })
|
||||
})
|
||||
|
||||
test('应正确脱敏显示 API Key', async ({ page }) => {
|
||||
test('应正确显示完整 API Key', async ({ page }) => {
|
||||
const testId = nextId()
|
||||
await page.getByRole('button', { name: '添加供应商' }).click()
|
||||
await expect(page.locator('.t-dialog:visible')).toBeVisible()
|
||||
@@ -95,9 +110,13 @@ test.describe('供应商管理', () => {
|
||||
await inputs.name.fill('Mask Test')
|
||||
await inputs.apiKey.fill('sk_abcdefghijklmnopqrstuvwxyz')
|
||||
await inputs.baseUrl.fill('https://api.example.com/v1')
|
||||
await inputs.protocol.click()
|
||||
await page.waitForSelector('.t-select__dropdown', { timeout: 3000 })
|
||||
await page.locator('.t-select__dropdown .t-select-option').first().click()
|
||||
await page.waitForSelector('.t-select__dropdown', { state: 'hidden', timeout: 3000 })
|
||||
await inputs.saveBtn.click()
|
||||
await expect(page.locator('.t-dialog:visible')).not.toBeVisible({ timeout: 5000 })
|
||||
await expect(page.locator('.t-table__body').getByText('Mask Test')).toBeVisible({ timeout: 10000 })
|
||||
|
||||
await expect(page.locator('.t-table__body')).toContainText('***stuv')
|
||||
await expect(page.locator('.t-table__body')).toContainText('sk_abcdefghijklmnopqrstuvwxyz')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { test, expect } from '@playwright/test'
|
||||
import { API_BASE, seedUsageStats } from './fixtures'
|
||||
import { API_BASE, seedUsageStats, clearDatabase } from './fixtures'
|
||||
|
||||
test.describe('统计概览', () => {
|
||||
test.beforeAll(async ({ request }) => {
|
||||
await clearDatabase(request)
|
||||
const p1 = `sp1_${Date.now()}`
|
||||
const p2 = `sp2_${Date.now()}`
|
||||
process.env._STATS_P1 = p1
|
||||
@@ -65,6 +66,7 @@ test.describe('统计概览', () => {
|
||||
|
||||
test.describe('统计筛选', () => {
|
||||
test.beforeAll(async ({ request }) => {
|
||||
await clearDatabase(request)
|
||||
const p1 = `fp1_${Date.now()}`
|
||||
const p2 = `fp2_${Date.now()}`
|
||||
process.env._FILTER_P1 = p1
|
||||
|
||||
@@ -5,7 +5,7 @@ function formInputs(page: import('@playwright/test').Page) {
|
||||
return {
|
||||
id: dialog.locator('input[placeholder="例如: openai"]'),
|
||||
name: dialog.locator('input[placeholder="例如: OpenAI"]'),
|
||||
apiKey: dialog.locator('input[type="password"]'),
|
||||
apiKey: dialog.locator('input[placeholder="sk-..."]'),
|
||||
baseUrl: dialog.locator('input[placeholder="例如: https://api.openai.com/v1"]'),
|
||||
saveBtn: dialog.locator('.t-dialog__footer').getByRole('button', { name: '保存' }),
|
||||
cancelBtn: dialog.locator('.t-dialog__footer').getByRole('button', { name: '取消' }),
|
||||
@@ -20,7 +20,7 @@ test.describe('供应商表单验证', () => {
|
||||
|
||||
test('应显示必填字段验证', async ({ page }) => {
|
||||
await page.getByRole('button', { name: '添加供应商' }).click()
|
||||
await expect(page.locator('.t-dialog')).toBeVisible()
|
||||
await expect(page.locator('.t-dialog:visible')).toBeVisible()
|
||||
|
||||
await formInputs(page).saveBtn.click()
|
||||
|
||||
@@ -32,7 +32,7 @@ test.describe('供应商表单验证', () => {
|
||||
|
||||
test('应验证URL格式', async ({ page }) => {
|
||||
await page.getByRole('button', { name: '添加供应商' }).click()
|
||||
await expect(page.locator('.t-dialog')).toBeVisible()
|
||||
await expect(page.locator('.t-dialog:visible')).toBeVisible()
|
||||
|
||||
const inputs = formInputs(page)
|
||||
await inputs.id.fill('test_url')
|
||||
@@ -46,17 +46,17 @@ test.describe('供应商表单验证', () => {
|
||||
|
||||
test('取消后表单应重置', async ({ page }) => {
|
||||
await page.getByRole('button', { name: '添加供应商' }).click()
|
||||
await expect(page.locator('.t-dialog')).toBeVisible()
|
||||
await expect(page.locator('.t-dialog:visible')).toBeVisible()
|
||||
|
||||
let inputs = formInputs(page)
|
||||
await inputs.id.fill('should_be_reset')
|
||||
await inputs.name.fill('Should Be Reset')
|
||||
|
||||
await inputs.cancelBtn.click()
|
||||
await expect(page.locator('.t-dialog')).not.toBeVisible()
|
||||
await expect(page.locator('.t-dialog:visible')).not.toBeVisible()
|
||||
|
||||
await page.getByRole('button', { name: '添加供应商' }).click()
|
||||
await expect(page.locator('.t-dialog')).toBeVisible()
|
||||
await expect(page.locator('.t-dialog:visible')).toBeVisible()
|
||||
|
||||
inputs = formInputs(page)
|
||||
await expect(inputs.id).toHaveValue('')
|
||||
@@ -65,11 +65,11 @@ test.describe('供应商表单验证', () => {
|
||||
|
||||
test('快速连续点击只打开一个对话框', async ({ page }) => {
|
||||
await page.getByRole('button', { name: '添加供应商' }).click()
|
||||
await expect(page.locator('.t-dialog')).toBeVisible()
|
||||
await expect(page.locator('.t-dialog:visible')).toBeVisible()
|
||||
|
||||
expect(await page.locator('.t-dialog').count()).toBe(1)
|
||||
expect(await page.locator('.t-dialog:visible').count()).toBe(1)
|
||||
|
||||
await formInputs(page).cancelBtn.click()
|
||||
await expect(page.locator('.t-dialog')).not.toBeVisible()
|
||||
await expect(page.locator('.t-dialog:visible')).not.toBeVisible()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -8,7 +8,10 @@ const __filename = fileURLToPath(import.meta.url)
|
||||
const __dirname = path.dirname(__filename)
|
||||
|
||||
const E2E_PORT = 19026
|
||||
const tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'nex-e2e-'))
|
||||
const tempDir = path.join(os.tmpdir(), 'nex-e2e')
|
||||
if (!fs.existsSync(path.join(tempDir, 'test.db'))) {
|
||||
fs.rmSync(tempDir, { recursive: true, force: true })
|
||||
}
|
||||
const dbPath = path.join(tempDir, 'test.db')
|
||||
const logPath = path.join(tempDir, 'log')
|
||||
|
||||
@@ -24,11 +27,12 @@ export default defineConfig({
|
||||
fullyParallel: false,
|
||||
forbidOnly: !!process.env.CI,
|
||||
retries: process.env.CI ? 2 : 0,
|
||||
workers: process.env.CI ? 1 : undefined,
|
||||
workers: 1,
|
||||
reporter: 'html',
|
||||
use: {
|
||||
baseURL: 'http://localhost:5173',
|
||||
trace: 'on-first-retry',
|
||||
storageState: undefined,
|
||||
},
|
||||
projects: [
|
||||
{
|
||||
@@ -50,6 +54,9 @@ export default defineConfig({
|
||||
command: 'bun run dev',
|
||||
url: 'http://localhost:5173',
|
||||
reuseExistingServer: false,
|
||||
env: {
|
||||
NEX_BACKEND_PORT: String(E2E_PORT),
|
||||
},
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
@@ -16,7 +16,10 @@ const queryClient = new QueryClient({
|
||||
function App() {
|
||||
return (
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<ConfigProvider globalConfig={{}}>
|
||||
<ConfigProvider globalConfig={{
|
||||
animation: { include: ['ripple', 'expand', 'fade'] },
|
||||
table: { size: 'medium' },
|
||||
}}>
|
||||
<BrowserRouter>
|
||||
<AppRoutes />
|
||||
</BrowserRouter>
|
||||
|
||||
@@ -125,9 +125,6 @@ describe('ModelForm', () => {
|
||||
const dialog = getDialog();
|
||||
expect(within(dialog).getByText('编辑模型')).toBeInTheDocument();
|
||||
|
||||
// Check that unified ID field is displayed
|
||||
expect(within(dialog).getByText('统一模型 ID')).toBeInTheDocument();
|
||||
|
||||
// Check model name input
|
||||
const modelNameInput = within(dialog).getByPlaceholderText('例如: gpt-4o') as HTMLInputElement;
|
||||
expect(modelNameInput.value).toBe('gpt-4o');
|
||||
|
||||
@@ -63,13 +63,16 @@ describe('ProviderForm', () => {
|
||||
|
||||
const baseUrlInput = within(dialog).getByPlaceholderText('例如: https://api.openai.com/v1') as HTMLInputElement;
|
||||
expect(baseUrlInput.value).toBe('https://api.openai.com/v1');
|
||||
|
||||
const apiKeyInput = within(dialog).getByPlaceholderText('sk-...') as HTMLInputElement;
|
||||
expect(apiKeyInput.value).toBe('sk-old-key');
|
||||
});
|
||||
|
||||
it('shows API Key label variant in edit mode', () => {
|
||||
it('shows API Key label in edit mode', () => {
|
||||
render(<ProviderForm {...defaultProps} provider={mockProvider} />);
|
||||
|
||||
const dialog = getDialog();
|
||||
expect(within(dialog).getByText('API Key(留空则不修改)')).toBeInTheDocument();
|
||||
expect(within(dialog).getByText('API Key')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows validation error messages for required fields', async () => {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { render, screen, fireEvent } from '@testing-library/react';
|
||||
import { render, screen } from '@testing-library/react';
|
||||
import userEvent from '@testing-library/user-event';
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { ProviderTable } from '@/pages/Providers/ProviderTable';
|
||||
@@ -48,19 +48,18 @@ const defaultProps = {
|
||||
};
|
||||
|
||||
describe('ProviderTable', () => {
|
||||
it('renders provider list with name, baseUrl, masked apiKey, and status tags', () => {
|
||||
it('renders provider list with name, baseUrl, apiKey, and status tags', () => {
|
||||
render(<ProviderTable {...defaultProps} />);
|
||||
|
||||
expect(screen.getByText('供应商列表')).toBeInTheDocument();
|
||||
|
||||
// Check that provider names appear (they will appear in both name column and potentially protocol column)
|
||||
expect(screen.getAllByText('OpenAI').length).toBeGreaterThan(0);
|
||||
expect(screen.getByText('https://api.openai.com/v1')).toBeInTheDocument();
|
||||
expect(screen.getByText('****5678')).toBeInTheDocument();
|
||||
expect(screen.getByText('sk-abcdefgh12345678')).toBeInTheDocument();
|
||||
|
||||
expect(screen.getAllByText('Anthropic').length).toBeGreaterThan(0);
|
||||
expect(screen.getByText('https://api.anthropic.com')).toBeInTheDocument();
|
||||
expect(screen.getByText('****test')).toBeInTheDocument();
|
||||
expect(screen.getByText('sk-ant-test')).toBeInTheDocument();
|
||||
|
||||
const enabledTags = screen.getAllByText('启用');
|
||||
const disabledTags = screen.getAllByText('禁用');
|
||||
@@ -77,7 +76,7 @@ describe('ProviderTable', () => {
|
||||
expect(container.querySelector('.t-card__body')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders short api keys fully masked', () => {
|
||||
it('renders short api keys directly', () => {
|
||||
const shortKeyProvider: Provider[] = [
|
||||
{
|
||||
...mockProviders[0],
|
||||
@@ -88,7 +87,7 @@ describe('ProviderTable', () => {
|
||||
];
|
||||
render(<ProviderTable {...defaultProps} providers={shortKeyProvider} />);
|
||||
|
||||
expect(screen.getByText('****')).toBeInTheDocument();
|
||||
expect(screen.getByText('ab')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('calls onAdd when clicking "添加供应商" button', async () => {
|
||||
|
||||
@@ -37,55 +37,19 @@ describe('StatCards', () => {
|
||||
expect(screen.getByText('今日请求量')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('calculates total requests correctly', () => {
|
||||
render(<StatCards stats={mockStats} />);
|
||||
|
||||
const totalRequests = mockStats.reduce((sum, s) => sum + s.requestCount, 0);
|
||||
expect(screen.getByText(totalRequests.toString())).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('calculates active models correctly', () => {
|
||||
render(<StatCards stats={mockStats} />);
|
||||
|
||||
const activeModels = new Set(mockStats.map((s) => s.modelName)).size;
|
||||
const valueElements = screen.getAllByText(activeModels.toString());
|
||||
expect(valueElements.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('calculates active providers correctly', () => {
|
||||
render(<StatCards stats={mockStats} />);
|
||||
|
||||
const activeProviders = new Set(mockStats.map((s) => s.providerId)).size;
|
||||
const valueElements = screen.getAllByText(activeProviders.toString());
|
||||
expect(valueElements.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('renders with empty stats', () => {
|
||||
render(<StatCards stats={[]} />);
|
||||
|
||||
expect(screen.getByText('总请求量')).toBeInTheDocument();
|
||||
const zeroValues = screen.getAllByText('0');
|
||||
expect(zeroValues.length).toBeGreaterThan(0);
|
||||
expect(screen.getByText('活跃模型数')).toBeInTheDocument();
|
||||
expect(screen.getByText('活跃供应商数')).toBeInTheDocument();
|
||||
expect(screen.getByText('今日请求量')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('calculates today requests correctly', () => {
|
||||
const today = new Date().toISOString().split('T')[0];
|
||||
const statsWithToday: UsageStats[] = [
|
||||
...mockStats,
|
||||
{
|
||||
id: 4,
|
||||
providerId: 'openai',
|
||||
modelName: 'gpt-4o',
|
||||
requestCount: 50,
|
||||
date: today,
|
||||
},
|
||||
];
|
||||
it('renders suffix units', () => {
|
||||
render(<StatCards stats={mockStats} />);
|
||||
|
||||
render(<StatCards stats={statsWithToday} />);
|
||||
|
||||
const todayRequests = statsWithToday
|
||||
.filter((s) => s.date === today)
|
||||
.reduce((sum, s) => sum + s.requestCount, 0);
|
||||
expect(screen.getByText(todayRequests.toString())).toBeInTheDocument();
|
||||
expect(screen.getAllByText('次').length).toBeGreaterThan(0);
|
||||
expect(screen.getAllByText('个').length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -6,8 +6,8 @@ import type { UsageStats } from '@/types';
|
||||
// Mock Recharts components
|
||||
vi.mock('recharts', () => ({
|
||||
ResponsiveContainer: vi.fn(({ children }) => <div data-testid="mock-chart-container">{children}</div>),
|
||||
LineChart: vi.fn(() => <div data-testid="mock-line-chart" />),
|
||||
Line: vi.fn(() => null),
|
||||
AreaChart: vi.fn(() => <div data-testid="mock-area-chart" />),
|
||||
Area: vi.fn(() => null),
|
||||
XAxis: vi.fn(() => null),
|
||||
YAxis: vi.fn(() => null),
|
||||
CartesianGrid: vi.fn(() => null),
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { Layout, Menu } from 'tdesign-react';
|
||||
import { ServerIcon, ChartLineIcon, SettingIcon } from 'tdesign-icons-react';
|
||||
import { useState } from 'react';
|
||||
import { Layout, Menu, Button } from 'tdesign-react';
|
||||
import { ServerIcon, ChartLineIcon, SettingIcon, ChevronLeftIcon, ChevronRightIcon } from 'tdesign-icons-react';
|
||||
import { Outlet, useLocation, useNavigate } from 'react-router';
|
||||
|
||||
const { MenuItem } = Menu;
|
||||
@@ -7,6 +8,7 @@ const { MenuItem } = Menu;
|
||||
export function AppLayout() {
|
||||
const location = useLocation();
|
||||
const navigate = useNavigate();
|
||||
const [collapsed, setCollapsed] = useState(false);
|
||||
|
||||
const getPageTitle = () => {
|
||||
if (location.pathname === '/providers') return '供应商管理';
|
||||
@@ -15,10 +17,12 @@ export function AppLayout() {
|
||||
return 'AI Gateway';
|
||||
};
|
||||
|
||||
const asideWidth = collapsed ? '64px' : '232px';
|
||||
|
||||
return (
|
||||
<Layout style={{ minHeight: '100vh' }}>
|
||||
<Layout.Aside
|
||||
width="232px"
|
||||
width={asideWidth}
|
||||
style={{
|
||||
overflow: 'hidden',
|
||||
height: '100vh',
|
||||
@@ -28,45 +32,52 @@ export function AppLayout() {
|
||||
bottom: 0,
|
||||
}}
|
||||
>
|
||||
<div style={{ display: 'flex', flexDirection: 'column', height: '100%' }}>
|
||||
<div
|
||||
style={{
|
||||
<Menu
|
||||
value={location.pathname}
|
||||
onChange={(value) => navigate(value as string)}
|
||||
collapsed={collapsed}
|
||||
width={['232px', '64px']}
|
||||
logo={
|
||||
<div style={{
|
||||
height: 64,
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
fontSize: '1.25rem',
|
||||
fontWeight: 600,
|
||||
flexShrink: 0,
|
||||
}}
|
||||
>
|
||||
AI Gateway
|
||||
</div>
|
||||
<Menu
|
||||
value={location.pathname}
|
||||
onChange={(value) => navigate(value as string)}
|
||||
style={{ flex: 1, overflow: 'auto' }}
|
||||
>
|
||||
<MenuItem value="/providers" icon={<ServerIcon />}>
|
||||
供应商管理
|
||||
</MenuItem>
|
||||
<MenuItem value="/stats" icon={<ChartLineIcon />}>
|
||||
用量统计
|
||||
</MenuItem>
|
||||
<MenuItem value="/settings" icon={<SettingIcon />}>
|
||||
设置
|
||||
</MenuItem>
|
||||
</Menu>
|
||||
</div>
|
||||
}}>
|
||||
{!collapsed && 'AI Gateway'}
|
||||
</div>
|
||||
}
|
||||
operations={
|
||||
<Button
|
||||
variant="text"
|
||||
shape="square"
|
||||
icon={collapsed ? <ChevronRightIcon /> : <ChevronLeftIcon />}
|
||||
onClick={() => setCollapsed(!collapsed)}
|
||||
/>
|
||||
}
|
||||
style={{ height: '100%' }}
|
||||
>
|
||||
<MenuItem value="/providers" icon={<ServerIcon />}>
|
||||
供应商管理
|
||||
</MenuItem>
|
||||
<MenuItem value="/stats" icon={<ChartLineIcon />}>
|
||||
用量统计
|
||||
</MenuItem>
|
||||
<MenuItem value="/settings" icon={<SettingIcon />}>
|
||||
设置
|
||||
</MenuItem>
|
||||
</Menu>
|
||||
</Layout.Aside>
|
||||
<Layout style={{ marginLeft: 232 }}>
|
||||
<Layout style={{ marginLeft: asideWidth }}>
|
||||
<Layout.Header
|
||||
style={{
|
||||
padding: '0 2rem',
|
||||
background: '#fff',
|
||||
background: 'var(--td-bg-color-container)',
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
borderBottom: '1px solid #e7e7e7',
|
||||
borderBottom: '1px solid var(--td-component-stroke)',
|
||||
}}
|
||||
>
|
||||
<h1 style={{ margin: 0, fontSize: '1.25rem' }}>{getPageTitle()}</h1>
|
||||
|
||||
@@ -7,3 +7,20 @@ body,
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
/* TDesign 主题微调 */
|
||||
:root {
|
||||
/* 页面背景色 */
|
||||
--td-bg-color-page: #f5f7fa;
|
||||
|
||||
/* 圆角调大 */
|
||||
--td-radius-default: 6px;
|
||||
--td-radius-medium: 9px;
|
||||
--td-radius-large: 12px;
|
||||
--td-radius-extraLarge: 16px;
|
||||
|
||||
/* 系统字体栈 */
|
||||
--td-font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto,
|
||||
"Helvetica Neue", Arial, "Noto Sans", sans-serif, "Apple Color Emoji",
|
||||
"Segoe UI Emoji", "Segoe UI Symbol", "Noto Color Emoji";
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { StrictMode } from 'react'
|
||||
import { createRoot } from 'react-dom/client'
|
||||
import 'tdesign-react/es/style/index.css'
|
||||
import 'tdesign-react/es/_util/react-19-adapter'
|
||||
import './index.scss'
|
||||
import App from './App'
|
||||
|
||||
|
||||
@@ -13,8 +13,8 @@ export function NotFound() {
|
||||
minHeight: '100vh',
|
||||
padding: '2rem',
|
||||
}}>
|
||||
<h1 style={{ fontSize: '6rem', margin: 0, color: '#999' }}>404</h1>
|
||||
<p style={{ fontSize: '1.25rem', color: '#666', marginBottom: '2rem' }}>
|
||||
<h1 style={{ fontSize: '6rem', margin: 0, color: 'var(--td-text-color-placeholder)' }}>404</h1>
|
||||
<p style={{ fontSize: '1.25rem', color: 'var(--td-text-color-secondary)', marginBottom: '2rem' }}>
|
||||
抱歉,您访问的页面不存在。
|
||||
</p>
|
||||
<Button theme="primary" onClick={() => navigate('/providers')}>
|
||||
|
||||
@@ -14,7 +14,7 @@ interface ModelFormProps {
|
||||
model?: Model;
|
||||
providerId: string;
|
||||
providers: Provider[];
|
||||
onSave: (values: ModelFormValues) => void;
|
||||
onSave: (values: ModelFormValues) => Promise<void> | void;
|
||||
onCancel: () => void;
|
||||
loading: boolean;
|
||||
}
|
||||
@@ -63,23 +63,18 @@ export function ModelForm({
|
||||
<Dialog
|
||||
header={isEdit ? '编辑模型' : '添加模型'}
|
||||
visible={open}
|
||||
placement="center"
|
||||
width="520px"
|
||||
closeOnOverlayClick={false}
|
||||
closeOnEscKeydown={false}
|
||||
lazy={false}
|
||||
onConfirm={() => { form?.submit(); return false; }}
|
||||
onClose={onCancel}
|
||||
confirmLoading={loading}
|
||||
confirmBtn="保存"
|
||||
cancelBtn="取消"
|
||||
destroyOnClose
|
||||
>
|
||||
<Form form={form} layout="vertical" onSubmit={handleSubmit}>
|
||||
{isEdit && model?.unifiedId && (
|
||||
<Form.FormItem label="统一模型 ID">
|
||||
<Input value={model.unifiedId} disabled />
|
||||
<div style={{ color: '#999', fontSize: 12, marginTop: 4 }}>
|
||||
格式:provider_id/model_name
|
||||
</div>
|
||||
</Form.FormItem>
|
||||
)}
|
||||
|
||||
<Form.FormItem
|
||||
label="供应商"
|
||||
name="providerId"
|
||||
|
||||
@@ -31,7 +31,11 @@ export function ModelTable({ providerId, onAdd, onEdit }: ModelTableProps) {
|
||||
colKey: 'enabled',
|
||||
width: 80,
|
||||
cell: ({ row }) =>
|
||||
row.enabled ? <Tag theme="success">启用</Tag> : <Tag theme="danger">禁用</Tag>,
|
||||
row.enabled ? (
|
||||
<Tag theme="success" variant="light" shape="round">启用</Tag>
|
||||
) : (
|
||||
<Tag theme="danger" variant="light" shape="round">禁用</Tag>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: '操作',
|
||||
@@ -72,6 +76,7 @@ export function ModelTable({ providerId, onAdd, onEdit }: ModelTableProps) {
|
||||
data={models}
|
||||
rowKey="id"
|
||||
loading={isLoading}
|
||||
stripe
|
||||
pagination={undefined}
|
||||
size="small"
|
||||
empty="暂无模型,点击上方按钮添加"
|
||||
|
||||
@@ -15,7 +15,7 @@ interface ProviderFormValues {
|
||||
interface ProviderFormProps {
|
||||
open: boolean;
|
||||
provider?: Provider;
|
||||
onSave: (values: ProviderFormValues) => void;
|
||||
onSave: (values: ProviderFormValues) => Promise<void> | void;
|
||||
onCancel: () => void;
|
||||
loading: boolean;
|
||||
}
|
||||
@@ -30,26 +30,23 @@ export function ProviderForm({
|
||||
const [form] = Form.useForm();
|
||||
const isEdit = !!provider;
|
||||
|
||||
// 当弹窗打开或provider变化时,设置表单值
|
||||
useEffect(() => {
|
||||
if (open && form) {
|
||||
if (provider) {
|
||||
// 编辑模式:设置现有值
|
||||
form.setFieldsValue({
|
||||
id: provider.id,
|
||||
name: provider.name,
|
||||
apiKey: '',
|
||||
apiKey: provider.apiKey,
|
||||
baseUrl: provider.baseUrl,
|
||||
protocol: provider.protocol,
|
||||
enabled: provider.enabled,
|
||||
});
|
||||
} else {
|
||||
// 新增模式:重置表单
|
||||
form.reset();
|
||||
form.setFieldsValue({ enabled: true, protocol: 'openai' });
|
||||
}
|
||||
}
|
||||
}, [open, provider]); // 移除form依赖,避免循环
|
||||
}, [open, provider]);
|
||||
|
||||
const handleSubmit = (context: SubmitContext) => {
|
||||
if (context.validateResult === true && form) {
|
||||
@@ -62,12 +59,16 @@ export function ProviderForm({
|
||||
<Dialog
|
||||
header={isEdit ? '编辑供应商' : '添加供应商'}
|
||||
visible={open}
|
||||
placement="center"
|
||||
width="520px"
|
||||
closeOnOverlayClick={false}
|
||||
closeOnEscKeydown={false}
|
||||
lazy={false}
|
||||
onConfirm={() => { form?.submit(); return false; }}
|
||||
onClose={onCancel}
|
||||
confirmLoading={loading}
|
||||
confirmBtn="保存"
|
||||
cancelBtn="取消"
|
||||
destroyOnClose
|
||||
>
|
||||
<Form form={form} layout="vertical" onSubmit={handleSubmit}>
|
||||
<Form.FormItem label="ID" name="id" rules={[{ required: true, message: '请输入供应商 ID' }]}>
|
||||
@@ -79,11 +80,11 @@ export function ProviderForm({
|
||||
</Form.FormItem>
|
||||
|
||||
<Form.FormItem
|
||||
label={isEdit ? 'API Key(留空则不修改)' : 'API Key'}
|
||||
label="API Key"
|
||||
name="apiKey"
|
||||
rules={isEdit ? [] : [{ required: true, message: '请输入 API Key' }]}
|
||||
rules={[{ required: true, message: '请输入 API Key' }]}
|
||||
>
|
||||
<Input type="password" placeholder="sk-..." autocomplete="current-password" />
|
||||
<Input placeholder="sk-..." />
|
||||
</Form.FormItem>
|
||||
|
||||
<Form.FormItem
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Button, Table, Tag, Popconfirm, Space, Card, Tooltip } from 'tdesign-react';
|
||||
import { Button, Table, Tag, Popconfirm, Space, Card } from 'tdesign-react';
|
||||
import type { PrimaryTableCol } from 'tdesign-react/es/table/type';
|
||||
import type { Provider, Model } from '@/types';
|
||||
import { ModelTable } from './ModelTable';
|
||||
@@ -13,12 +13,6 @@ interface ProviderTableProps {
|
||||
onEditModel: (model: Model) => void;
|
||||
}
|
||||
|
||||
function maskApiKey(key: string | null | undefined): string {
|
||||
if (!key) return '****';
|
||||
if (key.length <= 4) return '****';
|
||||
return `****${key.slice(-4)}`;
|
||||
}
|
||||
|
||||
export function ProviderTable({
|
||||
providers,
|
||||
loading,
|
||||
@@ -45,7 +39,7 @@ export function ProviderTable({
|
||||
colKey: 'protocol',
|
||||
width: 100,
|
||||
cell: ({ row }) => (
|
||||
<Tag theme={row.protocol === 'openai' ? 'primary' : 'success'}>
|
||||
<Tag theme={row.protocol === 'openai' ? 'primary' : 'success'} variant="light" shape="round">
|
||||
{row.protocol === 'openai' ? 'OpenAI' : 'Anthropic'}
|
||||
</Tag>
|
||||
),
|
||||
@@ -53,20 +47,18 @@ export function ProviderTable({
|
||||
{
|
||||
title: 'API Key',
|
||||
colKey: 'apiKey',
|
||||
width: 120,
|
||||
ellipsis: true,
|
||||
cell: ({ row }) => (
|
||||
<Tooltip content={maskApiKey(row.apiKey)}>
|
||||
<span>{maskApiKey(row.apiKey)}</span>
|
||||
</Tooltip>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: '状态',
|
||||
colKey: 'enabled',
|
||||
width: 80,
|
||||
cell: ({ row }) =>
|
||||
row.enabled ? <Tag theme="success">启用</Tag> : <Tag theme="danger">禁用</Tag>,
|
||||
row.enabled ? (
|
||||
<Tag theme="success" variant="light" shape="round">启用</Tag>
|
||||
) : (
|
||||
<Tag theme="danger" variant="light" shape="round">禁用</Tag>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: '操作',
|
||||
@@ -93,6 +85,8 @@ export function ProviderTable({
|
||||
return (
|
||||
<Card
|
||||
title="供应商列表"
|
||||
headerBordered
|
||||
hoverShadow
|
||||
actions={
|
||||
<Button theme="primary" onClick={onAdd}>
|
||||
添加供应商
|
||||
@@ -104,6 +98,7 @@ export function ProviderTable({
|
||||
data={providers}
|
||||
rowKey="id"
|
||||
loading={loading}
|
||||
stripe
|
||||
expandedRow={({ row }) => (
|
||||
<ModelTable
|
||||
providerId={row.id}
|
||||
|
||||
@@ -50,21 +50,21 @@ export function ProvidersPage() {
|
||||
open={providerFormOpen}
|
||||
provider={editingProvider}
|
||||
loading={createProvider.isPending || updateProvider.isPending}
|
||||
onSave={(values) => {
|
||||
if (editingProvider) {
|
||||
const input: Partial<UpdateProviderInput> = {};
|
||||
if (values.name !== editingProvider.name) input.name = values.name;
|
||||
if (values.apiKey) input.apiKey = values.apiKey;
|
||||
if (values.baseUrl !== editingProvider.baseUrl) input.baseUrl = values.baseUrl;
|
||||
if (values.enabled !== editingProvider.enabled) input.enabled = values.enabled;
|
||||
updateProvider.mutate(
|
||||
{ id: editingProvider.id, input },
|
||||
{ onSuccess: () => setProviderFormOpen(false) },
|
||||
);
|
||||
} else {
|
||||
createProvider.mutate(values, {
|
||||
onSuccess: () => setProviderFormOpen(false),
|
||||
});
|
||||
onSave={async (values) => {
|
||||
try {
|
||||
if (editingProvider) {
|
||||
const input: Partial<UpdateProviderInput> = {};
|
||||
if (values.name !== editingProvider.name) input.name = values.name;
|
||||
if (values.apiKey !== editingProvider.apiKey) input.apiKey = values.apiKey;
|
||||
if (values.baseUrl !== editingProvider.baseUrl) input.baseUrl = values.baseUrl;
|
||||
if (values.enabled !== editingProvider.enabled) input.enabled = values.enabled;
|
||||
await updateProvider.mutateAsync({ id: editingProvider.id, input });
|
||||
} else {
|
||||
await createProvider.mutateAsync(values);
|
||||
}
|
||||
setProviderFormOpen(false);
|
||||
} catch {
|
||||
// 错误已由 hooks 的 onError 处理
|
||||
}
|
||||
}}
|
||||
onCancel={() => setProviderFormOpen(false)}
|
||||
@@ -76,20 +76,20 @@ export function ProvidersPage() {
|
||||
providerId={modelFormProviderId}
|
||||
providers={providers}
|
||||
loading={createModel.isPending || updateModel.isPending}
|
||||
onSave={(values) => {
|
||||
if (editingModel) {
|
||||
const input: Partial<UpdateModelInput> = {};
|
||||
if (values.providerId !== editingModel.providerId) input.providerId = values.providerId;
|
||||
if (values.modelName !== editingModel.modelName) input.modelName = values.modelName;
|
||||
if (values.enabled !== editingModel.enabled) input.enabled = values.enabled;
|
||||
updateModel.mutate(
|
||||
{ id: editingModel.id, input },
|
||||
{ onSuccess: () => setModelFormOpen(false) },
|
||||
);
|
||||
} else {
|
||||
createModel.mutate(values, {
|
||||
onSuccess: () => setModelFormOpen(false),
|
||||
});
|
||||
onSave={async (values) => {
|
||||
try {
|
||||
if (editingModel) {
|
||||
const input: Partial<UpdateModelInput> = {};
|
||||
if (values.providerId !== editingModel.providerId) input.providerId = values.providerId;
|
||||
if (values.modelName !== editingModel.modelName) input.modelName = values.modelName;
|
||||
if (values.enabled !== editingModel.enabled) input.enabled = values.enabled;
|
||||
await updateModel.mutateAsync({ id: editingModel.id, input });
|
||||
} else {
|
||||
await createModel.mutateAsync(values);
|
||||
}
|
||||
setModelFormOpen(false);
|
||||
} catch {
|
||||
// 错误已由 hooks 的 onError 处理
|
||||
}
|
||||
}}
|
||||
onCancel={() => setModelFormOpen(false)}
|
||||
|
||||
@@ -3,7 +3,7 @@ import { Card } from 'tdesign-react';
|
||||
export function SettingsPage() {
|
||||
return (
|
||||
<Card title="设置">
|
||||
<div style={{ textAlign: 'center', padding: '40px 0', color: '#999' }}>
|
||||
<div style={{ textAlign: 'center', padding: '40px 0', color: 'var(--td-text-color-placeholder)' }}>
|
||||
设置功能开发中...
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { Row, Col, Card, Statistic } from 'tdesign-react';
|
||||
import { ChartBarIcon, ChartLineIcon, ServerIcon, Calendar1Icon } from 'tdesign-icons-react';
|
||||
import type { UsageStats } from '@/types';
|
||||
|
||||
interface StatCardsProps {
|
||||
@@ -18,23 +19,55 @@ export function StatCards({ stats }: StatCardsProps) {
|
||||
return (
|
||||
<Row gutter={[16, 16]} style={{ marginBottom: 16 }}>
|
||||
<Col xs={12} md={6}>
|
||||
<Card>
|
||||
<Statistic title="总请求量" value={totalRequests} />
|
||||
<Card bordered={false} hoverShadow>
|
||||
<Statistic
|
||||
title="总请求量"
|
||||
value={totalRequests}
|
||||
color="blue"
|
||||
prefix={<ChartBarIcon />}
|
||||
suffix="次"
|
||||
animation={{ duration: 800, valueFrom: 0 }}
|
||||
animationStart
|
||||
/>
|
||||
</Card>
|
||||
</Col>
|
||||
<Col xs={12} md={6}>
|
||||
<Card>
|
||||
<Statistic title="活跃模型数" value={activeModels} />
|
||||
<Card bordered={false} hoverShadow>
|
||||
<Statistic
|
||||
title="活跃模型数"
|
||||
value={activeModels}
|
||||
color="green"
|
||||
prefix={<ChartLineIcon />}
|
||||
suffix="个"
|
||||
animation={{ duration: 800, valueFrom: 0 }}
|
||||
animationStart
|
||||
/>
|
||||
</Card>
|
||||
</Col>
|
||||
<Col xs={12} md={6}>
|
||||
<Card>
|
||||
<Statistic title="活跃供应商数" value={activeProviders} />
|
||||
<Card bordered={false} hoverShadow>
|
||||
<Statistic
|
||||
title="活跃供应商数"
|
||||
value={activeProviders}
|
||||
color="orange"
|
||||
prefix={<ServerIcon />}
|
||||
suffix="个"
|
||||
animation={{ duration: 800, valueFrom: 0 }}
|
||||
animationStart
|
||||
/>
|
||||
</Card>
|
||||
</Col>
|
||||
<Col xs={12} md={6}>
|
||||
<Card>
|
||||
<Statistic title="今日请求量" value={todayRequests} />
|
||||
<Card bordered={false} hoverShadow>
|
||||
<Statistic
|
||||
title="今日请求量"
|
||||
value={todayRequests}
|
||||
color="red"
|
||||
prefix={<Calendar1Icon />}
|
||||
suffix="次"
|
||||
animation={{ duration: 800, valueFrom: 0 }}
|
||||
animationStart
|
||||
/>
|
||||
</Card>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
@@ -78,8 +78,8 @@ export function StatsTable({
|
||||
};
|
||||
|
||||
return (
|
||||
<Card title="统计数据">
|
||||
<Space style={{ marginBottom: 16 }}>
|
||||
<Card title="统计数据" headerBordered hoverShadow>
|
||||
<Space style={{ marginBottom: 16 }} size="medium" breakLine>
|
||||
<Select
|
||||
clearable
|
||||
placeholder="所有供应商"
|
||||
@@ -107,6 +107,7 @@ export function StatsTable({
|
||||
data={stats}
|
||||
rowKey="id"
|
||||
loading={loading}
|
||||
stripe
|
||||
pagination={{ pageSize: 20 }}
|
||||
empty="暂无统计数据"
|
||||
/>
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import { Card } from 'tdesign-react';
|
||||
import { LineChart, Line, XAxis, YAxis, CartesianGrid, ResponsiveContainer, Tooltip } from 'recharts';
|
||||
import { AreaChart, Area, XAxis, YAxis, CartesianGrid, ResponsiveContainer, Tooltip } from 'recharts';
|
||||
import type { UsageStats } from '@/types';
|
||||
|
||||
interface UsageChartProps {
|
||||
stats: UsageStats[];
|
||||
isLoading?: boolean;
|
||||
}
|
||||
|
||||
export function UsageChart({ stats }: UsageChartProps) {
|
||||
export function UsageChart({ stats, isLoading }: UsageChartProps) {
|
||||
const chartData = Object.entries(
|
||||
stats.reduce<Record<string, number>>((acc, s) => {
|
||||
acc[s.date] = (acc[s.date] || 0) + s.requestCount;
|
||||
@@ -17,25 +18,31 @@ export function UsageChart({ stats }: UsageChartProps) {
|
||||
.sort((a, b) => a.date.localeCompare(b.date));
|
||||
|
||||
return (
|
||||
<Card title="请求趋势" style={{ marginBottom: 16 }}>
|
||||
<Card title="请求趋势" headerBordered hoverShadow loading={isLoading} style={{ marginBottom: 16 }}>
|
||||
{chartData.length > 0 ? (
|
||||
<ResponsiveContainer width="100%" height={300}>
|
||||
<LineChart data={chartData}>
|
||||
<CartesianGrid strokeDasharray="3 3" />
|
||||
<AreaChart data={chartData}>
|
||||
<defs>
|
||||
<linearGradient id="requestGradient" x1="0" y1="0" x2="0" y2="1">
|
||||
<stop offset="0%" stopColor="#0052D9" stopOpacity={0.4} />
|
||||
<stop offset="100%" stopColor="#0052D9" stopOpacity={0} />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
<CartesianGrid strokeDasharray="3 3" stroke="#e8e8e8" />
|
||||
<XAxis dataKey="date" />
|
||||
<YAxis />
|
||||
<Tooltip />
|
||||
<Line
|
||||
<Area
|
||||
type="monotone"
|
||||
dataKey="requestCount"
|
||||
stroke="#0052D9"
|
||||
strokeWidth={2}
|
||||
dot={{ fill: '#0052D9' }}
|
||||
fill="url(#requestGradient)"
|
||||
/>
|
||||
</LineChart>
|
||||
</AreaChart>
|
||||
</ResponsiveContainer>
|
||||
) : (
|
||||
<div style={{ textAlign: 'center', padding: '40px 0', color: '#999' }}>
|
||||
<div style={{ textAlign: 'center', padding: '40px 0', color: 'var(--td-text-color-placeholder)' }}>
|
||||
暂无数据
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -27,7 +27,7 @@ export function StatsPage() {
|
||||
return (
|
||||
<div>
|
||||
<StatCards stats={stats} />
|
||||
<UsageChart stats={stats} />
|
||||
<UsageChart stats={stats} isLoading={isLoading} />
|
||||
<StatsTable
|
||||
providers={providers}
|
||||
stats={stats}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user