1
0

24 Commits

Author SHA1 Message Date
380586afa6 Merge branch 'dev-update-readme' 2026-04-22 19:43:01 +08:00
ebb70809bf Merge branch 'dev-app' 2026-04-22 19:38:29 +08:00
7399afbc5c Merge branch 'dev-frontend-style-optimization' 2026-04-22 19:37:30 +08:00
c0669e4b07 Merge branch 'dev-database-write-optimization' 2026-04-22 19:36:38 +08:00
05c04091b3 Merge branch 'dev-add-review-prompt' 2026-04-22 19:35:47 +08:00
0b05e08705 feat: 新增桌面应用支持
- 新增 desktop 应用入口,将后端与前端打包为单一可执行文件
- 集成系统托盘功能(getlantern/systray)
- 支持单实例锁和端口冲突检测
- 启动时自动打开浏览器显示管理界面
- 新增 embedfs 模块嵌入静态资源
- 新增跨平台构建脚本(macOS/Windows/Linux)
- 新增 macOS .app 打包脚本
- 统一 Makefile,移除 backend/Makefile
- 更新 README 添加桌面应用使用说明
2026-04-22 19:27:27 +08:00
df253559a5 feat(cache): 实现 RoutingCache 和 StatsBuffer 优化数据库写入
- 新增 RoutingCache 组件,使用 sync.Map 缓存 Provider 和 Model
- 新增 StatsBuffer 组件,使用 sync.Map + atomic.Int64 缓冲统计数据
- 扩展 StatsRepository.BatchUpdate 支持批量增量更新
- 改造 RoutingService/StatsService/ProviderService/ModelService 集成缓存
- 更新 usage-statistics spec,新增 routing-cache 和 stats-buffer spec
- 新增单元测试覆盖缓存命中/失效/并发场景
2026-04-22 19:24:36 +08:00
669cbb8c51 feat(prompts): 添加 proposal-review 和 apply-review 审查提示词 2026-04-22 19:12:03 +08:00
5ae9d85272 style: 优化前端样式,提升现代化设计感
- ConfigProvider 注入全局配置(动画、表格尺寸)
- CSS Variables 主题微调(页面背景、圆角、字体栈)
- AppLayout Menu 支持 logo/operations/collapsed
- Statistic 组件增加 color/prefix/suffix/animation
- Card 组件启用 hoverShadow/headerBordered
- Table 组件启用 stripe 斑马纹
- Tag 组件使用 variant="light" + shape="round"
- Dialog 居中显示并设置固定宽度
- 布局样式硬编码颜色替换为 TDesign Token
- UsageChart 改用 AreaChart + 渐变填充
- 更新 frontend spec 同步样式体系要求
2026-04-22 18:09:22 +08:00
72aebef625 docs: 更新三份 README 文档以反映实际项目情况 2026-04-22 16:24:38 +08:00
f5e45d032e refactor: 重命名提示词文件为英文 prompt-xxx 格式,优化智能合并提示词 2026-04-22 15:34:41 +08:00
b03e5f809f Merge branch 'dev-initial-scritps' into master 2026-04-22 15:26:14 +08:00
ec563aaa16 docs: 优化 prompts 提示词,面向 AI 精简 token 并新增书写原则 2026-04-22 15:21:43 +08:00
873f09d3bf refactor(scripts): 拆分脚本为 init/ 和 detect/ 子目录,优化 init-llm.sh 2026-04-22 14:57:14 +08:00
5e7267db07 fix(e2e): 修复 10 个被 skip 的 E2E 测试
- 将 playwright.config.ts 的 mkdtemp 替换为固定路径,解决主进程/worker 临时目录不一致问题
- 交换后端 WAL 与迁移执行顺序,确保 sql.js 能读取到完整 schema
- 修复 models.spec.ts 断言使用 exact:true 避免统一模型 ID 列干扰
- 移除全部 10 个 test.skip,26 个 E2E 测试全部通过
2026-04-22 14:32:12 +08:00
7b28cee7a1 Merge branch 'dev-frontend-optimization' 2026-04-22 13:31:37 +08:00
934c8dea77 Merge branch 'dev-testcase-analysis' 2026-04-22 13:23:32 +08:00
7d91fe345e Merge branch 'dev-conversion-docs' 2026-04-22 13:23:15 +08:00
4e86adffb7 feat: 系统性改进后端测试体系
- 新增 6 个测试场景 (config load pipe, handler errors, service aggregation, engine degradation, openai decoder edges, negative tests)
- 更新测试工具规格 (mockgen, in-memory SQLite)
- 覆盖率目标从 >80% 提升至 >85%
- 新增 test-unit 和 test-integration Makefile 命令
- 新增死代码清理和 mockgen 需求
- 归档变更至 openspec/changes/archive/2026-04-22-improve-backend-testing/
2026-04-22 13:18:51 +08:00
5d58acf5a6 fix: 修复供应商管理弹窗交互问题并去掉 API Key 脱敏
- Dialog 设置 lazy={false} 修复首次打开编辑弹窗表单为空
- API Key 改为普通字段(前端去掉 password 类型,后端去掉掩码逻辑)
- 删除模型编辑弹窗中的统一模型 ID 字段
- 简化 ProviderService.Get 签名(去掉 maskKey 参数)
- 删除 domain 和 config 层的 MaskAPIKey() 方法
- 更新前后端测试(107 单元测试 + 16 E2E 全部通过)
- 同步 delta spec 到主 spec
2026-04-22 13:13:25 +08:00
81dcecb723 docs: 补充 bun 作为前端唯一包管理器的说明 2026-04-22 11:37:05 +08:00
141f5f886f fix: 修复供应商管理弹窗交互问题
- 导入 TDesign react-19-adapter 修复 MessagePlugin 在 React 19 下的渲染错误
- Dialog 禁用蒙版点击和 ESC 键关闭,防止误操作丢失表单数据
- 重构弹窗关闭逻辑,使用 mutateAsync 替代 useEffect 监听 isSuccess
- 成功后自动关闭弹窗,失败后保持弹窗打开并显示错误提示
2026-04-22 11:36:16 +08:00
7fa5af483b docs: 更新 conversion 设计文档以匹配代码实现
主要更新内容:
- 新增三车道数据流模型(透传/智能透传/完整转换)
- 补充 ProtocolAdapter 智能透传相关 4 个方法
- 更新 InterfaceType 枚举(移除 AUDIO/IMAGES,增加 PASSTHROUGH)
- 新增 HTTPRequestSpec/HTTPResponseSpec 类型定义
- 更新引擎方法签名(增加 modelOverride 和 interfaceType 参数)
- 明确接口类型分发策略和中间件应用范围
- 新增三种流式转换器变体
- 重写错误处理策略为分层宽容策略
- 标记多模态和扩展点为 Deferred
- 更新附录 B 接口速查和附录 D 协议适配清单
2026-04-22 10:54:30 +08:00
f488b9cc15 fix(e2e): 修复对话框关闭问题,完善 E2E 测试
- 修复 TDesign Dialog onConfirm 不自动关闭的问题
- 使用 useEffect 监听 mutation 状态自动关闭对话框
- 测试使用 waitForResponse 等待 API 响应
- 添加 clearDatabase 函数确保测试隔离
- 归档 e2e-real-backend 变更到 archive/2026-04-22
- 同步 e2e-testing spec 到主 specs
2026-04-22 10:32:57 +08:00
121 changed files with 7866 additions and 1774 deletions

7
.gitignore vendored
View File

@@ -405,4 +405,9 @@ openspec/changes/archive
temp
.agents
skills-lock.json
.worktrees
.worktrees
!scripts/build/
# Embedfs generated
embedfs/assets/
embedfs/frontend-dist/

115
Makefile Normal file
View File

@@ -0,0 +1,115 @@
.PHONY: all clean \
backend-build backend-run backend-test backend-test-unit backend-test-integration backend-test-coverage \
backend-lint backend-deps backend-generate \
backend-migrate-up backend-migrate-down backend-migrate-status backend-migrate-create \
frontend-build frontend-dev frontend-test frontend-test-watch frontend-test-coverage frontend-test-e2e frontend-lint \
desktop desktop-darwin desktop-windows desktop-linux package-macos
# ============================================
# 后端
# ============================================
all: backend-build
backend-build:
cd backend && go build -o bin/server ./cmd/server
backend-run:
cd backend && go run ./cmd/server
backend-test:
cd backend && go test ./... -v
backend-test-unit:
cd backend && go test ./internal/... ./pkg/... -v
backend-test-integration:
cd backend && go test ./tests/... -v
backend-test-coverage:
cd backend && go test ./... -coverprofile=coverage.out
cd backend && go tool cover -html=coverage.out -o coverage.html
@echo "Coverage report generated: backend/coverage.html"
backend-lint:
cd backend && go tool golangci-lint run ./...
backend-deps:
cd backend && go mod tidy
backend-generate:
cd backend && go generate ./...
backend-migrate-up:
cd backend && goose -dir migrations sqlite3 $(DB_PATH) up
backend-migrate-down:
cd backend && goose -dir migrations sqlite3 $(DB_PATH) down
backend-migrate-status:
cd backend && goose -dir migrations sqlite3 $(DB_PATH) status
backend-migrate-create:
@read -p "Migration name: " name; \
cd backend && goose -dir migrations create $$name sql
# ============================================
# 前端
# ============================================
frontend-build:
cd frontend && bun install && bun run build
frontend-dev:
cd frontend && bun dev
frontend-test:
cd frontend && bun run test
frontend-test-watch:
cd frontend && bun run test:watch
frontend-test-coverage:
cd frontend && bun run test:coverage
frontend-test-e2e:
cd frontend && bun run test:e2e
frontend-lint:
cd frontend && bun run lint
# ============================================
# 桌面应用
# ============================================
desktop: frontend-build-desktop embedfs-prepare
cd backend && CGO_ENABLED=1 go build -o ../build/nex ./cmd/desktop
frontend-build-desktop:
cd frontend && cp .env.desktop .env.production.local && bun install && bun run build && rm -f .env.production.local
embedfs-prepare:
rm -rf embedfs/assets embedfs/frontend-dist
cp -r assets embedfs/assets
cp -r frontend/dist embedfs/frontend-dist
desktop-darwin: frontend-build-desktop embedfs-prepare
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -o ../build/nex-darwin-arm64 ./cmd/desktop
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -o ../build/nex-darwin-amd64 ./cmd/desktop
desktop-windows: frontend-build-desktop embedfs-prepare
cd backend && CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -ldflags "-H=windowsgui" -o ../build/nex-windows-amd64.exe ./cmd/desktop
desktop-linux: frontend-build-desktop embedfs-prepare
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -o ../build/nex-linux-amd64 ./cmd/desktop
package-macos:
./scripts/build/package-macos.sh
# ============================================
# 清理
# ============================================
clean:
rm -rf backend/bin/ backend/coverage.out backend/coverage.html
rm -rf build/

148
README.md
View File

@@ -7,13 +7,15 @@
```
nex/
├── backend/ # Go 后端服务(分层架构)
│ ├── cmd/server/ # 主程序入口
│ ├── cmd/
│ │ ├── server/ # CLI 主程序入口
│ │ └── desktop/ # 桌面应用入口
│ ├── internal/
│ │ ├── handler/ # HTTP 处理器 + 中间件
│ │ ├── service/ # 业务逻辑层
│ │ ├── repository/ # 数据访问层
│ │ ├── domain/ # 领域模型
│ │ ├── protocol/ # 协议适配器OpenAI/Anthropic
│ │ ├── conversion/ # 协议转换引擎OpenAI/Anthropic 适配器
│ │ ├── provider/ # 供应商客户端
│ │ └── config/ # 配置管理
│ ├── pkg/ # 公共包logger/errors/validator
@@ -32,16 +34,28 @@ nex/
│ ├── e2e/ # Playwright E2E 测试
│ └── package.json
├── assets/ # 应用资源
│ ├── icon.png # 托盘图标
│ ├── AppIcon.icns # macOS 应用图标
│ └── icon.ico # Windows 应用图标
├── scripts/ # 构建脚本
│ └── build/
│ └── package-macos.sh # macOS .app 打包脚本
└── README.md # 本文件
```
## 功能特性
- **双协议支持**:同时支持 OpenAI 和 Anthropic 协议
- **跨协议转换**Hub-and-Spoke 架构实现 OpenAI ↔ Anthropic 双向转换
- **统一模型 ID**`provider_id/model_name` 格式全局唯一标识模型(如 `openai/gpt-4`
- **透明代理**:对 OpenAI 兼容供应商 Smart Passthrough最小化改写保持参数保真
- **流式响应**:完整支持 SSE 流式传输
- **Smart Passthrough**:同协议请求零序列化开销,仅改写 model 字段
- **流式响应**:完整支持 SSE 流式传输,包括跨协议流式转换
- **Function Calling**支持工具调用Tools
- **Thinking / Reasoning**:支持 OpenAI `reasoning_effort` 和 Anthropic `thinking` 配置
- **扩展接口**:支持 Embeddings 和 Rerank 接口
- **多供应商管理**:配置和管理多个供应商(供应商 ID 仅限字母、数字、下划线)
- **用量统计**:按供应商、模型、日期统计请求数量
- **Web 配置界面**:提供供应商和模型配置管理
@@ -54,7 +68,7 @@ nex/
- **ORM**: GORM
- **数据库**: SQLite
- **日志**: zap + lumberjack结构化日志 + 日志轮转)
- **配置**: gopkg.in/yaml.v3
- **配置**: Viper + pflag多层配置CLI > 环境变量 > 配置文件 > 默认值)
- **验证**: go-playground/validator/v10
- **迁移**: goose
@@ -72,7 +86,46 @@ nex/
## 快速开始
### 后端
### 桌面应用(推荐)
**构建桌面应用**
```bash
# 当前平台
make desktop
# macOS (arm64 + amd64)
make desktop-darwin
make package-macos # 打包为 .app
# Windows
make desktop-windows
# Linux
make desktop-linux
```
**使用桌面应用**
- 双击启动应用macOS: Nex.appWindows: nex.exeLinux: nex
- 系统托盘图标出现,浏览器自动打开管理界面
- 点击托盘图标显示菜单,可打开管理界面或退出
- 关闭浏览器后服务继续运行,可通过托盘重新打开
**注意事项**
- 桌面应用需要 CGO 支持
- macOS: 自带 Xcode Command Line Tools
- Linux: 自带 gcc部分桌面环境需要 `libappindicator3-dev`
- Windows: 需要 MinGW-w64 或在 Windows 环境构建
**Linux 桌面环境兼容性**
- GNOME: 需要 AppIndicator 扩展
- KDE Plasma: 原生支持
- Xfce: 需要 libappindicator
- 其他支持 StatusNotifierItem 规范的环境
### CLI 模式
#### 后端
```bash
cd backend
@@ -100,12 +153,19 @@ bun dev
### 代理接口(对外部应用)
代理接口统一使用 `provider_id/model_name` 格式的模型 ID(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough最小化 JSON 改写保持参数保真;跨协议请求走完整 decode/encode 转换。
代理接口统一使用 `/{protocol}/*path` 路由格式,模型 ID 使用 `provider_id/model_name` 格式(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough最小化 JSON 改写保持参数保真;跨协议请求走完整 decode/encode 转换。
- `POST /v1/chat/completions` - OpenAI Chat Completions API
- `POST /v1/messages` - Anthropic Messages API
- `GET /v1/models` - 模型列表(本地数据库聚合,不请求上游
- `GET /v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
**OpenAI 协议**`protocol=openai`
- `POST /openai/chat/completions` - 对话补全
- `GET /openai/models` - 模型列表(本地数据库聚合)
- `GET /openai/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
- `POST /openai/embeddings` - 嵌入
- `POST /openai/rerank` - 重排序
**Anthropic 协议**`protocol=anthropic`
- `POST /anthropic/v1/messages` - 消息对话
- `GET /anthropic/v1/models` - 模型列表(本地数据库聚合)
- `GET /anthropic/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
### 管理接口(对前端)
@@ -131,6 +191,10 @@ bun dev
## 配置
配置支持多种方式,优先级为:**CLI 参数 > 环境变量 > 配置文件 > 默认值**
### 配置文件
配置文件位于 `~/.nex/config.yaml`,首次启动自动生成:
```yaml
@@ -154,48 +218,52 @@ log:
compress: true
```
数据文件:
### 环境变量
所有配置项支持环境变量,使用 `NEX_` 前缀:
```bash
export NEX_SERVER_PORT=9000
export NEX_DATABASE_PATH=/data/nex.db
export NEX_LOG_LEVEL=debug
```
命名规则:配置路径转大写 + 下划线(如 `server.port``NEX_SERVER_PORT`)。
### CLI 参数
```bash
./server --server-port 9000 --log-level debug --database-path /tmp/test.db
```
命名规则:配置路径转 kebab-case`server.port``--server-port`)。
### 数据文件
- `~/.nex/config.yaml` - 配置文件
- `~/.nex/config.db` - SQLite 数据库
- `~/.nex/log/` - 日志目录
## 测试
### 后端测试
```bash
cd backend
make test # 运行所有测试
make test-coverage # 生成覆盖率报告
```
### 前端测试
```bash
cd frontend
bun run test # 单元测试 + 组件测试
bun run test:watch # 监听模式
bun run test:coverage # 生成覆盖率报告
bun run test:e2e # E2E 测试
make backend-test # 后端测试
make backend-test-coverage # 后端覆盖率
make frontend-test # 前端测试
make frontend-test-e2e # 前端 E2E 测试
```
## 开发
### 后端开发
```bash
cd backend
make build # 构建
make lint # 代码检查
make migrate-up # 数据库迁移
```
make backend-build # 构建后端
make backend-run # 运行后端
make backend-lint # 后端代码检查
make backend-migrate-up # 数据库迁移
### 前端开发
```bash
cd frontend
bun run build # 构建生产版本
bun run lint # 代码检查
make frontend-build # 构建前端
make frontend-dev # 前端开发模式
make frontend-lint # 前端代码检查
```
## 开发规范

BIN
assets/AppIcon.icns Normal file

Binary file not shown.

64
assets/README.md Normal file
View File

@@ -0,0 +1,64 @@
# Assets
应用资源文件目录。
## 文件说明
| 文件 | 用途 | 尺寸 | 格式 |
|------|------|------|------|
| `icon.svg` | 源图标 | 64x64 | SVG |
| `icon.png` | 托盘图标 | 64x64 | PNG |
| `AppIcon.icns` | macOS 应用图标 | 多尺寸 | ICNS |
| `icon.ico` | Windows 应用图标 | 256x256 | ICO |
## 替换图标
### 1. 准备图标
推荐使用 SVG 格式的源图标,尺寸至少 256x256。
### 2. 生成各平台图标
**托盘图标 (PNG)**
```bash
magick your-icon.svg -resize 64x64 icon.png
```
**macOS 应用图标 (ICNS)**
```bash
mkdir icon.iconset
magick your-icon.svg -resize 16x16 icon.iconset/icon_16x16.png
magick your-icon.svg -resize 32x32 icon.iconset/icon_16x16@2x.png
magick your-icon.svg -resize 32x32 icon.iconset/icon_32x32.png
magick your-icon.svg -resize 64x64 icon.iconset/icon_32x32@2x.png
magick your-icon.svg -resize 128x128 icon.iconset/icon_128x128.png
magick your-icon.svg -resize 256x256 icon.iconset/icon_128x128@2x.png
iconutil -c icns icon.iconset -o AppIcon.icns
rm -rf icon.iconset
```
**Windows 应用图标 (ICO)**
```bash
magick your-icon.svg -resize 256x256 icon.ico
```
### 3. 替换文件
将生成的文件放入此目录,然后重新构建桌面应用:
```bash
./scripts/build/build-darwin-arm64.sh
```
## macOS Template 图标
macOS 支持 Template 图标,自动适配深浅色模式:
- 使用黑色 + 透明设计
- 文件名以 `Template` 结尾(如 `iconTemplate.png`
- 黑色在深色模式下自动变为白色
## 设计建议
- 托盘图标应简洁,在小尺寸下清晰可辨
- 避免过多细节和文字
- 使用高对比度颜色
- macOS 建议使用 Template 图标风格

BIN
assets/icon.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 264 KiB

BIN
assets/icon.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 KiB

13
assets/icon.svg Normal file
View File

@@ -0,0 +1,13 @@
<svg width="64" height="64" viewBox="0 0 64 64" xmlns="http://www.w3.org/2000/svg">
<rect width="64" height="64" rx="12" fill="#4A90D9"/>
<polygon points="32,8 52,20 52,44 32,56 12,44 12,20" fill="none" stroke="white" stroke-width="3"/>
<circle cx="32" cy="32" r="6" fill="white"/>
<line x1="32" y1="32" x2="20" y2="20" stroke="white" stroke-width="2"/>
<line x1="32" y1="32" x2="44" y2="20" stroke="white" stroke-width="2"/>
<line x1="32" y1="32" x2="20" y2="44" stroke="white" stroke-width="2"/>
<line x1="32" y1="32" x2="44" y2="44" stroke="white" stroke-width="2"/>
<circle cx="20" cy="20" r="3" fill="white"/>
<circle cx="44" cy="20" r="3" fill="white"/>
<circle cx="20" cy="44" r="3" fill="white"/>
<circle cx="44" cy="44" r="3" fill="white"/>
</svg>

After

Width:  |  Height:  |  Size: 779 B

View File

@@ -1,45 +0,0 @@
.PHONY: build run test test-coverage clean migrate-up migrate-down migrate-status migrate-create lint
# 构建
build:
go build -o bin/server ./cmd/server
# 运行
run:
go run ./cmd/server
# 测试
test:
go test ./... -v
# 测试覆盖率
test-coverage:
go test ./... -coverprofile=coverage.out
go tool cover -html=coverage.out -o coverage.html
@echo "Coverage report generated: coverage.html"
# 清理
clean:
rm -rf bin/ coverage.out coverage.html
# 数据库迁移
migrate-up:
goose -dir migrations sqlite3 $(DB_PATH) up
migrate-down:
goose -dir migrations sqlite3 $(DB_PATH) down
migrate-status:
goose -dir migrations sqlite3 $(DB_PATH) status
migrate-create:
@read -p "Migration name: " name; \
goose -dir migrations create $$name sql
# 代码检查
lint:
golangci-lint run ./...
# 安装依赖
deps:
go mod tidy

View File

@@ -116,8 +116,11 @@ backend/
├── migrations/ # 数据库迁移
│ └── 20260421000001_initial_schema.sql
├── tests/ # 集成测试
│ ├── helpers.go
── integration/
│ ├── helpers.go # 测试辅助函数
── config/ # 测试配置
│ ├── integration/ # 集成测试
│ │ └── e2e_conversion_test.go # E2E 协议转换测试
│ └── mocks/ # Mock 实现
├── Makefile
├── go.mod
└── README.md
@@ -146,6 +149,120 @@ Client Request (clientProtocol)
同协议时自动透传,跳过序列化开销。
## 协议转换架构
### Canonical Model 中间表示
所有协议转换都经过 Canonical Model 中间表示层,实现 Hub-and-Spoke 架构:
```
OpenAI Request → Canonical Request → Anthropic Request
(中间表示)
OpenAI Response ← Canonical Response ← Anthropic Response
```
**CanonicalRequest 核心字段**
- `Model` - 统一模型 ID
- `Messages` - 消息列表(支持 text、tool_use、tool_result、thinking 类型)
- `Tools` - 工具定义
- `Thinking` - 推理配置(`budget_tokens``effort`
- `Parameters` - 通用参数(`max_tokens``temperature``top_p` 等)
### Smart Passthrough 机制
同协议请求走 Smart Passthrough 路径,**零序列化开销**
```
1. 检测 clientProtocol == providerProtocol
2. 仅改写请求体中的 model 字段unified_id → upstream_model_name
3. 直接转发请求到上游
4. 响应中仅改写 model 字段upstream_model_name → unified_id
```
### 流式转换器层次
```
StreamConverter (接口)
├── PassthroughStreamConverter # 直接透传,无任何处理
├── SmartPassthroughStreamConverter # 同协议 + 逐 chunk 改写 model
└── CanonicalStreamConverter # 跨协议完整转换decode → encode
```
### InterfaceType 枚举
| 类型 | 说明 |
|------|------|
| `CHAT` | 对话补全chat/completions、messages |
| `MODELS` | 模型列表 |
| `MODEL_INFO` | 模型详情 |
| `EMBEDDINGS` | 嵌入接口 |
| `RERANK` | 重排序接口 |
| `PASSTHROUGH` | 未知接口,直接透传 |
## 协议适配器特性
### OpenAI 适配器
**特有字段支持**
- `reasoning_effort` - 映射到 Canonical Thinking 配置(`none` → 禁用,其他 → `effort`
- `reasoning_content` - 非标准字段,映射到 Canonical thinking 块
- `max_completion_tokens` - 新字段,优先于 `max_tokens`
- `refusal` - 非标准字段,作为 text 块处理
**废弃字段兼容**
- `functions` / `function_call` - 自动转换为 `tools` / `tool_choice`
**消息处理**
- 合并连续同角色消息Anthropic 不支持连续同角色)
- 工具选择映射:`any``required`
### Anthropic 适配器
**特有字段支持**
- `thinking` - 推理配置(`type: enabled``budget_tokens``effort`
- `output_config` - 结构化输出配置
- `disable_parallel_tool_use` - 禁用并行工具调用
- `container` - 工具容器字段
**不支持的功能**
- Embeddings 接口(返回 `INTERFACE_NOT_SUPPORTED` 错误)
### 跨协议转换注意事项
| 源协议 | 目标协议 | 转换说明 |
|--------|----------|----------|
| OpenAI | Anthropic | `reasoning_effort``thinking`,消息角色合并 |
| Anthropic | OpenAI | `thinking` 块 → `reasoning_content`,工具选择转换 |
## 错误码
### ConversionError 错误码
| 错误码 | 说明 |
|--------|------|
| `INVALID_INPUT` | 输入数据无效 |
| `MISSING_REQUIRED_FIELD` | 缺少必填字段 |
| `INCOMPATIBLE_FEATURE` | 功能不兼容(如跨协议不支持某特性) |
| `FIELD_MAPPING_FAILURE` | 字段映射失败 |
| `TOOL_CALL_PARSE_ERROR` | 工具调用解析错误 |
| `JSON_PARSE_ERROR` | JSON 解析错误 |
| `STREAM_STATE_ERROR` | 流式状态错误 |
| `UTF8_DECODE_ERROR` | UTF-8 解码错误(流式 chunk 截断) |
| `PROTOCOL_CONSTRAINT_VIOLATION` | 协议约束违反 |
| `ENCODING_FAILURE` | 编码失败 |
| `INTERFACE_NOT_SUPPORTED` | 接口不支持(如 Anthropic Embeddings |
### AppError 预定义错误
| 错误 | HTTP 状态码 | 说明 |
|------|-------------|------|
| `ErrModelNotFound` | 404 | 模型未找到 |
| `ErrModelDisabled` | 404 | 模型已禁用 |
| `ErrProviderNotFound` | 404 | 供应商未找到 |
| `ErrInvalidProviderID` | 400 | 供应商 ID 格式无效 |
| `ErrDuplicateModel` | 409 | 同一供应商下模型名称重复 |
| `ErrImmutableField` | 400 | 不可修改字段(如供应商 ID |
## 运行方式
### 安装依赖
@@ -278,10 +395,10 @@ goose -dir migrations sqlite3 ~/.nex/config.db up
#### OpenAI 协议
```
POST /openai/v1/chat/completions
GET /openai/v1/models
POST /openai/v1/embeddings
POST /openai/v1/rerank
POST /openai/chat/completions
GET /openai/models
POST /openai/embeddings
POST /openai/rerank
```
#### Anthropic 协议
@@ -322,7 +439,7 @@ GET /anthropic/v1/models
- Anthropic 协议:配置到域名,不包含版本路径,如 `https://api.anthropic.com`
**对外 URL 格式**
- OpenAI 协议:`/{protocol}/{endpoint}`,如 `/openai/chat/completions``/openai/models`
- OpenAI 协议:`/{protocol}/{endpoint}`,如 `/openai/chat/completions``/openai/models``/openai/embeddings`
- Anthropic 协议:`/{protocol}/v1/{endpoint}`,如 `/anthropic/v1/messages``/anthropic/v1/models`
#### 模型管理

456
backend/cmd/desktop/main.go Normal file
View File

@@ -0,0 +1,456 @@
package main
import (
"context"
"fmt"
"io/fs"
"log"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"syscall"
"time"
"github.com/gin-gonic/gin"
"github.com/getlantern/systray"
"github.com/pressly/goose/v3"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"nex/backend/internal/config"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"nex/backend/internal/handler"
"nex/backend/internal/handler/middleware"
"nex/backend/internal/provider"
"nex/backend/internal/repository"
"nex/backend/internal/service"
pkgLogger "nex/backend/pkg/logger"
"nex/embedfs"
)
var (
server *http.Server
zapLogger *zap.Logger
shutdownCtx context.Context
shutdownCancel context.CancelFunc
)
func main() {
port := 9826
if err := acquireSingleInstance(); err != nil {
showError("Nex Gateway", "已有 Nex 实例运行")
os.Exit(1)
}
defer releaseSingleInstance()
if err := checkPortAvailable(port); err != nil {
showError("Nex Gateway", err.Error())
os.Exit(1)
}
cfg, err := config.LoadConfig()
if err != nil {
showError("Nex Gateway", fmt.Sprintf("加载配置失败: %v", err))
os.Exit(1)
}
zapLogger, err = pkgLogger.New(pkgLogger.Config{
Level: cfg.Log.Level,
Path: cfg.Log.Path,
MaxSize: cfg.Log.MaxSize,
MaxBackups: cfg.Log.MaxBackups,
MaxAge: cfg.Log.MaxAge,
Compress: cfg.Log.Compress,
})
if err != nil {
showError("Nex Gateway", fmt.Sprintf("初始化日志失败: %v", err))
os.Exit(1)
}
defer zapLogger.Sync()
db, err := initDatabase(cfg)
if err != nil {
showError("Nex Gateway", fmt.Sprintf("初始化数据库失败: %v", err))
os.Exit(1)
}
defer closeDB(db)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
statsRepo := repository.NewStatsRepository(db)
providerService := service.NewProviderService(providerRepo, modelRepo)
modelService := service.NewModelService(modelRepo, providerRepo)
routingService := service.NewRoutingService(modelRepo, providerRepo)
statsService := service.NewStatsService(statsRepo)
registry := conversion.NewMemoryRegistry()
if err := registry.Register(openai.NewAdapter()); err != nil {
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error()))
}
if err := registry.Register(anthropic.NewAdapter()); err != nil {
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.String("error", err.Error()))
}
engine := conversion.NewConversionEngine(registry, zapLogger)
providerClient := provider.NewClient()
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
providerHandler := handler.NewProviderHandler(providerService)
modelHandler := handler.NewModelHandler(modelService)
statsHandler := handler.NewStatsHandler(statsService)
gin.SetMode(gin.ReleaseMode)
r := gin.New()
r.Use(middleware.RequestID())
r.Use(middleware.Recovery(zapLogger))
r.Use(middleware.Logging(zapLogger))
r.Use(middleware.CORS())
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
setupStaticFiles(r)
server = &http.Server{
Addr: fmt.Sprintf(":%d", port),
Handler: r,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
}
shutdownCtx, shutdownCancel = context.WithCancel(context.Background())
go func() {
zapLogger.Info("AI Gateway 启动", zap.String("addr", server.Addr))
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error()))
}
}()
go func() {
time.Sleep(500 * time.Millisecond)
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
zapLogger.Warn("无法打开浏览器", zap.String("error", err.Error()))
}
}()
setupSystray(port)
}
func initDatabase(cfg *config.Config) (*gorm.DB, error) {
dbDir := filepath.Dir(cfg.Database.Path)
if err := os.MkdirAll(dbDir, 0755); err != nil {
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
}
db, err := gorm.Open(sqlite.Open(cfg.Database.Path), &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
})
if err != nil {
return nil, err
}
if err := runMigrations(db); err != nil {
return nil, fmt.Errorf("数据库迁移失败: %w", err)
}
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
log.Printf("警告: 启用 WAL 模式失败: %v", err)
}
sqlDB, err := db.DB()
if err != nil {
return nil, err
}
sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns)
sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns)
sqlDB.SetConnMaxLifetime(cfg.Database.ConnMaxLifetime)
return db, nil
}
func runMigrations(db *gorm.DB) error {
sqlDB, err := db.DB()
if err != nil {
return err
}
migrationsDir := getMigrationsDir()
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
}
goose.SetDialect("sqlite3")
if err := goose.Up(sqlDB, migrationsDir); err != nil {
return err
}
return nil
}
func getMigrationsDir() string {
_, filename, _, ok := runtime.Caller(0)
if ok {
dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations")
if abs, err := filepath.Abs(dir); err == nil {
return abs
}
}
return "./migrations"
}
func closeDB(db *gorm.DB) {
sqlDB, err := db.DB()
if err != nil {
return
}
sqlDB.Close()
}
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
r.Any("/v1/*path", proxyHandler.HandleProxy)
providers := r.Group("/api/providers")
{
providers.GET("", providerHandler.ListProviders)
providers.POST("", providerHandler.CreateProvider)
providers.GET("/:id", providerHandler.GetProvider)
providers.PUT("/:id", providerHandler.UpdateProvider)
providers.DELETE("/:id", providerHandler.DeleteProvider)
}
models := r.Group("/api/models")
{
models.GET("", modelHandler.ListModels)
models.POST("", modelHandler.CreateModel)
models.GET("/:id", modelHandler.GetModel)
models.PUT("/:id", modelHandler.UpdateModel)
models.DELETE("/:id", modelHandler.DeleteModel)
}
stats := r.Group("/api/stats")
{
stats.GET("", statsHandler.GetStats)
stats.GET("/aggregate", statsHandler.AggregateStats)
}
r.GET("/health", func(c *gin.Context) {
c.JSON(200, gin.H{"status": "ok"})
})
}
func setupStaticFiles(r *gin.Engine) {
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
if err != nil {
zapLogger.Fatal("无法加载前端资源", zap.String("error", err.Error()))
}
getContentType := func(path string) string {
if strings.HasSuffix(path, ".js") {
return "application/javascript"
}
if strings.HasSuffix(path, ".css") {
return "text/css"
}
if strings.HasSuffix(path, ".svg") {
return "image/svg+xml"
}
if strings.HasSuffix(path, ".png") {
return "image/png"
}
if strings.HasSuffix(path, ".ico") {
return "image/x-icon"
}
if strings.HasSuffix(path, ".woff") || strings.HasSuffix(path, ".woff2") {
return "font/woff2"
}
return "application/octet-stream"
}
r.GET("/assets/*filepath", func(c *gin.Context) {
filepath := c.Param("filepath")
data, err := fs.ReadFile(distFS, "assets"+filepath)
if err != nil {
c.Status(404)
return
}
c.Data(200, getContentType(filepath), data)
})
r.GET("/favicon.svg", func(c *gin.Context) {
data, err := fs.ReadFile(distFS, "favicon.svg")
if err != nil {
c.Status(404)
return
}
c.Data(200, "image/svg+xml", data)
})
r.NoRoute(func(c *gin.Context) {
path := c.Request.URL.Path
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/health") {
c.JSON(404, gin.H{"error": "not found"})
return
}
data, err := fs.ReadFile(distFS, "index.html")
if err != nil {
c.Status(500)
return
}
c.Data(200, "text/html; charset=utf-8", data)
})
}
func setupSystray(port int) {
systray.Run(func() {
icon, err := embedfs.Assets.ReadFile("assets/icon.png")
if err != nil {
zapLogger.Error("无法加载托盘图标", zap.String("error", err.Error()))
}
systray.SetIcon(icon)
systray.SetTitle("Nex Gateway")
systray.SetTooltip("AI Gateway")
mOpen := systray.AddMenuItem("打开管理界面", "在浏览器中打开")
systray.AddSeparator()
mStatus := systray.AddMenuItem("状态: 运行中", "")
mStatus.Disable()
mPort := systray.AddMenuItem(fmt.Sprintf("端口: %d", port), "")
mPort.Disable()
systray.AddSeparator()
mAbout := systray.AddMenuItem("关于", "")
systray.AddSeparator()
mQuit := systray.AddMenuItem("退出", "停止服务并退出")
go func() {
for {
select {
case <-mOpen.ClickedCh:
openBrowser(fmt.Sprintf("http://localhost:%d", port))
case <-mAbout.ClickedCh:
showAbout()
case <-mQuit.ClickedCh:
doShutdown()
systray.Quit()
return
}
}
}()
}, nil)
}
func doShutdown() {
if zapLogger != nil {
zapLogger.Info("正在关闭服务器...")
}
if server != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
server.Shutdown(ctx)
}
if shutdownCancel != nil {
shutdownCancel()
}
}
func checkPortAvailable(port int) error {
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return fmt.Errorf("端口 %d 已被占用\n\n可能原因:\n- 已有 Nex 实例运行\n- 其他程序占用了该端口\n\n请检查并关闭占用端口的程序", port)
}
ln.Close()
return nil
}
var lockFile *os.File
func acquireSingleInstance() error {
lockPath := filepath.Join(os.TempDir(), "nex-gateway.lock")
f, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0666)
if err != nil {
return err
}
err = syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
if err != nil {
f.Close()
return fmt.Errorf("已有实例运行")
}
lockFile = f
return nil
}
func releaseSingleInstance() {
if lockFile != nil {
syscall.Flock(int(lockFile.Fd()), syscall.LOCK_UN)
lockFile.Close()
}
}
func openBrowser(url string) error {
var cmd *exec.Cmd
switch runtime.GOOS {
case "darwin":
cmd = exec.Command("open", url)
case "windows":
cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
case "linux":
browsers := []string{"xdg-open", "google-chrome", "firefox"}
for _, browser := range browsers {
if _, err := exec.LookPath(browser); err == nil {
cmd = exec.Command(browser, url)
break
}
}
}
if cmd == nil {
return fmt.Errorf("无法打开浏览器")
}
return cmd.Start()
}
func showError(title, message string) {
switch runtime.GOOS {
case "darwin":
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`, message, title)
exec.Command("osascript", "-e", script).Run()
case "windows":
exec.Command("msg", "*", message).Run()
case "linux":
exec.Command("zenity", "--error", fmt.Sprintf("--title=%s", title), fmt.Sprintf("--text=%s", message)).Run()
}
}
func showAbout() {
message := "Nex Gateway\n\nAI Gateway - 统一的大模型 API 网关\n\nhttps://github.com/nex/gateway"
switch runtime.GOOS {
case "darwin":
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "关于 Nex Gateway"`, message)
exec.Command("osascript", "-e", script).Run()
case "windows":
exec.Command("msg", "*", message).Run()
case "linux":
exec.Command("zenity", "--info", "--title=关于 Nex Gateway", fmt.Sprintf("--text=%s", message)).Run()
}
}

View File

@@ -0,0 +1,69 @@
package main
import (
"net"
"net/http"
"testing"
"time"
)
func TestCheckPortAvailable(t *testing.T) {
port := 19826
err := checkPortAvailable(port)
if err != nil {
t.Fatalf("端口 %d 应该可用: %v", port, err)
}
t.Log("端口可用测试通过")
}
func TestCheckPortOccupied(t *testing.T) {
port := 19827
listener, err := net.Listen("tcp", ":19827")
if err != nil {
t.Fatalf("无法启动测试服务器: %v", err)
}
defer listener.Close()
go func() {
conn, err := listener.Accept()
if err == nil {
conn.Close()
}
}()
time.Sleep(100 * time.Millisecond)
err = checkPortAvailable(port)
if err == nil {
t.Fatal("端口被占用时应该返回错误")
}
t.Log("端口占用检测测试通过")
}
func TestCheckPortAvailableAfterClose(t *testing.T) {
port := 19828
listener, err := net.Listen("tcp", ":19828")
if err != nil {
t.Fatalf("无法启动测试服务器: %v", err)
}
server := &http.Server{}
go server.Serve(listener)
time.Sleep(100 * time.Millisecond)
listener.Close()
time.Sleep(100 * time.Millisecond)
err = checkPortAvailable(port)
if err != nil {
t.Fatalf("端口关闭后应该可用: %v", err)
}
t.Log("端口关闭后可用测试通过")
}

View File

@@ -0,0 +1,39 @@
package main
import (
"os"
"path/filepath"
"syscall"
"testing"
)
func TestAcquireSingleInstance(t *testing.T) {
lockPath := filepath.Join(os.TempDir(), "nex-gateway-test.lock")
origLockFile := lockFile
lockFile = nil
defer func() { lockFile = origLockFile }()
f, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0666)
if err != nil {
t.Fatalf("无法创建锁文件: %v", err)
}
defer f.Close()
defer os.Remove(lockPath)
err = syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
if err != nil {
t.Fatalf("无法获取文件锁: %v", err)
}
defer syscall.Flock(int(f.Fd()), syscall.LOCK_UN)
t.Log("单实例锁测试通过")
}
func TestReleaseSingleInstance(t *testing.T) {
lockFile = nil
releaseSingleInstance()
t.Log("释放空锁测试通过")
}

View File

@@ -0,0 +1,123 @@
package main
import (
"io/fs"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"nex/embedfs"
)
func TestSetupStaticFiles(t *testing.T) {
gin.SetMode(gin.TestMode)
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
if err != nil {
t.Skipf("跳过测试: 前端资源未构建: %v", err)
return
}
getContentType := func(path string) string {
if strings.HasSuffix(path, ".js") {
return "application/javascript"
}
if strings.HasSuffix(path, ".css") {
return "text/css"
}
if strings.HasSuffix(path, ".svg") {
return "image/svg+xml"
}
return "application/octet-stream"
}
r := gin.New()
r.GET("/assets/*filepath", func(c *gin.Context) {
filepath := c.Param("filepath")
data, err := fs.ReadFile(distFS, "assets"+filepath)
if err != nil {
c.Status(404)
return
}
c.Data(200, getContentType(filepath), data)
})
r.GET("/favicon.svg", func(c *gin.Context) {
data, err := fs.ReadFile(distFS, "favicon.svg")
if err != nil {
c.Status(404)
return
}
c.Data(200, "image/svg+xml", data)
})
r.NoRoute(func(c *gin.Context) {
path := c.Request.URL.Path
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/health") {
c.JSON(404, gin.H{"error": "not found"})
return
}
data, err := fs.ReadFile(distFS, "index.html")
if err != nil {
c.Status(500)
return
}
c.Data(200, "text/html; charset=utf-8", data)
})
t.Run("API 404", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != 404 {
t.Errorf("期望状态码 404, 实际 %d", w.Code)
}
})
t.Run("SPA fallback", func(t *testing.T) {
req := httptest.NewRequest("GET", "/providers", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != 200 {
t.Errorf("期望状态码 200, 实际 %d", w.Code)
}
})
t.Run("MIME type for JS", func(t *testing.T) {
req := httptest.NewRequest("GET", "/assets/test.js", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code == 200 {
expected := "application/javascript"
if w.Header().Get("Content-Type") != expected {
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
}
} else {
t.Log("文件不存在,跳过 MIME 类型验证")
}
})
t.Run("MIME type for CSS", func(t *testing.T) {
req := httptest.NewRequest("GET", "/assets/test.css", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code == 200 {
expected := "text/css"
if w.Header().Get("Content-Type") != expected {
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
}
} else {
t.Log("文件不存在,跳过 MIME 类型验证")
}
})
t.Log("静态文件服务测试通过")
}

View File

@@ -67,13 +67,25 @@ func main() {
modelRepo := repository.NewModelRepository(db)
statsRepo := repository.NewStatsRepository(db)
// 5. 初始化 service 层
providerService := service.NewProviderService(providerRepo, modelRepo)
modelService := service.NewModelService(modelRepo, providerRepo)
routingService := service.NewRoutingService(modelRepo, providerRepo)
statsService := service.NewStatsService(statsRepo)
// 5. 初始化缓存
routingCache := service.NewRoutingCache(modelRepo, providerRepo, zapLogger)
if err := routingCache.Preload(); err != nil {
zapLogger.Warn("缓存预热失败,将使用懒加载", zap.Error(err))
}
// 6. 创建 ConversionEngine
// 6. 初始化统计缓冲
statsBuffer := service.NewStatsBuffer(statsRepo, zapLogger,
service.WithFlushInterval(5*time.Second),
service.WithFlushThreshold(100))
statsBuffer.Start()
// 7. 初始化 service 层
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
routingService := service.NewRoutingService(routingCache)
statsService := service.NewStatsService(statsRepo, statsBuffer)
// 8. 创建 ConversionEngine
registry := conversion.NewMemoryRegistry()
if err := registry.Register(openai.NewAdapter()); err != nil {
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error()))
@@ -83,16 +95,16 @@ func main() {
}
engine := conversion.NewConversionEngine(registry, zapLogger)
// 7. 初始化 provider client
// 9. 初始化 provider client
providerClient := provider.NewClient()
// 8. 初始化 handler 层
// 10. 初始化 handler 层
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
providerHandler := handler.NewProviderHandler(providerService)
modelHandler := handler.NewModelHandler(modelService)
statsHandler := handler.NewStatsHandler(statsService)
// 9. 创建 Gin 引擎
// 11. 创建 Gin 引擎
gin.SetMode(gin.ReleaseMode)
r := gin.New()
@@ -103,7 +115,7 @@ func main() {
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
// 10. 启动服务器
// 12. 启动服务器
srv := &http.Server{
Addr: formatAddr(cfg.Server.Port),
Handler: r,
@@ -131,6 +143,8 @@ func main() {
zapLogger.Fatal("服务器强制关闭", zap.String("error", err.Error()))
}
statsBuffer.Stop()
zapLogger.Info("服务器已关闭")
}
@@ -142,14 +156,14 @@ func initDatabase(cfg *config.Config) (*gorm.DB, error) {
return nil, err
}
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
log.Printf("警告: 启用 WAL 模式失败: %v", err)
}
if err := runMigrations(db); err != nil {
return nil, fmt.Errorf("数据库迁移失败: %w", err)
}
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
log.Printf("警告: 启用 WAL 模式失败: %v", err)
}
sqlDB, err := db.DB()
if err != nil {
return nil, err

View File

@@ -2,7 +2,15 @@ module nex/backend
go 1.26.2
replace nex/embedfs => ../embedfs
tool (
github.com/golangci/golangci-lint/cmd/golangci-lint
go.uber.org/mock/mockgen
)
require (
github.com/getlantern/systray v1.2.2
github.com/gin-gonic/gin v1.12.0
github.com/go-playground/validator/v10 v10.30.2
github.com/google/uuid v1.6.0
@@ -11,60 +19,229 @@ require (
github.com/spf13/pflag v1.0.10
github.com/spf13/viper v1.21.0
github.com/stretchr/testify v1.11.1
go.uber.org/mock v0.6.0
go.uber.org/zap v1.27.1
gopkg.in/lumberjack.v2 v2.0.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/sqlite v1.6.0
gorm.io/gorm v1.31.1
nex/embedfs v0.0.0-00010101000000-000000000000
)
require (
4d63.com/gocheckcompilerdirectives v1.3.0 // indirect
4d63.com/gochecknoglobals v0.2.2 // indirect
github.com/4meepo/tagalign v1.4.2 // indirect
github.com/Abirdcfly/dupword v0.1.3 // indirect
github.com/Antonboom/errname v1.0.0 // indirect
github.com/Antonboom/nilnil v1.0.1 // indirect
github.com/Antonboom/testifylint v1.5.2 // indirect
github.com/BurntSushi/toml v1.6.0 // indirect
github.com/Crocmagnon/fatcontext v0.7.1 // indirect
github.com/Djarvur/go-err113 v0.0.0-20210108212216-aea10b59be24 // indirect
github.com/GaijinEntertainment/go-exhaustruct/v3 v3.3.1 // indirect
github.com/Masterminds/semver/v3 v3.3.0 // indirect
github.com/OpenPeeDeeP/depguard/v2 v2.2.1 // indirect
github.com/alecthomas/go-check-sumtype v0.3.1 // indirect
github.com/alexkohler/nakedret/v2 v2.0.5 // indirect
github.com/alexkohler/prealloc v1.0.0 // indirect
github.com/alingse/asasalint v0.0.11 // indirect
github.com/alingse/nilnesserr v0.1.2 // indirect
github.com/ashanbrown/forbidigo v1.6.0 // indirect
github.com/ashanbrown/makezero v1.2.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/bkielbasa/cyclop v1.2.3 // indirect
github.com/blizzy78/varnamelen v0.8.0 // indirect
github.com/bombsimon/wsl/v4 v4.5.0 // indirect
github.com/breml/bidichk v0.3.2 // indirect
github.com/breml/errchkjson v0.4.0 // indirect
github.com/butuzov/ireturn v0.3.1 // indirect
github.com/butuzov/mirror v1.3.0 // indirect
github.com/bytedance/gopkg v0.1.3 // indirect
github.com/bytedance/sonic v1.15.0 // indirect
github.com/bytedance/sonic/loader v0.5.0 // indirect
github.com/catenacyber/perfsprint v0.8.2 // indirect
github.com/ccojocar/zxcvbn-go v1.0.2 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charithe/durationcheck v0.0.10 // indirect
github.com/chavacava/garif v0.1.0 // indirect
github.com/ckaznocha/intrange v0.3.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/curioswitch/go-reassign v0.3.0 // indirect
github.com/daixiang0/gci v0.13.5 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/denis-tingaikin/go-header v0.5.0 // indirect
github.com/ettle/strcase v0.2.0 // indirect
github.com/fatih/color v1.18.0 // indirect
github.com/fatih/structtag v1.2.0 // indirect
github.com/firefart/nonamedreturns v1.0.5 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/fzipp/gocyclo v0.6.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 // indirect
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 // indirect
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 // indirect
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 // indirect
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 // indirect
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
github.com/ghostiam/protogetter v0.3.9 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect
github.com/go-critic/go-critic v0.12.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-stack/stack v1.8.0 // indirect
github.com/go-toolsmith/astcast v1.1.0 // indirect
github.com/go-toolsmith/astcopy v1.1.0 // indirect
github.com/go-toolsmith/astequal v1.2.0 // indirect
github.com/go-toolsmith/astfmt v1.1.0 // indirect
github.com/go-toolsmith/astp v1.1.0 // indirect
github.com/go-toolsmith/strparse v1.1.0 // indirect
github.com/go-toolsmith/typep v1.1.0 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/go-xmlfmt/xmlfmt v1.1.3 // indirect
github.com/gobwas/glob v0.2.3 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/goccy/go-yaml v1.19.2 // indirect
github.com/gofrs/flock v0.12.1 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/golangci/dupl v0.0.0-20250308024227-f665c8d69b32 // indirect
github.com/golangci/go-printf-func-name v0.1.0 // indirect
github.com/golangci/gofmt v0.0.0-20250106114630-d62b90e6713d // indirect
github.com/golangci/golangci-lint v1.64.8 // indirect
github.com/golangci/misspell v0.6.0 // indirect
github.com/golangci/plugin-module-register v0.1.1 // indirect
github.com/golangci/revgrep v0.8.0 // indirect
github.com/golangci/unconvert v0.0.0-20240309020433-c5143eacb3ed // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/gordonklaus/ineffassign v0.1.0 // indirect
github.com/gostaticanalysis/analysisutil v0.7.1 // indirect
github.com/gostaticanalysis/comment v1.5.0 // indirect
github.com/gostaticanalysis/forcetypeassert v0.2.0 // indirect
github.com/gostaticanalysis/nilerr v0.1.1 // indirect
github.com/hashicorp/go-immutable-radix/v2 v2.1.0 // indirect
github.com/hashicorp/go-version v1.7.0 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
github.com/hexops/gotextdiff v1.0.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jgautheron/goconst v1.7.1 // indirect
github.com/jingyugao/rowserrcheck v1.1.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/jjti/go-spancheck v0.6.4 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/julz/importas v0.2.0 // indirect
github.com/karamaru-alpha/copyloopvar v1.2.1 // indirect
github.com/kisielk/errcheck v1.9.0 // indirect
github.com/kkHAIKE/contextcheck v1.1.6 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/kulti/thelper v0.6.3 // indirect
github.com/kunwardeep/paralleltest v1.0.10 // indirect
github.com/lasiar/canonicalheader v1.1.2 // indirect
github.com/ldez/exptostd v0.4.2 // indirect
github.com/ldez/gomoddirectives v0.6.1 // indirect
github.com/ldez/grignotin v0.9.0 // indirect
github.com/ldez/tagliatelle v0.7.1 // indirect
github.com/ldez/usetesting v0.4.2 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/leonklingele/grouper v1.1.2 // indirect
github.com/macabu/inamedparam v0.1.3 // indirect
github.com/maratori/testableexamples v1.0.0 // indirect
github.com/maratori/testpackage v1.1.1 // indirect
github.com/matoous/godox v1.1.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/mfridman/interpolate v0.0.2 // indirect
github.com/mgechev/revive v1.7.0 // indirect
github.com/mitchellh/go-homedir v1.1.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/moricho/tparallel v0.3.2 // indirect
github.com/nakabonne/nestif v0.3.1 // indirect
github.com/nishanths/exhaustive v0.12.0 // indirect
github.com/nishanths/predeclared v0.2.2 // indirect
github.com/nunnatsa/ginkgolinter v0.19.1 // indirect
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/polyfloyd/go-errorlint v1.7.1 // indirect
github.com/prometheus/client_golang v1.12.1 // indirect
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.32.1 // indirect
github.com/prometheus/procfs v0.19.2 // indirect
github.com/quasilyte/go-ruleguard v0.4.3-0.20240823090925-0fe6f58b47b1 // indirect
github.com/quasilyte/go-ruleguard/dsl v0.3.22 // indirect
github.com/quasilyte/gogrep v0.5.0 // indirect
github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727 // indirect
github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567 // indirect
github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.59.0 // indirect
github.com/raeperd/recvcheck v0.2.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
github.com/ryancurrah/gomodguard v1.3.5 // indirect
github.com/ryanrolds/sqlclosecheck v0.5.1 // indirect
github.com/sagikazarmark/locafero v0.11.0 // indirect
github.com/sanposhiho/wastedassign/v2 v2.1.0 // indirect
github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 // indirect
github.com/sashamelentyev/interfacebloat v1.1.0 // indirect
github.com/sashamelentyev/usestdlibvars v1.28.0 // indirect
github.com/securego/gosec/v2 v2.22.2 // indirect
github.com/sethvargo/go-retry v0.3.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/sivchari/containedctx v1.0.3 // indirect
github.com/sivchari/tenv v1.12.1 // indirect
github.com/sonatard/noctx v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
github.com/sourcegraph/go-diff v0.7.0 // indirect
github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/cobra v1.9.1 // indirect
github.com/ssgreg/nlreturn/v2 v2.2.1 // indirect
github.com/stbenjam/no-sprintf-host-port v0.2.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tdakkota/asciicheck v0.4.1 // indirect
github.com/tetafro/godot v1.5.0 // indirect
github.com/timakin/bodyclose v0.0.0-20241017074812-ed6a65f985e3 // indirect
github.com/timonwong/loggercheck v0.10.1 // indirect
github.com/tomarrell/wrapcheck/v2 v2.10.0 // indirect
github.com/tommy-muehle/go-mnd/v2 v2.5.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.3.1 // indirect
github.com/ultraware/funlen v0.2.0 // indirect
github.com/ultraware/whitespace v0.2.0 // indirect
github.com/uudashr/gocognit v1.2.0 // indirect
github.com/uudashr/iface v1.3.1 // indirect
github.com/xen0n/gosmopolitan v1.2.2 // indirect
github.com/yagipy/maintidx v1.0.0 // indirect
github.com/yeya24/promlinter v0.3.0 // indirect
github.com/ykadowak/zerologlint v0.1.5 // indirect
gitlab.com/bosi/decorder v0.4.2 // indirect
go-simpler.org/musttag v0.13.0 // indirect
go-simpler.org/sloglint v0.9.0 // indirect
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/arch v0.22.0 // indirect
golang.org/x/crypto v0.49.0 // indirect
golang.org/x/exp/typeparams v0.0.0-20250210185358-939b2ce775ac // indirect
golang.org/x/mod v0.33.0 // indirect
golang.org/x/net v0.51.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0 // indirect
golang.org/x/tools v0.42.0 // indirect
golang.org/x/tools/go/expect v0.1.1-deprecated // indirect
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect
google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
honnef.co/go/tools v0.6.1 // indirect
mvdan.cc/gofumpt v0.7.0 // indirect
mvdan.cc/unparam v0.0.0-20240528143540-8a5130ca722f // indirect
)

File diff suppressed because it is too large Load Diff

View File

@@ -48,11 +48,3 @@ func (UsageStats) TableName() string {
return "usage_stats"
}
// MaskAPIKey 掩码 API Key仅显示最后 4 个字符)
func (p *Provider) MaskAPIKey() {
if len(p.APIKey) > 4 {
p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:]
} else {
p.APIKey = "***"
}
}

View File

@@ -321,3 +321,58 @@ func (m *testMiddleware) InterceptStreamEvent(event *canonical.CanonicalStreamEv
}
var _ = json.Marshal
func TestConvertEmbeddingBody_DecodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
clientAdapter := newMockAdapter("client", false)
clientAdapter.decodeEmbeddingReqFn = func(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
return nil, errors.New("decode embedding failed")
}
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
providerAdapter := newMockAdapter("provider", false)
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(providerAdapter)
body := []byte(`{"model":"text-embedding","input":"hello"}`)
result, err := engine.convertEmbeddingBody(clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
require.NoError(t, err)
assert.Equal(t, body, result)
}
func TestConvertRerankBody_DecodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
clientAdapter := newMockAdapter("client", false)
clientAdapter.decodeRerankReqFn = func(raw []byte) (*canonical.CanonicalRerankRequest, error) {
return nil, errors.New("decode rerank failed")
}
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
providerAdapter := newMockAdapter("provider", false)
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(providerAdapter)
body := []byte(`{"model":"rerank","query":"test","documents":["a"]}`)
result, err := engine.convertRerankBody(clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
require.NoError(t, err)
assert.Equal(t, body, result)
}
func TestConvertBody_UnknownInterfaceType(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
clientAdapter := newMockAdapter("client", false)
providerAdapter := newMockAdapter("provider", false)
_ = engine.RegisterAdapter(clientAdapter)
_ = engine.RegisterAdapter(providerAdapter)
body := []byte(`{"test":"data"}`)
result, err := engine.convertBody(InterfaceType("UNKNOWN"), clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
require.NoError(t, err)
assert.Equal(t, body, result)
}

View File

@@ -13,18 +13,20 @@ import (
// mockProtocolAdapter 模拟协议适配器
type mockProtocolAdapter struct {
protocolName string
passthrough bool
ifaceType InterfaceType
supportsIface map[InterfaceType]bool
decodeReqFn func([]byte) (*canonical.CanonicalRequest, error)
encodeReqFn func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
decodeRespFn func([]byte) (*canonical.CanonicalResponse, error)
encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error)
streamDecoderFn func() StreamDecoder
streamEncoderFn func() StreamEncoder
rewriteReqFn func([]byte, string, InterfaceType) ([]byte, error)
rewriteRespFn func([]byte, string, InterfaceType) ([]byte, error)
protocolName string
passthrough bool
ifaceType InterfaceType
supportsIface map[InterfaceType]bool
decodeReqFn func([]byte) (*canonical.CanonicalRequest, error)
encodeReqFn func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
decodeRespFn func([]byte) (*canonical.CanonicalResponse, error)
encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error)
streamDecoderFn func() StreamDecoder
streamEncoderFn func() StreamEncoder
rewriteReqFn func([]byte, string, InterfaceType) ([]byte, error)
rewriteRespFn func([]byte, string, InterfaceType) ([]byte, error)
decodeEmbeddingReqFn func([]byte) (*canonical.CanonicalEmbeddingRequest, error)
decodeRerankReqFn func([]byte) (*canonical.CanonicalRerankRequest, error)
}
func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
@@ -126,6 +128,9 @@ func (m *mockProtocolAdapter) EncodeModelInfoResponse(info *canonical.CanonicalM
}
func (m *mockProtocolAdapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
if m.decodeEmbeddingReqFn != nil {
return m.decodeEmbeddingReqFn(raw)
}
return &canonical.CanonicalEmbeddingRequest{}, nil
}
@@ -142,6 +147,9 @@ func (m *mockProtocolAdapter) EncodeEmbeddingResponse(resp *canonical.CanonicalE
}
func (m *mockProtocolAdapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
if m.decodeRerankReqFn != nil {
return m.decodeRerankReqFn(raw)
}
return &canonical.CanonicalRerankRequest{}, nil
}

View File

@@ -409,3 +409,25 @@ func TestDecodeResponse_Refusal(t *testing.T) {
}
assert.True(t, found)
}
func TestDecodeRequest_AssistantContentArray(t *testing.T) {
body := []byte(`{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "hello"},
{
"role": "assistant",
"content": [{"type": "text", "text": "hello back"}]
}
]
}`)
req, err := decodeRequest(body)
require.NoError(t, err)
assert.Len(t, req.Messages, 2)
assistantMsg := req.Messages[1]
assert.Equal(t, canonical.RoleAssistant, assistantMsg.Role)
assert.Len(t, assistantMsg.Content, 1)
assert.Equal(t, "text", assistantMsg.Content[0].Type)
assert.Equal(t, "hello back", assistantMsg.Content[0].Text)
}

View File

@@ -14,11 +14,3 @@ type Provider struct {
UpdatedAt time.Time `json:"updated_at"`
}
// MaskAPIKey 掩码 API Key仅显示最后 4 个字符)
func (p *Provider) MaskAPIKey() {
if len(p.APIKey) > 4 {
p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:]
} else {
p.APIKey = "***"
}
}

View File

@@ -9,12 +9,19 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
)
func TestProviderHandler_CreateProvider_Success(t *testing.T) {
h := NewProviderHandler(&mockProviderService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).Return(nil)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"id": "p1",
@@ -33,11 +40,16 @@ func TestProviderHandler_CreateProvider_Success(t *testing.T) {
var result domain.Provider
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "p1", result.ID)
assert.Contains(t, result.APIKey, "***")
assert.Equal(t, "sk-test", result.APIKey)
}
func TestProviderHandler_CreateProvider_WithProtocol(t *testing.T) {
h := NewProviderHandler(&mockProviderService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).Return(nil)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"id": "p1",
@@ -56,9 +68,13 @@ func TestProviderHandler_CreateProvider_WithProtocol(t *testing.T) {
}
func TestProviderHandler_UpdateProvider(t *testing.T) {
h := NewProviderHandler(&mockProviderService{
provider: &domain.Provider{ID: "p1", Name: "Updated", APIKey: "***"},
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(nil)
mockSvc.EXPECT().Get(gomock.Eq("p1")).Return(&domain.Provider{ID: "p1", Name: "Updated", APIKey: "sk-test"}, nil)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{"name": "Updated"})
w := httptest.NewRecorder()
@@ -72,7 +88,11 @@ func TestProviderHandler_UpdateProvider(t *testing.T) {
}
func TestProviderHandler_UpdateProvider_InvalidBody(t *testing.T) {
h := NewProviderHandler(&mockProviderService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
h := NewProviderHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -84,7 +104,12 @@ func TestProviderHandler_UpdateProvider_InvalidBody(t *testing.T) {
}
func TestProviderHandler_DeleteProvider(t *testing.T) {
h := NewProviderHandler(&mockProviderService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Delete(gomock.Eq("p1")).Return(nil)
h := NewProviderHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -97,7 +122,12 @@ func TestProviderHandler_DeleteProvider(t *testing.T) {
}
func TestModelHandler_DeleteModel(t *testing.T) {
h := NewModelHandler(&mockModelService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Delete(gomock.Eq("m1")).Return(nil)
h := NewModelHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -110,7 +140,15 @@ func TestModelHandler_DeleteModel(t *testing.T) {
}
func TestModelHandler_CreateModel_Success(t *testing.T) {
h := NewModelHandler(&mockModelService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).DoAndReturn(func(model *domain.Model) error {
model.ID = "mock-uuid-1234"
return nil
})
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"provider_id": "p1",
@@ -130,9 +168,12 @@ func TestModelHandler_CreateModel_Success(t *testing.T) {
}
func TestModelHandler_GetModel(t *testing.T) {
h := NewModelHandler(&mockModelService{
model: &domain.Model{ID: "m1", ModelName: "gpt-4"},
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ModelName: "gpt-4"}, nil)
h := NewModelHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -148,9 +189,13 @@ func TestModelHandler_GetModel(t *testing.T) {
}
func TestModelHandler_UpdateModel(t *testing.T) {
h := NewModelHandler(&mockModelService{
model: &domain.Model{ID: "m1", ModelName: "gpt-4o"},
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Update(gomock.Eq("m1"), gomock.Any()).Return(nil)
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ModelName: "gpt-4o"}, nil)
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{"model_name": "gpt-4o"})
w := httptest.NewRecorder()

View File

@@ -2,119 +2,36 @@ package handler
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"gorm.io/gorm"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
appErrors "nex/backend/pkg/errors"
"nex/backend/tests/mocks"
)
func init() {
gin.SetMode(gin.TestMode)
}
// ============ Mock 实现 ============
type mockRoutingService struct {
result *domain.RouteResult
err error
}
func (m *mockRoutingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
return m.result, m.err
}
type mockStatsService struct {
err error
stats []domain.UsageStats
aggrResult []map[string]interface{}
}
func (m *mockStatsService) Record(providerID, modelName string) error {
return m.err
}
func (m *mockStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
return m.stats, nil
}
func (m *mockStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} {
return m.aggrResult
}
type mockProviderService struct {
provider *domain.Provider
providers []domain.Provider
err error
}
func (m *mockProviderService) ListEnabledModels() ([]domain.Model, error) {
return nil, nil
}
func (m *mockProviderService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
return nil, nil
}
func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err }
func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
return m.provider, m.err
}
func (m *mockProviderService) List() ([]domain.Provider, error) { return m.providers, m.err }
func (m *mockProviderService) Update(id string, updates map[string]interface{}) error {
return m.err
}
func (m *mockProviderService) Delete(id string) error { return m.err }
type mockModelService struct {
model *domain.Model
models []domain.Model
err error
}
func (m *mockModelService) Create(model *domain.Model) error {
if m.err == nil {
model.ID = "mock-uuid-1234"
}
return m.err
}
func (m *mockModelService) Get(id string) (*domain.Model, error) {
return m.model, m.err
}
func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
return m.models, m.err
}
func (m *mockModelService) ListEnabled() ([]domain.Model, error) {
return []domain.Model{}, nil
}
func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
return m.err
}
func (m *mockModelService) Delete(id string) error { return m.err }
type mockProviderClient struct {
err error
}
func (m *mockProviderClient) Send(ctx context.Context, spec interface{}) (interface{}, error) {
return nil, m.err
}
func (m *mockProviderClient) SendStream(ctx context.Context, spec interface{}) (<-chan provider.StreamEvent, error) {
return nil, m.err
}
// ============ Provider Handler 测试 ============
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
h := NewProviderHandler(&mockProviderService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{"id": "p1"})
w := httptest.NewRecorder()
@@ -127,12 +44,15 @@ func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
}
func TestProviderHandler_ListProviders(t *testing.T) {
h := NewProviderHandler(&mockProviderService{
providers: []domain.Provider{
{ID: "p1", Name: "P1"},
{ID: "p2", Name: "P2"},
},
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().List().Return([]domain.Provider{
{ID: "p1", Name: "P1"},
{ID: "p2", Name: "P2"},
}, nil)
h := NewProviderHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -142,14 +62,17 @@ func TestProviderHandler_ListProviders(t *testing.T) {
assert.Equal(t, 200, w.Code)
var result []domain.Provider
json.Unmarshal(w.Body.Bytes(), &result)
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Len(t, result, 2)
}
func TestProviderHandler_GetProvider(t *testing.T) {
h := NewProviderHandler(&mockProviderService{
provider: &domain.Provider{ID: "p1", Name: "P1", APIKey: "***"},
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Get(gomock.Eq("p1")).Return(&domain.Provider{ID: "p1", Name: "P1", APIKey: "sk-test"}, nil)
h := NewProviderHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -160,10 +83,12 @@ func TestProviderHandler_GetProvider(t *testing.T) {
assert.Equal(t, 200, w.Code)
}
// ============ Model Handler 测试 ============
func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
h := NewModelHandler(&mockModelService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{"id": "m1"})
w := httptest.NewRecorder()
@@ -176,12 +101,15 @@ func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
}
func TestModelHandler_ListModels(t *testing.T) {
h := NewModelHandler(&mockModelService{
models: []domain.Model{
{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3"},
},
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().List(gomock.Eq("")).Return([]domain.Model{
{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3"},
}, nil)
h := NewModelHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -198,9 +126,12 @@ func TestModelHandler_ListModels(t *testing.T) {
}
func TestModelHandler_GetModel_UnifiedID(t *testing.T) {
h := NewModelHandler(&mockModelService{
model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"}, nil)
h := NewModelHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -217,7 +148,15 @@ func TestModelHandler_GetModel_UnifiedID(t *testing.T) {
}
func TestModelHandler_CreateModel_UnifiedID(t *testing.T) {
h := NewModelHandler(&mockModelService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).DoAndReturn(func(model *domain.Model) error {
model.ID = "mock-uuid-1234"
return nil
})
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"provider_id": "openai",
@@ -238,9 +177,13 @@ func TestModelHandler_CreateModel_UnifiedID(t *testing.T) {
}
func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) {
h := NewModelHandler(&mockModelService{
model: &domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4-turbo"},
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Update(gomock.Eq("m1"), gomock.Any()).Return(nil)
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4-turbo"}, nil)
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]interface{}{"enabled": false})
w := httptest.NewRecorder()
@@ -257,14 +200,15 @@ func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) {
assert.Equal(t, "openai/gpt-4-turbo", result.UnifiedModelID)
}
// ============ Stats Handler 测试 ============
func TestStatsHandler_GetStats(t *testing.T) {
h := NewStatsHandler(&mockStatsService{
stats: []domain.UsageStats{
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
},
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockStatsService(ctrl)
mockSvc.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]domain.UsageStats{
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
}, nil)
h := NewStatsHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -275,7 +219,11 @@ func TestStatsHandler_GetStats(t *testing.T) {
}
func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
h := NewStatsHandler(&mockStatsService{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockStatsService(ctrl)
h := NewStatsHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -286,14 +234,17 @@ func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
}
func TestStatsHandler_AggregateStats(t *testing.T) {
h := NewStatsHandler(&mockStatsService{
stats: []domain.UsageStats{
{ProviderID: "p1", RequestCount: 10},
},
aggrResult: []map[string]interface{}{
{"provider_id": "p1", "request_count": 10},
},
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockStatsService(ctrl)
mockSvc.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]domain.UsageStats{
{ProviderID: "p1", RequestCount: 10},
}, nil)
mockSvc.EXPECT().Aggregate(gomock.Any(), gomock.Eq("provider")).Return([]map[string]interface{}{
{"provider_id": "p1", "request_count": 10},
})
h := NewStatsHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -303,8 +254,6 @@ func TestStatsHandler_AggregateStats(t *testing.T) {
assert.Equal(t, 200, w.Code)
}
// ============ writeError 测试 ============
func TestWriteError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -333,12 +282,13 @@ func formatMapErrors(errs map[string]string) string {
return "请求验证失败: " + strings.Join(parts, "; ")
}
// ============ 错误类型判断测试 ============
func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
h := NewProviderHandler(&mockProviderService{
err: appErrors.ErrConflict,
})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrConflict)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"id": "p1",
@@ -354,3 +304,158 @@ func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
h.CreateProvider(c)
assert.Equal(t, 409, w.Code)
}
func TestModelHandler_CreateModel_ProviderNotFound(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrProviderNotFound)
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"provider_id": "nonexistent",
"model_name": "gpt-4",
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateModel(c)
assert.Equal(t, 400, w.Code)
assert.Contains(t, w.Body.String(), "供应商不存在")
}
func TestModelHandler_CreateModel_DuplicateModel(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrDuplicateModel)
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"provider_id": "openai",
"model_name": "gpt-4",
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateModel(c)
assert.Equal(t, 409, w.Code)
assert.Contains(t, w.Body.String(), "同一供应商下模型名称已存在")
}
func TestModelHandler_CreateModel_InternalError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
mockSvc.EXPECT().Create(gomock.Any()).Return(errors.New("database error"))
h := NewModelHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"provider_id": "openai",
"model_name": "gpt-4",
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateModel(c)
assert.Equal(t, 500, w.Code)
}
func TestProviderHandler_UpdateProvider_NotFound(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(gorm.ErrRecordNotFound)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "p1"}}
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.UpdateProvider(c)
assert.Equal(t, 404, w.Code)
}
func TestProviderHandler_UpdateProvider_ImmutableField(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(appErrors.ErrImmutableField)
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "p1"}}
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.UpdateProvider(c)
assert.Equal(t, 400, w.Code)
assert.Contains(t, w.Body.String(), "供应商 ID 不允许修改")
}
func TestProviderHandler_UpdateProvider_InternalError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(errors.New("database error"))
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "p1"}}
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
h.UpdateProvider(c)
assert.Equal(t, 500, w.Code)
}
func TestModelHandler_CreateModel_InvalidJSON(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockModelService(ctrl)
h := NewModelHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader([]byte("{invalid json")))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateModel(c)
assert.Equal(t, 400, w.Code)
}
func TestProviderHandler_CreateProvider_InvalidJSON(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSvc := mocks.NewMockProviderService(ctrl)
h := NewProviderHandler(mockSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader([]byte("{invalid json")))
c.Request.Header.Set("Content-Type", "application/json")
h.CreateProvider(c)
assert.Equal(t, 400, w.Code)
}

View File

@@ -66,7 +66,6 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
return
}
provider.MaskAPIKey()
c.JSON(http.StatusCreated, provider)
}
@@ -85,7 +84,7 @@ func (h *ProviderHandler) ListProviders(c *gin.Context) {
func (h *ProviderHandler) GetProvider(c *gin.Context) {
id := c.Param("id")
provider, err := h.providerService.Get(id, true)
provider, err := h.providerService.Get(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
c.JSON(http.StatusNotFound, gin.H{
@@ -131,7 +130,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
return
}
provider, err := h.providerService.Get(id, true)
provider, err := h.providerService.Get(id)
if err != nil {
writeError(c, err)
return

File diff suppressed because it is too large Load Diff

View File

@@ -50,6 +50,7 @@ type Client struct {
}
// ProviderClient 供应商客户端接口
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
type ProviderClient interface {
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error)

View File

@@ -2,6 +2,8 @@ package repository
import "nex/backend/internal/domain"
//go:generate go run go.uber.org/mock/mockgen -source=model_repo.go -destination=../../tests/mocks/mock_model_repository.go -package=mocks
// ModelRepository 模型数据仓库接口
type ModelRepository interface {
Create(model *domain.Model) error

View File

@@ -2,6 +2,8 @@ package repository
import "nex/backend/internal/domain"
//go:generate go run go.uber.org/mock/mockgen -source=provider_repo.go -destination=../../tests/mocks/mock_provider_repository.go -package=mocks
// ProviderRepository 供应商数据仓库接口
type ProviderRepository interface {
Create(provider *domain.Provider) error
@@ -9,7 +11,4 @@ type ProviderRepository interface {
List() ([]domain.Provider, error)
Update(id string, updates map[string]interface{}) error
Delete(id string) error
// 统一模型 ID 相关方法
ListEnabledModels() ([]domain.Model, error)
FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error)
}

View File

@@ -71,25 +71,6 @@ func (r *providerRepository) Delete(id string) error {
return nil
}
// ListEnabledModels 返回所有启用的模型(关联启用的供应商)
func (r *providerRepository) ListEnabledModels() ([]domain.Model, error) {
var models []domain.Model
err := r.db.Joins("JOIN providers ON providers.id = models.provider_id").
Where("models.enabled = ? AND providers.enabled = ?", true, true).
Find(&models).Error
return models, err
}
// FindByProviderAndModelName 按 provider_id 和 model_name 查询模型
func (r *providerRepository) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
var model domain.Model
err := r.db.Where("provider_id = ? AND model_name = ?", providerID, modelName).First(&model).Error
if err != nil {
return nil, err
}
return &model, nil
}
func toDomainProvider(p *config.Provider) domain.Provider {
return domain.Provider{
ID: p.ID,

View File

@@ -5,28 +5,16 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"nex/backend/internal/config"
testHelpers "nex/backend/tests"
"nex/backend/internal/domain"
)
func setupTestDB(t *testing.T) *gorm.DB {
t.Helper()
dir := t.TempDir()
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
require.NoError(t, err)
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
require.NoError(t, err)
// 关闭数据库连接以便 TempDir 清理
t.Cleanup(func() {
sqlDB, _ := db.DB()
if sqlDB != nil {
sqlDB.Close()
}
})
return db
return testHelpers.SetupTestDB(t)
}
// ============ ProviderRepository 测试 ============
@@ -88,7 +76,7 @@ func TestProviderRepository_Update(t *testing.T) {
db := setupTestDB(t)
repo := NewProviderRepository(db)
repo.Create(&domain.Provider{ID: "p1", Name: "Old", APIKey: "key", BaseURL: "https://old.com"})
require.NoError(t, repo.Create(&domain.Provider{ID: "p1", Name: "Old", APIKey: "key", BaseURL: "https://old.com"}))
err := repo.Update("p1", map[string]interface{}{"name": "New"})
require.NoError(t, err)
@@ -109,7 +97,7 @@ func TestProviderRepository_Delete(t *testing.T) {
db := setupTestDB(t)
repo := NewProviderRepository(db)
repo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
require.NoError(t, repo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
err := repo.Delete("p1")
require.NoError(t, err)
@@ -129,17 +117,21 @@ func TestProviderRepository_Delete_NotFound(t *testing.T) {
func TestModelRepository_Create(t *testing.T) {
db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
err := repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
require.NoError(t, err)
}
func TestModelRepository_GetByID(t *testing.T) {
db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
result, err := repo.GetByID("m1")
require.NoError(t, err)
@@ -149,9 +141,11 @@ func TestModelRepository_GetByID(t *testing.T) {
func TestModelRepository_FindByProviderAndModelName(t *testing.T) {
db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
result, err := repo.FindByProviderAndModelName("p1", "gpt-4")
require.NoError(t, err)
@@ -162,9 +156,11 @@ func TestModelRepository_FindByProviderAndModelName(t *testing.T) {
func TestModelRepository_FindByProviderAndModelName_NotFound(t *testing.T) {
db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
// Wrong provider_id
_, err := repo.FindByProviderAndModelName("p2", "gpt-4")
@@ -181,11 +177,14 @@ func TestModelRepository_FindByProviderAndModelName_NotFound(t *testing.T) {
func TestModelRepository_List(t *testing.T) {
db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
repo.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})
repo.Create(&domain.Model{ID: "m3", ProviderID: "p2", ModelName: "claude-3"})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p2", Name: "Test2", APIKey: "key", BaseURL: "https://test2.com"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m3", ProviderID: "p2", ModelName: "claude-3"}))
all, err := repo.List("")
require.NoError(t, err)
@@ -246,9 +245,11 @@ func TestModelRepository_ListEnabled(t *testing.T) {
func TestModelRepository_Update(t *testing.T) {
db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
err := repo.Update("m1", map[string]interface{}{"enabled": false})
require.NoError(t, err)
@@ -259,9 +260,11 @@ func TestModelRepository_Update(t *testing.T) {
func TestModelRepository_Delete(t *testing.T) {
db := setupTestDB(t)
providerRepo := NewProviderRepository(db)
repo := NewModelRepository(db)
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}))
err := repo.Delete("m1")
require.NoError(t, err)
@@ -293,10 +296,32 @@ func TestStatsRepository_Query(t *testing.T) {
db := setupTestDB(t)
repo := NewStatsRepository(db)
repo.Record("p1", "gpt-4")
require.NoError(t, repo.Record("p1", "gpt-4"))
// 注意:当前 schema 只有 date 字段有唯一约束
// 所以同一 provider + model 只能有一条记录
stats, err := repo.Query("p1", "", nil, nil)
require.NoError(t, err)
assert.Len(t, stats, 1)
}
func TestModelRepository_List_EmptyResult(t *testing.T) {
db := setupTestDB(t)
repo := NewModelRepository(db)
result, err := repo.List("")
require.NoError(t, err)
assert.NotNil(t, result)
assert.Empty(t, result)
assert.Len(t, result, 0)
}
func TestProviderRepository_List_EmptyResult(t *testing.T) {
db := setupTestDB(t)
repo := NewProviderRepository(db)
result, err := repo.List()
require.NoError(t, err)
assert.NotNil(t, result)
assert.Empty(t, result)
assert.Len(t, result, 0)
}

View File

@@ -6,8 +6,11 @@ import (
"nex/backend/internal/domain"
)
//go:generate go run go.uber.org/mock/mockgen -source=stats_repo.go -destination=../../tests/mocks/mock_stats_repository.go -package=mocks
// StatsRepository 统计数据仓库接口
type StatsRepository interface {
Record(providerID, modelName string) error
BatchUpdate(providerID, modelName string, date time.Time, delta int) error
Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error)
}

View File

@@ -43,6 +43,28 @@ func (r *statsRepository) Record(providerID, modelName string) error {
})
}
func (r *statsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
return r.db.Transaction(func(tx *gorm.DB) error {
var stats config.UsageStats
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
providerID, modelName, date).First(&stats).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return tx.Create(&config.UsageStats{
ProviderID: providerID,
ModelName: modelName,
RequestCount: delta,
Date: date,
}).Error
} else if err != nil {
return err
}
return tx.Model(&stats).
Update("request_count", gorm.Expr("request_count + ?", delta)).Error
})
}
func (r *statsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
var stats []config.UsageStats
query := r.db.Model(&config.UsageStats{})

View File

@@ -2,6 +2,8 @@ package service
import "nex/backend/internal/domain"
//go:generate go run go.uber.org/mock/mockgen -source=model_service.go -destination=../../tests/mocks/mock_model_service.go -package=mocks
// ModelService 模型服务接口
type ModelService interface {
Create(model *domain.Model) error

View File

@@ -11,27 +11,30 @@ import (
type modelService struct {
modelRepo repository.ModelRepository
providerRepo repository.ProviderRepository
cache *RoutingCache
}
func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) ModelService {
return &modelService{modelRepo: modelRepo, providerRepo: providerRepo}
func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository, cache *RoutingCache) ModelService {
return &modelService{modelRepo: modelRepo, providerRepo: providerRepo, cache: cache}
}
func (s *modelService) Create(model *domain.Model) error {
// 校验供应商存在
if _, err := s.providerRepo.GetByID(model.ProviderID); err != nil {
return appErrors.ErrProviderNotFound
}
// 联合唯一校验:同一供应商下 model_name 不重复
if err := s.checkDuplicateModelName(model.ProviderID, model.ModelName, ""); err != nil {
return err
}
// 自动生成 UUID 作为 id
model.ID = uuid.New().String()
model.Enabled = true
return s.modelRepo.Create(model)
err := s.modelRepo.Create(model)
if err != nil {
return err
}
s.cache.SetModel(model)
return nil
}
func (s *modelService) Get(id string) (*domain.Model, error) {
@@ -47,20 +50,17 @@ func (s *modelService) ListEnabled() ([]domain.Model, error) {
}
func (s *modelService) Update(id string, updates map[string]interface{}) error {
// 获取当前模型
current, err := s.modelRepo.GetByID(id)
if err != nil {
return appErrors.ErrModelNotFound
}
// 如果更新 provider_id校验新供应商存在
if providerID, ok := updates["provider_id"].(string); ok {
if _, err := s.providerRepo.GetByID(providerID); err != nil {
return appErrors.ErrProviderNotFound
}
}
// 确定更新后的 provider_id 和 model_name
newProviderID := current.ProviderID
if v, ok := updates["provider_id"].(string); ok {
newProviderID = v
@@ -70,18 +70,37 @@ func (s *modelService) Update(id string, updates map[string]interface{}) error {
newModelName = v
}
// 如果 provider_id 或 model_name 发生变化,校验联合唯一
if newProviderID != current.ProviderID || newModelName != current.ModelName {
if err := s.checkDuplicateModelName(newProviderID, newModelName, id); err != nil {
return err
}
}
return s.modelRepo.Update(id, updates)
err = s.modelRepo.Update(id, updates)
if err != nil {
return err
}
s.cache.InvalidateModel(current.ProviderID, current.ModelName)
if newProviderID != current.ProviderID || newModelName != current.ModelName {
s.cache.InvalidateModel(newProviderID, newModelName)
}
return nil
}
func (s *modelService) Delete(id string) error {
return s.modelRepo.Delete(id)
model, err := s.modelRepo.GetByID(id)
if err != nil {
return appErrors.ErrModelNotFound
}
err = s.modelRepo.Delete(id)
if err != nil {
return err
}
s.cache.InvalidateModel(model.ProviderID, model.ModelName)
return nil
}
// checkDuplicateModelName 校验同一供应商下 model_name 是否重复

View File

@@ -2,10 +2,12 @@ package service
import "nex/backend/internal/domain"
//go:generate go run go.uber.org/mock/mockgen -source=provider_service.go -destination=../../tests/mocks/mock_provider_service.go -package=mocks
// ProviderService 供应商服务接口
type ProviderService interface {
Create(provider *domain.Provider) error
Get(id string, maskKey bool) (*domain.Provider, error)
Get(id string) (*domain.Provider, error)
List() ([]domain.Provider, error)
Update(id string, updates map[string]interface{}) error
Delete(id string) error

View File

@@ -13,56 +13,56 @@ import (
type providerService struct {
providerRepo repository.ProviderRepository
modelRepo repository.ModelRepository
cache *RoutingCache
}
func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository) ProviderService {
return &providerService{providerRepo: providerRepo, modelRepo: modelRepo}
func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository, cache *RoutingCache) ProviderService {
return &providerService{providerRepo: providerRepo, modelRepo: modelRepo, cache: cache}
}
func (s *providerService) Create(provider *domain.Provider) error {
// 校验 provider_id 字符集
if err := modelid.ValidateProviderID(provider.ID); err != nil {
return appErrors.ErrInvalidProviderID
}
provider.Enabled = true
err := s.providerRepo.Create(provider)
if err != nil && isUniqueConstraintError(err) {
return appErrors.ErrConflict
if err != nil {
if isUniqueConstraintError(err) {
return appErrors.ErrConflict
}
return err
}
return err
s.cache.SetProvider(provider)
return nil
}
func (s *providerService) Get(id string, maskKey bool) (*domain.Provider, error) {
provider, err := s.providerRepo.GetByID(id)
if err != nil {
return nil, err
}
if maskKey {
provider.MaskAPIKey()
}
return provider, nil
func (s *providerService) Get(id string) (*domain.Provider, error) {
return s.providerRepo.GetByID(id)
}
func (s *providerService) List() ([]domain.Provider, error) {
providers, err := s.providerRepo.List()
if err != nil {
return nil, err
}
for i := range providers {
providers[i].MaskAPIKey()
}
return providers, nil
return s.providerRepo.List()
}
func (s *providerService) Update(id string, updates map[string]interface{}) error {
if _, ok := updates["id"]; ok {
return appErrors.ErrImmutableField
}
return s.providerRepo.Update(id, updates)
err := s.providerRepo.Update(id, updates)
if err != nil {
return err
}
s.cache.InvalidateProvider(id)
return nil
}
func (s *providerService) Delete(id string) error {
return s.providerRepo.Delete(id)
err := s.providerRepo.Delete(id)
if err != nil {
return err
}
s.cache.InvalidateProvider(id)
return nil
}
// ListEnabledModels 返回所有启用的模型(用于 Models 接口本地聚合)

View File

@@ -0,0 +1,134 @@
package service
import (
"strings"
"sync"
"go.uber.org/zap"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
)
type RoutingCache struct {
providers sync.Map
models sync.Map
modelRepo repository.ModelRepository
providerRepo repository.ProviderRepository
logger *zap.Logger
}
func NewRoutingCache(
modelRepo repository.ModelRepository,
providerRepo repository.ProviderRepository,
logger *zap.Logger,
) *RoutingCache {
return &RoutingCache{
modelRepo: modelRepo,
providerRepo: providerRepo,
logger: logger,
}
}
func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
if v, ok := c.providers.Load(id); ok {
return v.(*domain.Provider), nil
}
provider, err := c.providerRepo.GetByID(id)
if err != nil {
return nil, err
}
if v, ok := c.providers.Load(id); ok {
return v.(*domain.Provider), nil
}
c.providers.Store(id, provider)
return provider, nil
}
func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, error) {
key := providerID + "/" + modelName
if v, ok := c.models.Load(key); ok {
return v.(*domain.Model), nil
}
model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName)
if err != nil {
return nil, err
}
if v, ok := c.models.Load(key); ok {
return v.(*domain.Model), nil
}
c.models.Store(key, model)
return model, nil
}
func (c *RoutingCache) SetProvider(provider *domain.Provider) {
c.providers.Store(provider.ID, provider)
}
func (c *RoutingCache) SetModel(model *domain.Model) {
key := model.ProviderID + "/" + model.ModelName
c.models.Store(key, model)
}
func (c *RoutingCache) InvalidateProvider(id string) {
c.providers.Delete(id)
c.invalidateModelsByProvider(id)
c.logger.Debug("Provider 缓存失效", zap.String("provider_id", id))
}
func (c *RoutingCache) InvalidateModel(providerID, modelName string) {
key := providerID + "/" + modelName
c.models.Delete(key)
c.logger.Debug("Model 缓存失效",
zap.String("provider_id", providerID),
zap.String("model_name", modelName))
}
func (c *RoutingCache) invalidateModelsByProvider(providerID string) {
prefix := providerID + "/"
count := 0
c.models.Range(func(key, value interface{}) bool {
if strings.HasPrefix(key.(string), prefix) {
c.models.Delete(key)
count++
}
return true
})
if count > 0 {
c.logger.Debug("清除 Provider 相关 Model 缓存",
zap.String("provider_id", providerID),
zap.Int("count", count))
}
}
func (c *RoutingCache) Preload() error {
providers, err := c.providerRepo.List()
if err != nil {
return err
}
for i := range providers {
c.providers.Store(providers[i].ID, &providers[i])
}
models, err := c.modelRepo.List("")
if err != nil {
return err
}
for i := range models {
key := models[i].ProviderID + "/" + models[i].ModelName
c.models.Store(key, &models[i])
}
c.logger.Info("缓存预热完成",
zap.Int("providers", len(providers)),
zap.Int("models", len(models)))
return nil
}

View File

@@ -0,0 +1,273 @@
package service
import (
"errors"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"nex/backend/internal/domain"
)
type mockModelRepo struct {
models map[string]*domain.Model
}
func (m *mockModelRepo) Create(model *domain.Model) error {
return nil
}
func (m *mockModelRepo) GetByID(id string) (*domain.Model, error) {
return nil, nil
}
func (m *mockModelRepo) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
key := providerID + "/" + modelName
if model, ok := m.models[key]; ok {
return model, nil
}
return nil, errors.New("not found")
}
func (m *mockModelRepo) List(providerID string) ([]domain.Model, error) {
var result []domain.Model
for _, model := range m.models {
if providerID == "" || model.ProviderID == providerID {
result = append(result, *model)
}
}
return result, nil
}
func (m *mockModelRepo) ListEnabled() ([]domain.Model, error) {
return nil, nil
}
func (m *mockModelRepo) Update(id string, updates map[string]interface{}) error {
return nil
}
func (m *mockModelRepo) Delete(id string) error {
return nil
}
type mockProviderRepo struct {
providers map[string]*domain.Provider
}
func (m *mockProviderRepo) Create(provider *domain.Provider) error {
return nil
}
func (m *mockProviderRepo) GetByID(id string) (*domain.Provider, error) {
if provider, ok := m.providers[id]; ok {
return provider, nil
}
return nil, errors.New("not found")
}
func (m *mockProviderRepo) List() ([]domain.Provider, error) {
var result []domain.Provider
for _, provider := range m.providers {
result = append(result, *provider)
}
return result, nil
}
func (m *mockProviderRepo) Update(id string, updates map[string]interface{}) error {
return nil
}
func (m *mockProviderRepo) Delete(id string) error {
return nil
}
func TestRoutingCache_GetProvider_CacheHit(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
provider, err := cache.GetProvider("openai")
require.NoError(t, err)
assert.Equal(t, "openai", provider.ID)
provider2, err := cache.GetProvider("openai")
require.NoError(t, err)
assert.Equal(t, provider, provider2)
}
func TestRoutingCache_GetProvider_CacheMiss(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
_, err := cache.GetProvider("notexist")
assert.Error(t, err)
}
func TestRoutingCache_GetModel_CacheHit(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
model, err := cache.GetModel("openai", "gpt-4")
require.NoError(t, err)
assert.Equal(t, "gpt-4", model.ModelName)
model2, err := cache.GetModel("openai", "gpt-4")
require.NoError(t, err)
assert.Equal(t, model, model2)
}
func TestRoutingCache_GetModel_CacheMiss(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
_, err := cache.GetModel("openai", "notexist")
assert.Error(t, err)
}
func TestRoutingCache_DoubleCheck(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, err := cache.GetProvider("openai")
assert.NoError(t, err)
}()
}
wg.Wait()
}
func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
"openai/gpt-3.5": {ID: "2", ProviderID: "openai", ModelName: "gpt-3.5", Enabled: true},
"anthropic/claude": {ID: "3", ProviderID: "anthropic", ModelName: "claude", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
"anthropic": {ID: "anthropic", Name: "Anthropic", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
_, err := cache.GetModel("openai", "gpt-4")
require.NoError(t, err)
_, err = cache.GetModel("openai", "gpt-3.5")
require.NoError(t, err)
_, err = cache.GetModel("anthropic", "claude")
require.NoError(t, err)
cache.InvalidateProvider("openai")
var openaiCount, anthropicCount int
cache.models.Range(func(key, value interface{}) bool {
if key.(string) == "anthropic/claude" {
anthropicCount++
}
return true
})
assert.Equal(t, 0, openaiCount)
assert.Equal(t, 1, anthropicCount)
}
func TestRoutingCache_InvalidateModel(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
_, err := cache.GetModel("openai", "gpt-4")
require.NoError(t, err)
cache.InvalidateModel("openai", "gpt-4")
var count int
cache.models.Range(func(key, value interface{}) bool {
count++
return true
})
assert.Equal(t, 0, count)
}
func TestRoutingCache_Preload_Success(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
err := cache.Preload()
require.NoError(t, err)
var providerCount, modelCount int
cache.providers.Range(func(key, value interface{}) bool {
providerCount++
return true
})
cache.models.Range(func(key, value interface{}) bool {
modelCount++
return true
})
assert.Equal(t, 1, providerCount)
assert.Equal(t, 1, modelCount)
}
func TestRoutingCache_ConcurrentAccess(t *testing.T) {
logger := zap.NewNop()
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
}}
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
}}
cache := NewRoutingCache(modelRepo, providerRepo, logger)
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, _ = cache.GetProvider("openai")
_, _ = cache.GetModel("openai", "gpt-4")
cache.InvalidateProvider("openai")
cache.InvalidateModel("openai", "gpt-4")
}()
}
wg.Wait()
}

View File

@@ -2,6 +2,8 @@ package service
import "nex/backend/internal/domain"
//go:generate go run go.uber.org/mock/mockgen -source=routing_service.go -destination=../../tests/mocks/mock_routing_service.go -package=mocks
// RoutingService 路由服务接口
type RoutingService interface {
RouteByModelName(providerID, modelName string) (*domain.RouteResult, error)

View File

@@ -4,20 +4,18 @@ import (
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
)
type routingService struct {
modelRepo repository.ModelRepository
providerRepo repository.ProviderRepository
cache *RoutingCache
}
func NewRoutingService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) RoutingService {
return &routingService{modelRepo: modelRepo, providerRepo: providerRepo}
func NewRoutingService(cache *RoutingCache) RoutingService {
return &routingService{cache: cache}
}
func (s *routingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
model, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
model, err := s.cache.GetModel(providerID, modelName)
if err != nil {
return nil, appErrors.ErrModelNotFound
}
@@ -26,7 +24,7 @@ func (s *routingService) RouteByModelName(providerID, modelName string) (*domain
return nil, appErrors.ErrModelDisabled
}
provider, err := s.providerRepo.GetByID(model.ProviderID)
provider, err := s.cache.GetProvider(model.ProviderID)
if err != nil {
return nil, appErrors.ErrProviderNotFound
}

View File

@@ -14,14 +14,15 @@ func TestProviderService_Update(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"})
require.NoError(t, svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"}))
err := svc.Update("p1", map[string]interface{}{"name": "Updated"})
require.NoError(t, err)
result, err := svc.Get("p1", false)
result, err := svc.Get("p1")
require.NoError(t, err)
assert.Equal(t, "Updated", result.Name)
}
@@ -30,7 +31,8 @@ func TestProviderService_Update_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
err := svc.Update("nonexistent", map[string]interface{}{"name": "test"})
assert.Error(t, err)
@@ -40,9 +42,10 @@ func TestModelService_Get(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model))
@@ -55,9 +58,10 @@ func TestModelService_Update(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model))
@@ -73,9 +77,10 @@ func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model))
@@ -87,9 +92,10 @@ func TestModelService_Delete(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
require.NoError(t, svc.Create(model))
@@ -104,7 +110,8 @@ func TestModelService_Delete_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
modelRepo := repository.NewModelRepository(db)
providerRepo := repository.NewProviderRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
err := svc.Delete("nonexistent")
assert.Error(t, err)
@@ -112,7 +119,8 @@ func TestModelService_Delete_NotFound(t *testing.T) {
func TestStatsService_Aggregate_Default(t *testing.T) {
statsRepo := repository.NewStatsRepository(nil)
svc := NewStatsService(statsRepo)
buffer := NewStatsBuffer(statsRepo, nil)
svc := NewStatsService(statsRepo, buffer)
stats := []domain.UsageStats{
{ProviderID: "p1", RequestCount: 10},
@@ -133,7 +141,8 @@ func TestModelService_Update_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
modelRepo := repository.NewModelRepository(db)
providerRepo := repository.NewProviderRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
err := svc.Update("nonexistent", map[string]interface{}{"model_name": "test"})
assert.Error(t, err)

View File

@@ -3,14 +3,16 @@ package service
import (
"errors"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"go.uber.org/zap"
"gorm.io/gorm"
"nex/backend/internal/config"
testHelpers "nex/backend/tests"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
appErrors "nex/backend/pkg/errors"
@@ -18,18 +20,14 @@ import (
func setupServiceTestDB(t *testing.T) *gorm.DB {
t.Helper()
dir := t.TempDir()
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
require.NoError(t, err)
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
require.NoError(t, err)
t.Cleanup(func() {
sqlDB, _ := db.DB()
if sqlDB != nil {
sqlDB.Close()
}
})
return db
return testHelpers.SetupTestDB(t)
}
func setupRoutingCache(t *testing.T, db *gorm.DB) *RoutingCache {
t.Helper()
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
return NewRoutingCache(modelRepo, providerRepo, zap.NewNop())
}
// ============ RoutingService - RouteByModelName 测试 ============
@@ -38,11 +36,11 @@ func TestRoutingService_RouteByModelName_Success(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewRoutingService(cache)
// 创建供应商和模型
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
result, err := svc.RouteByModelName("openai", "gpt-4")
require.NoError(t, err)
@@ -52,9 +50,8 @@ func TestRoutingService_RouteByModelName_Success(t *testing.T) {
func TestRoutingService_RouteByModelName_NotFound(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewRoutingService(cache)
_, err := svc.RouteByModelName("openai", "nonexistent-model")
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
@@ -64,12 +61,12 @@ func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewRoutingService(cache)
// 创建启用的供应商和禁用的模型
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
require.NoError(t, modelRepo.Update("m1", map[string]interface{}{"enabled": false}))
_, err := svc.RouteByModelName("openai", "gpt-4")
assert.True(t, errors.Is(err, appErrors.ErrModelDisabled))
@@ -79,12 +76,12 @@ func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewRoutingService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewRoutingService(cache)
// 创建启用的供应商和模型,然后禁用供应商
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true})
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true})
providerRepo.Update("openai", map[string]interface{}{"enabled": false})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
require.NoError(t, providerRepo.Update("openai", map[string]interface{}{"enabled": false}))
_, err := svc.RouteByModelName("openai", "gpt-4")
assert.True(t, errors.Is(err, appErrors.ErrProviderDisabled))
@@ -96,20 +93,19 @@ func TestModelService_Create_GeneratesUUID(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model)
require.NoError(t, err)
// 验证返回的 model 拥有有效的 UUID
assert.NotEmpty(t, model.ID)
_, err = uuid.Parse(model.ID)
assert.NoError(t, err, "model.ID should be a valid UUID")
// 通过 Get 验证持久化
stored, err := svc.Get(model.ID)
require.NoError(t, err)
assert.Equal(t, model.ID, stored.ID)
@@ -120,15 +116,15 @@ func TestModelService_Create_DuplicateModelName(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model1)
require.NoError(t, err)
// 使用相同的 (providerID, modelName) 创建第二个模型应失败
model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err = svc.Create(model2)
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
@@ -138,7 +134,8 @@ func TestModelService_Create_ProviderNotFound(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
err := svc.Create(model)
@@ -151,7 +148,8 @@ func TestProviderService_Create_InvalidID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
@@ -162,7 +160,8 @@ func TestProviderService_Create_ValidID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
@@ -177,10 +176,11 @@ func TestModelService_Update_DuplicateModelName(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"}))
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model1)
@@ -202,7 +202,8 @@ func TestModelService_Update_ModelNotFound(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
err := svc.Update("nonexistent-id", map[string]interface{}{
"model_name": "gpt-4",
@@ -214,9 +215,10 @@ func TestModelService_Update_Success(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewModelService(modelRepo, providerRepo)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"})
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
err := svc.Create(model)
@@ -239,7 +241,8 @@ func TestProviderService_Update_ImmutableID(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
@@ -256,7 +259,8 @@ func TestProviderService_Update_Success(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
svc := NewProviderService(repo, modelRepo)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
err := svc.Create(provider)
@@ -268,7 +272,233 @@ func TestProviderService_Update_Success(t *testing.T) {
})
require.NoError(t, err)
updated, err := svc.Get("openai", false)
updated, err := svc.Get("openai")
require.NoError(t, err)
assert.Equal(t, "OpenAI Updated", updated.Name)
}
// ============ StatsService - Aggregate ByModel 测试 ============
func TestStatsService_Aggregate_ByModel(t *testing.T) {
tests := []struct {
name string
stats []domain.UsageStats
expected []map[string]interface{}
}{
{
name: "multiple providers with same model name",
stats: []domain.UsageStats{
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10},
{ProviderID: "azure", ModelName: "gpt-4", RequestCount: 20},
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 5},
},
expected: []map[string]interface{}{
{"provider_id": "openai", "model_name": "gpt-4", "request_count": 15},
{"provider_id": "azure", "model_name": "gpt-4", "request_count": 20},
},
},
{
name: "empty providerID",
stats: []domain.UsageStats{
{ProviderID: "", ModelName: "gpt-4", RequestCount: 10},
{ProviderID: "", ModelName: "gpt-4", RequestCount: 5},
},
expected: []map[string]interface{}{
{"provider_id": "", "model_name": "gpt-4", "request_count": 15},
},
},
{
name: "empty result set",
stats: []domain.UsageStats{},
expected: []map[string]interface{}{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db)
buffer := NewStatsBuffer(statsRepo, nil); svc := NewStatsService(statsRepo, buffer)
result := svc.Aggregate(tt.stats, "model")
assert.Len(t, result, len(tt.expected))
for _, exp := range tt.expected {
found := false
for _, r := range result {
if r["provider_id"] == exp["provider_id"] && r["model_name"] == exp["model_name"] {
assert.Equal(t, exp["request_count"], r["request_count"])
found = true
break
}
}
assert.True(t, found, "expected result not found: %v", exp)
}
})
}
}
// ============ StatsService - Aggregate ByDate 测试 ============
func TestStatsService_Aggregate_ByDate(t *testing.T) {
tests := []struct {
name string
stats []domain.UsageStats
expected []map[string]interface{}
}{
{
name: "normal date grouping",
stats: []domain.UsageStats{
{Date: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), RequestCount: 10},
{Date: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), RequestCount: 5},
{Date: time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC), RequestCount: 20},
},
expected: []map[string]interface{}{
{"date": "2024-01-01", "request_count": 15},
{"date": "2024-01-02", "request_count": 20},
},
},
{
name: "zero-value time",
stats: []domain.UsageStats{
{Date: time.Time{}, RequestCount: 10},
{Date: time.Time{}, RequestCount: 5},
},
expected: []map[string]interface{}{
{"date": "0001-01-01", "request_count": 15},
},
},
{
name: "empty result set",
stats: []domain.UsageStats{},
expected: []map[string]interface{}{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db)
buffer := NewStatsBuffer(statsRepo, nil); svc := NewStatsService(statsRepo, buffer)
result := svc.Aggregate(tt.stats, "date")
assert.Len(t, result, len(tt.expected))
for _, exp := range tt.expected {
found := false
for _, r := range result {
if r["date"] == exp["date"] {
assert.Equal(t, exp["request_count"], r["request_count"])
found = true
break
}
}
assert.True(t, found, "expected result not found: %v", exp)
}
})
}
}
// ============ ProviderService - isUniqueConstraintError 测试 ============
func TestProviderService_isUniqueConstraintError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "nil error",
err: nil,
expected: false,
},
{
name: "UNIQUE constraint failed",
err: errors.New("UNIQUE constraint failed"),
expected: true,
},
{
name: "duplicate key value",
err: errors.New("duplicate key value"),
expected: true,
},
{
name: "UNIQUE constraint case insensitive",
err: errors.New("unique constraint violation"),
expected: true,
},
{
name: "other error",
err: errors.New("some other error"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isUniqueConstraintError(tt.err)
assert.Equal(t, tt.expected, result)
})
}
}
// ============ ProviderService - List API Key 测试 ============
func TestProviderService_List_APIKeyNotMasked(t *testing.T) {
db := setupServiceTestDB(t)
repo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
svc := NewProviderService(repo, modelRepo, cache)
provider1 := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"}
provider2 := &domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "sk-anthropic1234", BaseURL: "https://api.anthropic.com"}
require.NoError(t, svc.Create(provider1))
require.NoError(t, svc.Create(provider2))
providers, err := svc.List()
require.NoError(t, err)
require.Len(t, providers, 2)
expectedKeys := map[string]string{
"openai": "sk-1234567890",
"anthropic": "sk-anthropic1234",
}
for _, p := range providers {
assert.NotContains(t, p.APIKey, "***")
assert.Equal(t, expectedKeys[p.ID], p.APIKey)
}
}
func TestModelService_ConcurrentCreate(t *testing.T) {
db := setupServiceTestDB(t)
providerRepo := repository.NewProviderRepository(db)
modelRepo := repository.NewModelRepository(db)
cache := setupRoutingCache(t, db)
svc := NewModelService(modelRepo, providerRepo, cache)
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
results := make(chan error, 2)
for i := 0; i < 2; i++ {
go func() {
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
results <- svc.Create(model)
}()
}
err1 := <-results
err2 := <-results
successCount := 0
errorCount := 0
for _, err := range []error{err1, err2} {
if err == nil {
successCount++
} else {
errorCount++
}
}
assert.Equal(t, 1, successCount)
assert.Equal(t, 1, errorCount)
}

View File

@@ -0,0 +1,169 @@
package service
import (
"strings"
"sync"
"sync/atomic"
"time"
"go.uber.org/zap"
"nex/backend/internal/repository"
)
type StatsBuffer struct {
counters sync.Map
flushInterval time.Duration
flushThreshold int
totalCount atomic.Int64
statsRepo repository.StatsRepository
logger *zap.Logger
stopCh chan struct{}
doneCh chan struct{}
}
type StatsBufferOption func(*StatsBuffer)
func WithFlushInterval(d time.Duration) StatsBufferOption {
return func(b *StatsBuffer) {
b.flushInterval = d
}
}
func WithFlushThreshold(threshold int) StatsBufferOption {
return func(b *StatsBuffer) {
b.flushThreshold = threshold
}
}
func NewStatsBuffer(
statsRepo repository.StatsRepository,
logger *zap.Logger,
opts ...StatsBufferOption,
) *StatsBuffer {
b := &StatsBuffer{
statsRepo: statsRepo,
logger: logger,
flushInterval: 5 * time.Second,
flushThreshold: 100,
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
for _, opt := range opts {
opt(b)
}
return b
}
func (b *StatsBuffer) Increment(providerID, modelName string) {
today := time.Now().Format("2006-01-02")
key := providerID + "/" + modelName + "/" + today
var counter *int64
if v, ok := b.counters.Load(key); ok {
counter = v.(*int64)
} else {
val := int64(0)
counter = &val
actual, loaded := b.counters.LoadOrStore(key, counter)
if loaded {
counter = actual.(*int64)
}
}
atomic.AddInt64(counter, 1)
if b.totalCount.Add(1) >= int64(b.flushThreshold) {
go b.flush()
}
}
func (b *StatsBuffer) Start() {
go func() {
ticker := time.NewTicker(b.flushInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
b.flush()
case <-b.stopCh:
b.flush()
close(b.doneCh)
return
}
}
}()
}
func (b *StatsBuffer) Stop() {
close(b.stopCh)
<-b.doneCh
}
func (b *StatsBuffer) flush() {
type statEntry struct {
providerID string
modelName string
date string
count int64
}
var entries []statEntry
b.counters.Range(func(key, value interface{}) bool {
keyStr := key.(string)
parts := strings.Split(keyStr, "/")
if len(parts) != 3 {
return true
}
counter := value.(*int64)
count := atomic.SwapInt64(counter, 0)
if count > 0 {
entries = append(entries, statEntry{
providerID: parts[0],
modelName: parts[1],
date: parts[2],
count: count,
})
}
return true
})
if len(entries) == 0 {
return
}
success := 0
for _, entry := range entries {
date, _ := time.Parse("2006-01-02", entry.date)
err := b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count))
if err != nil {
b.logger.Error("批量更新统计失败",
zap.String("provider_id", entry.providerID),
zap.String("model_name", entry.modelName),
zap.Int64("count", entry.count),
zap.Error(err))
key := entry.providerID + "/" + entry.modelName + "/" + entry.date
if v, ok := b.counters.Load(key); ok {
counter := v.(*int64)
atomic.AddInt64(counter, entry.count)
}
} else {
success++
}
}
b.totalCount.Store(0)
b.logger.Debug("统计刷新完成",
zap.Int("total", len(entries)),
zap.Int("success", success),
zap.Int("failed", len(entries)-success))
}

View File

@@ -0,0 +1,251 @@
package service
import (
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"nex/backend/internal/domain"
)
type mockStatsRepo struct {
records []struct {
providerID string
modelName string
date time.Time
delta int
}
fail bool
mu sync.Mutex
}
func (m *mockStatsRepo) Record(providerID, modelName string) error {
return nil
}
func (m *mockStatsRepo) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.fail {
return errors.New("db error")
}
m.records = append(m.records, struct {
providerID string
modelName string
date time.Time
delta int
}{providerID, modelName, date, delta})
return nil
}
func (m *mockStatsRepo) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
return nil, nil
}
func TestStatsBuffer_Increment(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger)
buffer.Increment("openai", "gpt-4")
buffer.Increment("openai", "gpt-4")
buffer.Increment("openai", "gpt-3.5")
var count int64
buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64)
count += atomic.LoadInt64(counter)
return true
})
assert.Equal(t, int64(3), count)
}
func TestStatsBuffer_ConcurrentIncrement(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger)
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
buffer.Increment("openai", "gpt-4")
}()
}
wg.Wait()
var count int64
buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64)
count = atomic.LoadInt64(counter)
return true
})
assert.Equal(t, int64(100), count)
}
func TestStatsBuffer_LoadOrStore(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger)
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
buffer.Increment("openai", "gpt-4")
}()
}
wg.Wait()
var counterCount int
buffer.counters.Range(func(key, value interface{}) bool {
counterCount++
return true
})
assert.Equal(t, 1, counterCount)
}
func TestStatsBuffer_FlushByInterval(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger,
WithFlushInterval(100*time.Millisecond))
buffer.Start()
defer buffer.Stop()
buffer.Increment("openai", "gpt-4")
buffer.Increment("openai", "gpt-4")
time.Sleep(200 * time.Millisecond)
statsRepo.mu.Lock()
assert.GreaterOrEqual(t, len(statsRepo.records), 1)
statsRepo.mu.Unlock()
}
func TestStatsBuffer_FlushByThreshold(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger,
WithFlushThreshold(10))
buffer.Start()
defer buffer.Stop()
for i := 0; i < 10; i++ {
buffer.Increment("openai", "gpt-4")
}
time.Sleep(50 * time.Millisecond)
statsRepo.mu.Lock()
assert.GreaterOrEqual(t, len(statsRepo.records), 1)
statsRepo.mu.Unlock()
}
func TestStatsBuffer_SwapInt64(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger)
buffer.Increment("openai", "gpt-4")
buffer.Increment("openai", "gpt-4")
var beforeCount int64
buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64)
beforeCount = atomic.LoadInt64(counter)
return true
})
assert.Equal(t, int64(2), beforeCount)
buffer.flush()
var afterCount int64
buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64)
afterCount = atomic.LoadInt64(counter)
return true
})
assert.Equal(t, int64(0), afterCount)
}
func TestStatsBuffer_FailRetry(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{fail: true}
buffer := NewStatsBuffer(statsRepo, logger)
buffer.Increment("openai", "gpt-4")
buffer.Increment("openai", "gpt-4")
buffer.flush()
var count int64
buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64)
count = atomic.LoadInt64(counter)
return true
})
assert.Equal(t, int64(2), count)
}
func TestStatsBuffer_Stop(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger,
WithFlushInterval(10*time.Second))
buffer.Start()
buffer.Increment("openai", "gpt-4")
buffer.Increment("openai", "gpt-4")
start := time.Now()
buffer.Stop()
elapsed := time.Since(start)
assert.Less(t, elapsed, 1*time.Second)
statsRepo.mu.Lock()
assert.GreaterOrEqual(t, len(statsRepo.records), 1)
statsRepo.mu.Unlock()
}
func TestStatsBuffer_ConcurrentIncrementAndFlush(t *testing.T) {
logger := zap.NewNop()
statsRepo := &mockStatsRepo{}
buffer := NewStatsBuffer(statsRepo, logger,
WithFlushInterval(50*time.Millisecond))
buffer.Start()
defer buffer.Stop()
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
buffer.Increment("openai", "gpt-4")
}()
}
wg.Wait()
time.Sleep(100 * time.Millisecond)
statsRepo.mu.Lock()
totalDelta := 0
for _, r := range statsRepo.records {
totalDelta += r.delta
}
statsRepo.mu.Unlock()
assert.Equal(t, 100, totalDelta)
}

View File

@@ -6,6 +6,8 @@ import (
"nex/backend/internal/domain"
)
//go:generate go run go.uber.org/mock/mockgen -source=stats_service.go -destination=../../tests/mocks/mock_stats_service.go -package=mocks
// StatsService 统计服务接口
type StatsService interface {
Record(providerID, modelName string) error

View File

@@ -10,14 +10,16 @@ import (
type statsService struct {
statsRepo repository.StatsRepository
buffer *StatsBuffer
}
func NewStatsService(statsRepo repository.StatsRepository) StatsService {
return &statsService{statsRepo: statsRepo}
func NewStatsService(statsRepo repository.StatsRepository, buffer *StatsBuffer) StatsService {
return &statsService{statsRepo: statsRepo, buffer: buffer}
}
func (s *statsService) Record(providerID, modelName string) error {
return s.statsRepo.Record(providerID, modelName)
s.buffer.Increment(providerID, modelName)
return nil
}
func (s *statsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {

View File

@@ -0,0 +1,193 @@
package config
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/config"
)
func TestLoadConfig_DefaultValues(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
cfg, err := config.LoadConfigFromPath(configPath)
require.NoError(t, err)
assert.Equal(t, 9826, cfg.Server.Port)
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
assert.Equal(t, "info", cfg.Log.Level)
assert.Equal(t, 100, cfg.Log.MaxSize)
assert.Equal(t, 10, cfg.Log.MaxBackups)
assert.Equal(t, 30, cfg.Log.MaxAge)
assert.True(t, cfg.Log.Compress)
}
func TestLoadConfig_EnvOverride(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
t.Setenv("NEX_SERVER_PORT", "9000")
t.Setenv("NEX_LOG_LEVEL", "debug")
t.Setenv("NEX_DATABASE_MAX_IDLE_CONNS", "20")
cfg, err := config.LoadConfigFromPath(configPath)
require.NoError(t, err)
assert.Equal(t, 9000, cfg.Server.Port)
assert.Equal(t, "debug", cfg.Log.Level)
assert.Equal(t, 20, cfg.Database.MaxIdleConns)
}
func TestLoadConfig_YAMLFile(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
yamlContent := `
server:
port: 8080
read_timeout: 60s
write_timeout: 60s
database:
path: /custom/path.db
max_idle_conns: 5
max_open_conns: 50
conn_max_lifetime: 2h
log:
level: warn
path: /custom/log
max_size: 200
max_backups: 5
max_age: 7
compress: false
`
err := os.WriteFile(configPath, []byte(yamlContent), 0644)
require.NoError(t, err)
cfg, err := config.LoadConfigFromPath(configPath)
require.NoError(t, err)
assert.Equal(t, 8080, cfg.Server.Port)
assert.Equal(t, 60*time.Second, cfg.Server.ReadTimeout)
assert.Equal(t, 60*time.Second, cfg.Server.WriteTimeout)
assert.Equal(t, "/custom/path.db", cfg.Database.Path)
assert.Equal(t, 5, cfg.Database.MaxIdleConns)
assert.Equal(t, 50, cfg.Database.MaxOpenConns)
assert.Equal(t, 2*time.Hour, cfg.Database.ConnMaxLifetime)
assert.Equal(t, "warn", cfg.Log.Level)
assert.Equal(t, "/custom/log", cfg.Log.Path)
assert.Equal(t, 200, cfg.Log.MaxSize)
assert.Equal(t, 5, cfg.Log.MaxBackups)
assert.Equal(t, 7, cfg.Log.MaxAge)
assert.False(t, cfg.Log.Compress)
}
func TestLoadConfig_PriorityChain(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
yamlContent := `
server:
port: 8080
log:
level: warn
`
err := os.WriteFile(configPath, []byte(yamlContent), 0644)
require.NoError(t, err)
t.Setenv("NEX_SERVER_PORT", "9000")
originalArgs := os.Args
defer func() { os.Args = originalArgs }()
os.Args = []string{"test", "--server-port", "9999"}
cfg, err := config.LoadConfigFromPath(configPath)
require.NoError(t, err)
assert.Equal(t, 9999, cfg.Server.Port, "CLI should override ENV and YAML")
assert.Equal(t, "warn", cfg.Log.Level, "YAML value should be used when no CLI/ENV override")
}
func TestLoadConfig_AutoCreate(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
_, err := os.Stat(configPath)
assert.True(t, os.IsNotExist(err), "config file should not exist before load")
cfg, err := config.LoadConfigFromPath(configPath)
require.NoError(t, err)
require.NotNil(t, cfg)
assert.Equal(t, 9826, cfg.Server.Port, "should load with default values")
}
func TestSaveAndLoadConfig(t *testing.T) {
tmpDir := t.TempDir()
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
nexDir := filepath.Join(homeDir, ".nex")
configPath := filepath.Join(nexDir, "config.yaml")
originalConfig, err := os.ReadFile(configPath)
if err != nil && !os.IsNotExist(err) {
require.NoError(t, err)
}
defer func() {
if originalConfig != nil {
_ = os.WriteFile(configPath, originalConfig, 0644)
}
}()
cfg := &config.Config{
Server: config.ServerConfig{
Port: 7777,
ReadTimeout: 45 * time.Second,
WriteTimeout: 45 * time.Second,
},
Database: config.DatabaseConfig{
Path: filepath.Join(tmpDir, "test.db"),
MaxIdleConns: 15,
MaxOpenConns: 150,
ConnMaxLifetime: 2 * time.Hour,
},
Log: config.LogConfig{
Level: "debug",
Path: filepath.Join(tmpDir, "log"),
MaxSize: 50,
MaxBackups: 3,
MaxAge: 14,
Compress: false,
},
}
err = config.SaveConfig(cfg)
require.NoError(t, err)
loaded, err := config.LoadConfig()
require.NoError(t, err)
assert.Equal(t, cfg.Server.Port, loaded.Server.Port)
assert.Equal(t, cfg.Server.ReadTimeout, loaded.Server.ReadTimeout)
assert.Equal(t, cfg.Server.WriteTimeout, loaded.Server.WriteTimeout)
assert.Equal(t, cfg.Database.MaxIdleConns, loaded.Database.MaxIdleConns)
assert.Equal(t, cfg.Database.MaxOpenConns, loaded.Database.MaxOpenConns)
assert.Equal(t, cfg.Database.ConnMaxLifetime, loaded.Database.ConnMaxLifetime)
assert.Equal(t, cfg.Log.Level, loaded.Log.Level)
assert.Equal(t, cfg.Log.MaxSize, loaded.Log.MaxSize)
assert.Equal(t, cfg.Log.MaxBackups, loaded.Log.MaxBackups)
assert.Equal(t, cfg.Log.MaxAge, loaded.Log.MaxAge)
assert.Equal(t, cfg.Log.Compress, loaded.Log.Compress)
}

View File

@@ -7,49 +7,40 @@ import (
"nex/backend/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// SetupTestDB initializes an in-memory SQLite database with auto-migration.
// Uses :memory: mode with MaxOpenConns(1) to ensure all operations share the
// same connection, avoiding "database is closed" errors from connection pool.
// Enables foreign key constraints for SQLite.
func SetupTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:?_foreign_keys=on"), &gorm.Config{})
assert.NoError(t, err, "failed to open test database")
require.NoError(t, err, "failed to open test database")
// 限制为单连接,确保 :memory: 数据库不被连接池丢弃
sqlDB, err := db.DB()
assert.NoError(t, err, "failed to get underlying sql.DB")
require.NoError(t, err, "failed to get underlying sql.DB")
sqlDB.SetMaxOpenConns(1)
sqlDB.SetConnMaxLifetime(0)
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
assert.NoError(t, err, "failed to auto-migrate test database")
require.NoError(t, err, "failed to auto-migrate test database")
return db
}
// CleanupTestDB closes the database after a brief delay to allow async
// goroutines (e.g. stats recording) to finish.
func CleanupTestDB(t *testing.T, db *gorm.DB) {
t.Helper()
// 等待异步 goroutine如 statsService.Record完成
time.Sleep(50 * time.Millisecond)
sqlDB, err := db.DB()
assert.NoError(t, err, "failed to get underlying sql.DB")
require.NoError(t, err, "failed to get underlying sql.DB")
err = sqlDB.Close()
assert.NoError(t, err, "failed to close test database")
require.NoError(t, err, "failed to close test database")
}
// CreateTestProvider creates a test provider and returns it.
func CreateTestProvider(t *testing.T, db *gorm.DB, id string) config.Provider {
t.Helper()
@@ -62,13 +53,11 @@ func CreateTestProvider(t *testing.T, db *gorm.DB, id string) config.Provider {
}
err := db.Create(&provider).Error
assert.NoError(t, err, "failed to create test provider")
require.NoError(t, err, "failed to create test provider")
return provider
}
// CreateTestModel creates a test model and returns it.
// Does NOT assert on error - returns the model and error for caller to verify.
func CreateTestModel(t *testing.T, db *gorm.DB, id string, providerID string, modelName string) (config.Model, error) {
t.Helper()

View File

@@ -14,6 +14,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"gorm.io/gorm"
"nex/backend/internal/conversion"
@@ -54,10 +55,14 @@ func setupConversionTest(t *testing.T) (*gin.Engine, *gorm.DB, *httptest.Server)
modelRepo := repository.NewModelRepository(db)
statsRepo := repository.NewStatsRepository(db)
providerService := service.NewProviderService(providerRepo, modelRepo)
modelService := service.NewModelService(modelRepo, providerRepo)
routingService := service.NewRoutingService(modelRepo, providerRepo)
statsService := service.NewStatsService(statsRepo)
logger := zap.NewNop()
routingCache := service.NewRoutingCache(modelRepo, providerRepo, logger)
statsBuffer := service.NewStatsBuffer(statsRepo, logger)
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
routingService := service.NewRoutingService(routingCache)
statsService := service.NewStatsService(statsRepo, statsBuffer)
// 创建 ConversionEngine
registry := conversion.NewMemoryRegistry()
@@ -533,8 +538,7 @@ func TestConversion_ProviderWithProtocol(t *testing.T) {
var created map[string]any
json.Unmarshal(w.Body.Bytes(), &created)
// API Key 被掩码
assert.Contains(t, created["api_key"], "***")
assert.Equal(t, "sk-test", created["api_key"])
// 获取时应包含 protocol
w = httptest.NewRecorder()

View File

@@ -15,6 +15,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
@@ -48,10 +49,14 @@ func setupE2ETest(t *testing.T) (*gin.Engine, *httptest.Server) {
modelRepo := repository.NewModelRepository(db)
statsRepo := repository.NewStatsRepository(db)
providerService := service.NewProviderService(providerRepo, modelRepo)
modelService := service.NewModelService(modelRepo, providerRepo)
routingService := service.NewRoutingService(modelRepo, providerRepo)
statsService := service.NewStatsService(statsRepo)
logger := zap.NewNop()
routingCache := service.NewRoutingCache(modelRepo, providerRepo, logger)
statsBuffer := service.NewStatsBuffer(statsRepo, logger)
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
routingService := service.NewRoutingService(routingCache)
statsService := service.NewStatsService(statsRepo, statsBuffer)
registry := conversion.NewMemoryRegistry()
require.NoError(t, registry.Register(openaiConv.NewAdapter()))

View File

@@ -9,6 +9,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"gorm.io/gorm"
"nex/backend/internal/domain"
@@ -30,10 +31,14 @@ func setupIntegrationTest(t *testing.T) (*gin.Engine, *gorm.DB) {
modelRepo := repository.NewModelRepository(db)
statsRepo := repository.NewStatsRepository(db)
providerService := service.NewProviderService(providerRepo, modelRepo)
modelService := service.NewModelService(modelRepo, providerRepo)
_ = service.NewRoutingService(modelRepo, providerRepo)
statsService := service.NewStatsService(statsRepo)
logger := zap.NewNop()
routingCache := service.NewRoutingCache(modelRepo, providerRepo, logger)
statsBuffer := service.NewStatsBuffer(statsRepo, logger)
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
_ = service.NewRoutingService(routingCache)
statsService := service.NewStatsService(statsRepo, statsBuffer)
providerHandler := handler.NewProviderHandler(providerService)
modelHandler := handler.NewModelHandler(modelService)
@@ -103,7 +108,7 @@ func TestOpenAI_CompleteFlow(t *testing.T) {
var providers []domain.Provider
json.Unmarshal(w.Body.Bytes(), &providers)
assert.Len(t, providers, 1)
assert.Contains(t, providers[0].APIKey, "***") // 已掩码
assert.Equal(t, "sk-test-key", providers[0].APIKey)
// 4. 列出 Model
w = httptest.NewRecorder()

View File

@@ -0,0 +1,143 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: model_repo.go
//
// Generated by this command:
//
// mockgen -source=model_repo.go -destination=../../tests/mocks/mock_model_repository.go -package=mocks
//
// Package mocks is a generated GoMock package.
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockModelRepository is a mock of ModelRepository interface.
type MockModelRepository struct {
ctrl *gomock.Controller
recorder *MockModelRepositoryMockRecorder
isgomock struct{}
}
// MockModelRepositoryMockRecorder is the mock recorder for MockModelRepository.
type MockModelRepositoryMockRecorder struct {
mock *MockModelRepository
}
// NewMockModelRepository creates a new mock instance.
func NewMockModelRepository(ctrl *gomock.Controller) *MockModelRepository {
mock := &MockModelRepository{ctrl: ctrl}
mock.recorder = &MockModelRepositoryMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockModelRepository) EXPECT() *MockModelRepositoryMockRecorder {
return m.recorder
}
// Create mocks base method.
func (m *MockModelRepository) Create(model *domain.Model) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Create", model)
ret0, _ := ret[0].(error)
return ret0
}
// Create indicates an expected call of Create.
func (mr *MockModelRepositoryMockRecorder) Create(model any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockModelRepository)(nil).Create), model)
}
// Delete mocks base method.
func (m *MockModelRepository) Delete(id string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", id)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockModelRepositoryMockRecorder) Delete(id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockModelRepository)(nil).Delete), id)
}
// FindByProviderAndModelName mocks base method.
func (m *MockModelRepository) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FindByProviderAndModelName", providerID, modelName)
ret0, _ := ret[0].(*domain.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FindByProviderAndModelName indicates an expected call of FindByProviderAndModelName.
func (mr *MockModelRepositoryMockRecorder) FindByProviderAndModelName(providerID, modelName any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByProviderAndModelName", reflect.TypeOf((*MockModelRepository)(nil).FindByProviderAndModelName), providerID, modelName)
}
// GetByID mocks base method.
func (m *MockModelRepository) GetByID(id string) (*domain.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetByID", id)
ret0, _ := ret[0].(*domain.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetByID indicates an expected call of GetByID.
func (mr *MockModelRepositoryMockRecorder) GetByID(id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByID", reflect.TypeOf((*MockModelRepository)(nil).GetByID), id)
}
// List mocks base method.
func (m *MockModelRepository) List(providerID string) ([]domain.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "List", providerID)
ret0, _ := ret[0].([]domain.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// List indicates an expected call of List.
func (mr *MockModelRepositoryMockRecorder) List(providerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockModelRepository)(nil).List), providerID)
}
// ListEnabled mocks base method.
func (m *MockModelRepository) ListEnabled() ([]domain.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListEnabled")
ret0, _ := ret[0].([]domain.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListEnabled indicates an expected call of ListEnabled.
func (mr *MockModelRepositoryMockRecorder) ListEnabled() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEnabled", reflect.TypeOf((*MockModelRepository)(nil).ListEnabled))
}
// Update mocks base method.
func (m *MockModelRepository) Update(id string, updates map[string]any) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Update", id, updates)
ret0, _ := ret[0].(error)
return ret0
}
// Update indicates an expected call of Update.
func (mr *MockModelRepositoryMockRecorder) Update(id, updates any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockModelRepository)(nil).Update), id, updates)
}

View File

@@ -0,0 +1,128 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: model_service.go
//
// Generated by this command:
//
// mockgen -source=model_service.go -destination=../../tests/mocks/mock_model_service.go -package=mocks
//
// Package mocks is a generated GoMock package.
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockModelService is a mock of ModelService interface.
type MockModelService struct {
ctrl *gomock.Controller
recorder *MockModelServiceMockRecorder
isgomock struct{}
}
// MockModelServiceMockRecorder is the mock recorder for MockModelService.
type MockModelServiceMockRecorder struct {
mock *MockModelService
}
// NewMockModelService creates a new mock instance.
func NewMockModelService(ctrl *gomock.Controller) *MockModelService {
mock := &MockModelService{ctrl: ctrl}
mock.recorder = &MockModelServiceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockModelService) EXPECT() *MockModelServiceMockRecorder {
return m.recorder
}
// Create mocks base method.
func (m *MockModelService) Create(model *domain.Model) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Create", model)
ret0, _ := ret[0].(error)
return ret0
}
// Create indicates an expected call of Create.
func (mr *MockModelServiceMockRecorder) Create(model any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockModelService)(nil).Create), model)
}
// Delete mocks base method.
func (m *MockModelService) Delete(id string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", id)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockModelServiceMockRecorder) Delete(id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockModelService)(nil).Delete), id)
}
// Get mocks base method.
func (m *MockModelService) Get(id string) (*domain.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", id)
ret0, _ := ret[0].(*domain.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get.
func (mr *MockModelServiceMockRecorder) Get(id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockModelService)(nil).Get), id)
}
// List mocks base method.
func (m *MockModelService) List(providerID string) ([]domain.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "List", providerID)
ret0, _ := ret[0].([]domain.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// List indicates an expected call of List.
func (mr *MockModelServiceMockRecorder) List(providerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockModelService)(nil).List), providerID)
}
// ListEnabled mocks base method.
func (m *MockModelService) ListEnabled() ([]domain.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListEnabled")
ret0, _ := ret[0].([]domain.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListEnabled indicates an expected call of ListEnabled.
func (mr *MockModelServiceMockRecorder) ListEnabled() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEnabled", reflect.TypeOf((*MockModelService)(nil).ListEnabled))
}
// Update mocks base method.
func (m *MockModelService) Update(id string, updates map[string]any) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Update", id, updates)
ret0, _ := ret[0].(error)
return ret0
}
// Update indicates an expected call of Update.
func (mr *MockModelServiceMockRecorder) Update(id, updates any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockModelService)(nil).Update), id, updates)
}

View File

@@ -0,0 +1,73 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: client.go
//
// Generated by this command:
//
// mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
//
// Package mocks is a generated GoMock package.
package mocks
import (
context "context"
conversion "nex/backend/internal/conversion"
provider "nex/backend/internal/provider"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockProviderClient is a mock of ProviderClient interface.
type MockProviderClient struct {
ctrl *gomock.Controller
recorder *MockProviderClientMockRecorder
isgomock struct{}
}
// MockProviderClientMockRecorder is the mock recorder for MockProviderClient.
type MockProviderClientMockRecorder struct {
mock *MockProviderClient
}
// NewMockProviderClient creates a new mock instance.
func NewMockProviderClient(ctrl *gomock.Controller) *MockProviderClient {
mock := &MockProviderClient{ctrl: ctrl}
mock.recorder = &MockProviderClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockProviderClient) EXPECT() *MockProviderClientMockRecorder {
return m.recorder
}
// Send mocks base method.
func (m *MockProviderClient) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Send", ctx, spec)
ret0, _ := ret[0].(*conversion.HTTPResponseSpec)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Send indicates an expected call of Send.
func (mr *MockProviderClientMockRecorder) Send(ctx, spec any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockProviderClient)(nil).Send), ctx, spec)
}
// SendStream mocks base method.
func (m *MockProviderClient) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendStream", ctx, spec)
ret0, _ := ret[0].(<-chan provider.StreamEvent)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SendStream indicates an expected call of SendStream.
func (mr *MockProviderClientMockRecorder) SendStream(ctx, spec any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendStream", reflect.TypeOf((*MockProviderClient)(nil).SendStream), ctx, spec)
}

View File

@@ -0,0 +1,113 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: provider_repo.go
//
// Generated by this command:
//
// mockgen -source=provider_repo.go -destination=../../tests/mocks/mock_provider_repository.go -package=mocks
//
// Package mocks is a generated GoMock package.
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockProviderRepository is a mock of ProviderRepository interface.
type MockProviderRepository struct {
ctrl *gomock.Controller
recorder *MockProviderRepositoryMockRecorder
isgomock struct{}
}
// MockProviderRepositoryMockRecorder is the mock recorder for MockProviderRepository.
type MockProviderRepositoryMockRecorder struct {
mock *MockProviderRepository
}
// NewMockProviderRepository creates a new mock instance.
func NewMockProviderRepository(ctrl *gomock.Controller) *MockProviderRepository {
mock := &MockProviderRepository{ctrl: ctrl}
mock.recorder = &MockProviderRepositoryMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockProviderRepository) EXPECT() *MockProviderRepositoryMockRecorder {
return m.recorder
}
// Create mocks base method.
func (m *MockProviderRepository) Create(provider *domain.Provider) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Create", provider)
ret0, _ := ret[0].(error)
return ret0
}
// Create indicates an expected call of Create.
func (mr *MockProviderRepositoryMockRecorder) Create(provider any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockProviderRepository)(nil).Create), provider)
}
// Delete mocks base method.
func (m *MockProviderRepository) Delete(id string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", id)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockProviderRepositoryMockRecorder) Delete(id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockProviderRepository)(nil).Delete), id)
}
// GetByID mocks base method.
func (m *MockProviderRepository) GetByID(id string) (*domain.Provider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetByID", id)
ret0, _ := ret[0].(*domain.Provider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetByID indicates an expected call of GetByID.
func (mr *MockProviderRepositoryMockRecorder) GetByID(id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByID", reflect.TypeOf((*MockProviderRepository)(nil).GetByID), id)
}
// List mocks base method.
func (m *MockProviderRepository) List() ([]domain.Provider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "List")
ret0, _ := ret[0].([]domain.Provider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// List indicates an expected call of List.
func (mr *MockProviderRepositoryMockRecorder) List() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockProviderRepository)(nil).List))
}
// Update mocks base method.
func (m *MockProviderRepository) Update(id string, updates map[string]any) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Update", id, updates)
ret0, _ := ret[0].(error)
return ret0
}
// Update indicates an expected call of Update.
func (mr *MockProviderRepositoryMockRecorder) Update(id, updates any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockProviderRepository)(nil).Update), id, updates)
}

View File

@@ -0,0 +1,143 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: provider_service.go
//
// Generated by this command:
//
// mockgen -source=provider_service.go -destination=../../tests/mocks/mock_provider_service.go -package=mocks
//
// Package mocks is a generated GoMock package.
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockProviderService is a mock of ProviderService interface.
type MockProviderService struct {
ctrl *gomock.Controller
recorder *MockProviderServiceMockRecorder
isgomock struct{}
}
// MockProviderServiceMockRecorder is the mock recorder for MockProviderService.
type MockProviderServiceMockRecorder struct {
mock *MockProviderService
}
// NewMockProviderService creates a new mock instance.
func NewMockProviderService(ctrl *gomock.Controller) *MockProviderService {
mock := &MockProviderService{ctrl: ctrl}
mock.recorder = &MockProviderServiceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockProviderService) EXPECT() *MockProviderServiceMockRecorder {
return m.recorder
}
// Create mocks base method.
func (m *MockProviderService) Create(provider *domain.Provider) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Create", provider)
ret0, _ := ret[0].(error)
return ret0
}
// Create indicates an expected call of Create.
func (mr *MockProviderServiceMockRecorder) Create(provider any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockProviderService)(nil).Create), provider)
}
// Delete mocks base method.
func (m *MockProviderService) Delete(id string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", id)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockProviderServiceMockRecorder) Delete(id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockProviderService)(nil).Delete), id)
}
// Get mocks base method.
func (m *MockProviderService) Get(id string) (*domain.Provider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", id)
ret0, _ := ret[0].(*domain.Provider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get.
func (mr *MockProviderServiceMockRecorder) Get(id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockProviderService)(nil).Get), id)
}
// GetModelByProviderAndName mocks base method.
func (m *MockProviderService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetModelByProviderAndName", providerID, modelName)
ret0, _ := ret[0].(*domain.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetModelByProviderAndName indicates an expected call of GetModelByProviderAndName.
func (mr *MockProviderServiceMockRecorder) GetModelByProviderAndName(providerID, modelName any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetModelByProviderAndName", reflect.TypeOf((*MockProviderService)(nil).GetModelByProviderAndName), providerID, modelName)
}
// List mocks base method.
func (m *MockProviderService) List() ([]domain.Provider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "List")
ret0, _ := ret[0].([]domain.Provider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// List indicates an expected call of List.
func (mr *MockProviderServiceMockRecorder) List() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockProviderService)(nil).List))
}
// ListEnabledModels mocks base method.
func (m *MockProviderService) ListEnabledModels() ([]domain.Model, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListEnabledModels")
ret0, _ := ret[0].([]domain.Model)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListEnabledModels indicates an expected call of ListEnabledModels.
func (mr *MockProviderServiceMockRecorder) ListEnabledModels() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEnabledModels", reflect.TypeOf((*MockProviderService)(nil).ListEnabledModels))
}
// Update mocks base method.
func (m *MockProviderService) Update(id string, updates map[string]any) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Update", id, updates)
ret0, _ := ret[0].(error)
return ret0
}
// Update indicates an expected call of Update.
func (mr *MockProviderServiceMockRecorder) Update(id, updates any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockProviderService)(nil).Update), id, updates)
}

View File

@@ -0,0 +1,56 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: routing_service.go
//
// Generated by this command:
//
// mockgen -source=routing_service.go -destination=../../tests/mocks/mock_routing_service.go -package=mocks
//
// Package mocks is a generated GoMock package.
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockRoutingService is a mock of RoutingService interface.
type MockRoutingService struct {
ctrl *gomock.Controller
recorder *MockRoutingServiceMockRecorder
isgomock struct{}
}
// MockRoutingServiceMockRecorder is the mock recorder for MockRoutingService.
type MockRoutingServiceMockRecorder struct {
mock *MockRoutingService
}
// NewMockRoutingService creates a new mock instance.
func NewMockRoutingService(ctrl *gomock.Controller) *MockRoutingService {
mock := &MockRoutingService{ctrl: ctrl}
mock.recorder = &MockRoutingServiceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRoutingService) EXPECT() *MockRoutingServiceMockRecorder {
return m.recorder
}
// RouteByModelName mocks base method.
func (m *MockRoutingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteByModelName", providerID, modelName)
ret0, _ := ret[0].(*domain.RouteResult)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RouteByModelName indicates an expected call of RouteByModelName.
func (mr *MockRoutingServiceMockRecorder) RouteByModelName(providerID, modelName any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteByModelName", reflect.TypeOf((*MockRoutingService)(nil).RouteByModelName), providerID, modelName)
}

View File

@@ -0,0 +1,85 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: stats_repo.go
//
// Generated by this command:
//
// mockgen -source=stats_repo.go -destination=../../tests/mocks/mock_stats_repository.go -package=mocks
//
// Package mocks is a generated GoMock package.
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
time "time"
gomock "go.uber.org/mock/gomock"
)
// MockStatsRepository is a mock of StatsRepository interface.
type MockStatsRepository struct {
ctrl *gomock.Controller
recorder *MockStatsRepositoryMockRecorder
isgomock struct{}
}
// MockStatsRepositoryMockRecorder is the mock recorder for MockStatsRepository.
type MockStatsRepositoryMockRecorder struct {
mock *MockStatsRepository
}
// NewMockStatsRepository creates a new mock instance.
func NewMockStatsRepository(ctrl *gomock.Controller) *MockStatsRepository {
mock := &MockStatsRepository{ctrl: ctrl}
mock.recorder = &MockStatsRepositoryMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStatsRepository) EXPECT() *MockStatsRepositoryMockRecorder {
return m.recorder
}
// BatchUpdate mocks base method.
func (m *MockStatsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BatchUpdate", providerID, modelName, date, delta)
ret0, _ := ret[0].(error)
return ret0
}
// BatchUpdate indicates an expected call of BatchUpdate.
func (mr *MockStatsRepositoryMockRecorder) BatchUpdate(providerID, modelName, date, delta any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchUpdate", reflect.TypeOf((*MockStatsRepository)(nil).BatchUpdate), providerID, modelName, date, delta)
}
// Query mocks base method.
func (m *MockStatsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Query", providerID, modelName, startDate, endDate)
ret0, _ := ret[0].([]domain.UsageStats)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Query indicates an expected call of Query.
func (mr *MockStatsRepositoryMockRecorder) Query(providerID, modelName, startDate, endDate any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStatsRepository)(nil).Query), providerID, modelName, startDate, endDate)
}
// Record mocks base method.
func (m *MockStatsRepository) Record(providerID, modelName string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Record", providerID, modelName)
ret0, _ := ret[0].(error)
return ret0
}
// Record indicates an expected call of Record.
func (mr *MockStatsRepositoryMockRecorder) Record(providerID, modelName any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Record", reflect.TypeOf((*MockStatsRepository)(nil).Record), providerID, modelName)
}

View File

@@ -0,0 +1,85 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: stats_service.go
//
// Generated by this command:
//
// mockgen -source=stats_service.go -destination=../../tests/mocks/mock_stats_service.go -package=mocks
//
// Package mocks is a generated GoMock package.
package mocks
import (
domain "nex/backend/internal/domain"
reflect "reflect"
time "time"
gomock "go.uber.org/mock/gomock"
)
// MockStatsService is a mock of StatsService interface.
type MockStatsService struct {
ctrl *gomock.Controller
recorder *MockStatsServiceMockRecorder
isgomock struct{}
}
// MockStatsServiceMockRecorder is the mock recorder for MockStatsService.
type MockStatsServiceMockRecorder struct {
mock *MockStatsService
}
// NewMockStatsService creates a new mock instance.
func NewMockStatsService(ctrl *gomock.Controller) *MockStatsService {
mock := &MockStatsService{ctrl: ctrl}
mock.recorder = &MockStatsServiceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStatsService) EXPECT() *MockStatsServiceMockRecorder {
return m.recorder
}
// Aggregate mocks base method.
func (m *MockStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]any {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Aggregate", stats, groupBy)
ret0, _ := ret[0].([]map[string]any)
return ret0
}
// Aggregate indicates an expected call of Aggregate.
func (mr *MockStatsServiceMockRecorder) Aggregate(stats, groupBy any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Aggregate", reflect.TypeOf((*MockStatsService)(nil).Aggregate), stats, groupBy)
}
// Get mocks base method.
func (m *MockStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", providerID, modelName, startDate, endDate)
ret0, _ := ret[0].([]domain.UsageStats)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get.
func (mr *MockStatsServiceMockRecorder) Get(providerID, modelName, startDate, endDate any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockStatsService)(nil).Get), providerID, modelName, startDate, endDate)
}
// Record mocks base method.
func (m *MockStatsService) Record(providerID, modelName string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Record", providerID, modelName)
ret0, _ := ret[0].(error)
return ret0
}
// Record indicates an expected call of Record.
func (mr *MockStatsServiceMockRecorder) Record(providerID, modelName any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Record", reflect.TypeOf((*MockStatsService)(nil).Record), providerID, modelName)
}

View File

@@ -126,24 +126,54 @@
### 2.3 请求处理流程
每个 HTTP 请求的转换流程:
#### 2.3.1 三车道数据流模型
引擎根据协议匹配情况和参数条件,选择三条数据流通道之一:
```
客户端入站 调用方(协议识别+前缀剥离) SDK 内部处理 上游出站
┌──────────────────┐ ┌──────────────────┐
│ URL: │ 调用方完成: 1. 接口识别: CHAT │ URL: │
/<protocol>/ │ · clientProtocol 2. 同协议? ──yes──▶ 直接转发│ 目标协议
│ v1/... │ · nativePath └──no──▶ 继续转换 │ 原生路径 │
Headers: │ · providerProtocol 3. URL 映射: 目标路径 │ Headers:
协议原生格式 4. Header 构建: 目标格式 │ 目标协议格式
Body:5. Body 转换: D→C→E │ Body:
协议原生格式 │ 目标协议格式
└──────────────────┘ └──────────────────┘
┌──────────────────────────────────────────────────────────────────────────────────┐
数据流三车道模型
├────────────────┬─────────────────────────┬────────────────────────────────────────┤
透传车道 │ 智能透传车道 完整转换车道
├────────────────┼─────────────────────────┼────────────────────────────────────────┤
触发条件 │ 同协议 │ 同协议 + 接口∈{Chat,Embed,Rerank} │ 不同协议
│ + Body非空 + ModelName非空 │
请求处理 │ 重建Headers │ 重建HeadersDecode→Middleware→Encode
│ Body原样转发 │ RewriteRequestModelName(body) │
│ │ │ 最小化JSON字段手术 │ │
│ │ │ │ │
│ 响应处理 │ 原样返回 │ modelOverride非空时 │ Decode→modelOverride │
│ │ │ RewriteResponseModelName(body) │ →Encode │
│ │ │ │ │
│ 流式处理 │ chunk→[chunk] │ chunk→RewriteResponseModelName │ Decode→Middleware │
│ │ │ →[rewritten] │ →modelOverride→Encode │
│ │ │ │ │
│ 性能开销 │ 最低 │ 低仅JSON字段改写 │ 高(完整序列化) │
└────────────────┴─────────────────────────┴────────────────────────────────────────┘
```
响应方向同理含流式。D=Decoder, C=Canonical, E=Encoder
**智能透传的设计动机**:同协议场景下,若仅需改写 `model` 字段(如客户端请求模型 "X",上游需要模型 "Y"),无需完整解码/编码。直接在 JSON 层面手术式改写该字段,既保留原始请求的所有细节,又避免序列化开销
#### 2.3.2 完整请求处理流程
```
客户端入站 调用方职责 SDK 内部处理 上游出站
┌──────────────┐ ┌──────────────┐
│ URL: │ 调用方完成: 1. 接口识别 │ URL: │
│ /<protocol>/ │ · clientProtocol 2. IsPassthrough? │ 目标协议 │
│ v1/... │ · nativePath ├─ yes ─┬─ 无 modelOverride → 透传车道 │ 原生路径 │
│ Headers: │ · providerProtocol │ └─ 有 modelOverride → 智能透传车道│ Headers: │
│ 协议原生格式 │ │ │ 目标协议格式 │
│ Body: │ └─ no → 完整转换车道 │ Body: │
│ 协议原生格式 │ │ 目标协议格式 │
└──────────────┘ └──────────────┘
```
响应方向同理(含流式)。
**同协议透传**client == provider 时,仅重建 Header 后原样转发到上游。
**智能透传**:同协议且需改写 model 字段时,最小化 JSON 改写后转发。
**未知接口透传**无法识别的路径URL+Header 适配后 Body 原样转发。
---
@@ -229,13 +259,15 @@ ContentBlock = Union<
content: Union<String, Array<ContentBlock>>,
is_error: Option<Boolean> }
ThinkingBlock, { type: "thinking", thinking: String }
ImageBlock, { type: "image", source: ... } // 多模态预留
AudioBlock, { type: "audio", source: ... } // 多模态预留
VideoBlock, { type: "video", source: ... } // 多模态预留
FileBlock { type: "file", source: ... } // 多模态预留
ImageBlock, { type: "image", source: ... } // Deferred多模态预留
AudioBlock, { type: "audio", source: ... } // Deferred多模态预留
VideoBlock, { type: "video", source: ... } // Deferred多模态预留
FileBlock { type: "file", source: ... } // Deferred多模态预留
>
```
**当前实现状态**:仅实现了 TextBlock、ToolUseBlock、ToolResultBlock、ThinkingBlock。多模态类型ImageBlock、AudioBlock、VideoBlock、FileBlock为预留扩展点。
### 4.4 CanonicalTool / ToolChoice
```
@@ -400,7 +432,7 @@ interface ProtocolAdapter {
createStreamEncoder(): StreamEncoder
// 错误编码
encodeError(error: ConversionError): RawResponse
encodeError(error: ConversionError): (body, statusCode)
// 扩展层
decodeModelsResponse(raw): CanonicalModelList
@@ -415,20 +447,40 @@ interface ProtocolAdapter {
encodeRerankRequest(canonical, provider): RawRequest
decodeRerankResponse(raw): CanonicalRerankResponse
encodeRerankResponse(canonical): RawResponse
// 智能透传支持
extractUnifiedModelID(nativePath: String): (modelID, error) // 从路径提取统一模型 ID
extractModelName(body: Raw, interfaceType: InterfaceType): (modelName, error) // 从请求体提取 model 字段值
rewriteRequestModelName(body: Raw, newModel: String, interfaceType: InterfaceType): RawRequest // 最小化改写请求体中的 model 字段
rewriteResponseModelName(body: Raw, newModel: String, interfaceType: InterfaceType): RawResponse // 最小化改写响应体中的 model 字段
}
```
**`buildHeaders` 的设计**Adapter 只需从 `provider` 中提取自己协议需要的认证和配置信息,构建自己的 Header 格式。不再需要理解其他协议的 Header。
**智能透传方法的契约**
- `rewriteRequestModelName` / `rewriteResponseModelName` 必须**幂等**(多次调用结果相同)
- Rewrite 方法必须**最小化**(仅修改 model 字段,不触碰其他字段)
- Rewrite 失败时,引擎使用宽容策略:记录警告日志,使用原始 body 继续处理
- `extractModelName` 支持的接口类型CHAT、EMBEDDINGS、RERANK这些接口的请求体包含 model 字段)
### 5.3 InterfaceType
```
InterfaceType = Enum<
CHAT, MODELS, MODEL_INFO, EMBEDDINGS, RERANK,
AUDIO, IMAGES // 预留:多模态扩展时启用
CHAT, MODELS, MODEL_INFO, EMBEDDINGS, RERANK, PASSTHROUGH
>
```
| 类型 | 说明 |
|------|------|
| CHAT | 核心对话接口(各协议的 Chat/Messages 接口) |
| MODELS | 模型列表接口 |
| MODEL_INFO | 模型详情接口 |
| EMBEDDINGS | 向量嵌入接口 |
| RERANK | 重排序接口 |
| PASSTHROUGH | 未知接口,透传处理 |
### 5.4 StreamDecoder / StreamEncoder
```
@@ -463,68 +515,138 @@ ConversionEngine 是无状态的格式转换工具,仅做协议间的编解码
**协议识别**`clientProtocol``providerProtocol` 由调用方确定并传入引擎(详见 §2.2)。
#### 6.1.1 HTTP 规格
```
HTTPRequestSpec {
url: String // 请求路径(不含 base_url
method: String // HTTP 方法
headers: Map<String, String>
body: ByteArray // 原始请求体
}
HTTPResponseSpec {
statusCode: Integer
headers: Map<String, String>
body: ByteArray // 原始响应体
}
```
#### 6.1.2 引擎接口
```
class ConversionEngine {
registry: AdapterRegistry
middlewareChain: MiddlewareChain
// 生命周期
registerAdapter(adapter): void
use(middleware): void
getRegistry(): AdapterRegistry
// 核心转换
isPassthrough(clientProtocol, providerProtocol): Boolean {
return clientProtocol == providerProtocol && registry.get(clientProtocol).supportsPassthrough()
}
// 非流式请求转换
convertHttpRequest(request, clientProtocol, providerProtocol, provider): HttpRequest {
nativePath = request.url
convertHttpRequest(request, clientProtocol, providerProtocol, provider): HTTPRequestSpec
convertHttpResponse(response, clientProtocol, providerProtocol, interfaceType, modelOverride): HTTPResponseSpec
createStreamConverter(clientProtocol, providerProtocol, modelOverride, interfaceType): StreamConverter
if isPassthrough(clientProtocol, providerProtocol):
providerAdapter = registry.get(providerProtocol)
return {url: provider.base_url + nativePath, method: request.method,
headers: providerAdapter.buildHeaders(provider), body: request.body}
clientAdapter = registry.get(clientProtocol)
providerAdapter = registry.get(providerProtocol)
// 使用 clientAdapter 识别接口类型
interfaceType = clientAdapter.detectInterfaceType(nativePath)
providerUrl = providerAdapter.buildUrl(nativePath, interfaceType)
providerHeaders = providerAdapter.buildHeaders(provider)
providerBody = convertBody(interfaceType, clientAdapter, providerAdapter, provider, request.body)
return {url: provider.base_url + providerUrl, method: request.method,
headers: providerHeaders, body: providerBody}
}
// 非流式响应转换
convertHttpResponse(response, clientProtocol, providerProtocol, interfaceType): HttpResponse {
if isPassthrough(clientProtocol, providerProtocol): return response
clientAdapter = registry.get(clientProtocol)
providerAdapter = registry.get(providerProtocol)
providerBody = convertResponseBody(interfaceType, clientAdapter, providerAdapter, response.body)
return {status: response.status, headers: response.headers, body: providerBody}
}
// 流式转换:从 provider 协议解码,编码为 client 协议
createStreamConverter(clientProtocol, providerProtocol, provider): StreamConverter {
if isPassthrough(clientProtocol, providerProtocol):
providerAdapter = registry.get(providerProtocol)
return new PassthroughStreamConverter(providerAdapter.buildHeaders(provider))
providerAdapter = registry.get(providerProtocol)
clientAdapter = registry.get(clientProtocol)
return new CanonicalStreamConverter(
providerAdapter.createStreamDecoder(), clientAdapter.createStreamEncoder(), middlewareChain)
}
// 辅助方法
detectInterfaceType(nativePath, clientProtocol): InterfaceType
encodeError(error, clientProtocol): (body, statusCode)
}
```
#### 6.1.3 请求转换流程
```
function convertHttpRequest(request, clientProtocol, providerProtocol, provider):
nativePath = request.url
if isPassthrough(clientProtocol, providerProtocol):
providerAdapter = registry.get(providerProtocol)
interfaceType = providerAdapter.detectInterfaceType(nativePath)
// 智能透传:对 Chat/Embeddings/Rerank 接口改写 model 字段
if interfaceType in {CHAT, EMBEDDINGS, RERANK} and request.body非空 and provider.modelName非空:
rewrittenBody = providerAdapter.rewriteRequestModelName(request.body, provider.modelName, interfaceType)
if rewrite失败:
log.warn("智能透传改写失败,使用原始请求体")
rewrittenBody = request.body
else:
rewrittenBody = request.body
return HTTPRequestSpec {
url: provider.base_url + nativePath,
method: request.method,
headers: providerAdapter.buildHeaders(provider),
body: rewrittenBody
}
// 完整转换车道
clientAdapter = registry.get(clientProtocol)
providerAdapter = registry.get(providerProtocol)
interfaceType = clientAdapter.detectInterfaceType(nativePath)
providerUrl = providerAdapter.buildUrl(nativePath, interfaceType)
providerHeaders = providerAdapter.buildHeaders(provider)
providerBody = convertBody(interfaceType, clientAdapter, providerAdapter, provider, request.body)
return HTTPRequestSpec {
url: provider.base_url + providerUrl,
method: request.method,
headers: providerHeaders,
body: providerBody
}
```
#### 6.1.4 响应转换流程
```
function convertHttpResponse(response, clientProtocol, providerProtocol, interfaceType, modelOverride):
if isPassthrough(clientProtocol, providerProtocol):
// 智能透传:改写响应体中的 model 字段
if modelOverride非空 and response.body非空:
adapter = registry.get(clientProtocol)
rewrittenBody = adapter.rewriteResponseModelName(response.body, modelOverride, interfaceType)
if rewrite失败:
log.warn("智能透传改写失败,使用原始响应体")
return response
return HTTPResponseSpec { statusCode: response.statusCode, headers: response.headers, body: rewrittenBody }
return response
// 完整转换车道
clientAdapter = registry.get(clientProtocol)
providerAdapter = registry.get(providerProtocol)
convertedBody = convertResponseBody(interfaceType, clientAdapter, providerAdapter, response.body, modelOverride)
return HTTPResponseSpec { statusCode: response.statusCode, headers: response.headers, body: convertedBody }
```
**modelOverride 参数语义**:跨协议场景下,客户端期望看到的模型名。在响应转换时直接覆写 `canonicalResponse.model = modelOverride`,在流式转换时覆写 `event.message.model = modelOverride`
### 6.2 Body 转换分发
#### 6.2.1 接口类型分发策略
| 接口类型 | 请求转换 | 响应转换 | 中间件 | modelOverride |
|---------|---------|---------|--------|---------------|
| CHAT | Decode→**Middleware**→Encode | Decode→modelOverride→Encode | **应用** | 支持(响应) |
| MODELS | Body透传GET | Decode→Encode | 不应用 | 不支持 |
| MODEL_INFO | Body透传GET | Decode→Encode | 不应用 | 不支持 |
| EMBEDDINGS | Decode→Encode | Decode→modelOverride→Encode | 不应用 | 支持(响应) |
| RERANK | Decode→Encode | Decode→modelOverride→Encode | 不应用 | 支持(响应) |
| PASSTHROUGH | Body透传 | Body透传 | 不应用 | 不支持 |
**关键说明**
- **只有 CHAT 接口**走完整的 `decode → middleware → encode` 管道
- **扩展层接口**EMBEDDINGS、RERANK**跳过中间件**,直接 decode → encode
- **扩展层响应支持 modelOverride**,在 encode 前直接覆写 canonical 字段
#### 6.2.2 请求体转换
```
function convertBody(interfaceType, clientAdapter, providerAdapter, provider, body):
switch interfaceType:
@@ -546,41 +668,146 @@ function convertBody(interfaceType, clientAdapter, providerAdapter, provider, bo
// 同 EMBEDDINGS 模式
default:
return body // 透传层:原样转发
```
function convertResponseBody(interfaceType, clientAdapter, providerAdapter, body):
// 结构与 convertBody 对称CHAT 走 Canonical 深度转换,扩展层走轻量映射,默认透传
// 各接口的具体响应转换逻辑详见各协议适配文档(附录 E
#### 6.2.3 响应体转换
```
function convertResponseBody(interfaceType, clientAdapter, providerAdapter, body, modelOverride):
switch interfaceType:
case CHAT:
canonical = providerAdapter.decodeResponse(body)
if modelOverride非空: canonical.model = modelOverride
return clientAdapter.encodeResponse(canonical)
case MODELS:
if !clientAdapter.supportsInterface(MODELS) || !providerAdapter.supportsInterface(MODELS):
return body
return clientAdapter.encodeModelsResponse(providerAdapter.decodeModelsResponse(body))
case MODEL_INFO:
// 同 MODELS 模式
case EMBEDDINGS:
if !clientAdapter.supportsInterface(EMBEDDINGS) || !providerAdapter.supportsInterface(EMBEDDINGS):
return body
canonical = providerAdapter.decodeEmbeddingResponse(body)
if modelOverride非空: canonical.model = modelOverride
return clientAdapter.encodeEmbeddingResponse(canonical)
case RERANK:
// 同 EMBEDDINGS 模式,支持 modelOverride
default:
return body // 透传层:原样转发
```
### 6.3 StreamConverter
#### 6.3.1 接口定义
```
interface StreamConverter {
processChunk(rawChunk): Array<RawSSEChunk>
flush(): Array<RawSSEChunk>
}
```
#### 6.3.2 三种转换器变体
| 转换器 | 触发条件 | processChunk | flush |
|--------|---------|--------------|-------|
| `PassthroughStreamConverter` | 同协议 + 无 modelOverride | `[rawChunk]` | `[]` |
| `SmartPassthroughStreamConverter` | 同协议 + 有 modelOverride | `[rewriteResponseModelName(rawChunk)]` | `[]` |
| `CanonicalStreamConverter` | 不同协议 | Decode→Middleware→modelOverride→Encode | decoder.flush()→encoder.flush() |
#### 6.3.3 PassthroughStreamConverter
```
class PassthroughStreamConverter implements StreamConverter {
headers: Map<String, String>
constructor(headers) { this.headers = headers }
processChunk(rawChunk): Array<RawSSEChunk> { return [rawChunk] }
flush(): Array<RawSSEChunk> { return [] }
}
```
#### 6.3.4 SmartPassthroughStreamConverter
```
class SmartPassthroughStreamConverter implements StreamConverter {
adapter: ProtocolAdapter
modelOverride: String
interfaceType: InterfaceType
processChunk(rawChunk): Array<RawSSEChunk> {
if rawChunk为空: return []
rewrittenChunk = adapter.rewriteResponseModelName(rawChunk, modelOverride, interfaceType)
if rewrite失败:
log.warn("智能透传改写失败,使用原始 chunk")
return [rawChunk]
return [rewrittenChunk]
}
flush(): Array<RawSSEChunk> { return [] }
}
```
#### 6.3.5 CanonicalStreamConverter
```
class CanonicalStreamConverter implements StreamConverter {
decoder: StreamDecoder
encoder: StreamEncoder
middleware: MiddlewareChain
context: ConversionContext
clientProtocol: String
providerProtocol: String
modelOverride: String
processChunk(rawChunk):
events = decoder.processChunk(rawChunk).map(e => middleware.applyStreamEvent(e))
return events.flatMap(e => encoder.encodeEvent(e))
events = decoder.processChunk(rawChunk)
result = []
for each event in events:
// 中间件:转换 canonical 事件
if middleware != null:
processed, err = middleware.applyStreamEvent(event, clientProtocol, providerProtocol, context)
if err != null:
continue // 宽容策略:跳过错误事件,继续处理
event = processed
// modelOverride覆写 model 字段
if modelOverride非空 and event.message != null:
event.message.model = modelOverride
// 编码为目标协议 SSE
chunks = encoder.encodeEvent(event)
result.append(chunks)
return result
flush():
return decoder.flush().flatMap(e => encoder.encodeEvent(e)) + encoder.flush()
events = decoder.flush()
// 同 processChunk 的中间件 + modelOverride + encode 管道
result = [经过管道处理的所有事件]
encoderChunks = encoder.flush()
result.append(encoderChunks)
return result
}
```
#### 6.3.6 流式转换器创建
```
function createStreamConverter(clientProtocol, providerProtocol, modelOverride, interfaceType):
if isPassthrough(clientProtocol, providerProtocol):
if modelOverride非空:
adapter = registry.get(clientProtocol)
return new SmartPassthroughStreamConverter(adapter, modelOverride, interfaceType)
return new PassthroughStreamConverter()
providerAdapter = registry.get(providerProtocol)
clientAdapter = registry.get(clientProtocol)
return new CanonicalStreamConverter(
decoder: providerAdapter.createStreamDecoder(),
encoder: clientAdapter.createStreamEncoder(),
middleware: middlewareChain,
modelOverride: modelOverride
)
```
### 6.4 Middleware
引擎内部的拦截钩子,在 decode → encode 之间对 Canonical 进行变换。
@@ -588,48 +815,73 @@ class CanonicalStreamConverter implements StreamConverter {
```
interface ConversionMiddleware {
intercept(canonical, clientProtocol, providerProtocol, context): canonical | error
interceptStreamEvent?(event, clientProtocol, providerProtocol, context): event | error
interceptStreamEvent(event, clientProtocol, providerProtocol, context): event | error
}
ConversionContext { conversionId, interfaceType, timestamp, metadata }
ConversionContext {
conversionId: String // 唯一转换 IDUUID
interfaceType: InterfaceType
timestamp: DateTime
metadata: Map<String, Any>
}
```
#### 6.4.1 中间件执行规则
- `intercept` 返回修改后的 canonical或返回 ConversionError 以**中断转换**
- `interceptStreamEvent` 同理,返回错误可中断流式转换
- `interceptStreamEvent` 返回修改后的 event或返回 error
- 多个 Middleware 按注册顺序链式执行,任一中断则后续不再执行
#### 6.4.2 错误处理差异
| 场景 | 错误处理策略 |
|------|-------------|
| 请求中间件 `intercept` 返回 error | **严格模式**:中断整个转换,返回错误 |
| 流式中间件 `interceptStreamEvent` 返回 error | **宽容模式**:跳过该事件,继续处理后续事件 |
**设计动机**:流式场景下,单个事件的错误不应中断整个流。请求场景下,错误请求应被明确拒绝。
### 6.5 使用示例
```
engine = new ConversionEngine()
engine.registerAdapter(new ProtocolAAdapter())
engine.registerAdapter(new ProtocolBAdapter())
engine.registerAdapter(new OpenAIAdapter())
engine.registerAdapter(new AnthropicAdapter())
// 场景1: 跨协议 Chat 转换
// 入站: /protocol_a/v1/chat/completions
// 入站: /openai/v1/chat/completions
provider = TargetProvider {
base_url: "https://api.protocol-b.com",
base_url: "https://api.anthropic.com",
api_key: "xxx",
model_name: "model-b",
adapter_config: { ... }
model_name: "claude-3-opus",
adapter_config: { anthropic_version: "2023-06-01" }
}
out = engine.convertHttpRequest(inRequest, "protocol_a", "protocol_b", provider)
// 出站: 目标协议路径 + 目标协议 headers + 转换后的 body
out = engine.convertHttpRequest(inRequest, "openai", "anthropic", provider)
// 出站: /v1/messages + Anthropic headers + 转换后的 body
// 场景2: /models 跨协议
out = engine.convertHttpRequest(inRequest, "protocol_a", "protocol_b", provider)
// URL: /v1/models(通常不变), headers 按目标协议格式重建
out = engine.convertHttpRequest(inRequest, "openai", "anthropic", provider)
// URL: /v1/models, headers 按 Anthropic 格式重建
// 场景3: 同协议透传
out = engine.convertHttpRequest(inRequest, "protocol_a", "protocol_a", provider)
// client == provider → 剥离前缀, 用 provider 重建 headers 后原样转发
out = engine.convertHttpRequest(inRequest, "openai", "openai", provider)
// client == provider → 透传车道:重建 headers 后原样转发
// 场景4: 流式转换(从 provider 协议解码,编码为 client 协议
converter = engine.createStreamConverter("protocol_a", "protocol_b", provider)
// 场景4: 同协议智能透传(改写 model 字段
provider = TargetProvider { model_name: "gpt-4-turbo", ... }
out = engine.convertHttpRequest(inRequest, "openai", "openai", provider)
// 智能透传车道:请求体中的 model 字段改写为 "gpt-4-turbo"
// 场景5: 流式转换(从 provider 协议解码,编码为 client 协议)
converter = engine.createStreamConverter("openai", "anthropic", "claude-3-opus", CHAT)
for chunk in upstreamSSE {
for out in converter.processChunk(chunk) { sendToClient(out) }
}
converter.flush()
// 场景6: 同协议流式智能透传
converter = engine.createStreamConverter("openai", "openai", "gpt-4-turbo", CHAT)
// 使用 SmartPassthroughStreamConverter逐 chunk 改写 model 字段
```
---
@@ -641,9 +893,13 @@ converter.flush()
```
上游 SSE 流
├─ 同协议: PassthroughStreamConverter(用 provider 重建 Headers 后逐块转发)
├─ 同协议 + 无 modelOverride: PassthroughStreamConverter
│ chunk → [chunk]
└─协议: CanonicalStreamConverter
协议 + 有 modelOverride: SmartPassthroughStreamConverter
│ chunk → [rewriteResponseModelName(chunk)]
└─ 不同协议: CanonicalStreamConverter
StreamDecoder StreamEncoder
┌───────────┐ ┌───────────┐
│ SSE Parser│ │SSE Writer │
@@ -653,9 +909,17 @@ converter.flush()
│ Event │──────────────────────▶│ Event │
│ Translator │ ┌──────────┐ │ Translator │
│ (状态机) │ │Middleware│ │ │
└───────────┘ └──────────┘ └───────────┘
└───────────┘ │(宽容模式) │ └───────────┘
└──────────┘
┌─────▼─────┐
│modelOverride│
│(覆写 model) │
└───────────┘
```
**流式中间件错误处理**`interceptStreamEvent` 返回 error 时,跳过该事件继续处理后续事件(宽容模式),而非中断整个流。
### 7.2 StreamDecoder 通用状态
StreamDecoder 需要跟踪以下通用状态。具体协议的 Decoder 可根据需要扩展:
@@ -697,12 +961,16 @@ StreamDecoder 将协议原生 SSE 事件翻译为 CanonicalStreamEventStreamE
### 8.2 多模态扩展
**状态**Deferred未实现
Canonical Model 已预留 ImageBlock / AudioBlock / VideoBlock / FileBlock。实现路径
1. 在各 ProtocolAdapter 中实现多模态块的编解码
2. 在 StreamDecoder/StreamEncoder 中处理多模态增量数据
### 8.3 有状态特性扩展
**状态**Deferred未实现
```
interface StatefulMiddleware extends ConversionMiddleware {
stateStore: SessionStateStore
@@ -722,6 +990,8 @@ interface StatefulMiddleware extends ConversionMiddleware {
### 8.5 自定义接口支持
**状态**Deferred未实现
```
interface CustomInterfaceHandler {
interfaceType(): InterfaceType
@@ -732,6 +1002,8 @@ interface CustomInterfaceHandler {
engine.registerCustomHandler(handler)
```
当前实现中,未知接口直接走 PASSTHROUGH 透传。
---
## 9. 错误处理
@@ -759,16 +1031,36 @@ ErrorCode = Enum<
### 9.2 错误处理策略
```
ErrorHandler { mode: "strict" | "lenient" }
引擎采用**分层宽容策略**,根据接口层级和场景选择不同的错误处理方式:
strict: 任何错误抛出异常
lenient: 尽力继续
INCOMPATIBLE_FEATURE → 降级继续
INTERFACE_NOT_SUPPORTED → 透传或返回空响应
TOOL_CALL_PARSE_ERROR → 保留原始内容继续
PROTOCOL_CONSTRAINT_VIOLATION → 自动修复
```
┌─────────────────────────────────────────────────────────────────────┐
│ 错误处理分层策略 │
├─────────────────┬───────────────────────────────────────────────────┤
│ 核心层CHAT │ 严格模式:任何错误返回 ConversionError中断转换 │
│ │ - Decode 失败 → 返回 JSON_PARSE_ERROR │
│ │ - Middleware 失败 → 返回错误 │
│ │ - Encode 失败 → 返回 ENCODING_FAILURE │
├─────────────────┼───────────────────────────────────────────────────┤
│ 扩展层 │ 宽容模式:记录警告日志,返回原始 body 透传 │
│ (Models/Embed/ │ - Decode 失败 → log.warn + 返回原始 body │
│ Rerank) │ - Encode 失败 → log.warn + 返回原始 body │
├─────────────────┼───────────────────────────────────────────────────┤
│ 流式中间件 │ 宽容模式:跳过错误事件,继续处理后续事件 │
│ │ - interceptStreamEvent 返回 error → continue │
├─────────────────┼───────────────────────────────────────────────────┤
│ 智能透传 │ 宽容模式:重写失败则使用原始 body/chunk │
│ │ - Rewrite 失败 → log.warn + 返回原始 body/chunk │
├─────────────────┼───────────────────────────────────────────────────┤
│ 请求中间件 │ 严格模式:返回错误则中断整个转换 │
│ │ - intercept 返回 error → 返回 error │
└─────────────────┴───────────────────────────────────────────────────┘
```
**设计动机**
- 核心接口CHAT必须保证语义正确性错误应明确暴露
- 扩展层接口优先保证可用性,错误应降级处理
- 流式场景不应因单个事件错误中断整个流
**不支持接口的处理**`INTERFACE_NOT_SUPPORTED`
@@ -778,7 +1070,7 @@ lenient: 尽力继续
| 返回空响应 | 不影响核心功能 | 返回空列表 `{data: []}` |
| 返回错误 | 客户端明确需要此功能 | 返回 501 或协议格式错误 |
具体策略通过配置或 Middleware 决定
具体策略`supportsInterface` 返回值决定:返回 false 时引擎直接透传 body
### 9.3 错误响应格式
@@ -786,6 +1078,24 @@ lenient: 尽力继续
Middleware 中断转换时同理,引擎调用 clientAdapter.encodeError 将 ConversionError 编码为客户端可理解的格式。
#### 9.3.1 EncodeError Fallback 行为
当客户端适配器不可用时,引擎使用通用 JSON 错误作为 fallback
```
function encodeError(error, clientProtocol):
adapter = registry.get(clientProtocol)
if adapter不存在:
// Fallback: 通用 JSON 错误
return {
"error": {
"message": error.message,
"type": "internal_error"
}
}, 500
return adapter.encodeError(error)
```
---
## 附录 A模块依赖
@@ -796,21 +1106,25 @@ Middleware 中断转换时同理,引擎调用 clientAdapter.encodeError 将 Co
│ 门面HTTP 转换 / 透传判断 / 流式转换 │
│ 无状态;协议识别见 §2.2 │
├──────────────────────────────────────────────────┤
│ HTTPRequestSpec / HTTPResponseSpec │
│ url, method, headers, body / statusCode, ... │
├──────────────────────────────────────────────────┤
│ TargetProvider │
│ base_url / api_key / model_name / adapter_config │
├──────────────────┬───────────────────────────────┤
│ AdapterRegistry │ MiddlewareChain │
├──────────────────┴───────────────────────────────┤
│ StreamConverter: Passthrough | Canonical
│ StreamConverter: Passthrough | SmartPassthrough | Canonical │
├──────────────────────────────────────────────────┤
│ ProtocolAdapter: 各协议实现 │
│ · buildHeaders(provider) · URL 映射 │
│ · Chat/Models/ModelInfo/Embeddings/Rerank/... 编解码 │
│ · encodeError · StreamDecoder / StreamEncoder │
│ · rewriteRequestModelName / rewriteResponseModelName (智能透传) │
├──────────────────────────────────────────────────┤
│ Canonical Model (Core + Extended) │
├──────────────────────────────────────────────────┤
│ Error Handling
│ Error Handling (分层宽容策略)
├──────────────────────────────────────────────────┤
│ Utility: UTF-8 Buffer / SSE Parser / Detector │
└──────────────────────────────────────────────────┘
@@ -823,22 +1137,25 @@ Middleware 中断转换时同理,引擎调用 clientAdapter.encodeError 将 Co
```
// ─── 核心入口 ───
ConversionEngine
.registerAdapter(adapter)
.use(middleware)
.registerAdapter(adapter): void
.use(middleware): void
.getRegistry(): AdapterRegistry
.isPassthrough(clientProtocol, providerProtocol): Boolean
.convertHttpRequest(request, clientProtocol, providerProtocol, provider): HttpRequest
.convertHttpResponse(response, clientProtocol, providerProtocol, interfaceType): HttpResponse
.createStreamConverter(clientProtocol, providerProtocol, provider): StreamConverter
.convertHttpRequest(request, clientProtocol, providerProtocol, provider): HTTPRequestSpec
.convertHttpResponse(response, clientProtocol, providerProtocol, interfaceType, modelOverride): HTTPResponseSpec
.createStreamConverter(clientProtocol, providerProtocol, modelOverride, interfaceType): StreamConverter
.detectInterfaceType(nativePath, clientProtocol): InterfaceType
.encodeError(error, clientProtocol): (body, statusCode)
// ─── HTTP 规格 ───
HTTPRequestSpec { url, method, headers, body }
HTTPResponseSpec { statusCode, headers, body }
// ─── 目标上游信息 ───
TargetProvider
.base_url: String
.api_key: String
.model_name: String
.adapter_config: Map<String, Any>
TargetProvider { base_url, api_key, model_name, adapter_config }
// ─── URL 路由 ───
// 协议识别见 §2.2;出站: provider.base_url + 目标协议原生路径
// ─── 接口类型 ───
InterfaceType = CHAT | MODELS | MODEL_INFO | EMBEDDINGS | RERANK | PASSTHROUGH
// ─── 协议适配器 ───
ProtocolAdapter
@@ -847,19 +1164,30 @@ ProtocolAdapter
.decodeRequest(raw) / .encodeRequest(canonical, provider)
.decodeResponse(raw) / .encodeResponse(canonical)
.createStreamDecoder() / .createStreamEncoder()
.encodeError(error): RawResponse
.encodeError(error): (body, statusCode)
// 扩展层
.decodeModelsResponse / .encodeModelsResponse
.decodeModelInfoResponse / .encodeModelInfoResponse
.decodeEmbeddingRequest / .encodeEmbeddingRequest(canonical, provider) / ...Response
.decodeRerankRequest / .encodeRerankRequest(canonical, provider) / ...Response
.decodeEmbeddingRequest / .encodeEmbeddingRequest / ...Response
.decodeRerankRequest / .encodeRerankRequest / ...Response
// 智能透传支持
.extractUnifiedModelID(nativePath)
.extractModelName(body, interfaceType)
.rewriteRequestModelName(body, newModel, interfaceType)
.rewriteResponseModelName(body, newModel, interfaceType)
// ─── 流式处理 ───
StreamConverter: .processChunk(raw) / .flush()
├─ PassthroughStreamConverter [raw] → [raw](用 provider 重建 Headers
CanonicalStreamConverter decode → middleware → encode
├─ PassthroughStreamConverter [raw] → [raw]
SmartPassthroughStreamConverter [raw] → [rewrite(raw)]
└─ CanonicalStreamConverter decode → middleware → modelOverride → encode
// ─── 接口类型 ───
InterfaceType = CHAT | MODELS | MODEL_INFO | EMBEDDINGS | RERANK | AUDIO | IMAGES
// ─── 中间件 ───
ConversionMiddleware
.intercept(req, clientProtocol, providerProtocol, ctx): (req, error)
.interceptStreamEvent(event, clientProtocol, providerProtocol, ctx): (event, error)
ConversionContext { conversionId, interfaceType, timestamp, metadata }
```
---
@@ -1070,6 +1398,28 @@ Canonical Model 是**活的公共契约**,不是固定不变的。其字段集
| D.6 | [ ] 流式 StreamDecoder 和 StreamEncoder 已实现(对照 §4.8 |
| D.7 | [ ] 扩展层接口的编解码已实现(支持的接口) |
| D.8 | [ ] `encodeError` 已实现 |
| D.10 | [ ] `extractUnifiedModelID(nativePath)` 已实现 |
| D.10 | [ ] `extractModelName(body, interfaceType)` 已实现(覆盖 Chat/Embeddings/Rerank |
| D.10 | [ ] `rewriteRequestModelName` 已实现(幂等、最小化) |
| D.10 | [ ] `rewriteResponseModelName` 已实现(按接口类型处理 model 字段存在性) |
### D.10 智能透传支持
| 项目 | 说明 |
|------|------|
| extractUnifiedModelID | 从路径提取统一模型 ID 的规则(如 `/models/{provider_id}/{model_name}` |
| extractModelName | 从请求体提取 model 字段值的规则按接口类型Chat/Embeddings/Rerank |
| rewriteRequestModelName | 请求体 model 字段改写规则(最小化 JSON 手术,仅修改 model 字段) |
| rewriteResponseModelName | 响应体 model 字段改写规则(按接口类型处理 model 字段存在性) |
**rewriteResponseModelName 的接口类型差异**
| 接口类型 | model 字段处理 |
|---------|---------------|
| CHAT | 存在则改写,不存在则添加(协议要求必须有 model 字段) |
| EMBEDDINGS | 存在则改写,不存在则添加(协议要求必须有 model 字段) |
| RERANK | 存在则改写不存在则不添加model 字段可选) |
| 其他 | 直接返回原始 body |
---

61
docs/prompts/README.md Normal file
View File

@@ -0,0 +1,61 @@
# Prompts
面向 AI 大模型的提示词集合,每份提示词独立可用,完整复制给 AI 工具即可启动对应流程。
## 命名规则
文件名格式:`prompt-{action}.md``{action}` 使用明确无歧义的英文单词或短语,用连字符连接,例如 `prompt-smart-merge.md``prompt-spec-review.md`
## 提示词
| 文件 | 用途 |
| ---- | ---- |
| [prompt-smart-merge.md](prompt-smart-merge.md) | 批量合并 dev 分支到主干,含依赖分析、冲突处理、安全回退 |
| [prompt-spec-review.md](prompt-spec-review.md) | 审查和重构 openspec/specs/ 下的规范文件 |
| [prompt-proposal-review.md](prompt-proposal-review.md) | 审查 openspec 变更文档与讨论内容的一致性 |
| [prompt-apply-review.md](prompt-apply-review.md) | 审查代码实现与 openspec 变更文档的一致性 |
## 书写原则
### 面向 AI 而非人类
- 不写背景知识、适用场景、定期审查节奏等"补充说明"AI 只需要执行指令
- 不写解释性注释(为什么用 merge 不用 rebase直接作为约束声明
- 不写示例输出模板AI 自行推断格式
### 结构
```
一句话描述任务目标
## 约束
全局不可违反的规则,顶部声明,不在步骤中重复
## 1. 收集/准备
## 2. 分析
## 3. 计划(用户确认)
## 4. 执行(用户确认)
## 5. 清理/收尾
```
编号步骤,不用"第X步"(省 token。步骤之间有确认节点的明确标注。
### 精简
- 每句话只含一条指令,不嵌套子句
- 表格优于列表,列表优于段落
- 规则只声明一次,不在多处重复
- 单文件控制在 150 行以内
### 安全
- 破坏性操作(删除、重写、推送、合并提交)执行前必须用提问工具获得用户确认
- 提供回退机制安全锚点、备份标记、abort 路径)
- 危险命令的约束直接写在约束块中,不用"严禁""务必"等修饰词,用"禁止"即可
- 信息展示分层渐进(概览 → 详情 → 原始数据),避免一次输出过多内容
### 可操作性
- 给出具体命令或工具调用方式,不抽象描述("分析分支"→ "git diff --name-status target...branch"
- 标注并行/串行:可并行的步骤明确写"并行",有副作用的操作标注"串行"
-`{占位符}` 标记需要 AI 替换的参数

View File

@@ -0,0 +1,71 @@
审查 openspec apply 完成后的实现是否与变更文档一致,按以下流程执行。
## 约束
- 仅审查代码和文档,不修改源码
- 每批修改建议执行前用提问工具获得用户确认
- 涉及删除/重写操作前必须备份原文件
## 1. 收集
并行读取:
- 本次变更涉及的所有文档proposal.md、design.md、tasks.md、specs/*.md
- 实际变更的源码文件
- 测试文件和测试结果
- openspec/config.yaml
## 2. 分析
根据上下文判断 apply 后是否有手动改动,将实现与文档双向对照:
| 维度 | 检查点 |
| ---- | ------ |
| 目标覆盖 | 代码实现是否覆盖 proposal 中的所有目标;是否有遗漏的功能点 |
| 方案一致性 | 实现是否与 design 中的技术方案一致;是否有偏离设计的地方 |
| 规范遵循 | 代码是否遵循 specs 中的规范要求;是否有违反 SHALL 约束的地方 |
| 任务完成度 | tasks 中每项任务是否真正完成;是否有未完成但标记完成的任务 |
| 测试完整性 | 测试是否覆盖所有场景;是否有跳过、取消、降低难度的测试;测试是否真正验证了功能 |
| 代码质量 | 是否有明显的代码问题(重复、复杂度过高、命名不清等);是否有可优化的地方 |
重点识别:
- 文档要求但未实现的功能 → 需补充代码
- 实现与文档描述不一致的地方 → 根据上下文判断apply 后手动改动则更新文档,否则确认修正方向
- 实现了但文档未提及的功能 → 根据上下文判断apply 后手动改动则补充文档,否则标记为未讨论的新增
- 标记完成但实际未完成的任务
- 掩盖错误的测试skip、only、降低断言强度等
输出审查结果:
1. **问题总览表**:问题类型 × 涉及文件数
2. **逐项分析**:每个问题文件,说明问题、影响和建议
3. **未覆盖清单**:哪些文档要求未在代码中实现(需补充代码)
4. **不一致清单**:哪些实现与文档描述不一致(需确认修正方向)
5. **需补充文档清单**:哪些代码改动未在文档中体现(需补充文档)
6. **任务问题清单**:哪些任务未真正完成或标记错误
7. **测试问题清单**:哪些测试存在问题或掩盖了错误
8. **优化建议清单**:哪些代码可以优化
若所有清单均为空,输出"审查通过,未发现问题",跳至步骤 5。
## 3. 计划(用户确认)
针对发现的问题,分类提出修复方案:
**需补充代码**:文档要求但未实现的功能,建议补充实现。
**需补充文档**:代码已实现但文档未记录,且上下文确认为 apply 后手动改动,建议更新 proposal/design/tasks/specs。
**需确认修正方向**:实现与文档不一致,且上下文无法判断:
- 若为 apply 后手动改动 → 更新文档
- 否则用提问工具确认以文档还是代码为准
**任务和测试问题**:逐项说明未完成原因、测试掩盖错误的风险,提出修复方案。
用提问工具展示完整修复方案,获得用户确认后执行。
## 4. 执行
逐项执行修复方案。
## 5. 收尾
列出所有修改的文件和变更摘要。

View File

@@ -0,0 +1,56 @@
审查 openspec 变更文档proposal、design、tasks、specs是否完整准确地记录技术方案按以下流程执行。
## 约束
- 仅审查文档内容,不修改源码
- 每批修改建议执行前用提问工具获得用户确认
- 涉及删除/重写操作前必须备份原文件
## 1. 收集
并行读取:
- 本次变更涉及的所有文档proposal.md、design.md、tasks.md、specs/*.md
- 之前上下文中讨论的内容
- openspec/config.yaml
- 与变更规范有依赖关系的其他规范
## 2. 分析
将文档与讨论内容逐项对照,检查:
| 文档 | 检查点 |
| ---- | ------ |
| proposal.md | 是否完整记录讨论确定的目标、范围、影响;是否遗漏决策点 |
| design.md | 是否覆盖讨论中所有技术方案;边界条件和异常处理是否与讨论一致 |
| tasks.md | 是否覆盖 design 中所有方案;任务划分是否合理;依赖关系是否明确 |
| specs/*.md | 是否严格遵循 OpenSpec 格式;术语是否一致;依赖声明是否完整;无实现细节混入 |
重点识别:
- 讨论中确定但文档未记录的内容
- 文档描述与讨论不一致的地方
- 文档新增但未在讨论中提及的内容
输出审查结果:
1. **问题总览表**:问题类型 × 涉及文档数
2. **逐项分析**:每个问题文档,说明问题、影响和建议
3. **缺失清单**:哪些技术方案/需求未在文档中体现
4. **冲突清单**:哪些描述与其他文档或讨论结果不一致
5. **待澄清清单**:哪些事项描述不明确,需要进一步确认
若所有清单均为空,输出"审查通过,未发现问题",跳至步骤 5。
## 3. 计划(用户确认)
针对发现的问题,提出修复方案(补充遗漏、修正冲突、优化表述)。
针对待澄清清单,用提问工具逐项向用户确认,根据反馈更新文档。
用提问工具展示完整修复方案,获得用户确认后执行。
## 4. 执行
逐项执行修复方案。
## 5. 收尾
列出所有修改的文件和变更摘要。

View File

@@ -0,0 +1,120 @@
请对当前项目中所有 `dev*` 分支进行智能合并到目标分支(默认 main按以下流程执行。
## 约束(全局,不可违反)
- 所有操作(合并、删除)执行前必须用提问工具获得用户确认
- 冲突文件严禁自主编辑,仅分析方案后交用户选择
- 全程仅使用 `git merge`,禁止 rebaserebase 会重写目标分支历史)
- `git add` 仅指定已解决冲突的文件路径,禁止 `git add .`/`git add -A`
- `git reset --hard` 仅配合安全锚点 tag 使用,禁止裸用
- 禁止自动 `git stash` `git push`
## 1. 环境检查
- `git status` 确认工作区干净,不干净则提示用户处理
- 确认目标分支,拉取最新:`git pull`
- 列出 dev 分支:`git branch --list 'dev*'`,无则结束
- 创建全局安全锚点:`git tag pre-merge-backup-{timestamp}`,报告标签名
## 2. 分支分析
对每个 dev 分支并行收集:
| 维度 | 内容 |
| ------ | ---------------------------------------------------------------------------------------- |
| 基础 | 分支名、分叉 commit、commit 数/消息、是否推远端、是否已合并(`git branch --merged` |
| 变更 | 文件列表(`git diff --name-status target...branch`)、所属模块、行数统计 |
| 依赖 | 是否依赖/被依赖其他 dev 分支、是否修改公共文件(共享类型、工具函数、配置) |
| 冲突 | dry-run 预测(逐个串行,因需修改工作区):`git merge --no-commit --no-ff branch` → 收集冲突 → `git merge --abort`;与其他 dev 分支文件重叠 |
## 3. 合并顺序
按以下优先级排序:已合并(跳过) → 公共/基础设施变更 → 独立模块 → 有依赖的 → 高冲突/跨模块。
输出计划表(分支名、模块、文件数、依赖、预估冲突、风险),用提问工具让用户确认,用户可调整顺序或排除。
## 4. 逐个合并
对每个分支重复以下流程:
### 准备
1. 确认工作区干净、当前在目标分支
2. `git tag merge-before-{分支名}` 创建分支级安全锚点
3. 向用户确认即将合并的分支及风险
### 执行
`git merge {分支} --no-ff`
- 无冲突 → 进入验证
- 有冲突 → 进入冲突处理
### 冲突处理(三层渐进)
`git diff --name-only --diff-filter=U` 列出冲突文件,然后按以下三层逐步展开:
**第一层:冲突概览表**
向用户展示所有冲突文件的摘要,每文件一行:
| 文件 | 冲突区域 | 冲突类型 | 目标分支改动 | 合并分支改动 | 推荐方案 |
| ---- | -------- | -------- | ------------ | ------------ | -------- |
- 冲突类型:双方修改同一区域 / 一方删除一方修改 / 文件重命名冲突等
- 改动描述:精简到关键差异(如"新增 rateLimit 字段"而非展示原文)
- 推荐方案:根据分析给出最合理的选项
用提问工具让用户选择处理方式:
- 批量处理:对推荐方案无异议的文件一键确认
- 逐个审查:用户指定要详细审查的文件,进入第二层
- 放弃合并:`git merge --abort`,跳过当前分支,继续下一个
**第二层:单个文件详情(按需)**
仅展示用户指定审查的文件:
- 冲突区域的上下文(前后几行非冲突代码)
- HEAD 侧与分支侧的具体差异(精简 diff非完整文件内容
- 方案选项:双保留 / 保留目标(--ours) / 保留分支(--theirs) / 用户编辑 / 放弃合并
若用户仍觉信息不足,进入第三层。
**第三层:原始冲突标记(按需)**
展示用户指定文件的完整 `<<<<<<<`/`=======`/`>>>>>>>` 原始标记内容。
**确认与提交**
1. 方案全部确定后逐个执行:
- 双保留AI 生成合并后的文件内容,展示给用户确认后才写入,严禁未经确认直接写入
- 保留目标/分支:`git checkout --ours/--theirs {file}`
- 用户编辑:等待用户编辑完成后 `git add {file}`
2. 逐个 add 已解决文件:`git add {file}`(禁止 `git add .`
3. 展示 `git diff --cached --stat`,用户确认后完成提交
### 验证
- `cd backend && go build ./...`
- `cd frontend && bun run build`
- 失败则提供回退选项:`git reset --soft HEAD~1``git reset --hard merge-before-{分支名}`,由用户决定
### 断点
每个分支完成后询问是否继续,暂停则记录进度。
## 5. 清理
### 删除分支
输出合并结果表,逐个确认删除:
- 本地:`git branch -d {分支}`
- 远端:独立确认后 `git push origin --delete {分支}`
### 锚点
询问是否保留安全锚点标签。
### 总结
输出:目标分支、安全锚点标签、成功/失败/跳过数量、冲突解决文件数、已删除分支、保留分支及原因。

View File

@@ -0,0 +1,58 @@
请对 openspec/specs/ 下所有规范文件进行审查和整理,按以下流程执行。
## 约束
- 规范描述"应该是什么"不含实现细节具体文件路径、代码引用和变更记录ADDED/MODIFIED、"移除以下列"等措辞)
- 每批重构操作执行前用提问工具获得用户确认
- 仅删除内容已完全覆盖在其他规范中的冗余规范,非冗余内容仅迁移/合并/重命名
## 1. 收集
并行读取以下内容:
- `openspec/specs/` 每个子目录的 `spec.md`
- 项目源码backend/、frontend/),理解实际实现
- `openspec/config.yaml`,了解项目约束
## 2. 审查
将每个规范与代码和命名约定对比,按以下维度逐项检查:
| 维度 | 检查点 |
| ---- | ------ |
| 过时 | 描述的功能/组件是否仍存在于代码引用的路径、类名、API 是否一致;交互流程是否匹配当前实现 |
| 重复 | 不同规范是否描述同一组件/功能/场景(场景级或概念级) |
| 错位 | 场景是否放错了功能域Requirement 是否归属错误的规范 |
| 合并 | 同一主题是否分散在多个规范;某个规范是否是另一个的子集 |
| 命名 | 是否准确反映内容;是否符合命名约定(见下);是否暴露可搜索的业务关键词 |
| 格式 | 是否使用 SHALL/WHEN/THEN是否混入变更记录或实现细节是否有空目录 |
### 命名约定
| 类型 | 模式 | 示例 |
| ---- | ---- | ---- |
| 平台专属 | `{平台}-{功能}` | admin-platform、console-my-skills |
| 跨平台组件 | `{类别}` | component-library、layout-system |
| 技能领域 | `skill-{方面}` | skill-market、skill-status-rules |
| 业务功能 | `{业务名词}` | account-management、chat-scenarios |
命名原则统一平台前缀admin-/console-/developer-、统一领域前缀skill-、2-3 词、避免泛化词display/basic/general/info/data和实现模式词crud/list/table
## 3. 报告
输出分析结果:
1. **问题总览表**:问题类型 × 涉及规范数
2. **逐项分析**:每个有问题的规范,说明具体问题和建议(涉及文件、冲突点、推荐操作)
3. **重构方案**:按优先级分批:
- P0删除空目录和完全冗余规范
- P1合并重复/子集规范到主规范
- P2重命名不精准的规范、拆分错位内容
- P3修正与代码不匹配的描述、清理实现细节和变更记录
4. **重构后目录结构**:预期的新 specs/ 目录树
## 4. 执行
逐批执行重构P0→P3每批执行前
- 展示该批次的具体操作列表(源路径 → 目标路径、操作类型)
- 用提问工具获得用户确认
- 执行后确认目录结构完整,规范文件可正常读取

View File

@@ -1,106 +0,0 @@
# 规范文件整理流程
## 使用方式
将下方提示词完整复制给 AI 工具,即可启动一次规范文件的全面审查和整理。
---
## 提示词
```
请对 openspec/specs/ 下的所有规范文件进行审查和整理,按以下流程执行:
## 第一步:全面阅读
1. 逐个读取 openspec/specs/ 下每个子目录的 spec.md理解每个规范的覆盖范围
2. 读取项目源码,理解实际代码实现
3. 读取 openspec/config.yaml了解项目约束和规范
## 第二步:对比分析
将每个规范与实际代码对比,按以下维度逐项检查:
### A. 过时检查
- 规范描述的功能/组件/样式是否在当前代码中仍然存在
- 规范引用的文件路径、类名、API 接口是否与代码一致
- 规范描述的交互流程是否仍是当前的实现方式
### B. 重复检查
- 不同规范是否描述了相同的组件/功能/场景
- 场景级别的重复A 规范的 Scenario 与 B 规范的 Scenario 重复)
- 概念级别的重复A 规范整体描述的就是 B 规范已覆盖的内容)
### C. 错位检查
- A 规范中是否有场景应该属于 B 规范
- 某个 Requirement 是否放在了错误的功能域下
### D. 合并检查
- 描述同一类主题的规范是否分散在多个文件中
- 某个规范是否可以作为子集被另一个更大的规范吸收
### E. 命名检查
- 规范名称是否准确反映其实际内容
- 命名是否遵循统一的前缀约定平台前缀admin- / developer- / console-
- 名称是否便于 AI 工具搜索匹配(暴露关键业务词和组件名)
### F. 格式检查
- 是否使用标准的 SHALL/WHEN/THEN 规范格式
- 是否混入了变更记录(如"移除以下列"、"ADDED Requirements")而非功能规范
- 是否存在空目录
## 第三步:输出分析报告
按以下结构输出:
1. 问题总览表(问题类型 × 涉及规范数)
2. 逐项分析(每个有问题的规范,说明具体问题和建议)
3. 重构方案(删除/合并/重命名/内容调整的具体操作)
4. 重构后的规范目录结构
## 第四步:执行重构
按优先级分批执行:
- P0删除空目录和完全冗余的规范
- P1合并重复/子集规范到主规范中
- P2重命名不精准的规范、拆分错位的内容
- P3修正与代码不匹配的细节描述
每步执行后确认目录结构完整。
## 命名约定
规范目录命名遵循以下规则,确保 AI 工具搜索时能精准匹配:
| 类型 | 命名模式 | 示例 |
|------|---------|------|
| 平台专属功能 | `{平台}-{功能}` | `admin-platform`、`console-my-skills`、`developer-platform` |
| 跨平台组件/架构 | `{类别}` | `component-library`、`layout-system`、`design-tokens` |
| 技能领域 | `skill-{方面}` | `skill-market`、`skill-status-rules`、`skill-version-management` |
| 业务功能 | `{业务名词}` | `account-management`、`chat-scenarios` |
命名原则(提升 AI 检索命中率):
- 名称中暴露可搜索的业务关键词(如 skill、modal、toast、account
- 同一平台的功能使用统一前缀admin- / console- / developer-
- 同一领域的功能使用统一领域词前缀skill-
- 避免泛化词display → rules/behaviorbasic → 删掉general → 删掉)
- 避免实现模式词crud、list、table而使用业务领域词
- 避免同一关键词在不同规范中重复出现导致歧义(如 layout 只出现在一个规范名中)
- 长度控制在 2-3 个词去掉不影响检索的冗余词info、data 等)
```
---
## 补充说明
### 审查时的判断边界
- **规范 vs 代码**:规范描述"应该是什么",不描述"代码怎么写"。如果规范中出现了具体文件路径(如 `src/data/adminData.js`),通常是实现细节而非规范,应该清理
- **规范 vs 变更记录**:规范用 SHALL/WHEN/THEN 格式描述功能需求。如果出现"移除以下列"、"保持现有样式"、"ADDED/MODIFIED Requirements"等措辞,说明混入了变更指令,需要改写
- **规范 vs 文档**:规范不替代 README 或开发文档,不需要描述项目背景、技术选型等宏观信息
### 建议的定期审查节奏
- 每完成一批功能变更后,对照新代码检查相关规范是否需要更新
- 规范数量超过 30 个时,建议做一次全面审查
- 新增规范前,先搜索现有规范名称和内容,确认是否有可复用/扩展的规范

9
embedfs/embedfs.go Normal file
View File

@@ -0,0 +1,9 @@
package embedfs
import "embed"
//go:embed assets/*
var Assets embed.FS
//go:embed frontend-dist/*
var FrontendDist embed.FS

3
embedfs/go.mod Normal file
View File

@@ -0,0 +1,3 @@
module nex/embedfs
go 1.26.2

1
frontend/.env.desktop Normal file
View File

@@ -0,0 +1 @@
VITE_API_BASE=

View File

@@ -14,6 +14,68 @@ AI 网关管理前端,提供供应商配置和用量统计界面。
- **样式**: SCSS Modules禁止使用纯 CSS
- **测试**: Vitest + React Testing Library + Playwright
## API 层
### 字段转换机制
后端使用 `snake_case`,前端使用 `camelCase`API 层自动转换:
```typescript
// 发送请求时camelCase → snake_case
toApi({ providerId: "openai" }) // → { provider_id: "openai" }
// 接收响应时snake_case → camelCase
fromApi({ provider_id: "openai" }) // → { providerId: "openai" }
```
### 统一请求函数
```typescript
export async function request<T>(method: string, path: string, body?: unknown): Promise<T>
```
- 自动处理字段转换
- 自动处理 204 响应(无 body
- 抛出 `ApiError` 包含 `status``code``message`
### 错误处理
```typescript
class ApiError extends Error {
status: number; // HTTP 状态码
code?: string; // 业务错误码
message: string; // 错误消息
}
```
## TanStack Query 模式
### Query Keys 定义
```typescript
// src/hooks/useProviders.ts
export const providerKeys = {
all: ['providers'] as const,
};
// src/hooks/useModels.ts
export const modelKeys = {
all: ['models'] as const,
byProvider: (providerId: string) => [...modelKeys.all, { providerId }] as const,
};
```
### Mutation 使用
```typescript
const mutation = useMutation({
mutationFn: createProvider,
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: providerKeys.all });
},
});
```
## 项目结构
```
@@ -25,7 +87,7 @@ frontend/
│ │ ├── models.ts # Model CRUD
│ │ └── stats.ts # Stats 查询
│ ├── components/
│ │ └── AppLayout/ # 顶部导航布局
│ │ └── AppLayout/ # 侧边栏导航布局
│ ├── hooks/ # TanStack Query hooks
│ │ ├── useProviders.ts
│ │ ├── useModels.ts
@@ -33,6 +95,7 @@ frontend/
│ ├── pages/
│ │ ├── Providers/ # 供应商管理(含内嵌模型管理)
│ │ ├── Stats/ # 用量统计
│ │ ├── Settings/ # 设置(开发中)
│ │ └── NotFound.tsx
│ ├── routes/
│ │ └── index.tsx # 路由配置
@@ -125,11 +188,50 @@ bun run test:e2e
- 按模型筛选
- 按日期范围筛选DatePicker.RangePicker
## 测试策略
### 目录结构
```
__tests__/
├── setup.ts # 测试配置happy-dom
├── api/ # API 层测试
│ └── client.test.ts
├── hooks/ # TanStack Query Hook 测试
│ ├── useProviders.test.ts
│ └── useModels.test.ts
└── components/ # 组件测试
└── AppLayout.test.tsx
```
### E2E 测试
- 位于 `e2e/` 目录
- 使用 Playwright
- 自动启动后端服务(临时端口 19026
- 配置文件:`playwright.config.ts`
### Mock 策略
- API 层测试使用 MSWMock Service Worker
- Hook 测试使用 `@testing-library/react-hooks`
- 组件测试使用 `@testing-library/react`
## 环境变量
| 变量 | 开发环境 | 生产环境 | 说明 |
|------|----------|----------|------|
| `VITE_API_BASE` | (空) | `/api` | API 基础路径,空则走 Vite proxy |
**E2E 测试特有**
- `NEX_BACKEND_PORT` - E2E 后端端口(默认 19026
- `NEX_E2E_TEMP_DIR` - E2E 临时目录
## 开发规范
- 所有样式使用 SCSS禁止使用纯 CSS 文件
- 组件级样式使用 SCSS Modules*.module.scss
- 图标优先使用 @ant-design/icons
- 图标优先使用 TDesign 图标tdesign-icons-react
- TypeScript strict 模式,禁止 any 类型
- API 层自动处理 snake_case ↔ camelCase 字段转换
- 使用路径别名 `@/` 引用 src 目录
@@ -143,4 +245,4 @@ bun run test:e2e
1.`src/pages/` 创建页面目录和组件
2.`src/hooks/` 创建对应的 TanStack Query hook
3.`src/routes/index.tsx` 添加路由
4.`src/components/AppLayout/index.tsx`menuItems 添加导航项
4.`src/components/AppLayout/index.tsx`Menu 中添加 MenuItem

View File

@@ -1,4 +1,5 @@
import fs from 'node:fs'
import os from 'node:os'
import path from 'node:path'
import initSqlite from 'sql.js'
@@ -26,6 +27,18 @@ export interface SeedStatsInput {
date: string
}
export async function clearDatabase(
request: import('@playwright/test').APIRequestContext,
) {
const providers = await request.get(`${API_BASE}/api/providers`)
if (providers.ok()) {
const data = await providers.json()
for (const p of data) {
await request.delete(`${API_BASE}/api/providers/${p.id}`)
}
}
}
export async function seedProvider(
request: import('@playwright/test').APIRequestContext,
data: SeedProviderInput,
@@ -64,12 +77,14 @@ export async function seedModel(
}
export async function seedUsageStats(statsData: SeedStatsInput[]) {
const tempDir = process.env.NEX_E2E_TEMP_DIR
if (!tempDir) {
throw new Error('NEX_E2E_TEMP_DIR not set')
}
const tempDir = path.join(os.tmpdir(), 'nex-e2e')
const dbPath = path.join(tempDir, 'test.db')
if (!fs.existsSync(dbPath)) {
throw new Error(`Database file not found at ${dbPath}. Backend may not have created it yet.`)
}
const SQL = await initSqlite()
const buf = fs.readFileSync(dbPath)
const db = new SQL.Database(buf)

View File

@@ -1,8 +1,10 @@
import fs from 'node:fs'
import os from 'node:os'
import path from 'node:path'
async function globalSetup() {
const tempDir = process.env.NEX_E2E_TEMP_DIR
if (tempDir && fs.existsSync(tempDir)) {
const tempDir = path.join(os.tmpdir(), 'nex-e2e')
if (fs.existsSync(tempDir)) {
console.log(`E2E temp dir: ${tempDir}`)
}
}

View File

@@ -1,8 +1,10 @@
import fs from 'node:fs'
import os from 'node:os'
import path from 'node:path'
async function globalTeardown() {
const tempDir = process.env.NEX_E2E_TEMP_DIR
if (tempDir && fs.existsSync(tempDir)) {
const tempDir = path.join(os.tmpdir(), 'nex-e2e')
if (fs.existsSync(tempDir)) {
await new Promise((resolve) => setTimeout(resolve, 500))
try {
fs.rmSync(tempDir, { recursive: true, force: true })

View File

@@ -1,5 +1,5 @@
import { test, expect } from '@playwright/test'
import { API_BASE } from './fixtures'
import { API_BASE, clearDatabase } from './fixtures'
let uid = Date.now()
function nextId() {
@@ -10,6 +10,7 @@ function modelFormInputs(page: import('@playwright/test').Page) {
const dialog = page.locator('.t-dialog:visible')
return {
modelName: dialog.locator('input[placeholder="例如: gpt-4o"]'),
providerSelect: dialog.locator('.t-select'),
saveBtn: dialog.locator('.t-dialog__footer').getByRole('button', { name: '保存' }),
cancelBtn: dialog.locator('.t-dialog__footer').getByRole('button', { name: '取消' }),
}
@@ -19,6 +20,7 @@ test.describe('模型管理', () => {
let providerId: string
test.beforeEach(async ({ page, request }) => {
await clearDatabase(request)
providerId = nextId()
await request.post(`${API_BASE}/api/providers`, {
data: {
@@ -36,24 +38,27 @@ test.describe('模型管理', () => {
})
test('应能展开供应商查看模型空状态', async ({ page }) => {
await page.locator('.t-table__expandable-icon').first().click()
await page.locator('.t-table__expand-box').first().click()
await expect(page.locator('.t-table__expanded-row').first()).toBeVisible()
await expect(page.getByText('暂无模型,点击上方按钮添加')).toBeVisible()
})
test('应能为供应商添加模型', async ({ page }) => {
await page.locator('.t-table__expandable-icon').first().click()
await page.locator('.t-table__expand-box').first().click()
await expect(page.locator('.t-table__expanded-row').first()).toBeVisible()
await page.locator('.t-dialog:visible').waitFor({ state: 'hidden', timeout: 3000 }).catch(() => {})
await page.locator('.t-table__expanded-row button:has-text("添加模型")').first().click()
await expect(page.locator('.t-dialog:visible')).toBeVisible()
const inputs = modelFormInputs(page)
await inputs.modelName.fill('gpt_4_turbo')
const responsePromise = page.waitForResponse(resp => resp.url().includes('/api/models') && resp.request().method() === 'POST')
await inputs.saveBtn.click()
await expect(page.locator('.t-dialog:visible')).not.toBeVisible({ timeout: 5000 })
await expect(page.locator('.t-table__expanded-row').getByText('gpt_4_turbo')).toBeVisible()
await responsePromise
await expect(page.locator('.t-table__expanded-row').getByText('gpt_4_turbo', { exact: true })).toBeVisible({ timeout: 5000 })
})
test('应显示统一模型 ID', async ({ page, request }) => {
@@ -68,7 +73,7 @@ test.describe('模型管理', () => {
await page.reload()
await expect(page.getByRole('heading', { name: '供应商管理' })).toBeVisible()
await page.locator('.t-table__expandable-icon').first().click()
await page.locator('.t-table__expand-box').first().click()
await expect(page.locator('.t-table__expanded-row').first()).toBeVisible()
await expect(page.locator('.t-table__expanded-row').getByText(`${providerId}/claude_3`)).toBeVisible()
@@ -86,7 +91,7 @@ test.describe('模型管理', () => {
await page.reload()
await expect(page.getByRole('heading', { name: '供应商管理' })).toBeVisible()
await page.locator('.t-table__expandable-icon').first().click()
await page.locator('.t-table__expand-box').first().click()
await expect(page.locator('.t-table__expanded-row').first()).toBeVisible()
await page.locator('.t-table__expanded-row button:has-text("编辑")').first().click()
@@ -95,10 +100,11 @@ test.describe('模型管理', () => {
const inputs = modelFormInputs(page)
await inputs.modelName.clear()
await inputs.modelName.fill('gpt_4o')
const responsePromise = page.waitForResponse(resp => resp.url().includes('/api/models') && resp.request().method() === 'PUT')
await inputs.saveBtn.click()
await expect(page.locator('.t-dialog:visible')).not.toBeVisible({ timeout: 5000 })
await expect(page.locator('.t-table__expanded-row').getByText('gpt_4o')).toBeVisible()
await responsePromise
await expect(page.locator('.t-table__expanded-row').getByText('gpt_4o', { exact: true })).toBeVisible({ timeout: 5000 })
})
test('应能删除模型', async ({ page, request }) => {
@@ -113,13 +119,13 @@ test.describe('模型管理', () => {
await page.reload()
await expect(page.getByRole('heading', { name: '供应商管理' })).toBeVisible()
await page.locator('.t-table__expandable-icon').first().click()
await page.locator('.t-table__expand-box').first().click()
await expect(page.locator('.t-table__expanded-row').first()).toBeVisible()
await expect(page.locator('.t-table__expanded-row').getByText('to_delete_model')).toBeVisible()
await expect(page.locator('.t-table__expanded-row').getByText('to_delete_model', { exact: true })).toBeVisible()
await page.locator('.t-table__expanded-row button:has-text("删除")').first().click()
await expect(page.getByText(/确定要删除/)).toBeVisible()
await page.locator('.t-popconfirm').getByRole('button', { name: '确定' }).click()
await expect(page.locator('.t-table__expanded-row').getByText('to_delete_model')).not.toBeVisible({ timeout: 5000 })
await expect(page.locator('.t-table__expanded-row').getByText('to_delete_model', { exact: true })).not.toBeVisible({ timeout: 5000 })
})
})

View File

@@ -1,4 +1,5 @@
import { test, expect } from '@playwright/test'
import { clearDatabase } from './fixtures'
let uid = Date.now()
@@ -11,15 +12,17 @@ function formInputs(page: import('@playwright/test').Page) {
return {
id: dialog.locator('input[placeholder="例如: openai"]'),
name: dialog.locator('input[placeholder="例如: OpenAI"]'),
apiKey: dialog.locator('input[type="password"]'),
apiKey: dialog.locator('input[placeholder="sk-..."]'),
baseUrl: dialog.locator('input[placeholder="例如: https://api.openai.com/v1"]'),
protocol: dialog.locator('.t-select'),
saveBtn: dialog.locator('.t-dialog__footer').getByRole('button', { name: '保存' }),
cancelBtn: dialog.locator('.t-dialog__footer').getByRole('button', { name: '取消' }),
}
}
test.describe('供应商管理', () => {
test.beforeEach(async ({ page }) => {
test.beforeEach(async ({ page, request }) => {
await clearDatabase(request)
await page.goto('/providers')
await expect(page.getByRole('heading', { name: '供应商管理' })).toBeVisible()
})
@@ -34,12 +37,14 @@ test.describe('供应商管理', () => {
await inputs.name.fill('Test Provider')
await inputs.apiKey.fill('sk_test_key_12345')
await inputs.baseUrl.fill('https://api.openai.com/v1')
await inputs.protocol.click()
await page.waitForSelector('.t-select__dropdown', { timeout: 3000 })
await page.locator('.t-select__dropdown .t-select-option').first().click()
await page.waitForSelector('.t-select__dropdown', { state: 'hidden', timeout: 3000 })
await inputs.saveBtn.click()
await expect(page.locator('.t-dialog:visible')).not.toBeVisible({ timeout: 5000 })
await expect(page.locator('.t-table__body td').getByText(testId)).toBeVisible()
await expect(page.locator('.t-table__body td').getByText('Test Provider')).toBeVisible()
await expect(page.locator('.t-table__body').getByText('Test Provider')).toBeVisible({ timeout: 10000 })
})
test('应能编辑供应商并验证更新生效', async ({ page }) => {
@@ -51,8 +56,15 @@ test.describe('供应商管理', () => {
await inputs.name.fill('Before Edit')
await inputs.apiKey.fill('sk_key')
await inputs.baseUrl.fill('https://api.example.com/v1')
await inputs.protocol.click()
await page.waitForSelector('.t-select__dropdown', { timeout: 3000 })
await page.locator('.t-select__dropdown .t-select-option').first().click()
await page.waitForSelector('.t-select__dropdown', { state: 'hidden', timeout: 3000 })
const responsePromise = page.waitForResponse(resp => resp.url().includes('/api/providers') && resp.request().method() === 'POST')
await inputs.saveBtn.click()
await expect(page.locator('.t-dialog:visible')).not.toBeVisible({ timeout: 5000 })
await responsePromise
await expect(page.locator('.t-table__body').getByText('Before Edit')).toBeVisible({ timeout: 5000 })
await page.locator('.t-table__body button:has-text("编辑")').first().click()
await expect(page.locator('.t-dialog:visible')).toBeVisible()
@@ -60,10 +72,11 @@ test.describe('供应商管理', () => {
const editInputs = formInputs(page)
await editInputs.name.clear()
await editInputs.name.fill('After Edit')
const updateResponsePromise = page.waitForResponse(resp => resp.url().includes('/api/providers') && resp.request().method() === 'PUT')
await editInputs.saveBtn.click()
await expect(page.locator('.t-dialog:visible')).not.toBeVisible({ timeout: 5000 })
await expect(page.locator('.t-table__body td').getByText('After Edit')).toBeVisible()
await updateResponsePromise
await expect(page.locator('.t-table__body').getByText('After Edit')).toBeVisible({ timeout: 5000 })
})
test('应能删除供应商并验证消失', async ({ page }) => {
@@ -75,18 +88,20 @@ test.describe('供应商管理', () => {
await inputs.name.fill('To Delete')
await inputs.apiKey.fill('sk_key')
await inputs.baseUrl.fill('https://api.example.com/v1')
await inputs.protocol.click()
await page.waitForSelector('.t-select__dropdown', { timeout: 3000 })
await page.locator('.t-select__dropdown .t-select-option').first().click()
await page.waitForSelector('.t-select__dropdown', { state: 'hidden', timeout: 3000 })
await inputs.saveBtn.click()
await expect(page.locator('.t-dialog:visible')).not.toBeVisible({ timeout: 5000 })
await expect(page.locator('.t-table__body').getByText('To Delete')).toBeVisible({ timeout: 10000 })
await page.locator('.t-table__body button:has-text("删除")').first().click()
await expect(page.getByText('确定要删除这个供应商吗?')).toBeVisible()
await page.locator('.t-popconfirm').getByRole('button', { name: '确定' }).click()
await expect(page.getByText('确定要删除这个供应商吗?')).not.toBeVisible({ timeout: 3000 })
await expect(page.locator('.t-table__body td').getByText(testId)).not.toBeVisible({ timeout: 5000 })
await expect(page.locator('.t-table__body').getByText('To Delete')).not.toBeVisible({ timeout: 5000 })
})
test('应正确脱敏显示 API Key', async ({ page }) => {
test('应正确显示完整 API Key', async ({ page }) => {
const testId = nextId()
await page.getByRole('button', { name: '添加供应商' }).click()
await expect(page.locator('.t-dialog:visible')).toBeVisible()
@@ -95,9 +110,13 @@ test.describe('供应商管理', () => {
await inputs.name.fill('Mask Test')
await inputs.apiKey.fill('sk_abcdefghijklmnopqrstuvwxyz')
await inputs.baseUrl.fill('https://api.example.com/v1')
await inputs.protocol.click()
await page.waitForSelector('.t-select__dropdown', { timeout: 3000 })
await page.locator('.t-select__dropdown .t-select-option').first().click()
await page.waitForSelector('.t-select__dropdown', { state: 'hidden', timeout: 3000 })
await inputs.saveBtn.click()
await expect(page.locator('.t-dialog:visible')).not.toBeVisible({ timeout: 5000 })
await expect(page.locator('.t-table__body').getByText('Mask Test')).toBeVisible({ timeout: 10000 })
await expect(page.locator('.t-table__body')).toContainText('***stuv')
await expect(page.locator('.t-table__body')).toContainText('sk_abcdefghijklmnopqrstuvwxyz')
})
})

View File

@@ -1,8 +1,9 @@
import { test, expect } from '@playwright/test'
import { API_BASE, seedUsageStats } from './fixtures'
import { API_BASE, seedUsageStats, clearDatabase } from './fixtures'
test.describe('统计概览', () => {
test.beforeAll(async ({ request }) => {
await clearDatabase(request)
const p1 = `sp1_${Date.now()}`
const p2 = `sp2_${Date.now()}`
process.env._STATS_P1 = p1
@@ -65,6 +66,7 @@ test.describe('统计概览', () => {
test.describe('统计筛选', () => {
test.beforeAll(async ({ request }) => {
await clearDatabase(request)
const p1 = `fp1_${Date.now()}`
const p2 = `fp2_${Date.now()}`
process.env._FILTER_P1 = p1

View File

@@ -5,7 +5,7 @@ function formInputs(page: import('@playwright/test').Page) {
return {
id: dialog.locator('input[placeholder="例如: openai"]'),
name: dialog.locator('input[placeholder="例如: OpenAI"]'),
apiKey: dialog.locator('input[type="password"]'),
apiKey: dialog.locator('input[placeholder="sk-..."]'),
baseUrl: dialog.locator('input[placeholder="例如: https://api.openai.com/v1"]'),
saveBtn: dialog.locator('.t-dialog__footer').getByRole('button', { name: '保存' }),
cancelBtn: dialog.locator('.t-dialog__footer').getByRole('button', { name: '取消' }),
@@ -20,7 +20,7 @@ test.describe('供应商表单验证', () => {
test('应显示必填字段验证', async ({ page }) => {
await page.getByRole('button', { name: '添加供应商' }).click()
await expect(page.locator('.t-dialog')).toBeVisible()
await expect(page.locator('.t-dialog:visible')).toBeVisible()
await formInputs(page).saveBtn.click()
@@ -32,7 +32,7 @@ test.describe('供应商表单验证', () => {
test('应验证URL格式', async ({ page }) => {
await page.getByRole('button', { name: '添加供应商' }).click()
await expect(page.locator('.t-dialog')).toBeVisible()
await expect(page.locator('.t-dialog:visible')).toBeVisible()
const inputs = formInputs(page)
await inputs.id.fill('test_url')
@@ -46,17 +46,17 @@ test.describe('供应商表单验证', () => {
test('取消后表单应重置', async ({ page }) => {
await page.getByRole('button', { name: '添加供应商' }).click()
await expect(page.locator('.t-dialog')).toBeVisible()
await expect(page.locator('.t-dialog:visible')).toBeVisible()
let inputs = formInputs(page)
await inputs.id.fill('should_be_reset')
await inputs.name.fill('Should Be Reset')
await inputs.cancelBtn.click()
await expect(page.locator('.t-dialog')).not.toBeVisible()
await expect(page.locator('.t-dialog:visible')).not.toBeVisible()
await page.getByRole('button', { name: '添加供应商' }).click()
await expect(page.locator('.t-dialog')).toBeVisible()
await expect(page.locator('.t-dialog:visible')).toBeVisible()
inputs = formInputs(page)
await expect(inputs.id).toHaveValue('')
@@ -65,11 +65,11 @@ test.describe('供应商表单验证', () => {
test('快速连续点击只打开一个对话框', async ({ page }) => {
await page.getByRole('button', { name: '添加供应商' }).click()
await expect(page.locator('.t-dialog')).toBeVisible()
await expect(page.locator('.t-dialog:visible')).toBeVisible()
expect(await page.locator('.t-dialog').count()).toBe(1)
expect(await page.locator('.t-dialog:visible').count()).toBe(1)
await formInputs(page).cancelBtn.click()
await expect(page.locator('.t-dialog')).not.toBeVisible()
await expect(page.locator('.t-dialog:visible')).not.toBeVisible()
})
})

View File

@@ -8,7 +8,10 @@ const __filename = fileURLToPath(import.meta.url)
const __dirname = path.dirname(__filename)
const E2E_PORT = 19026
const tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'nex-e2e-'))
const tempDir = path.join(os.tmpdir(), 'nex-e2e')
if (!fs.existsSync(path.join(tempDir, 'test.db'))) {
fs.rmSync(tempDir, { recursive: true, force: true })
}
const dbPath = path.join(tempDir, 'test.db')
const logPath = path.join(tempDir, 'log')
@@ -24,11 +27,12 @@ export default defineConfig({
fullyParallel: false,
forbidOnly: !!process.env.CI,
retries: process.env.CI ? 2 : 0,
workers: process.env.CI ? 1 : undefined,
workers: 1,
reporter: 'html',
use: {
baseURL: 'http://localhost:5173',
trace: 'on-first-retry',
storageState: undefined,
},
projects: [
{
@@ -50,6 +54,9 @@ export default defineConfig({
command: 'bun run dev',
url: 'http://localhost:5173',
reuseExistingServer: false,
env: {
NEX_BACKEND_PORT: String(E2E_PORT),
},
},
],
})

View File

@@ -16,7 +16,10 @@ const queryClient = new QueryClient({
function App() {
return (
<QueryClientProvider client={queryClient}>
<ConfigProvider globalConfig={{}}>
<ConfigProvider globalConfig={{
animation: { include: ['ripple', 'expand', 'fade'] },
table: { size: 'medium' },
}}>
<BrowserRouter>
<AppRoutes />
</BrowserRouter>

View File

@@ -125,9 +125,6 @@ describe('ModelForm', () => {
const dialog = getDialog();
expect(within(dialog).getByText('编辑模型')).toBeInTheDocument();
// Check that unified ID field is displayed
expect(within(dialog).getByText('统一模型 ID')).toBeInTheDocument();
// Check model name input
const modelNameInput = within(dialog).getByPlaceholderText('例如: gpt-4o') as HTMLInputElement;
expect(modelNameInput.value).toBe('gpt-4o');

View File

@@ -63,13 +63,16 @@ describe('ProviderForm', () => {
const baseUrlInput = within(dialog).getByPlaceholderText('例如: https://api.openai.com/v1') as HTMLInputElement;
expect(baseUrlInput.value).toBe('https://api.openai.com/v1');
const apiKeyInput = within(dialog).getByPlaceholderText('sk-...') as HTMLInputElement;
expect(apiKeyInput.value).toBe('sk-old-key');
});
it('shows API Key label variant in edit mode', () => {
it('shows API Key label in edit mode', () => {
render(<ProviderForm {...defaultProps} provider={mockProvider} />);
const dialog = getDialog();
expect(within(dialog).getByText('API Key(留空则不修改)')).toBeInTheDocument();
expect(within(dialog).getByText('API Key')).toBeInTheDocument();
});
it('shows validation error messages for required fields', async () => {

View File

@@ -1,4 +1,4 @@
import { render, screen, fireEvent } from '@testing-library/react';
import { render, screen } from '@testing-library/react';
import userEvent from '@testing-library/user-event';
import { describe, it, expect, vi } from 'vitest';
import { ProviderTable } from '@/pages/Providers/ProviderTable';
@@ -48,19 +48,18 @@ const defaultProps = {
};
describe('ProviderTable', () => {
it('renders provider list with name, baseUrl, masked apiKey, and status tags', () => {
it('renders provider list with name, baseUrl, apiKey, and status tags', () => {
render(<ProviderTable {...defaultProps} />);
expect(screen.getByText('供应商列表')).toBeInTheDocument();
// Check that provider names appear (they will appear in both name column and potentially protocol column)
expect(screen.getAllByText('OpenAI').length).toBeGreaterThan(0);
expect(screen.getByText('https://api.openai.com/v1')).toBeInTheDocument();
expect(screen.getByText('****5678')).toBeInTheDocument();
expect(screen.getByText('sk-abcdefgh12345678')).toBeInTheDocument();
expect(screen.getAllByText('Anthropic').length).toBeGreaterThan(0);
expect(screen.getByText('https://api.anthropic.com')).toBeInTheDocument();
expect(screen.getByText('****test')).toBeInTheDocument();
expect(screen.getByText('sk-ant-test')).toBeInTheDocument();
const enabledTags = screen.getAllByText('启用');
const disabledTags = screen.getAllByText('禁用');
@@ -77,7 +76,7 @@ describe('ProviderTable', () => {
expect(container.querySelector('.t-card__body')).toBeInTheDocument();
});
it('renders short api keys fully masked', () => {
it('renders short api keys directly', () => {
const shortKeyProvider: Provider[] = [
{
...mockProviders[0],
@@ -88,7 +87,7 @@ describe('ProviderTable', () => {
];
render(<ProviderTable {...defaultProps} providers={shortKeyProvider} />);
expect(screen.getByText('****')).toBeInTheDocument();
expect(screen.getByText('ab')).toBeInTheDocument();
});
it('calls onAdd when clicking "添加供应商" button', async () => {

View File

@@ -37,55 +37,19 @@ describe('StatCards', () => {
expect(screen.getByText('今日请求量')).toBeInTheDocument();
});
it('calculates total requests correctly', () => {
render(<StatCards stats={mockStats} />);
const totalRequests = mockStats.reduce((sum, s) => sum + s.requestCount, 0);
expect(screen.getByText(totalRequests.toString())).toBeInTheDocument();
});
it('calculates active models correctly', () => {
render(<StatCards stats={mockStats} />);
const activeModels = new Set(mockStats.map((s) => s.modelName)).size;
const valueElements = screen.getAllByText(activeModels.toString());
expect(valueElements.length).toBeGreaterThan(0);
});
it('calculates active providers correctly', () => {
render(<StatCards stats={mockStats} />);
const activeProviders = new Set(mockStats.map((s) => s.providerId)).size;
const valueElements = screen.getAllByText(activeProviders.toString());
expect(valueElements.length).toBeGreaterThan(0);
});
it('renders with empty stats', () => {
render(<StatCards stats={[]} />);
expect(screen.getByText('总请求量')).toBeInTheDocument();
const zeroValues = screen.getAllByText('0');
expect(zeroValues.length).toBeGreaterThan(0);
expect(screen.getByText('活跃模型数')).toBeInTheDocument();
expect(screen.getByText('活跃供应商数')).toBeInTheDocument();
expect(screen.getByText('今日请求量')).toBeInTheDocument();
});
it('calculates today requests correctly', () => {
const today = new Date().toISOString().split('T')[0];
const statsWithToday: UsageStats[] = [
...mockStats,
{
id: 4,
providerId: 'openai',
modelName: 'gpt-4o',
requestCount: 50,
date: today,
},
];
it('renders suffix units', () => {
render(<StatCards stats={mockStats} />);
render(<StatCards stats={statsWithToday} />);
const todayRequests = statsWithToday
.filter((s) => s.date === today)
.reduce((sum, s) => sum + s.requestCount, 0);
expect(screen.getByText(todayRequests.toString())).toBeInTheDocument();
expect(screen.getAllByText('次').length).toBeGreaterThan(0);
expect(screen.getAllByText('个').length).toBeGreaterThan(0);
});
});

View File

@@ -6,8 +6,8 @@ import type { UsageStats } from '@/types';
// Mock Recharts components
vi.mock('recharts', () => ({
ResponsiveContainer: vi.fn(({ children }) => <div data-testid="mock-chart-container">{children}</div>),
LineChart: vi.fn(() => <div data-testid="mock-line-chart" />),
Line: vi.fn(() => null),
AreaChart: vi.fn(() => <div data-testid="mock-area-chart" />),
Area: vi.fn(() => null),
XAxis: vi.fn(() => null),
YAxis: vi.fn(() => null),
CartesianGrid: vi.fn(() => null),

View File

@@ -1,5 +1,6 @@
import { Layout, Menu } from 'tdesign-react';
import { ServerIcon, ChartLineIcon, SettingIcon } from 'tdesign-icons-react';
import { useState } from 'react';
import { Layout, Menu, Button } from 'tdesign-react';
import { ServerIcon, ChartLineIcon, SettingIcon, ChevronLeftIcon, ChevronRightIcon } from 'tdesign-icons-react';
import { Outlet, useLocation, useNavigate } from 'react-router';
const { MenuItem } = Menu;
@@ -7,6 +8,7 @@ const { MenuItem } = Menu;
export function AppLayout() {
const location = useLocation();
const navigate = useNavigate();
const [collapsed, setCollapsed] = useState(false);
const getPageTitle = () => {
if (location.pathname === '/providers') return '供应商管理';
@@ -15,10 +17,12 @@ export function AppLayout() {
return 'AI Gateway';
};
const asideWidth = collapsed ? '64px' : '232px';
return (
<Layout style={{ minHeight: '100vh' }}>
<Layout.Aside
width="232px"
width={asideWidth}
style={{
overflow: 'hidden',
height: '100vh',
@@ -28,45 +32,52 @@ export function AppLayout() {
bottom: 0,
}}
>
<div style={{ display: 'flex', flexDirection: 'column', height: '100%' }}>
<div
style={{
<Menu
value={location.pathname}
onChange={(value) => navigate(value as string)}
collapsed={collapsed}
width={['232px', '64px']}
logo={
<div style={{
height: 64,
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
fontSize: '1.25rem',
fontWeight: 600,
flexShrink: 0,
}}
>
AI Gateway
</div>
<Menu
value={location.pathname}
onChange={(value) => navigate(value as string)}
style={{ flex: 1, overflow: 'auto' }}
>
<MenuItem value="/providers" icon={<ServerIcon />}>
</MenuItem>
<MenuItem value="/stats" icon={<ChartLineIcon />}>
</MenuItem>
<MenuItem value="/settings" icon={<SettingIcon />}>
</MenuItem>
</Menu>
</div>
}}>
{!collapsed && 'AI Gateway'}
</div>
}
operations={
<Button
variant="text"
shape="square"
icon={collapsed ? <ChevronRightIcon /> : <ChevronLeftIcon />}
onClick={() => setCollapsed(!collapsed)}
/>
}
style={{ height: '100%' }}
>
<MenuItem value="/providers" icon={<ServerIcon />}>
</MenuItem>
<MenuItem value="/stats" icon={<ChartLineIcon />}>
</MenuItem>
<MenuItem value="/settings" icon={<SettingIcon />}>
</MenuItem>
</Menu>
</Layout.Aside>
<Layout style={{ marginLeft: 232 }}>
<Layout style={{ marginLeft: asideWidth }}>
<Layout.Header
style={{
padding: '0 2rem',
background: '#fff',
background: 'var(--td-bg-color-container)',
display: 'flex',
alignItems: 'center',
borderBottom: '1px solid #e7e7e7',
borderBottom: '1px solid var(--td-component-stroke)',
}}
>
<h1 style={{ margin: 0, fontSize: '1.25rem' }}>{getPageTitle()}</h1>

View File

@@ -7,3 +7,20 @@ body,
width: 100%;
height: 100%;
}
/* TDesign 主题微调 */
:root {
/* 页面背景色 */
--td-bg-color-page: #f5f7fa;
/* 圆角调大 */
--td-radius-default: 6px;
--td-radius-medium: 9px;
--td-radius-large: 12px;
--td-radius-extraLarge: 16px;
/* 系统字体栈 */
--td-font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto,
"Helvetica Neue", Arial, "Noto Sans", sans-serif, "Apple Color Emoji",
"Segoe UI Emoji", "Segoe UI Symbol", "Noto Color Emoji";
}

View File

@@ -1,6 +1,7 @@
import { StrictMode } from 'react'
import { createRoot } from 'react-dom/client'
import 'tdesign-react/es/style/index.css'
import 'tdesign-react/es/_util/react-19-adapter'
import './index.scss'
import App from './App'

View File

@@ -13,8 +13,8 @@ export function NotFound() {
minHeight: '100vh',
padding: '2rem',
}}>
<h1 style={{ fontSize: '6rem', margin: 0, color: '#999' }}>404</h1>
<p style={{ fontSize: '1.25rem', color: '#666', marginBottom: '2rem' }}>
<h1 style={{ fontSize: '6rem', margin: 0, color: 'var(--td-text-color-placeholder)' }}>404</h1>
<p style={{ fontSize: '1.25rem', color: 'var(--td-text-color-secondary)', marginBottom: '2rem' }}>
访
</p>
<Button theme="primary" onClick={() => navigate('/providers')}>

View File

@@ -14,7 +14,7 @@ interface ModelFormProps {
model?: Model;
providerId: string;
providers: Provider[];
onSave: (values: ModelFormValues) => void;
onSave: (values: ModelFormValues) => Promise<void> | void;
onCancel: () => void;
loading: boolean;
}
@@ -63,23 +63,18 @@ export function ModelForm({
<Dialog
header={isEdit ? '编辑模型' : '添加模型'}
visible={open}
placement="center"
width="520px"
closeOnOverlayClick={false}
closeOnEscKeydown={false}
lazy={false}
onConfirm={() => { form?.submit(); return false; }}
onClose={onCancel}
confirmLoading={loading}
confirmBtn="保存"
cancelBtn="取消"
destroyOnClose
>
<Form form={form} layout="vertical" onSubmit={handleSubmit}>
{isEdit && model?.unifiedId && (
<Form.FormItem label="统一模型 ID">
<Input value={model.unifiedId} disabled />
<div style={{ color: '#999', fontSize: 12, marginTop: 4 }}>
provider_id/model_name
</div>
</Form.FormItem>
)}
<Form.FormItem
label="供应商"
name="providerId"

View File

@@ -31,7 +31,11 @@ export function ModelTable({ providerId, onAdd, onEdit }: ModelTableProps) {
colKey: 'enabled',
width: 80,
cell: ({ row }) =>
row.enabled ? <Tag theme="success"></Tag> : <Tag theme="danger"></Tag>,
row.enabled ? (
<Tag theme="success" variant="light" shape="round"></Tag>
) : (
<Tag theme="danger" variant="light" shape="round"></Tag>
),
},
{
title: '操作',
@@ -72,6 +76,7 @@ export function ModelTable({ providerId, onAdd, onEdit }: ModelTableProps) {
data={models}
rowKey="id"
loading={isLoading}
stripe
pagination={undefined}
size="small"
empty="暂无模型,点击上方按钮添加"

View File

@@ -15,7 +15,7 @@ interface ProviderFormValues {
interface ProviderFormProps {
open: boolean;
provider?: Provider;
onSave: (values: ProviderFormValues) => void;
onSave: (values: ProviderFormValues) => Promise<void> | void;
onCancel: () => void;
loading: boolean;
}
@@ -30,26 +30,23 @@ export function ProviderForm({
const [form] = Form.useForm();
const isEdit = !!provider;
// 当弹窗打开或provider变化时设置表单值
useEffect(() => {
if (open && form) {
if (provider) {
// 编辑模式:设置现有值
form.setFieldsValue({
id: provider.id,
name: provider.name,
apiKey: '',
apiKey: provider.apiKey,
baseUrl: provider.baseUrl,
protocol: provider.protocol,
enabled: provider.enabled,
});
} else {
// 新增模式:重置表单
form.reset();
form.setFieldsValue({ enabled: true, protocol: 'openai' });
}
}
}, [open, provider]); // 移除form依赖避免循环
}, [open, provider]);
const handleSubmit = (context: SubmitContext) => {
if (context.validateResult === true && form) {
@@ -62,12 +59,16 @@ export function ProviderForm({
<Dialog
header={isEdit ? '编辑供应商' : '添加供应商'}
visible={open}
placement="center"
width="520px"
closeOnOverlayClick={false}
closeOnEscKeydown={false}
lazy={false}
onConfirm={() => { form?.submit(); return false; }}
onClose={onCancel}
confirmLoading={loading}
confirmBtn="保存"
cancelBtn="取消"
destroyOnClose
>
<Form form={form} layout="vertical" onSubmit={handleSubmit}>
<Form.FormItem label="ID" name="id" rules={[{ required: true, message: '请输入供应商 ID' }]}>
@@ -79,11 +80,11 @@ export function ProviderForm({
</Form.FormItem>
<Form.FormItem
label={isEdit ? 'API Key留空则不修改' : 'API Key'}
label="API Key"
name="apiKey"
rules={isEdit ? [] : [{ required: true, message: '请输入 API Key' }]}
rules={[{ required: true, message: '请输入 API Key' }]}
>
<Input type="password" placeholder="sk-..." autocomplete="current-password" />
<Input placeholder="sk-..." />
</Form.FormItem>
<Form.FormItem

View File

@@ -1,4 +1,4 @@
import { Button, Table, Tag, Popconfirm, Space, Card, Tooltip } from 'tdesign-react';
import { Button, Table, Tag, Popconfirm, Space, Card } from 'tdesign-react';
import type { PrimaryTableCol } from 'tdesign-react/es/table/type';
import type { Provider, Model } from '@/types';
import { ModelTable } from './ModelTable';
@@ -13,12 +13,6 @@ interface ProviderTableProps {
onEditModel: (model: Model) => void;
}
function maskApiKey(key: string | null | undefined): string {
if (!key) return '****';
if (key.length <= 4) return '****';
return `****${key.slice(-4)}`;
}
export function ProviderTable({
providers,
loading,
@@ -45,7 +39,7 @@ export function ProviderTable({
colKey: 'protocol',
width: 100,
cell: ({ row }) => (
<Tag theme={row.protocol === 'openai' ? 'primary' : 'success'}>
<Tag theme={row.protocol === 'openai' ? 'primary' : 'success'} variant="light" shape="round">
{row.protocol === 'openai' ? 'OpenAI' : 'Anthropic'}
</Tag>
),
@@ -53,20 +47,18 @@ export function ProviderTable({
{
title: 'API Key',
colKey: 'apiKey',
width: 120,
ellipsis: true,
cell: ({ row }) => (
<Tooltip content={maskApiKey(row.apiKey)}>
<span>{maskApiKey(row.apiKey)}</span>
</Tooltip>
),
},
{
title: '状态',
colKey: 'enabled',
width: 80,
cell: ({ row }) =>
row.enabled ? <Tag theme="success"></Tag> : <Tag theme="danger"></Tag>,
row.enabled ? (
<Tag theme="success" variant="light" shape="round"></Tag>
) : (
<Tag theme="danger" variant="light" shape="round"></Tag>
),
},
{
title: '操作',
@@ -93,6 +85,8 @@ export function ProviderTable({
return (
<Card
title="供应商列表"
headerBordered
hoverShadow
actions={
<Button theme="primary" onClick={onAdd}>
@@ -104,6 +98,7 @@ export function ProviderTable({
data={providers}
rowKey="id"
loading={loading}
stripe
expandedRow={({ row }) => (
<ModelTable
providerId={row.id}

View File

@@ -50,21 +50,21 @@ export function ProvidersPage() {
open={providerFormOpen}
provider={editingProvider}
loading={createProvider.isPending || updateProvider.isPending}
onSave={(values) => {
if (editingProvider) {
const input: Partial<UpdateProviderInput> = {};
if (values.name !== editingProvider.name) input.name = values.name;
if (values.apiKey) input.apiKey = values.apiKey;
if (values.baseUrl !== editingProvider.baseUrl) input.baseUrl = values.baseUrl;
if (values.enabled !== editingProvider.enabled) input.enabled = values.enabled;
updateProvider.mutate(
{ id: editingProvider.id, input },
{ onSuccess: () => setProviderFormOpen(false) },
);
} else {
createProvider.mutate(values, {
onSuccess: () => setProviderFormOpen(false),
});
onSave={async (values) => {
try {
if (editingProvider) {
const input: Partial<UpdateProviderInput> = {};
if (values.name !== editingProvider.name) input.name = values.name;
if (values.apiKey !== editingProvider.apiKey) input.apiKey = values.apiKey;
if (values.baseUrl !== editingProvider.baseUrl) input.baseUrl = values.baseUrl;
if (values.enabled !== editingProvider.enabled) input.enabled = values.enabled;
await updateProvider.mutateAsync({ id: editingProvider.id, input });
} else {
await createProvider.mutateAsync(values);
}
setProviderFormOpen(false);
} catch {
// 错误已由 hooks 的 onError 处理
}
}}
onCancel={() => setProviderFormOpen(false)}
@@ -76,20 +76,20 @@ export function ProvidersPage() {
providerId={modelFormProviderId}
providers={providers}
loading={createModel.isPending || updateModel.isPending}
onSave={(values) => {
if (editingModel) {
const input: Partial<UpdateModelInput> = {};
if (values.providerId !== editingModel.providerId) input.providerId = values.providerId;
if (values.modelName !== editingModel.modelName) input.modelName = values.modelName;
if (values.enabled !== editingModel.enabled) input.enabled = values.enabled;
updateModel.mutate(
{ id: editingModel.id, input },
{ onSuccess: () => setModelFormOpen(false) },
);
} else {
createModel.mutate(values, {
onSuccess: () => setModelFormOpen(false),
});
onSave={async (values) => {
try {
if (editingModel) {
const input: Partial<UpdateModelInput> = {};
if (values.providerId !== editingModel.providerId) input.providerId = values.providerId;
if (values.modelName !== editingModel.modelName) input.modelName = values.modelName;
if (values.enabled !== editingModel.enabled) input.enabled = values.enabled;
await updateModel.mutateAsync({ id: editingModel.id, input });
} else {
await createModel.mutateAsync(values);
}
setModelFormOpen(false);
} catch {
// 错误已由 hooks 的 onError 处理
}
}}
onCancel={() => setModelFormOpen(false)}

View File

@@ -3,7 +3,7 @@ import { Card } from 'tdesign-react';
export function SettingsPage() {
return (
<Card title="设置">
<div style={{ textAlign: 'center', padding: '40px 0', color: '#999' }}>
<div style={{ textAlign: 'center', padding: '40px 0', color: 'var(--td-text-color-placeholder)' }}>
...
</div>
</Card>

View File

@@ -1,4 +1,5 @@
import { Row, Col, Card, Statistic } from 'tdesign-react';
import { ChartBarIcon, ChartLineIcon, ServerIcon, Calendar1Icon } from 'tdesign-icons-react';
import type { UsageStats } from '@/types';
interface StatCardsProps {
@@ -18,23 +19,55 @@ export function StatCards({ stats }: StatCardsProps) {
return (
<Row gutter={[16, 16]} style={{ marginBottom: 16 }}>
<Col xs={12} md={6}>
<Card>
<Statistic title="总请求量" value={totalRequests} />
<Card bordered={false} hoverShadow>
<Statistic
title="总请求量"
value={totalRequests}
color="blue"
prefix={<ChartBarIcon />}
suffix="次"
animation={{ duration: 800, valueFrom: 0 }}
animationStart
/>
</Card>
</Col>
<Col xs={12} md={6}>
<Card>
<Statistic title="活跃模型数" value={activeModels} />
<Card bordered={false} hoverShadow>
<Statistic
title="活跃模型数"
value={activeModels}
color="green"
prefix={<ChartLineIcon />}
suffix="个"
animation={{ duration: 800, valueFrom: 0 }}
animationStart
/>
</Card>
</Col>
<Col xs={12} md={6}>
<Card>
<Statistic title="活跃供应商数" value={activeProviders} />
<Card bordered={false} hoverShadow>
<Statistic
title="活跃供应商数"
value={activeProviders}
color="orange"
prefix={<ServerIcon />}
suffix="个"
animation={{ duration: 800, valueFrom: 0 }}
animationStart
/>
</Card>
</Col>
<Col xs={12} md={6}>
<Card>
<Statistic title="今日请求量" value={todayRequests} />
<Card bordered={false} hoverShadow>
<Statistic
title="今日请求量"
value={todayRequests}
color="red"
prefix={<Calendar1Icon />}
suffix="次"
animation={{ duration: 800, valueFrom: 0 }}
animationStart
/>
</Card>
</Col>
</Row>

View File

@@ -78,8 +78,8 @@ export function StatsTable({
};
return (
<Card title="统计数据">
<Space style={{ marginBottom: 16 }}>
<Card title="统计数据" headerBordered hoverShadow>
<Space style={{ marginBottom: 16 }} size="medium" breakLine>
<Select
clearable
placeholder="所有供应商"
@@ -107,6 +107,7 @@ export function StatsTable({
data={stats}
rowKey="id"
loading={loading}
stripe
pagination={{ pageSize: 20 }}
empty="暂无统计数据"
/>

View File

@@ -1,12 +1,13 @@
import { Card } from 'tdesign-react';
import { LineChart, Line, XAxis, YAxis, CartesianGrid, ResponsiveContainer, Tooltip } from 'recharts';
import { AreaChart, Area, XAxis, YAxis, CartesianGrid, ResponsiveContainer, Tooltip } from 'recharts';
import type { UsageStats } from '@/types';
interface UsageChartProps {
stats: UsageStats[];
isLoading?: boolean;
}
export function UsageChart({ stats }: UsageChartProps) {
export function UsageChart({ stats, isLoading }: UsageChartProps) {
const chartData = Object.entries(
stats.reduce<Record<string, number>>((acc, s) => {
acc[s.date] = (acc[s.date] || 0) + s.requestCount;
@@ -17,25 +18,31 @@ export function UsageChart({ stats }: UsageChartProps) {
.sort((a, b) => a.date.localeCompare(b.date));
return (
<Card title="请求趋势" style={{ marginBottom: 16 }}>
<Card title="请求趋势" headerBordered hoverShadow loading={isLoading} style={{ marginBottom: 16 }}>
{chartData.length > 0 ? (
<ResponsiveContainer width="100%" height={300}>
<LineChart data={chartData}>
<CartesianGrid strokeDasharray="3 3" />
<AreaChart data={chartData}>
<defs>
<linearGradient id="requestGradient" x1="0" y1="0" x2="0" y2="1">
<stop offset="0%" stopColor="#0052D9" stopOpacity={0.4} />
<stop offset="100%" stopColor="#0052D9" stopOpacity={0} />
</linearGradient>
</defs>
<CartesianGrid strokeDasharray="3 3" stroke="#e8e8e8" />
<XAxis dataKey="date" />
<YAxis />
<Tooltip />
<Line
<Area
type="monotone"
dataKey="requestCount"
stroke="#0052D9"
strokeWidth={2}
dot={{ fill: '#0052D9' }}
fill="url(#requestGradient)"
/>
</LineChart>
</AreaChart>
</ResponsiveContainer>
) : (
<div style={{ textAlign: 'center', padding: '40px 0', color: '#999' }}>
<div style={{ textAlign: 'center', padding: '40px 0', color: 'var(--td-text-color-placeholder)' }}>
</div>
)}

View File

@@ -27,7 +27,7 @@ export function StatsPage() {
return (
<div>
<StatCards stats={stats} />
<UsageChart stats={stats} />
<UsageChart stats={stats} isLoading={isLoading} />
<StatsTable
providers={providers}
stats={stats}

6
go.work Normal file
View File

@@ -0,0 +1,6 @@
go 1.26.2
use (
backend
embedfs
)

Some files were not shown because too many files have changed in this diff Show More