Compare commits
73 Commits
aea360bce8
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| f5c82b6980 | |||
| 9105a36097 | |||
| f1ee646ca4 | |||
| b9b487c591 | |||
| 4c62c071fb | |||
| b2e9dd8b7f | |||
| d143c5f3df | |||
| 4eebdfb8db | |||
| b517946585 | |||
| 4ddae6be74 | |||
| 195762ff97 | |||
| bcf5ca89e5 | |||
| 365943e4c4 | |||
| 4c6b49099d | |||
| 4c78ab6cc8 | |||
| 52007c9461 | |||
| 086dd1fed7 | |||
| 1d7e839b49 | |||
| fa7babf13b | |||
| 280099b89c | |||
| 0a92a25451 | |||
| 8c075194e5 | |||
| 53e477d383 | |||
| 1522c87c74 | |||
| e0d05c9869 | |||
| 5b401e29cb | |||
| 65ac9f740a | |||
| 58ebcaa299 | |||
| 5b765c8b5e | |||
| b3258e76df | |||
| 64dc66afa6 | |||
| 15f08ee2ca | |||
| 380586afa6 | |||
| ebb70809bf | |||
| 7399afbc5c | |||
| c0669e4b07 | |||
| 05c04091b3 | |||
| 0b05e08705 | |||
| df253559a5 | |||
| 669cbb8c51 | |||
| 5ae9d85272 | |||
| 72aebef625 | |||
| f5e45d032e | |||
| b03e5f809f | |||
| ec563aaa16 | |||
| 873f09d3bf | |||
| 5e7267db07 | |||
| 7b28cee7a1 | |||
| 934c8dea77 | |||
| 7d91fe345e | |||
| 4e86adffb7 | |||
| 5d58acf5a6 | |||
| 81dcecb723 | |||
| 141f5f886f | |||
| 7fa5af483b | |||
| f488b9cc15 | |||
| 59179094ed | |||
| 4fc5fb4764 | |||
| feff97acbd | |||
| b7e205f4b6 | |||
| 24f03595a7 | |||
| 395887667d | |||
| 44d6af026a | |||
| 6e11ada42c | |||
| da790db75b | |||
| e1af978c56 | |||
| 980875ecf3 | |||
| 7f0f831226 | |||
| f3a207fa16 | |||
| 56ecc73d1b | |||
| 1ae9336cbe | |||
| 3fa5827de3 | |||
| cfb0edf802 |
7
.editorconfig
Normal file
7
.editorconfig
Normal file
@@ -0,0 +1,7 @@
|
||||
root = true
|
||||
|
||||
[*]
|
||||
end_of_line = lf
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
charset = utf-8
|
||||
1
.gitattributes
vendored
Normal file
1
.gitattributes
vendored
Normal file
@@ -0,0 +1 @@
|
||||
* text=auto eol=lf
|
||||
90
.gitignore
vendored
90
.gitignore
vendored
@@ -317,10 +317,98 @@ Network Trash Folder
|
||||
Temporary Items
|
||||
.apdisk
|
||||
|
||||
### Python.gitignore ###
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# PyInstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Environments
|
||||
.venv/
|
||||
venv/
|
||||
ENV/
|
||||
env/
|
||||
.python-version
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
|
||||
# Pyre
|
||||
.pyre/
|
||||
|
||||
# pytype
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# Custom
|
||||
.claude
|
||||
.opencode
|
||||
.codex
|
||||
openspec/changes/archive
|
||||
temp
|
||||
.agents
|
||||
skills-lock.json
|
||||
skills-lock.json
|
||||
.worktrees
|
||||
!scripts/build/
|
||||
|
||||
# Embedfs generated
|
||||
embedfs/assets/
|
||||
embedfs/frontend-dist/
|
||||
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"files.eol": "\n"
|
||||
}
|
||||
253
Makefile
Normal file
253
Makefile
Normal file
@@ -0,0 +1,253 @@
|
||||
.PHONY: all dev build test lint clean \
|
||||
backend-build backend-run backend-dev backend-test backend-test-all backend-test-unit backend-test-integration backend-test-coverage \
|
||||
backend-lint backend-clean backend-deps backend-generate \
|
||||
backend-db-up backend-db-down backend-db-status backend-db-create \
|
||||
test-mysql-up test-mysql-down test-mysql test-mysql-quick \
|
||||
frontend-build frontend-dev frontend-test frontend-test-watch frontend-test-coverage frontend-test-e2e frontend-lint frontend-clean \
|
||||
desktop-build desktop-build-mac desktop-build-win desktop-build-linux \
|
||||
desktop-dev desktop-test desktop-clean \
|
||||
desktop-prepare-frontend desktop-prepare-embedfs
|
||||
|
||||
# ============================================
|
||||
# 顶层便捷命令
|
||||
# ============================================
|
||||
|
||||
dev:
|
||||
@echo "🚀 Starting development environment..."
|
||||
@$(MAKE) -j2 backend-dev frontend-dev
|
||||
|
||||
build: backend-build frontend-build
|
||||
@echo "✅ Build complete"
|
||||
|
||||
test: backend-test desktop-test frontend-test
|
||||
@echo "✅ All tests passed"
|
||||
|
||||
lint: backend-lint frontend-lint
|
||||
@echo "✅ Lint complete"
|
||||
|
||||
all: build test lint
|
||||
|
||||
# ============================================
|
||||
# 后端
|
||||
# ============================================
|
||||
|
||||
backend-build:
|
||||
cd backend && go build -o bin/server ./cmd/server
|
||||
|
||||
backend-run:
|
||||
cd backend && go run ./cmd/server
|
||||
|
||||
backend-dev:
|
||||
cd backend && go run ./cmd/server
|
||||
|
||||
backend-test:
|
||||
cd backend && go test ./internal/... ./pkg/... ./tests/... ./cmd/server/... -v
|
||||
|
||||
backend-test-all:
|
||||
cd backend && go test ./... -v
|
||||
|
||||
backend-test-unit:
|
||||
cd backend && go test ./internal/... ./pkg/... -v
|
||||
|
||||
backend-test-integration:
|
||||
cd backend && go test ./tests/... -v
|
||||
|
||||
backend-test-coverage:
|
||||
cd backend && go test ./... -coverprofile=coverage.out
|
||||
cd backend && go tool cover -html=coverage.out -o coverage.html
|
||||
@echo "Coverage report generated: backend/coverage.html"
|
||||
|
||||
backend-lint:
|
||||
cd backend && go tool golangci-lint run ./...
|
||||
|
||||
backend-clean:
|
||||
rm -rf backend/bin/ backend/coverage.out backend/coverage.html
|
||||
|
||||
backend-deps:
|
||||
cd backend && go mod tidy
|
||||
|
||||
backend-generate:
|
||||
cd backend && go generate ./...
|
||||
|
||||
backend-db-up:
|
||||
@echo "Running database migration up..."
|
||||
cd backend && goose -dir migrations/sqlite sqlite3 "$(DB_PATH)" up
|
||||
|
||||
backend-db-down:
|
||||
@echo "Running database migration down..."
|
||||
cd backend && goose -dir migrations/sqlite sqlite3 "$(DB_PATH)" down
|
||||
|
||||
backend-db-status:
|
||||
@echo "Checking database migration status..."
|
||||
cd backend && goose -dir migrations/sqlite sqlite3 "$(DB_PATH)" status
|
||||
|
||||
backend-db-create:
|
||||
@read -p "Migration name: " name; \
|
||||
cd backend && goose -dir migrations/sqlite create $$name sql; \
|
||||
cd backend && goose -dir migrations/mysql create $$name sql
|
||||
|
||||
# ============================================
|
||||
# MySQL 专项测试
|
||||
# ============================================
|
||||
|
||||
test-mysql-up:
|
||||
@echo "Starting MySQL test container..."
|
||||
cd backend/tests/mysql && docker-compose up -d
|
||||
@echo "Waiting for MySQL to be ready..."
|
||||
@for i in $$(seq 1 30); do \
|
||||
if docker exec nex-mysql-test mysqladmin ping -h localhost -u root -ptestpass --silent 2>/dev/null; then \
|
||||
echo "MySQL is ready!"; \
|
||||
exit 0; \
|
||||
fi; \
|
||||
echo "Waiting... ($$i/30)"; \
|
||||
sleep 1; \
|
||||
done; \
|
||||
echo "MySQL failed to start"; \
|
||||
exit 1
|
||||
|
||||
test-mysql-down:
|
||||
@echo "Stopping MySQL test container..."
|
||||
cd backend/tests/mysql && docker-compose down -v
|
||||
|
||||
test-mysql: test-mysql-up
|
||||
@echo "Running MySQL tests..."
|
||||
cd backend && go test -tags=mysql ./tests/mysql/... -v -count=1
|
||||
$(MAKE) test-mysql-down
|
||||
|
||||
test-mysql-quick:
|
||||
@echo "Running MySQL tests (without container management)..."
|
||||
cd backend && go test -tags=mysql ./tests/mysql/... -v -count=1
|
||||
|
||||
# ============================================
|
||||
# 前端
|
||||
# ============================================
|
||||
|
||||
frontend-install:
|
||||
cd frontend && bun install
|
||||
|
||||
frontend-build: frontend-install
|
||||
cd frontend && bun run build
|
||||
|
||||
frontend-dev: frontend-install
|
||||
cd frontend && bun dev
|
||||
|
||||
frontend-test: frontend-install
|
||||
cd frontend && bun run test
|
||||
|
||||
frontend-test-watch: frontend-install
|
||||
cd frontend && bun run test:watch
|
||||
|
||||
frontend-test-coverage: frontend-install
|
||||
cd frontend && bun run test:coverage
|
||||
|
||||
frontend-test-e2e: frontend-install
|
||||
cd frontend && bun run test:e2e
|
||||
|
||||
frontend-lint: frontend-install
|
||||
cd frontend && bun run lint
|
||||
|
||||
frontend-clean:
|
||||
rm -rf frontend/dist frontend/.next frontend/node_modules frontend/coverage frontend/playwright-report frontend/test-results frontend/tsconfig.tsbuildinfo
|
||||
|
||||
# ============================================
|
||||
# 桌面应用
|
||||
# ============================================
|
||||
|
||||
desktop-build: desktop-build-mac desktop-build-win desktop-build-linux
|
||||
@echo "✅ Desktop builds complete for all platforms"
|
||||
|
||||
desktop-prepare-frontend:
|
||||
@echo "📦 Preparing frontend for desktop..."
|
||||
cd frontend && cp .env.desktop .env.production.local
|
||||
cd frontend && bun install && bun run build
|
||||
rm -f frontend/.env.production.local
|
||||
|
||||
desktop-prepare-embedfs:
|
||||
@echo "📦 Preparing embedded filesystem..."
|
||||
rm -rf embedfs/assets embedfs/frontend-dist
|
||||
cp -r assets embedfs/assets
|
||||
cp -r frontend/dist embedfs/frontend-dist
|
||||
|
||||
desktop-build-mac: desktop-prepare-frontend desktop-prepare-embedfs
|
||||
@echo "🍎 Building macOS..."
|
||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -o ../build/nex-mac-arm64 ./cmd/desktop
|
||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -o ../build/nex-mac-amd64 ./cmd/desktop
|
||||
lipo -create build/nex-mac-arm64 build/nex-mac-amd64 -output build/nex-mac-universal
|
||||
@echo "📦 Packaging macOS .app..."
|
||||
mkdir -p build/Nex.app/Contents/MacOS build/Nex.app/Contents/Resources
|
||||
cp build/nex-mac-universal build/Nex.app/Contents/MacOS/nex
|
||||
@if [ -f assets/AppIcon.icns ]; then \
|
||||
cp assets/AppIcon.icns build/Nex.app/Contents/Resources/; \
|
||||
else \
|
||||
echo "⚠️ 未找到 assets/AppIcon.icns"; \
|
||||
fi
|
||||
@MIN_MACOS_VERSION=$$(vtool -show-build build/nex-mac-universal | awk '/minos / {print $$2; exit}'); \
|
||||
if [ -z "$$MIN_MACOS_VERSION" ]; then \
|
||||
echo "❌ 无法读取 macOS 最低系统版本"; \
|
||||
exit 1; \
|
||||
fi; \
|
||||
{ \
|
||||
printf '%s\n' '<?xml version="1.0" encoding="UTF-8"?>' \
|
||||
'<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">' \
|
||||
'<plist version="1.0">' \
|
||||
'<dict>' \
|
||||
' <key>CFBundleDevelopmentRegion</key>' \
|
||||
' <string>zh-Hans</string>' \
|
||||
' <key>CFBundleExecutable</key>' \
|
||||
' <string>nex</string>' \
|
||||
' <key>CFBundleIconFile</key>' \
|
||||
' <string>AppIcon</string>' \
|
||||
' <key>CFBundleIdentifier</key>' \
|
||||
' <string>com.lanyuanxiaoyao.nex</string>' \
|
||||
' <key>CFBundleInfoDictionaryVersion</key>' \
|
||||
' <string>6.0</string>' \
|
||||
' <key>LSApplicationCategoryType</key>' \
|
||||
' <string>public.app-category.developer-tools</string>' \
|
||||
' <key>CFBundleName</key>' \
|
||||
' <string>Nex</string>' \
|
||||
' <key>CFBundleDisplayName</key>' \
|
||||
' <string>Nex</string>' \
|
||||
' <key>CFBundlePackageType</key>' \
|
||||
' <string>APPL</string>' \
|
||||
' <key>CFBundleShortVersionString</key>' \
|
||||
' <string>1.0.0</string>' \
|
||||
' <key>CFBundleVersion</key>' \
|
||||
' <string>1.0.0</string>' \
|
||||
' <key>NSHumanReadableCopyright</key>' \
|
||||
' <string>Copyright © 2026 Nex</string>' \
|
||||
' <key>LSMinimumSystemVersion</key>' \
|
||||
" <string>$$MIN_MACOS_VERSION</string>" \
|
||||
' <key>LSUIElement</key>' \
|
||||
' <true/>' \
|
||||
' <key>NSHighResolutionCapable</key>' \
|
||||
' <true/>' \
|
||||
'</dict>' \
|
||||
'</plist>'; \
|
||||
} > build/Nex.app/Contents/Info.plist
|
||||
chmod +x build/Nex.app/Contents/MacOS/nex
|
||||
@echo "✅ macOS app packaged: build/Nex.app"
|
||||
|
||||
desktop-build-win: desktop-prepare-frontend desktop-prepare-embedfs
|
||||
@echo "🪟 Building Windows..."
|
||||
cd backend && CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -ldflags "-H=windowsgui" -o ../build/nex-win-amd64.exe ./cmd/desktop
|
||||
|
||||
desktop-build-linux: desktop-prepare-frontend desktop-prepare-embedfs
|
||||
@echo "🐧 Building Linux..."
|
||||
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -o ../build/nex-linux-amd64 ./cmd/desktop
|
||||
|
||||
desktop-dev: desktop-prepare-frontend desktop-prepare-embedfs
|
||||
@echo "🖥️ Starting desktop app in dev mode..."
|
||||
cd backend && go run ./cmd/desktop
|
||||
|
||||
desktop-test:
|
||||
cd backend && go test ./cmd/desktop/... -v
|
||||
|
||||
desktop-clean:
|
||||
rm -rf build/ embedfs/assets embedfs/frontend-dist
|
||||
|
||||
# ============================================
|
||||
# 清理
|
||||
# ============================================
|
||||
|
||||
clean: backend-clean frontend-clean desktop-clean
|
||||
@echo "✅ Clean complete"
|
||||
213
README.md
213
README.md
@@ -7,13 +7,15 @@
|
||||
```
|
||||
nex/
|
||||
├── backend/ # Go 后端服务(分层架构)
|
||||
│ ├── cmd/server/ # 主程序入口
|
||||
│ ├── cmd/
|
||||
│ │ ├── server/ # CLI 主程序入口
|
||||
│ │ └── desktop/ # 桌面应用入口
|
||||
│ ├── internal/
|
||||
│ │ ├── handler/ # HTTP 处理器 + 中间件
|
||||
│ │ ├── service/ # 业务逻辑层
|
||||
│ │ ├── repository/ # 数据访问层
|
||||
│ │ ├── domain/ # 领域模型
|
||||
│ │ ├── protocol/ # 协议适配器(OpenAI/Anthropic)
|
||||
│ │ ├── conversion/ # 协议转换引擎(OpenAI/Anthropic 适配器)
|
||||
│ │ ├── provider/ # 供应商客户端
|
||||
│ │ └── config/ # 配置管理
|
||||
│ ├── pkg/ # 公共包(logger/errors/validator)
|
||||
@@ -32,16 +34,25 @@ nex/
|
||||
│ ├── e2e/ # Playwright E2E 测试
|
||||
│ └── package.json
|
||||
│
|
||||
├── assets/ # 应用资源
|
||||
│ ├── icon.png # 托盘图标
|
||||
│ ├── AppIcon.icns # macOS 应用图标
|
||||
│ └── icon.ico # Windows 应用图标
|
||||
│
|
||||
└── README.md # 本文件
|
||||
```
|
||||
|
||||
## 功能特性
|
||||
|
||||
- **双协议支持**:同时支持 OpenAI 和 Anthropic 协议
|
||||
- **透明代理**:对 OpenAI 兼容供应商透传请求
|
||||
- **流式响应**:完整支持 SSE 流式传输
|
||||
- **跨协议转换**:Hub-and-Spoke 架构实现 OpenAI ↔ Anthropic 双向转换
|
||||
- **统一模型 ID**:`provider_id/model_name` 格式全局唯一标识模型(如 `openai/gpt-4`)
|
||||
- **Smart Passthrough**:同协议请求零序列化开销,仅改写 model 字段
|
||||
- **流式响应**:完整支持 SSE 流式传输,包括跨协议流式转换
|
||||
- **Function Calling**:支持工具调用(Tools)
|
||||
- **多供应商管理**:配置和管理多个供应商
|
||||
- **Thinking / Reasoning**:支持 OpenAI `reasoning_effort` 和 Anthropic `thinking` 配置
|
||||
- **扩展接口**:支持 Embeddings 和 Rerank 接口
|
||||
- **多供应商管理**:配置和管理多个供应商(供应商 ID 仅限字母、数字、下划线)
|
||||
- **用量统计**:按供应商、模型、日期统计请求数量
|
||||
- **Web 配置界面**:提供供应商和模型配置管理
|
||||
|
||||
@@ -51,12 +62,26 @@ nex/
|
||||
- **语言**: Go 1.26+
|
||||
- **HTTP 框架**: Gin
|
||||
- **ORM**: GORM
|
||||
- **数据库**: SQLite
|
||||
- **日志**: zap + lumberjack(结构化日志 + 日志轮转)
|
||||
- **配置**: gopkg.in/yaml.v3
|
||||
- **数据库**: SQLite / MySQL
|
||||
- **日志**: zap + lumberjack(结构化日志 + 日志轮转 + 模块标识)
|
||||
- **配置**: Viper + pflag(多层配置:CLI > 环境变量 > 配置文件 > 默认值)
|
||||
- **验证**: go-playground/validator/v10
|
||||
- **迁移**: goose
|
||||
|
||||
#### 日志模块标识规范
|
||||
|
||||
每个模块通过依赖注入获取带模块标识的 logger,日志输出格式为 `[module.name]`:
|
||||
|
||||
```
|
||||
Console: INFO [handler.proxy] 处理请求 method=POST path=/v1/chat
|
||||
JSON: {"level":"info","logger":"handler.proxy","msg":"处理请求","method":"POST"}
|
||||
```
|
||||
|
||||
模块命名规范:
|
||||
- 单一职责包:`database`、`config`
|
||||
- 多实体包:`handler.proxy`、`service.provider`
|
||||
- 子包:`handler.middleware`
|
||||
|
||||
### 前端
|
||||
- **运行时**: Bun
|
||||
- **构建工具**: Vite
|
||||
@@ -71,7 +96,45 @@ nex/
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 后端
|
||||
### 桌面应用(推荐)
|
||||
|
||||
**构建桌面应用**:
|
||||
|
||||
```bash
|
||||
# macOS (arm64 + amd64,并打包为 .app)
|
||||
make desktop-build-mac
|
||||
|
||||
# Windows
|
||||
make desktop-build-win
|
||||
|
||||
# Linux
|
||||
make desktop-build-linux
|
||||
|
||||
# 构建所有平台
|
||||
make desktop-build
|
||||
```
|
||||
|
||||
**使用桌面应用**:
|
||||
- 双击启动应用(macOS: Nex.app,Windows: nex-win-amd64.exe,Linux: nex-linux-amd64)
|
||||
- 系统托盘图标出现,浏览器自动打开管理界面
|
||||
- 点击托盘图标显示菜单,可打开管理界面或退出
|
||||
- 关闭浏览器后服务继续运行,可通过托盘重新打开
|
||||
|
||||
**注意事项**:
|
||||
- 桌面应用需要 CGO 支持
|
||||
- macOS: 自带 Xcode Command Line Tools
|
||||
- Linux: 自带 gcc,部分桌面环境需要 `libappindicator3-dev`
|
||||
- Windows: 需要 MinGW-w64 或在 Windows 环境构建
|
||||
|
||||
**Linux 桌面环境兼容性**:
|
||||
- GNOME: 需要 AppIndicator 扩展
|
||||
- KDE Plasma: 原生支持
|
||||
- Xfce: 需要 libappindicator
|
||||
- 其他支持 StatusNotifierItem 规范的环境
|
||||
|
||||
### CLI 模式
|
||||
|
||||
#### 后端
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
@@ -99,23 +162,34 @@ bun dev
|
||||
|
||||
### 代理接口(对外部应用)
|
||||
|
||||
- `POST /v1/chat/completions` - OpenAI Chat Completions API
|
||||
- `POST /v1/messages` - Anthropic Messages API
|
||||
代理接口统一使用 `/{protocol}/*path` 路由格式,模型 ID 使用 `provider_id/model_name` 格式(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough,最小化 JSON 改写保持参数保真;跨协议请求走完整 decode/encode 转换。
|
||||
|
||||
**OpenAI 协议**(`protocol=openai`):
|
||||
- `POST /openai/chat/completions` - 对话补全
|
||||
- `GET /openai/models` - 模型列表(本地数据库聚合)
|
||||
- `GET /openai/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
|
||||
- `POST /openai/embeddings` - 嵌入
|
||||
- `POST /openai/rerank` - 重排序
|
||||
|
||||
**Anthropic 协议**(`protocol=anthropic`):
|
||||
- `POST /anthropic/v1/messages` - 消息对话
|
||||
- `GET /anthropic/v1/models` - 模型列表(本地数据库聚合)
|
||||
- `GET /anthropic/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
|
||||
|
||||
### 管理接口(对前端)
|
||||
|
||||
#### 供应商管理
|
||||
- `GET /api/providers` - 列出所有供应商
|
||||
- `POST /api/providers` - 创建供应商
|
||||
- `POST /api/providers` - 创建供应商(`id` 仅限字母、数字、下划线,长度 1-64)
|
||||
- `GET /api/providers/:id` - 获取供应商
|
||||
- `PUT /api/providers/:id` - 更新供应商
|
||||
- `PUT /api/providers/:id` - 更新供应商(`id` 不可修改)
|
||||
- `DELETE /api/providers/:id` - 删除供应商
|
||||
|
||||
#### 模型管理
|
||||
- `GET /api/models` - 列出模型(支持 `?provider_id=xxx` 过滤)
|
||||
- `POST /api/models` - 创建模型
|
||||
- `GET /api/models/:id` - 获取模型
|
||||
- `PUT /api/models/:id` - 更新模型
|
||||
- `POST /api/models` - 创建模型(`id` 由系统自动生成 UUID,`provider_id` + `model_name` 联合唯一)
|
||||
- `GET /api/models/:id` - 获取模型(响应含 `unified_id` 字段,格式 `provider_id/model_name`)
|
||||
- `PUT /api/models/:id` - 更新模型(不可修改 `id`)
|
||||
- `DELETE /api/models/:id` - 删除模型
|
||||
|
||||
#### 统计查询
|
||||
@@ -126,6 +200,10 @@ bun dev
|
||||
|
||||
## 配置
|
||||
|
||||
配置支持多种方式,优先级为:**CLI 参数 > 环境变量 > 配置文件 > 默认值**
|
||||
|
||||
### 配置文件
|
||||
|
||||
配置文件位于 `~/.nex/config.yaml`,首次启动自动生成:
|
||||
|
||||
```yaml
|
||||
@@ -135,7 +213,14 @@ server:
|
||||
write_timeout: 30s
|
||||
|
||||
database:
|
||||
path: ~/.nex/config.db
|
||||
driver: sqlite # sqlite 或 mysql
|
||||
path: ~/.nex/config.db # SQLite 数据库文件路径
|
||||
# --- MySQL 配置(driver=mysql 时生效)---
|
||||
# host: localhost
|
||||
# port: 3306
|
||||
# user: nex
|
||||
# password: ""
|
||||
# dbname: nex
|
||||
max_idle_conns: 10
|
||||
max_open_conns: 100
|
||||
conn_max_lifetime: 1h
|
||||
@@ -149,48 +234,88 @@ log:
|
||||
compress: true
|
||||
```
|
||||
|
||||
数据文件:
|
||||
### 环境变量
|
||||
|
||||
所有配置项支持环境变量,使用 `NEX_` 前缀:
|
||||
|
||||
```bash
|
||||
export NEX_SERVER_PORT=9000
|
||||
export NEX_DATABASE_PATH=/data/nex.db
|
||||
export NEX_LOG_LEVEL=debug
|
||||
|
||||
# MySQL 模式
|
||||
export NEX_DATABASE_DRIVER=mysql
|
||||
export NEX_DATABASE_HOST=db.example.com
|
||||
export NEX_DATABASE_PORT=3306
|
||||
export NEX_DATABASE_USER=nex
|
||||
export NEX_DATABASE_PASSWORD=secret
|
||||
export NEX_DATABASE_DBNAME=nex
|
||||
```
|
||||
|
||||
命名规则:配置路径转大写 + 下划线(如 `server.port` → `NEX_SERVER_PORT`)。
|
||||
|
||||
### CLI 参数
|
||||
|
||||
```bash
|
||||
./server --server-port 9000 --log-level debug --database-path /tmp/test.db
|
||||
```
|
||||
|
||||
命名规则:配置路径转 kebab-case(如 `server.port` → `--server-port`)。
|
||||
|
||||
### 数据文件
|
||||
|
||||
- `~/.nex/config.yaml` - 配置文件
|
||||
- `~/.nex/config.db` - SQLite 数据库
|
||||
- `~/.nex/config.db` - SQLite 数据库(MySQL 模式下不使用本地数据库文件)
|
||||
- `~/.nex/log/` - 日志目录
|
||||
|
||||
## 测试
|
||||
|
||||
### 后端测试
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
make test # 运行所有测试
|
||||
make test-coverage # 生成覆盖率报告
|
||||
```
|
||||
# 顶层便捷命令
|
||||
make test # 运行所有测试
|
||||
|
||||
### 前端测试
|
||||
# 后端测试
|
||||
make backend-test # 后端测试
|
||||
make backend-test-coverage # 后端覆盖率
|
||||
make backend-test-unit # 后端单元测试
|
||||
make backend-test-integration # 后端集成测试
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
bun run test # 单元测试 + 组件测试
|
||||
bun run test:watch # 监听模式
|
||||
bun run test:coverage # 生成覆盖率报告
|
||||
bun run test:e2e # E2E 测试
|
||||
# 前端测试
|
||||
make frontend-test # 前端测试
|
||||
make frontend-test-e2e # 前端 E2E 测试
|
||||
make frontend-test-coverage # 前端覆盖率
|
||||
```
|
||||
|
||||
## 开发
|
||||
|
||||
### 后端开发
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
make build # 构建
|
||||
make lint # 代码检查
|
||||
make migrate-up # 数据库迁移
|
||||
```
|
||||
# 首次克隆后安装 Git hooks
|
||||
lefthook install
|
||||
|
||||
### 前端开发
|
||||
# 顶层便捷命令
|
||||
make dev # 启动开发环境(并行启动后端和前端)
|
||||
make build # 构建所有产物
|
||||
make lint # 检查所有代码
|
||||
make clean # 清理所有构建产物
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
bun run build # 构建生产版本
|
||||
bun run lint # 代码检查
|
||||
# 后端开发
|
||||
make backend-build # 构建后端
|
||||
make backend-run # 运行后端
|
||||
make backend-dev # 后端开发模式
|
||||
make backend-lint # 后端代码检查
|
||||
make backend-clean # 清理后端构建产物
|
||||
|
||||
# 数据库操作
|
||||
make backend-db-up # 数据库迁移
|
||||
make backend-db-down # 数据库回滚
|
||||
make backend-db-status # 数据库迁移状态
|
||||
make backend-db-create # 创建新迁移
|
||||
|
||||
# 前端开发
|
||||
make frontend-build # 构建前端
|
||||
make frontend-dev # 前端开发模式
|
||||
make frontend-lint # 前端代码检查
|
||||
make frontend-clean # 清理前端构建产物
|
||||
```
|
||||
|
||||
## 开发规范
|
||||
|
||||
BIN
assets/AppIcon.icns
Normal file
BIN
assets/AppIcon.icns
Normal file
Binary file not shown.
64
assets/README.md
Normal file
64
assets/README.md
Normal file
@@ -0,0 +1,64 @@
|
||||
# Assets
|
||||
|
||||
应用资源文件目录。
|
||||
|
||||
## 文件说明
|
||||
|
||||
| 文件 | 用途 | 尺寸 | 格式 |
|
||||
|------|------|------|------|
|
||||
| `icon.svg` | 源图标 | 64x64 | SVG |
|
||||
| `icon.png` | 托盘图标 | 64x64 | PNG |
|
||||
| `AppIcon.icns` | macOS 应用图标 | 多尺寸 | ICNS |
|
||||
| `icon.ico` | Windows 应用图标 | 256x256 | ICO |
|
||||
|
||||
## 替换图标
|
||||
|
||||
### 1. 准备图标
|
||||
|
||||
推荐使用 SVG 格式的源图标,尺寸至少 256x256。
|
||||
|
||||
### 2. 生成各平台图标
|
||||
|
||||
**托盘图标 (PNG)**:
|
||||
```bash
|
||||
magick your-icon.svg -resize 64x64 icon.png
|
||||
```
|
||||
|
||||
**macOS 应用图标 (ICNS)**:
|
||||
```bash
|
||||
mkdir icon.iconset
|
||||
magick your-icon.svg -resize 16x16 icon.iconset/icon_16x16.png
|
||||
magick your-icon.svg -resize 32x32 icon.iconset/icon_16x16@2x.png
|
||||
magick your-icon.svg -resize 32x32 icon.iconset/icon_32x32.png
|
||||
magick your-icon.svg -resize 64x64 icon.iconset/icon_32x32@2x.png
|
||||
magick your-icon.svg -resize 128x128 icon.iconset/icon_128x128.png
|
||||
magick your-icon.svg -resize 256x256 icon.iconset/icon_128x128@2x.png
|
||||
iconutil -c icns icon.iconset -o AppIcon.icns
|
||||
rm -rf icon.iconset
|
||||
```
|
||||
|
||||
**Windows 应用图标 (ICO)**:
|
||||
```bash
|
||||
magick your-icon.svg -resize 256x256 icon.ico
|
||||
```
|
||||
|
||||
### 3. 替换文件
|
||||
|
||||
将生成的文件放入此目录,然后重新构建桌面应用:
|
||||
```bash
|
||||
./scripts/build/build-darwin-arm64.sh
|
||||
```
|
||||
|
||||
## macOS Template 图标
|
||||
|
||||
macOS 支持 Template 图标,自动适配深浅色模式:
|
||||
- 使用黑色 + 透明设计
|
||||
- 文件名以 `Template` 结尾(如 `iconTemplate.png`)
|
||||
- 黑色在深色模式下自动变为白色
|
||||
|
||||
## 设计建议
|
||||
|
||||
- 托盘图标应简洁,在小尺寸下清晰可辨
|
||||
- 避免过多细节和文字
|
||||
- 使用高对比度颜色
|
||||
- macOS 建议使用 Template 图标风格
|
||||
BIN
assets/icon.ico
Normal file
BIN
assets/icon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 264 KiB |
BIN
assets/icon.png
Normal file
BIN
assets/icon.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.0 KiB |
13
assets/icon.svg
Normal file
13
assets/icon.svg
Normal file
@@ -0,0 +1,13 @@
|
||||
<svg width="64" height="64" viewBox="0 0 64 64" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect width="64" height="64" rx="12" fill="#4A90D9"/>
|
||||
<polygon points="32,8 52,20 52,44 32,56 12,44 12,20" fill="none" stroke="white" stroke-width="3"/>
|
||||
<circle cx="32" cy="32" r="6" fill="white"/>
|
||||
<line x1="32" y1="32" x2="20" y2="20" stroke="white" stroke-width="2"/>
|
||||
<line x1="32" y1="32" x2="44" y2="20" stroke="white" stroke-width="2"/>
|
||||
<line x1="32" y1="32" x2="20" y2="44" stroke="white" stroke-width="2"/>
|
||||
<line x1="32" y1="32" x2="44" y2="44" stroke="white" stroke-width="2"/>
|
||||
<circle cx="20" cy="20" r="3" fill="white"/>
|
||||
<circle cx="44" cy="20" r="3" fill="white"/>
|
||||
<circle cx="20" cy="44" r="3" fill="white"/>
|
||||
<circle cx="44" cy="44" r="3" fill="white"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 779 B |
91
backend/.golangci.yml
Normal file
91
backend/.golangci.yml
Normal file
@@ -0,0 +1,91 @@
|
||||
run:
|
||||
timeout: 5m
|
||||
tests: true
|
||||
|
||||
linters:
|
||||
disable-all: true
|
||||
enable:
|
||||
- forbidigo
|
||||
- errorlint
|
||||
- errcheck
|
||||
- staticcheck
|
||||
- revive
|
||||
- gocritic
|
||||
- gosec
|
||||
- bodyclose
|
||||
- noctx
|
||||
- nilerr
|
||||
- goimports
|
||||
- gocyclo
|
||||
|
||||
linters-settings:
|
||||
errcheck:
|
||||
check-blank: true
|
||||
check-type-assertions: true
|
||||
exclude-functions:
|
||||
- fmt.Fprintf
|
||||
forbidigo:
|
||||
analyze-types: true
|
||||
forbid:
|
||||
- p: '^fmt\.Print.*$'
|
||||
msg: 使用 zap logger,不要直接输出到 stdout/stderr
|
||||
- p: '^fmt\.Fprint.*$'
|
||||
msg: 使用 zap logger,不要直接输出到 stdout/stderr
|
||||
- p: '^log\.(Print|Println|Printf|Fatal|Fatalln|Fatalf|Panic|Panicln|Panicf)$'
|
||||
msg: 使用 zap logger,不要使用标准库 log
|
||||
- p: '^zap\.L$'
|
||||
msg: 通过依赖注入传递 *zap.Logger,不要使用全局 logger
|
||||
- p: '^zap\.S$'
|
||||
msg: 不使用 Sugar logger
|
||||
revive:
|
||||
rules:
|
||||
- name: exported
|
||||
- name: var-naming
|
||||
- name: indent-error-flow
|
||||
- name: error-strings
|
||||
- name: error-return
|
||||
- name: blank-imports
|
||||
- name: context-as-argument
|
||||
- name: unexported-return
|
||||
goimports:
|
||||
local-prefixes: nex/backend
|
||||
gocyclo:
|
||||
min-complexity: 10
|
||||
|
||||
issues:
|
||||
exclude-dirs:
|
||||
- tests/mocks
|
||||
exclude-generated: true
|
||||
exclude-rules:
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- forbidigo
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- errcheck
|
||||
source: '(^\s*_\s*=|,\s*_)'
|
||||
- path: 'tests/integration/e2e_conversion_test\.go'
|
||||
linters:
|
||||
- errcheck
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- revive
|
||||
text: '^exported:'
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- gosec
|
||||
text: 'G(101|401|501)'
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- gocyclo
|
||||
text: 'cyclomatic complexity (1[1-9]|20) of .* is high \(> 10\)'
|
||||
- linters:
|
||||
- revive
|
||||
text: '(that stutters|BuildUrl should be BuildURL|ConvertHttpRequest should be ConvertHTTPRequest|ConvertHttpResponse should be ConvertHTTPResponse)'
|
||||
- path: 'internal/conversion/.*\.go'
|
||||
linters:
|
||||
- gocyclo
|
||||
- gocritic
|
||||
- path: '(internal/provider/client\.go|internal/service/model_service_impl\.go|internal/service/stats_buffer\.go|internal/handler/proxy_handler\.go|cmd/(desktop|server)/main\.go)'
|
||||
linters:
|
||||
- gocyclo
|
||||
@@ -1,45 +0,0 @@
|
||||
.PHONY: build run test test-coverage clean migrate-up migrate-down migrate-status migrate-create lint
|
||||
|
||||
# 构建
|
||||
build:
|
||||
go build -o bin/server ./cmd/server
|
||||
|
||||
# 运行
|
||||
run:
|
||||
go run ./cmd/server
|
||||
|
||||
# 测试
|
||||
test:
|
||||
go test ./... -v
|
||||
|
||||
# 测试覆盖率
|
||||
test-coverage:
|
||||
go test ./... -coverprofile=coverage.out
|
||||
go tool cover -html=coverage.out -o coverage.html
|
||||
@echo "Coverage report generated: coverage.html"
|
||||
|
||||
# 清理
|
||||
clean:
|
||||
rm -rf bin/ coverage.out coverage.html
|
||||
|
||||
# 数据库迁移
|
||||
migrate-up:
|
||||
goose -dir migrations sqlite3 $(DB_PATH) up
|
||||
|
||||
migrate-down:
|
||||
goose -dir migrations sqlite3 $(DB_PATH) down
|
||||
|
||||
migrate-status:
|
||||
goose -dir migrations sqlite3 $(DB_PATH) status
|
||||
|
||||
migrate-create:
|
||||
@read -p "Migration name: " name; \
|
||||
goose -dir migrations create $$name sql
|
||||
|
||||
# 代码检查
|
||||
lint:
|
||||
golangci-lint run ./...
|
||||
|
||||
# 安装依赖
|
||||
deps:
|
||||
go mod tidy
|
||||
@@ -14,19 +14,65 @@ AI 网关后端服务,提供统一的大模型 API 代理接口。
|
||||
- 支持扩展层接口(Models、Embeddings、Rerank)
|
||||
- 多供应商配置和路由
|
||||
- 用量统计
|
||||
- 结构化日志(zap + lumberjack)
|
||||
- 结构化日志(zap + lumberjack + 模块标识)
|
||||
- YAML 配置管理
|
||||
- 请求验证
|
||||
- 中间件支持(请求 ID、日志、恢复、CORS)
|
||||
|
||||
## 日志规范
|
||||
|
||||
### 模块标识
|
||||
|
||||
每个模块通过依赖注入获取带模块标识的 logger:
|
||||
|
||||
```go
|
||||
func NewProxyHandler(..., logger *zap.Logger) *ProxyHandler {
|
||||
return &ProxyHandler{
|
||||
logger: pkglogger.WithModule(logger, "handler.proxy"),
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
输出格式:
|
||||
- Console: `INFO [handler.proxy] 处理请求 method=POST path=/v1/chat`
|
||||
- JSON: `{"level":"info","logger":"handler.proxy","msg":"处理请求"}`
|
||||
|
||||
### 模块命名规范
|
||||
|
||||
| 模块 | 命名 |
|
||||
|------|------|
|
||||
| ProxyHandler | `handler.proxy` |
|
||||
| ProviderHandler | `handler.provider` |
|
||||
| Provider Client | `provider.client` |
|
||||
| ConversionEngine | `conversion.engine` |
|
||||
| RoutingCache | `service.routing_cache` |
|
||||
| StatsBuffer | `service.stats_buffer` |
|
||||
| Database | `database` |
|
||||
|
||||
### 标准字段
|
||||
|
||||
使用 `pkg/logger/field.go` 中定义的字段构造函数:
|
||||
|
||||
```go
|
||||
logger.Info("请求开始",
|
||||
pkglogger.Method("POST"),
|
||||
pkglogger.Path("/v1/chat"),
|
||||
pkglogger.RequestID("xxx"),
|
||||
)
|
||||
```
|
||||
|
||||
### GORM 日志
|
||||
|
||||
GORM 日志自动桥接到 zap,SQL 查询映射到 Debug 级别。
|
||||
|
||||
## 技术栈
|
||||
|
||||
- **语言**: Go 1.26+
|
||||
- **HTTP 框架**: Gin
|
||||
- **ORM**: GORM
|
||||
- **数据库**: SQLite
|
||||
- **数据库**: SQLite / MySQL
|
||||
- **日志**: zap + lumberjack
|
||||
- **配置**: gopkg.in/yaml.v3
|
||||
- **配置**: Viper + pflag(多层配置:CLI > 环境变量 > 配置文件 > 默认值)
|
||||
- **验证**: go-playground/validator/v10
|
||||
- **迁移**: goose
|
||||
|
||||
@@ -39,7 +85,7 @@ backend/
|
||||
│ └── main.go # 主程序入口(依赖注入)
|
||||
├── internal/
|
||||
│ ├── config/ # 配置管理
|
||||
│ │ ├── config.go # 配置加载/保存/验证
|
||||
│ │ ├── config.go # Viper 多层配置加载/验证
|
||||
│ │ └── models.go # GORM 数据模型
|
||||
│ ├── domain/ # 领域模型
|
||||
│ │ ├── provider.go
|
||||
@@ -47,17 +93,17 @@ backend/
|
||||
│ │ ├── stats.go
|
||||
│ │ └── route.go
|
||||
│ ├── handler/ # HTTP 处理器
|
||||
│ │ ├── middleware/ # 中间件
|
||||
│ │ ├── middleware/ # 中间件
|
||||
│ │ │ ├── request_id.go
|
||||
│ │ │ ├── logging.go
|
||||
│ │ │ ├── recovery.go
|
||||
│ │ │ └── cors.go
|
||||
│ │ ├── proxy_handler.go # 统一代理处理器
|
||||
│ │ ├── proxy_handler.go # 统一代理处理器
|
||||
│ │ ├── provider_handler.go
|
||||
│ │ ├── model_handler.go
|
||||
│ │ └── stats_handler.go
|
||||
│ ├── conversion/ # 协议转换引擎
|
||||
│ │ ├── canonical/ # Canonical Model
|
||||
│ ├── conversion/ # 协议转换引擎
|
||||
│ │ ├── canonical/ # Canonical Model
|
||||
│ │ │ ├── types.go # 核心请求/响应类型
|
||||
│ │ │ ├── stream.go # 流式事件类型
|
||||
│ │ │ └── extended.go # 扩展层 Models
|
||||
@@ -105,19 +151,26 @@ backend/
|
||||
│ │ ├── errors.go
|
||||
│ │ └── wrap.go
|
||||
│ ├── logger/ # 日志系统
|
||||
│ │ ├── logger.go
|
||||
│ │ ├── rotate.go
|
||||
│ │ └── context.go
|
||||
│ │ ├── logger.go # 核心初始化
|
||||
│ │ ├── field.go # 标准字段定义
|
||||
│ │ ├── module.go # 模块日志器
|
||||
│ │ ├── context.go # Context 辅助函数
|
||||
│ │ ├── gorm.go # GORM 适配器
|
||||
│ │ ├── minimal.go # 最小化 logger
|
||||
│ │ └── rotate.go # 日志轮转
|
||||
│ ├── modelid/ # 统一模型 ID 工具包
|
||||
│ │ ├── model_id.go
|
||||
│ │ └── model_id_test.go
|
||||
│ └── validator/ # 验证器
|
||||
│ └── validator.go
|
||||
├── migrations/ # 数据库迁移
|
||||
│ ├── 001_initial_schema.sql
|
||||
│ └── 002_add_indexes.sql
|
||||
├── tests/ # 测试
|
||||
│ ├── helpers.go
|
||||
│ ├── integration/
|
||||
│ ├── unit/
|
||||
│ └── testdata/
|
||||
│ └── 20260421000001_initial_schema.sql
|
||||
├── tests/ # 集成测试
|
||||
│ ├── helpers.go # 测试辅助函数
|
||||
│ ├── config/ # 测试配置
|
||||
│ ├── integration/ # 集成测试
|
||||
│ │ └── e2e_conversion_test.go # E2E 协议转换测试
|
||||
│ └── mocks/ # Mock 实现
|
||||
├── Makefile
|
||||
├── go.mod
|
||||
└── README.md
|
||||
@@ -133,6 +186,133 @@ handler(HTTP 请求处理)
|
||||
→ repository(数据访问)
|
||||
```
|
||||
|
||||
代理请求通过 ConversionEngine 进行协议转换:
|
||||
|
||||
```
|
||||
Client Request (clientProtocol)
|
||||
→ ProxyHandler 路由到上游 provider
|
||||
→ ConversionEngine 请求转换 (clientProtocol → providerProtocol)
|
||||
→ ProviderClient 发送请求
|
||||
→ ConversionEngine 响应转换 (providerProtocol → clientProtocol)
|
||||
→ Client Response
|
||||
```
|
||||
|
||||
同协议时自动透传,跳过序列化开销。
|
||||
|
||||
## 协议转换架构
|
||||
|
||||
### Canonical Model 中间表示
|
||||
|
||||
所有协议转换都经过 Canonical Model 中间表示层,实现 Hub-and-Spoke 架构:
|
||||
|
||||
```
|
||||
OpenAI Request → Canonical Request → Anthropic Request
|
||||
(中间表示)
|
||||
OpenAI Response ← Canonical Response ← Anthropic Response
|
||||
```
|
||||
|
||||
**CanonicalRequest 核心字段**:
|
||||
- `Model` - 统一模型 ID
|
||||
- `Messages` - 消息列表(支持 text、tool_use、tool_result、thinking 类型)
|
||||
- `Tools` - 工具定义
|
||||
- `Thinking` - 推理配置(`budget_tokens`、`effort`)
|
||||
- `Parameters` - 通用参数(`max_tokens`、`temperature`、`top_p` 等)
|
||||
|
||||
### Smart Passthrough 机制
|
||||
|
||||
同协议请求走 Smart Passthrough 路径,**零序列化开销**:
|
||||
|
||||
```
|
||||
1. 检测 clientProtocol == providerProtocol
|
||||
2. 仅改写请求体中的 model 字段:unified_id → upstream_model_name
|
||||
3. 直接转发请求到上游
|
||||
4. 响应中仅改写 model 字段:upstream_model_name → unified_id
|
||||
```
|
||||
|
||||
### 流式转换器层次
|
||||
|
||||
```
|
||||
StreamConverter (接口)
|
||||
├── PassthroughStreamConverter # 直接透传,无任何处理
|
||||
├── SmartPassthroughStreamConverter # 同协议 + 逐 chunk 改写 model
|
||||
└── CanonicalStreamConverter # 跨协议完整转换(decode → encode)
|
||||
```
|
||||
|
||||
### InterfaceType 枚举
|
||||
|
||||
| 类型 | 说明 |
|
||||
|------|------|
|
||||
| `CHAT` | 对话补全(chat/completions、messages) |
|
||||
| `MODELS` | 模型列表 |
|
||||
| `MODEL_INFO` | 模型详情 |
|
||||
| `EMBEDDINGS` | 嵌入接口 |
|
||||
| `RERANK` | 重排序接口 |
|
||||
| `PASSTHROUGH` | 未知接口,直接透传 |
|
||||
|
||||
## 协议适配器特性
|
||||
|
||||
### OpenAI 适配器
|
||||
|
||||
**特有字段支持**:
|
||||
- `reasoning_effort` - 映射到 Canonical Thinking 配置(`none` → 禁用,其他 → `effort`)
|
||||
- `reasoning_content` - 非标准字段,映射到 Canonical thinking 块
|
||||
- `max_completion_tokens` - 新字段,优先于 `max_tokens`
|
||||
- `refusal` - 非标准字段,作为 text 块处理
|
||||
|
||||
**废弃字段兼容**:
|
||||
- `functions` / `function_call` - 自动转换为 `tools` / `tool_choice`
|
||||
|
||||
**消息处理**:
|
||||
- 合并连续同角色消息(Anthropic 不支持连续同角色)
|
||||
- 工具选择映射:`any` → `required`
|
||||
|
||||
### Anthropic 适配器
|
||||
|
||||
**特有字段支持**:
|
||||
- `thinking` - 推理配置(`type: enabled`、`budget_tokens`、`effort`)
|
||||
- `output_config` - 结构化输出配置
|
||||
- `disable_parallel_tool_use` - 禁用并行工具调用
|
||||
- `container` - 工具容器字段
|
||||
|
||||
**不支持的功能**:
|
||||
- Embeddings 接口(返回 `INTERFACE_NOT_SUPPORTED` 错误)
|
||||
|
||||
### 跨协议转换注意事项
|
||||
|
||||
| 源协议 | 目标协议 | 转换说明 |
|
||||
|--------|----------|----------|
|
||||
| OpenAI | Anthropic | `reasoning_effort` → `thinking`,消息角色合并 |
|
||||
| Anthropic | OpenAI | `thinking` 块 → `reasoning_content`,工具选择转换 |
|
||||
|
||||
## 错误码
|
||||
|
||||
### ConversionError 错误码
|
||||
|
||||
| 错误码 | 说明 |
|
||||
|--------|------|
|
||||
| `INVALID_INPUT` | 输入数据无效 |
|
||||
| `MISSING_REQUIRED_FIELD` | 缺少必填字段 |
|
||||
| `INCOMPATIBLE_FEATURE` | 功能不兼容(如跨协议不支持某特性) |
|
||||
| `FIELD_MAPPING_FAILURE` | 字段映射失败 |
|
||||
| `TOOL_CALL_PARSE_ERROR` | 工具调用解析错误 |
|
||||
| `JSON_PARSE_ERROR` | JSON 解析错误 |
|
||||
| `STREAM_STATE_ERROR` | 流式状态错误 |
|
||||
| `UTF8_DECODE_ERROR` | UTF-8 解码错误(流式 chunk 截断) |
|
||||
| `PROTOCOL_CONSTRAINT_VIOLATION` | 协议约束违反 |
|
||||
| `ENCODING_FAILURE` | 编码失败 |
|
||||
| `INTERFACE_NOT_SUPPORTED` | 接口不支持(如 Anthropic Embeddings) |
|
||||
|
||||
### AppError 预定义错误
|
||||
|
||||
| 错误 | HTTP 状态码 | 说明 |
|
||||
|------|-------------|------|
|
||||
| `ErrModelNotFound` | 404 | 模型未找到 |
|
||||
| `ErrModelDisabled` | 404 | 模型已禁用 |
|
||||
| `ErrProviderNotFound` | 404 | 供应商未找到 |
|
||||
| `ErrInvalidProviderID` | 400 | 供应商 ID 格式无效 |
|
||||
| `ErrDuplicateModel` | 409 | 同一供应商下模型名称重复 |
|
||||
| `ErrImmutableField` | 400 | 不可修改字段(如供应商 ID) |
|
||||
|
||||
## 运行方式
|
||||
|
||||
### 安装依赖
|
||||
@@ -151,6 +331,10 @@ go run cmd/server/main.go
|
||||
|
||||
## 配置
|
||||
|
||||
配置支持多种方式:配置文件、环境变量、命令行参数,优先级为:**CLI 参数 > 环境变量 > 配置文件 > 默认值**
|
||||
|
||||
### 配置文件
|
||||
|
||||
配置文件位于 `~/.nex/config.yaml`,首次启动自动生成。
|
||||
|
||||
```yaml
|
||||
@@ -160,7 +344,14 @@ server:
|
||||
write_timeout: 30s
|
||||
|
||||
database:
|
||||
path: ~/.nex/config.db
|
||||
driver: sqlite # sqlite 或 mysql
|
||||
path: ~/.nex/config.db # SQLite 数据库文件路径
|
||||
# --- MySQL 配置(driver=mysql 时生效)---
|
||||
# host: localhost
|
||||
# port: 3306
|
||||
# user: nex
|
||||
# password: ""
|
||||
# dbname: nex
|
||||
max_idle_conns: 10
|
||||
max_open_conns: 100
|
||||
conn_max_lifetime: 1h
|
||||
@@ -174,11 +365,72 @@ log:
|
||||
compress: true
|
||||
```
|
||||
|
||||
### 环境变量
|
||||
|
||||
所有配置项都支持环境变量,使用 `NEX_` 前缀:
|
||||
|
||||
```bash
|
||||
export NEX_SERVER_PORT=9000
|
||||
export NEX_DATABASE_PATH=/data/nex.db
|
||||
export NEX_LOG_LEVEL=debug
|
||||
|
||||
# MySQL 模式
|
||||
export NEX_DATABASE_DRIVER=mysql
|
||||
export NEX_DATABASE_HOST=db.example.com
|
||||
export NEX_DATABASE_PORT=3306
|
||||
export NEX_DATABASE_USER=nex
|
||||
export NEX_DATABASE_PASSWORD=secret
|
||||
export NEX_DATABASE_DBNAME=nex
|
||||
```
|
||||
|
||||
命名规则:配置路径转大写 + 下划线 + `NEX_` 前缀(如 `server.port` → `NEX_SERVER_PORT`)。
|
||||
|
||||
### 命令行参数
|
||||
|
||||
```bash
|
||||
./server --server-port 9000 --log-level debug --database-path /tmp/test.db
|
||||
```
|
||||
|
||||
命名规则:配置路径转 kebab-case + `--` 前缀(如 `server.port` → `--server-port`)。
|
||||
|
||||
完整参数列表:
|
||||
|
||||
```
|
||||
服务器: --server-port, --server-read-timeout, --server-write-timeout
|
||||
数据库: --database-driver, --database-path, --database-host, --database-port, --database-user, --database-password, --database-dbname, --database-max-idle-conns, --database-max-open-conns, --database-conn-max-lifetime
|
||||
日志: --log-level, --log-path, --log-max-size, --log-max-backups, --log-max-age, --log-compress
|
||||
通用: --config (指定配置文件路径)
|
||||
```
|
||||
|
||||
### 使用示例
|
||||
|
||||
```bash
|
||||
# 默认配置
|
||||
./server
|
||||
|
||||
# 临时修改端口
|
||||
./server --server-port 9000
|
||||
|
||||
# 测试场景
|
||||
./server --database-path /tmp/test.db --log-level debug
|
||||
|
||||
# Docker 部署
|
||||
docker run -d -e NEX_SERVER_PORT=9000 -e NEX_LOG_LEVEL=info nex-server
|
||||
|
||||
# MySQL 模式
|
||||
./server --database-driver mysql --database-host db.example.com --database-user nex --database-password secret --database-dbname nex
|
||||
|
||||
# 自定义配置文件
|
||||
./server --config /path/to/custom.yaml
|
||||
```
|
||||
|
||||
数据文件:
|
||||
- `~/.nex/config.yaml` - 配置文件
|
||||
- `~/.nex/config.db` - SQLite 数据库
|
||||
- `~/.nex/config.db` - SQLite 数据库(MySQL 模式下不使用本地数据库文件)
|
||||
- `~/.nex/log/` - 日志目录
|
||||
|
||||
**MySQL 连接说明**:MySQL 连接使用 DSN 格式: `user:password@tcp(host:port)/dbname?charset=utf8mb4&parseTime=true&loc=Local`,最低支持 MySQL 8.0+。
|
||||
|
||||
## 测试
|
||||
|
||||
```bash
|
||||
@@ -197,6 +449,9 @@ make migrate-up DB_PATH=~/.nex/config.db
|
||||
make migrate-down DB_PATH=~/.nex/config.db
|
||||
make migrate-status DB_PATH=~/.nex/config.db
|
||||
|
||||
# 创建新迁移
|
||||
make migrate-create
|
||||
|
||||
# 或直接使用 goose
|
||||
goose -dir migrations sqlite3 ~/.nex/config.db up
|
||||
```
|
||||
@@ -207,48 +462,26 @@ goose -dir migrations sqlite3 ~/.nex/config.db up
|
||||
|
||||
使用 `/{protocol}/v1/{path}` URL 前缀路由:
|
||||
|
||||
#### OpenAI 协议代理
|
||||
#### OpenAI 协议
|
||||
|
||||
```
|
||||
POST /openai/v1/chat/completions
|
||||
GET /openai/v1/models
|
||||
POST /openai/v1/embeddings
|
||||
POST /openai/v1/rerank
|
||||
POST /openai/chat/completions
|
||||
GET /openai/models
|
||||
POST /openai/embeddings
|
||||
POST /openai/rerank
|
||||
```
|
||||
|
||||
请求示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
#### Anthropic 协议代理
|
||||
#### Anthropic 协议
|
||||
|
||||
```
|
||||
POST /anthropic/v1/messages
|
||||
GET /anthropic/v1/models
|
||||
```
|
||||
|
||||
请求示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "claude-3-opus",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": [{"type": "text", "text": "Hello"}]}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传,零序列化开销。
|
||||
|
||||
**统一模型 ID**:代理请求中的 `model` 字段使用 `provider_id/model_name` 格式(如 `openai/gpt-4`),网关据此路由到对应供应商。同协议时自动改写为上游 `model_name`,跨协议时通过全量转换处理。
|
||||
|
||||
### 管理接口
|
||||
|
||||
#### 供应商管理
|
||||
@@ -259,8 +492,6 @@ GET /anthropic/v1/models
|
||||
- `PUT /api/providers/:id` - 更新供应商
|
||||
- `DELETE /api/providers/:id` - 删除供应商
|
||||
|
||||
创建供应商示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "openai",
|
||||
@@ -271,15 +502,15 @@ GET /anthropic/v1/models
|
||||
}
|
||||
```
|
||||
|
||||
**Protocol 字段说明:**
|
||||
- `protocol` 标识上游供应商使用的协议类型,可选值:`"openai"`(默认)、`"anthropic"`
|
||||
- 同协议透传时,请求体和响应体原样转发,零序列化开销
|
||||
**Protocol 字段**:标识上游供应商使用的协议类型,可选值 `"openai"`(默认)、`"anthropic"`。
|
||||
|
||||
**重要说明:**
|
||||
- `base_url` 应配置到 API 版本路径,不包含具体端点
|
||||
- OpenAI: `https://api.openai.com/v1`
|
||||
- GLM: `https://open.bigmodel.cn/api/paas/v4`
|
||||
- 其他 OpenAI 兼容供应商根据其文档配置版本路径
|
||||
**base_url 说明**:
|
||||
- OpenAI 协议:配置到 API 版本路径,如 `https://api.openai.com/v1`、`https://open.bigmodel.cn/api/paas/v4`
|
||||
- Anthropic 协议:配置到域名,不包含版本路径,如 `https://api.anthropic.com`
|
||||
|
||||
**对外 URL 格式**:
|
||||
- OpenAI 协议:`/{protocol}/{endpoint}`,如 `/openai/chat/completions`、`/openai/models`、`/openai/embeddings`
|
||||
- Anthropic 协议:`/{protocol}/v1/{endpoint}`,如 `/anthropic/v1/messages`、`/anthropic/v1/models`
|
||||
|
||||
#### 模型管理
|
||||
|
||||
@@ -289,76 +520,72 @@ GET /anthropic/v1/models
|
||||
- `PUT /api/models/:id` - 更新模型
|
||||
- `DELETE /api/models/:id` - 删除模型
|
||||
|
||||
创建模型示例:
|
||||
**创建请求**(id 由系统自动生成 UUID):
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "gpt-4",
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4"
|
||||
}
|
||||
```
|
||||
|
||||
**响应示例**:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
"unified_id": "openai/gpt-4",
|
||||
"enabled": true,
|
||||
"created_at": "2026-04-21T00:00:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
**统一模型 ID**:`unified_id` 字段为 `provider_id/model_name` 格式,用于代理请求的 `model` 参数。
|
||||
|
||||
#### 统计查询
|
||||
|
||||
- `GET /api/stats` - 查询统计
|
||||
- `GET /api/stats/aggregate` - 聚合统计
|
||||
|
||||
查询参数:
|
||||
查询参数:`provider_id`、`model_name`、`start_date`(YYYY-MM-DD)、`end_date`、`group_by`(provider/model/date)
|
||||
|
||||
- `provider_id` - 供应商 ID
|
||||
- `model_name` - 模型名称
|
||||
- `start_date` - 开始日期(YYYY-MM-DD)
|
||||
- `end_date` - 结束日期(YYYY-MM-DD)
|
||||
- `group_by` - 聚合维度(provider/model/date)
|
||||
#### 健康检查
|
||||
|
||||
- `GET /health` - 返回 `{"status": "ok"}`
|
||||
|
||||
## 开发
|
||||
|
||||
### 构建
|
||||
|
||||
```bash
|
||||
make build
|
||||
make build # 构建
|
||||
make lint # 代码检查
|
||||
make deps # 整理依赖
|
||||
```
|
||||
|
||||
### 代码检查
|
||||
|
||||
```bash
|
||||
make lint
|
||||
```
|
||||
|
||||
### 环境要求
|
||||
|
||||
- Go 1.26 或更高版本
|
||||
环境要求:Go 1.26 或更高版本
|
||||
|
||||
## 公共库使用指南
|
||||
|
||||
### pkg/errors — 结构化错误
|
||||
|
||||
使用预定义的错误类型,配合 `errors.Is` / `errors.As` 判断错误:
|
||||
|
||||
```go
|
||||
import (
|
||||
"errors"
|
||||
pkgErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
// 使用预定义错误
|
||||
return pkgErrors.ErrRequestSend.WithCause(err)
|
||||
|
||||
// 判断错误类型
|
||||
var appErr *pkgErrors.AppError
|
||||
if errors.As(err, &appErr) {
|
||||
// appErr.Code, appErr.HTTPStatus, appErr.Message
|
||||
}
|
||||
```
|
||||
|
||||
可用函数:`NewAppError`、`Wrap`、`WithContext`、`WithMessage`、`AsAppError`
|
||||
|
||||
预定义错误:`ErrModelNotFound`、`ErrProviderNotFound`、`ErrInvalidRequest`、`ErrRequestCreate`、`ErrRequestSend`、`ErrResponseRead` 等
|
||||
|
||||
### pkg/logger — 日志系统
|
||||
|
||||
使用依赖注入模式,构造函数接受 `*zap.Logger` 参数,nil 时回退到 `zap.L()`:
|
||||
构造函数接受 `*zap.Logger` 参数,nil 时回退到 `zap.L()`:
|
||||
|
||||
```go
|
||||
func NewMyService(repo Repository, logger *zap.Logger) *MyService {
|
||||
@@ -369,8 +596,6 @@ func NewMyService(repo Repository, logger *zap.Logger) *MyService {
|
||||
}
|
||||
```
|
||||
|
||||
禁止直接在业务代码中使用 `zap.L()` 全局 logger,应通过构造函数注入。
|
||||
|
||||
### pkg/validator — 请求验证
|
||||
|
||||
```go
|
||||
@@ -382,8 +607,9 @@ err := v.Validate(myStruct)
|
||||
|
||||
## 编码规范
|
||||
|
||||
- **JSON 解析**:使用 `encoding/json` 标准库(`json.Unmarshal` / `json.Marshal`),不手动扫描字节
|
||||
- **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节
|
||||
- **字符串拼接**:使用 `strings.Join`,不手写循环拼接
|
||||
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配(`strings.Contains(err.Error(), ...)`)
|
||||
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配(lint 强约束:errorlint 禁止 `err ==` 直接比较和 `err.(*T)` 直接断言)
|
||||
- **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()`
|
||||
- **字符串分割**:使用 `strings.SplitN(key, "/", 2)` 等精确分割,不使用索引切片
|
||||
- **日志 error 字段**:使用 `zap.Error(err)`,不使用 `zap.String("error", err.Error())` 手工字符串化
|
||||
- **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片
|
||||
|
||||
25
backend/cmd/desktop/dialog_darwin.go
Normal file
25
backend/cmd/desktop/dialog_darwin.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build darwin
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func showError(title, message string) {
|
||||
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`,
|
||||
escapeAppleScript(message), escapeAppleScript(title))
|
||||
if err := exec.Command("osascript", "-e", script).Run(); err != nil {
|
||||
dialogLogger().Warn("显示错误对话框失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func escapeAppleScript(s string) string {
|
||||
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||
s = strings.ReplaceAll(s, "\"", "\\\"")
|
||||
return s
|
||||
}
|
||||
67
backend/cmd/desktop/dialog_linux.go
Normal file
67
backend/cmd/desktop/dialog_linux.go
Normal file
@@ -0,0 +1,67 @@
|
||||
//go:build linux
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type dialogToolType int
|
||||
|
||||
const (
|
||||
toolNone dialogToolType = iota
|
||||
toolZenity
|
||||
toolKdialog
|
||||
toolNotifySend
|
||||
toolXmessage
|
||||
)
|
||||
|
||||
var (
|
||||
dialogTool dialogToolType
|
||||
dialogToolOnce sync.Once
|
||||
)
|
||||
|
||||
func init() {
|
||||
dialogToolOnce.Do(detectDialogTool)
|
||||
}
|
||||
|
||||
func detectDialogTool() {
|
||||
tools := []struct {
|
||||
name string
|
||||
typ dialogToolType
|
||||
}{
|
||||
{"zenity", toolZenity},
|
||||
{"kdialog", toolKdialog},
|
||||
{"notify-send", toolNotifySend},
|
||||
{"xmessage", toolXmessage},
|
||||
}
|
||||
|
||||
for _, tool := range tools {
|
||||
if _, err := exec.LookPath(tool.name); err == nil {
|
||||
dialogTool = tool.typ
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
dialogTool = toolNone
|
||||
}
|
||||
|
||||
func showError(title, message string) {
|
||||
switch dialogTool {
|
||||
case toolZenity:
|
||||
exec.Command("zenity", "--error",
|
||||
fmt.Sprintf("--title=%s", title),
|
||||
fmt.Sprintf("--text=%s", message)).Run()
|
||||
case toolKdialog:
|
||||
exec.Command("kdialog", "--error", message, "--title", title).Run()
|
||||
case toolNotifySend:
|
||||
exec.Command("notify-send", "-u", "critical", title, message).Run()
|
||||
case toolXmessage:
|
||||
exec.Command("xmessage", "-center",
|
||||
fmt.Sprintf("%s: %s", title, message)).Run()
|
||||
default:
|
||||
dialogLogger().Error("无法显示错误对话框")
|
||||
}
|
||||
}
|
||||
15
backend/cmd/desktop/dialog_logger.go
Normal file
15
backend/cmd/desktop/dialog_logger.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
func dialogLogger() *zap.Logger {
|
||||
if zapLogger != nil {
|
||||
return zapLogger
|
||||
}
|
||||
|
||||
return pkgLogger.NewMinimal()
|
||||
}
|
||||
33
backend/cmd/desktop/dialog_windows.go
Normal file
33
backend/cmd/desktop/dialog_windows.go
Normal file
@@ -0,0 +1,33 @@
|
||||
//go:build windows
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
MB_ICONERROR = 0x10
|
||||
MB_ICONINFORMATION = 0x40
|
||||
)
|
||||
|
||||
var (
|
||||
user32 = syscall.NewLazyDLL("user32.dll")
|
||||
procMessageBoxW = user32.NewProc("MessageBoxW")
|
||||
)
|
||||
|
||||
func showError(title, message string) {
|
||||
messageBox(title, message, MB_ICONERROR)
|
||||
}
|
||||
|
||||
func messageBox(title, message string, flags uint) {
|
||||
titlePtr, _ := syscall.UTF16PtrFromString(title)
|
||||
messagePtr, _ := syscall.UTF16PtrFromString(message)
|
||||
procMessageBoxW.Call(
|
||||
0,
|
||||
uintptr(unsafe.Pointer(messagePtr)),
|
||||
uintptr(unsafe.Pointer(titlePtr)),
|
||||
uintptr(flags),
|
||||
)
|
||||
}
|
||||
33
backend/cmd/desktop/icon_test.go
Normal file
33
backend/cmd/desktop/icon_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"nex/embedfs"
|
||||
)
|
||||
|
||||
func TestIconSelection_Windows(t *testing.T) {
|
||||
if runtime.GOOS != "windows" {
|
||||
t.Skip("图标格式选择测试仅在 Windows 上运行")
|
||||
}
|
||||
|
||||
if err := testIconLoad("assets/icon.ico"); err != nil {
|
||||
t.Fatalf("Windows 应加载 .ico 文件: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIconSelection_NonWindows(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("图标格式选择测试在非 Windows 平台运行")
|
||||
}
|
||||
|
||||
if err := testIconLoad("assets/icon.png"); err != nil {
|
||||
t.Fatalf("非 Windows 平台应加载 .png 文件: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func testIconLoad(path string) error {
|
||||
_, err := embedfs.Assets.ReadFile(path)
|
||||
return err
|
||||
}
|
||||
384
backend/cmd/desktop/main.go
Normal file
384
backend/cmd/desktop/main.go
Normal file
@@ -0,0 +1,384 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nex/embedfs"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
"nex/backend/internal/conversion/openai"
|
||||
"nex/backend/internal/database"
|
||||
"nex/backend/internal/handler"
|
||||
"nex/backend/internal/handler/middleware"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/internal/service"
|
||||
|
||||
"github.com/getlantern/systray"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/flock"
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
server *http.Server
|
||||
zapLogger *zap.Logger
|
||||
shutdownCtx context.Context
|
||||
shutdownCancel context.CancelFunc
|
||||
)
|
||||
|
||||
func main() {
|
||||
port := 9826
|
||||
|
||||
minimalLogger := pkgLogger.NewMinimal()
|
||||
|
||||
singleLock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway.lock"))
|
||||
if err := singleLock.Lock(); err != nil {
|
||||
minimalLogger.Error("已有 Nex 实例运行")
|
||||
showError(appName, "已有 Nex 实例运行")
|
||||
os.Exit(1)
|
||||
}
|
||||
defer func() {
|
||||
if err := singleLock.Unlock(); err != nil {
|
||||
minimalLogger.Warn("释放实例锁失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := checkPortAvailable(port); err != nil {
|
||||
minimalLogger.Error("端口不可用", zap.Error(err))
|
||||
showError(appName, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig()
|
||||
if err != nil {
|
||||
minimalLogger.Fatal("加载配置失败", zap.Error(err))
|
||||
}
|
||||
|
||||
zapLogger, err = pkgLogger.Upgrade(minimalLogger, pkgLogger.Config{
|
||||
Level: cfg.Log.Level,
|
||||
Path: cfg.Log.Path,
|
||||
MaxSize: cfg.Log.MaxSize,
|
||||
MaxBackups: cfg.Log.MaxBackups,
|
||||
MaxAge: cfg.Log.MaxAge,
|
||||
Compress: cfg.Log.Compress,
|
||||
})
|
||||
if err != nil {
|
||||
minimalLogger.Fatal("初始化日志失败", zap.Error(err))
|
||||
}
|
||||
defer func() {
|
||||
if err := zapLogger.Sync(); err != nil {
|
||||
minimalLogger.Warn("同步日志失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
cfg.PrintSummary(zapLogger)
|
||||
|
||||
db, err := database.Init(&cfg.Database, zapLogger)
|
||||
if err != nil {
|
||||
zapLogger.Fatal("初始化数据库失败", zap.Error(err))
|
||||
}
|
||||
defer database.Close(db)
|
||||
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
|
||||
routingCache := service.NewRoutingCache(modelRepo, providerRepo, zapLogger)
|
||||
if err := routingCache.Preload(); err != nil {
|
||||
zapLogger.Warn("缓存预热失败,将使用懒加载", zap.Error(err))
|
||||
}
|
||||
|
||||
statsBuffer := service.NewStatsBuffer(statsRepo, zapLogger,
|
||||
service.WithFlushInterval(5*time.Second),
|
||||
service.WithFlushThreshold(100))
|
||||
statsBuffer.Start()
|
||||
defer statsBuffer.Stop()
|
||||
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
|
||||
routingService := service.NewRoutingService(routingCache)
|
||||
statsService := service.NewStatsService(statsRepo, statsBuffer)
|
||||
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
if err := registry.Register(openai.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.Error(err))
|
||||
}
|
||||
if err := registry.Register(anthropic.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.Error(err))
|
||||
}
|
||||
engine := conversion.NewConversionEngine(registry, zapLogger)
|
||||
|
||||
providerClient := provider.NewClient(zapLogger)
|
||||
|
||||
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService, zapLogger)
|
||||
providerHandler := handler.NewProviderHandler(providerService)
|
||||
modelHandler := handler.NewModelHandler(modelService)
|
||||
statsHandler := handler.NewStatsHandler(statsService)
|
||||
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
r := gin.New()
|
||||
|
||||
r.Use(middleware.RequestID())
|
||||
r.Use(middleware.Recovery(zapLogger))
|
||||
r.Use(middleware.Logging(zapLogger))
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
|
||||
setupStaticFiles(r)
|
||||
|
||||
server = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", port),
|
||||
Handler: r,
|
||||
ReadTimeout: cfg.Server.ReadTimeout,
|
||||
WriteTimeout: cfg.Server.WriteTimeout,
|
||||
}
|
||||
|
||||
shutdownCtx, shutdownCancel = context.WithCancel(context.Background())
|
||||
|
||||
go func() {
|
||||
zapLogger.Info("AI Gateway 启动", zap.String("addr", server.Addr))
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
zapLogger.Fatal("服务器启动失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
|
||||
zapLogger.Warn("无法打开浏览器", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
setupSystray(port)
|
||||
}
|
||||
|
||||
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
|
||||
r.Any("/v1/*path", proxyHandler.HandleProxy)
|
||||
|
||||
providers := r.Group("/api/providers")
|
||||
{
|
||||
providers.GET("", providerHandler.ListProviders)
|
||||
providers.POST("", providerHandler.CreateProvider)
|
||||
providers.GET("/:id", providerHandler.GetProvider)
|
||||
providers.PUT("/:id", providerHandler.UpdateProvider)
|
||||
providers.DELETE("/:id", providerHandler.DeleteProvider)
|
||||
}
|
||||
|
||||
models := r.Group("/api/models")
|
||||
{
|
||||
models.GET("", modelHandler.ListModels)
|
||||
models.POST("", modelHandler.CreateModel)
|
||||
models.GET("/:id", modelHandler.GetModel)
|
||||
models.PUT("/:id", modelHandler.UpdateModel)
|
||||
models.DELETE("/:id", modelHandler.DeleteModel)
|
||||
}
|
||||
|
||||
stats := r.Group("/api/stats")
|
||||
{
|
||||
stats.GET("", statsHandler.GetStats)
|
||||
stats.GET("/aggregate", statsHandler.AggregateStats)
|
||||
}
|
||||
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
}
|
||||
|
||||
func setupStaticFiles(r *gin.Engine) {
|
||||
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
|
||||
if err != nil {
|
||||
zapLogger.Fatal("无法加载前端资源", zap.Error(err))
|
||||
}
|
||||
|
||||
getContentType := func(path string) string {
|
||||
if strings.HasSuffix(path, ".js") {
|
||||
return "application/javascript"
|
||||
}
|
||||
if strings.HasSuffix(path, ".css") {
|
||||
return "text/css"
|
||||
}
|
||||
if strings.HasSuffix(path, ".svg") {
|
||||
return "image/svg+xml"
|
||||
}
|
||||
if strings.HasSuffix(path, ".png") {
|
||||
return "image/png"
|
||||
}
|
||||
if strings.HasSuffix(path, ".ico") {
|
||||
return "image/x-icon"
|
||||
}
|
||||
if strings.HasSuffix(path, ".woff") || strings.HasSuffix(path, ".woff2") {
|
||||
return "font/woff2"
|
||||
}
|
||||
return "application/octet-stream"
|
||||
}
|
||||
|
||||
r.GET("/assets/*filepath", func(c *gin.Context) {
|
||||
filepath := c.Param("filepath")
|
||||
data, err := fs.ReadFile(distFS, "assets"+filepath)
|
||||
if err != nil {
|
||||
c.Status(404)
|
||||
return
|
||||
}
|
||||
c.Data(200, getContentType(filepath), data)
|
||||
})
|
||||
|
||||
r.GET("/favicon.svg", func(c *gin.Context) {
|
||||
data, err := fs.ReadFile(distFS, "favicon.svg")
|
||||
if err != nil {
|
||||
c.Status(404)
|
||||
return
|
||||
}
|
||||
c.Data(200, "image/svg+xml", data)
|
||||
})
|
||||
|
||||
r.NoRoute(func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
|
||||
if strings.HasPrefix(path, "/api/") ||
|
||||
strings.HasPrefix(path, "/v1/") ||
|
||||
strings.HasPrefix(path, "/health") {
|
||||
c.JSON(404, gin.H{"error": "not found"})
|
||||
return
|
||||
}
|
||||
|
||||
data, err := fs.ReadFile(distFS, "index.html")
|
||||
if err != nil {
|
||||
c.Status(500)
|
||||
return
|
||||
}
|
||||
c.Data(200, "text/html; charset=utf-8", data)
|
||||
})
|
||||
}
|
||||
|
||||
func setupSystray(port int) {
|
||||
systray.Run(func() {
|
||||
var icon []byte
|
||||
var err error
|
||||
if runtime.GOOS == "windows" {
|
||||
icon, err = embedfs.Assets.ReadFile("assets/icon.ico")
|
||||
} else {
|
||||
icon, err = embedfs.Assets.ReadFile("assets/icon.png")
|
||||
}
|
||||
if err != nil {
|
||||
zapLogger.Error("无法加载托盘图标", zap.Error(err))
|
||||
}
|
||||
systray.SetIcon(icon)
|
||||
systray.SetTooltip(appTooltip)
|
||||
|
||||
mOpen := systray.AddMenuItem("打开管理界面", "在浏览器中打开")
|
||||
systray.AddSeparator()
|
||||
mStatus := systray.AddMenuItem("状态: 运行中", "")
|
||||
mStatus.Disable()
|
||||
mPort := systray.AddMenuItem(fmt.Sprintf("端口: %d", port), "")
|
||||
mPort.Disable()
|
||||
systray.AddSeparator()
|
||||
mQuit := systray.AddMenuItem("退出", "停止服务并退出")
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-mOpen.ClickedCh:
|
||||
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
|
||||
zapLogger.Warn("打开浏览器失败", zap.Error(err))
|
||||
}
|
||||
case <-mQuit.ClickedCh:
|
||||
doShutdown()
|
||||
systray.Quit()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}, nil)
|
||||
}
|
||||
|
||||
func doShutdown() {
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("正在关闭服务器...")
|
||||
}
|
||||
|
||||
if server != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := server.Shutdown(ctx); err != nil && zapLogger != nil {
|
||||
zapLogger.Warn("关闭服务器失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
if shutdownCancel != nil {
|
||||
shutdownCancel()
|
||||
}
|
||||
}
|
||||
|
||||
func checkPortAvailable(port int) error {
|
||||
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
return fmt.Errorf("端口 %d 已被占用\n\n可能原因:\n- 已有 Nex 实例运行\n- 其他程序占用了该端口\n\n请检查并关闭占用端口的程序", port)
|
||||
}
|
||||
ln.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
type SingletonLock struct {
|
||||
flock *flock.Flock
|
||||
}
|
||||
|
||||
func NewSingletonLock(lockPath string) *SingletonLock {
|
||||
return &SingletonLock{
|
||||
flock: flock.New(lockPath),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SingletonLock) Lock() error {
|
||||
locked, err := s.flock.TryLock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !locked {
|
||||
return fmt.Errorf("已有实例运行")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SingletonLock) Unlock() error {
|
||||
return s.flock.Unlock()
|
||||
}
|
||||
|
||||
func openBrowser(url string) error {
|
||||
var cmd *exec.Cmd
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
cmd = exec.Command("open", url)
|
||||
case "windows":
|
||||
cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
|
||||
case "linux":
|
||||
browsers := []string{"xdg-open", "google-chrome", "firefox"}
|
||||
for _, browser := range browsers {
|
||||
if _, err := exec.LookPath(browser); err == nil {
|
||||
cmd = exec.Command(browser, url)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if cmd == nil {
|
||||
return fmt.Errorf("无法打开浏览器")
|
||||
}
|
||||
|
||||
return cmd.Start()
|
||||
}
|
||||
15
backend/cmd/desktop/messagebox_test.go
Normal file
15
backend/cmd/desktop/messagebox_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build windows
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMessageBoxW_WindowsOnly(t *testing.T) {
|
||||
messageBox("测试标题", "测试消息", MB_ICONINFORMATION)
|
||||
}
|
||||
|
||||
func TestShowError_WindowsBranch(t *testing.T) {
|
||||
showError("测试错误", "这是一条测试错误消息")
|
||||
}
|
||||
8
backend/cmd/desktop/metadata.go
Normal file
8
backend/cmd/desktop/metadata.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package main
|
||||
|
||||
const (
|
||||
appName = "Nex"
|
||||
appTooltip = appName
|
||||
appDescription = "AI Gateway - 统一的大模型 API 网关"
|
||||
appWebsite = "https://github.com/nex/gateway"
|
||||
)
|
||||
13
backend/cmd/desktop/metadata_test.go
Normal file
13
backend/cmd/desktop/metadata_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package main
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDesktopMetadata(t *testing.T) {
|
||||
if appName != "Nex" {
|
||||
t.Fatalf("appName = %q, want %q", appName, "Nex")
|
||||
}
|
||||
|
||||
if appTooltip != appName {
|
||||
t.Fatalf("appTooltip = %q, want %q", appTooltip, appName)
|
||||
}
|
||||
}
|
||||
69
backend/cmd/desktop/port_test.go
Normal file
69
backend/cmd/desktop/port_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCheckPortAvailable(t *testing.T) {
|
||||
port := 19826
|
||||
|
||||
err := checkPortAvailable(port)
|
||||
if err != nil {
|
||||
t.Fatalf("端口 %d 应该可用: %v", port, err)
|
||||
}
|
||||
|
||||
t.Log("端口可用测试通过")
|
||||
}
|
||||
|
||||
func TestCheckPortOccupied(t *testing.T) {
|
||||
port := 19827
|
||||
|
||||
listener, err := net.Listen("tcp", ":19827")
|
||||
if err != nil {
|
||||
t.Fatalf("无法启动测试服务器: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = checkPortAvailable(port)
|
||||
if err == nil {
|
||||
t.Fatal("端口被占用时应该返回错误")
|
||||
}
|
||||
|
||||
t.Log("端口占用检测测试通过")
|
||||
}
|
||||
|
||||
func TestCheckPortAvailableAfterClose(t *testing.T) {
|
||||
port := 19828
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:19828")
|
||||
if err != nil {
|
||||
t.Fatalf("无法启动测试服务器: %v", err)
|
||||
}
|
||||
|
||||
server := &http.Server{ReadHeaderTimeout: time.Second}
|
||||
defer server.Close()
|
||||
go func() {
|
||||
err := server.Serve(listener)
|
||||
if err != nil && err != http.ErrServerClosed && !errors.Is(err, net.ErrClosed) {
|
||||
t.Errorf("serve failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
listener.Close()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = checkPortAvailable(port)
|
||||
if err != nil {
|
||||
t.Fatalf("端口关闭后应该可用: %v", err)
|
||||
}
|
||||
|
||||
t.Log("端口关闭后可用测试通过")
|
||||
}
|
||||
74
backend/cmd/desktop/singleton_test.go
Normal file
74
backend/cmd/desktop/singleton_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSingletonLock_FirstLockSuccess(t *testing.T) {
|
||||
lockPath := filepath.Join(os.TempDir(), "nex-gateway-test-first.lock")
|
||||
defer os.Remove(lockPath)
|
||||
|
||||
lock := NewSingletonLock(lockPath)
|
||||
if err := lock.Lock(); err != nil {
|
||||
t.Fatalf("首次加锁应成功,但返回错误: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := lock.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func TestSingletonLock_DuplicateLockFails(t *testing.T) {
|
||||
lockPath := filepath.Join(os.TempDir(), "nex-gateway-test-dup.lock")
|
||||
defer os.Remove(lockPath)
|
||||
|
||||
lock1 := NewSingletonLock(lockPath)
|
||||
if err := lock1.Lock(); err != nil {
|
||||
t.Fatalf("首次加锁应成功: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := lock1.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
lock2 := NewSingletonLock(lockPath)
|
||||
err := lock2.Lock()
|
||||
if err == nil {
|
||||
if unlockErr := lock2.Unlock(); unlockErr != nil {
|
||||
t.Fatalf("解锁失败: %v", unlockErr)
|
||||
}
|
||||
t.Fatal("重复加锁应失败,但返回 nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingletonLock_UnlockThenRelock(t *testing.T) {
|
||||
lockPath := filepath.Join(os.TempDir(), "nex-gateway-test-relock.lock")
|
||||
defer os.Remove(lockPath)
|
||||
|
||||
lock1 := NewSingletonLock(lockPath)
|
||||
if err := lock1.Lock(); err != nil {
|
||||
t.Fatalf("首次加锁应成功: %v", err)
|
||||
}
|
||||
if err := lock1.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
|
||||
lock2 := NewSingletonLock(lockPath)
|
||||
if err := lock2.Lock(); err != nil {
|
||||
t.Fatalf("释放后重新加锁应成功: %v", err)
|
||||
}
|
||||
if err := lock2.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingletonLock_UnlockWithoutLock(t *testing.T) {
|
||||
lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock"))
|
||||
if err := lock.Unlock(); err != nil {
|
||||
t.Fatalf("未加锁时解锁失败: %v", err)
|
||||
}
|
||||
}
|
||||
123
backend/cmd/desktop/static_test.go
Normal file
123
backend/cmd/desktop/static_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nex/embedfs"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestSetupStaticFiles(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
|
||||
if err != nil {
|
||||
t.Skipf("跳过测试: 前端资源未构建: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
getContentType := func(path string) string {
|
||||
if strings.HasSuffix(path, ".js") {
|
||||
return "application/javascript"
|
||||
}
|
||||
if strings.HasSuffix(path, ".css") {
|
||||
return "text/css"
|
||||
}
|
||||
if strings.HasSuffix(path, ".svg") {
|
||||
return "image/svg+xml"
|
||||
}
|
||||
return "application/octet-stream"
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/assets/*filepath", func(c *gin.Context) {
|
||||
filepath := c.Param("filepath")
|
||||
data, err := fs.ReadFile(distFS, "assets"+filepath)
|
||||
if err != nil {
|
||||
c.Status(404)
|
||||
return
|
||||
}
|
||||
c.Data(200, getContentType(filepath), data)
|
||||
})
|
||||
|
||||
r.GET("/favicon.svg", func(c *gin.Context) {
|
||||
data, err := fs.ReadFile(distFS, "favicon.svg")
|
||||
if err != nil {
|
||||
c.Status(404)
|
||||
return
|
||||
}
|
||||
c.Data(200, "image/svg+xml", data)
|
||||
})
|
||||
|
||||
r.NoRoute(func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
if strings.HasPrefix(path, "/api/") ||
|
||||
strings.HasPrefix(path, "/v1/") ||
|
||||
strings.HasPrefix(path, "/health") {
|
||||
c.JSON(404, gin.H{"error": "not found"})
|
||||
return
|
||||
}
|
||||
data, err := fs.ReadFile(distFS, "index.html")
|
||||
if err != nil {
|
||||
c.Status(500)
|
||||
return
|
||||
}
|
||||
c.Data(200, "text/html; charset=utf-8", data)
|
||||
})
|
||||
|
||||
t.Run("API 404", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != 404 {
|
||||
t.Errorf("期望状态码 404, 实际 %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SPA fallback", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/providers", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MIME type for JS", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/assets/test.js", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == 200 {
|
||||
expected := "application/javascript"
|
||||
if w.Header().Get("Content-Type") != expected {
|
||||
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
|
||||
}
|
||||
} else {
|
||||
t.Log("文件不存在,跳过 MIME 类型验证")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MIME type for CSS", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/assets/test.css", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == 200 {
|
||||
expected := "text/css"
|
||||
if w.Header().Get("Content-Type") != expected {
|
||||
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
|
||||
}
|
||||
} else {
|
||||
t.Log("文件不存在,跳过 MIME 类型验证")
|
||||
}
|
||||
})
|
||||
|
||||
t.Log("静态文件服务测试通过")
|
||||
}
|
||||
@@ -3,26 +3,20 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pressly/goose/v3"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
"nex/backend/internal/conversion/openai"
|
||||
"nex/backend/internal/database"
|
||||
"nex/backend/internal/handler"
|
||||
"nex/backend/internal/handler/middleware"
|
||||
"nex/backend/internal/provider"
|
||||
@@ -32,17 +26,14 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 1. 加载配置
|
||||
minimalLogger := pkgLogger.NewMinimal()
|
||||
|
||||
cfg, err := config.LoadConfig()
|
||||
if err != nil {
|
||||
log.Fatalf("加载配置失败: %v", err)
|
||||
}
|
||||
if err := cfg.Validate(); err != nil {
|
||||
log.Fatalf("配置验证失败: %v", err)
|
||||
minimalLogger.Fatal("加载配置失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 2. 初始化日志
|
||||
zapLogger, err := pkgLogger.New(pkgLogger.Config{
|
||||
zapLogger, err := pkgLogger.Upgrade(minimalLogger, pkgLogger.Config{
|
||||
Level: cfg.Log.Level,
|
||||
Path: cfg.Log.Path,
|
||||
MaxSize: cfg.Log.MaxSize,
|
||||
@@ -51,48 +42,57 @@ func main() {
|
||||
Compress: cfg.Log.Compress,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("初始化日志失败: %v", err)
|
||||
minimalLogger.Fatal("初始化日志失败", zap.Error(err))
|
||||
}
|
||||
defer zapLogger.Sync()
|
||||
defer func() {
|
||||
if err := zapLogger.Sync(); err != nil {
|
||||
minimalLogger.Warn("同步日志失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
// 3. 初始化数据库
|
||||
db, err := initDatabase(cfg)
|
||||
cfg.PrintSummary(zapLogger)
|
||||
|
||||
db, err := database.Init(&cfg.Database, zapLogger)
|
||||
if err != nil {
|
||||
zapLogger.Fatal("初始化数据库失败", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("初始化数据库失败", zap.Error(err))
|
||||
}
|
||||
defer closeDB(db)
|
||||
defer database.Close(db)
|
||||
|
||||
// 4. 初始化 repository 层
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
|
||||
// 5. 初始化 service 层
|
||||
providerService := service.NewProviderService(providerRepo)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
||||
statsService := service.NewStatsService(statsRepo)
|
||||
routingCache := service.NewRoutingCache(modelRepo, providerRepo, zapLogger)
|
||||
if err := routingCache.Preload(); err != nil {
|
||||
zapLogger.Warn("缓存预热失败,将使用懒加载", zap.Error(err))
|
||||
}
|
||||
|
||||
statsBuffer := service.NewStatsBuffer(statsRepo, zapLogger,
|
||||
service.WithFlushInterval(5*time.Second),
|
||||
service.WithFlushThreshold(100))
|
||||
statsBuffer.Start()
|
||||
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
|
||||
routingService := service.NewRoutingService(routingCache)
|
||||
statsService := service.NewStatsService(statsRepo, statsBuffer)
|
||||
|
||||
// 6. 创建 ConversionEngine
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
if err := registry.Register(openai.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.Error(err))
|
||||
}
|
||||
if err := registry.Register(anthropic.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.Error(err))
|
||||
}
|
||||
engine := conversion.NewConversionEngine(registry, zapLogger)
|
||||
|
||||
// 7. 初始化 provider client
|
||||
providerClient := provider.NewClient()
|
||||
providerClient := provider.NewClient(zapLogger)
|
||||
|
||||
// 8. 初始化 handler 层
|
||||
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
|
||||
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService, zapLogger)
|
||||
providerHandler := handler.NewProviderHandler(providerService)
|
||||
modelHandler := handler.NewModelHandler(modelService)
|
||||
statsHandler := handler.NewStatsHandler(statsService)
|
||||
|
||||
// 9. 创建 Gin 引擎
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
r := gin.New()
|
||||
|
||||
@@ -103,9 +103,8 @@ func main() {
|
||||
|
||||
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
|
||||
|
||||
// 10. 启动服务器
|
||||
srv := &http.Server{
|
||||
Addr: formatAddr(cfg.Server.Port),
|
||||
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
|
||||
Handler: r,
|
||||
ReadTimeout: cfg.Server.ReadTimeout,
|
||||
WriteTimeout: cfg.Server.WriteTimeout,
|
||||
@@ -114,7 +113,7 @@ func main() {
|
||||
go func() {
|
||||
zapLogger.Info("AI Gateway 启动", zap.String("addr", srv.Addr))
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("服务器启动失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -128,89 +127,17 @@ func main() {
|
||||
defer cancel()
|
||||
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
zapLogger.Fatal("服务器强制关闭", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("服务器强制关闭", zap.Error(err))
|
||||
}
|
||||
|
||||
statsBuffer.Stop()
|
||||
|
||||
zapLogger.Info("服务器已关闭")
|
||||
}
|
||||
|
||||
func initDatabase(cfg *config.Config) (*gorm.DB, error) {
|
||||
db, err := gorm.Open(sqlite.Open(cfg.Database.Path), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Info),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||
log.Printf("警告: 启用 WAL 模式失败: %v", err)
|
||||
}
|
||||
|
||||
if err := runMigrations(db); err != nil {
|
||||
return nil, fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns)
|
||||
sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns)
|
||||
sqlDB.SetConnMaxLifetime(cfg.Database.ConnMaxLifetime)
|
||||
|
||||
log.Printf("数据库连接池配置: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v",
|
||||
cfg.Database.MaxIdleConns, cfg.Database.MaxOpenConns, cfg.Database.ConnMaxLifetime)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func runMigrations(db *gorm.DB) error {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
migrationsDir := getMigrationsDir()
|
||||
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
|
||||
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
|
||||
}
|
||||
|
||||
goose.SetDialect("sqlite3")
|
||||
if err := goose.Up(sqlDB, migrationsDir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMigrationsDir() string {
|
||||
_, filename, _, ok := runtime.Caller(0)
|
||||
if ok {
|
||||
dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations")
|
||||
if abs, err := filepath.Abs(dir); err == nil {
|
||||
return abs
|
||||
}
|
||||
}
|
||||
return "./migrations"
|
||||
}
|
||||
|
||||
func closeDB(db *gorm.DB) {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sqlDB.Close()
|
||||
}
|
||||
|
||||
func formatAddr(port int) string {
|
||||
return fmt.Sprintf(":%d", port)
|
||||
}
|
||||
|
||||
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
|
||||
// 统一代理入口: /{protocol}/v1/{path}
|
||||
r.Any("/:protocol/v1/*path", proxyHandler.HandleProxy)
|
||||
r.Any("/:protocol/*path", proxyHandler.HandleProxy)
|
||||
|
||||
// 供应商管理 API
|
||||
providers := r.Group("/api/providers")
|
||||
{
|
||||
providers.GET("", providerHandler.ListProviders)
|
||||
@@ -220,7 +147,6 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
|
||||
providers.DELETE("/:id", providerHandler.DeleteProvider)
|
||||
}
|
||||
|
||||
// 模型管理 API
|
||||
models := r.Group("/api/models")
|
||||
{
|
||||
models.GET("", modelHandler.ListModels)
|
||||
@@ -230,14 +156,12 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
|
||||
models.DELETE("/:id", modelHandler.DeleteModel)
|
||||
}
|
||||
|
||||
// 统计查询 API
|
||||
stats := r.Group("/api/stats")
|
||||
{
|
||||
stats.GET("", statsHandler.GetStats)
|
||||
stats.GET("/aggregate", statsHandler.AggregateStats)
|
||||
}
|
||||
|
||||
// 健康检查
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
217
backend/go.mod
217
backend/go.mod
@@ -2,58 +2,249 @@ module nex/backend
|
||||
|
||||
go 1.26.2
|
||||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.12.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/pressly/goose/v3 v3.27.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
go.uber.org/zap v1.27.1
|
||||
gopkg.in/lumberjack.v2 v2.0.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
gorm.io/gorm v1.31.1
|
||||
replace nex/embedfs => ../embedfs
|
||||
|
||||
tool (
|
||||
github.com/golangci/golangci-lint/cmd/golangci-lint
|
||||
go.uber.org/mock/mockgen
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/getlantern/systray v1.2.2
|
||||
github.com/gin-gonic/gin v1.12.0
|
||||
github.com/go-playground/validator/v10 v10.30.2
|
||||
github.com/gofrs/flock v0.13.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/mitchellh/mapstructure v1.5.0
|
||||
github.com/pressly/goose/v3 v3.27.0
|
||||
github.com/spf13/pflag v1.0.10
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
go.uber.org/mock v0.6.0
|
||||
go.uber.org/zap v1.27.1
|
||||
gopkg.in/lumberjack.v2 v2.0.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/mysql v1.6.0
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
gorm.io/gorm v1.31.1
|
||||
nex/embedfs v0.0.0-00010101000000-000000000000
|
||||
)
|
||||
|
||||
require (
|
||||
4d63.com/gocheckcompilerdirectives v1.3.0 // indirect
|
||||
4d63.com/gochecknoglobals v0.2.2 // indirect
|
||||
filippo.io/edwards25519 v1.2.0 // indirect
|
||||
github.com/4meepo/tagalign v1.4.2 // indirect
|
||||
github.com/Abirdcfly/dupword v0.1.3 // indirect
|
||||
github.com/Antonboom/errname v1.0.0 // indirect
|
||||
github.com/Antonboom/nilnil v1.0.1 // indirect
|
||||
github.com/Antonboom/testifylint v1.5.2 // indirect
|
||||
github.com/BurntSushi/toml v1.6.0 // indirect
|
||||
github.com/Crocmagnon/fatcontext v0.7.1 // indirect
|
||||
github.com/Djarvur/go-err113 v0.0.0-20210108212216-aea10b59be24 // indirect
|
||||
github.com/GaijinEntertainment/go-exhaustruct/v3 v3.3.1 // indirect
|
||||
github.com/Masterminds/semver/v3 v3.3.0 // indirect
|
||||
github.com/OpenPeeDeeP/depguard/v2 v2.2.1 // indirect
|
||||
github.com/alecthomas/go-check-sumtype v0.3.1 // indirect
|
||||
github.com/alexkohler/nakedret/v2 v2.0.5 // indirect
|
||||
github.com/alexkohler/prealloc v1.0.0 // indirect
|
||||
github.com/alingse/asasalint v0.0.11 // indirect
|
||||
github.com/alingse/nilnesserr v0.1.2 // indirect
|
||||
github.com/ashanbrown/forbidigo v1.6.0 // indirect
|
||||
github.com/ashanbrown/makezero v1.2.0 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/bkielbasa/cyclop v1.2.3 // indirect
|
||||
github.com/blizzy78/varnamelen v0.8.0 // indirect
|
||||
github.com/bombsimon/wsl/v4 v4.5.0 // indirect
|
||||
github.com/breml/bidichk v0.3.2 // indirect
|
||||
github.com/breml/errchkjson v0.4.0 // indirect
|
||||
github.com/butuzov/ireturn v0.3.1 // indirect
|
||||
github.com/butuzov/mirror v1.3.0 // indirect
|
||||
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||
github.com/bytedance/sonic v1.15.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||
github.com/catenacyber/perfsprint v0.8.2 // indirect
|
||||
github.com/ccojocar/zxcvbn-go v1.0.2 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charithe/durationcheck v0.0.10 // indirect
|
||||
github.com/chavacava/garif v0.1.0 // indirect
|
||||
github.com/ckaznocha/intrange v0.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/curioswitch/go-reassign v0.3.0 // indirect
|
||||
github.com/daixiang0/gci v0.13.5 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/denis-tingaikin/go-header v0.5.0 // indirect
|
||||
github.com/ettle/strcase v0.2.0 // indirect
|
||||
github.com/fatih/color v1.18.0 // indirect
|
||||
github.com/fatih/structtag v1.2.0 // indirect
|
||||
github.com/firefart/nonamedreturns v1.0.5 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/fzipp/gocyclo v0.6.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
|
||||
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 // indirect
|
||||
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 // indirect
|
||||
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 // indirect
|
||||
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 // indirect
|
||||
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 // indirect
|
||||
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
|
||||
github.com/ghostiam/protogetter v0.3.9 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
github.com/go-critic/go-critic v0.12.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.30.2 // indirect
|
||||
github.com/go-sql-driver/mysql v1.9.3 // indirect
|
||||
github.com/go-stack/stack v1.8.0 // indirect
|
||||
github.com/go-toolsmith/astcast v1.1.0 // indirect
|
||||
github.com/go-toolsmith/astcopy v1.1.0 // indirect
|
||||
github.com/go-toolsmith/astequal v1.2.0 // indirect
|
||||
github.com/go-toolsmith/astfmt v1.1.0 // indirect
|
||||
github.com/go-toolsmith/astp v1.1.0 // indirect
|
||||
github.com/go-toolsmith/strparse v1.1.0 // indirect
|
||||
github.com/go-toolsmith/typep v1.1.0 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/go-xmlfmt/xmlfmt v1.1.3 // indirect
|
||||
github.com/gobwas/glob v0.2.3 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/golangci/dupl v0.0.0-20250308024227-f665c8d69b32 // indirect
|
||||
github.com/golangci/go-printf-func-name v0.1.0 // indirect
|
||||
github.com/golangci/gofmt v0.0.0-20250106114630-d62b90e6713d // indirect
|
||||
github.com/golangci/golangci-lint v1.64.8 // indirect
|
||||
github.com/golangci/misspell v0.6.0 // indirect
|
||||
github.com/golangci/plugin-module-register v0.1.1 // indirect
|
||||
github.com/golangci/revgrep v0.8.0 // indirect
|
||||
github.com/golangci/unconvert v0.0.0-20240309020433-c5143eacb3ed // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/gordonklaus/ineffassign v0.1.0 // indirect
|
||||
github.com/gostaticanalysis/analysisutil v0.7.1 // indirect
|
||||
github.com/gostaticanalysis/comment v1.5.0 // indirect
|
||||
github.com/gostaticanalysis/forcetypeassert v0.2.0 // indirect
|
||||
github.com/gostaticanalysis/nilerr v0.1.1 // indirect
|
||||
github.com/hashicorp/go-immutable-radix/v2 v2.1.0 // indirect
|
||||
github.com/hashicorp/go-version v1.7.0 // indirect
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
|
||||
github.com/hexops/gotextdiff v1.0.3 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jgautheron/goconst v1.7.1 // indirect
|
||||
github.com/jingyugao/rowserrcheck v1.1.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/jjti/go-spancheck v0.6.4 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/julz/importas v0.2.0 // indirect
|
||||
github.com/karamaru-alpha/copyloopvar v1.2.1 // indirect
|
||||
github.com/kisielk/errcheck v1.9.0 // indirect
|
||||
github.com/kkHAIKE/contextcheck v1.1.6 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/kulti/thelper v0.6.3 // indirect
|
||||
github.com/kunwardeep/paralleltest v1.0.10 // indirect
|
||||
github.com/lasiar/canonicalheader v1.1.2 // indirect
|
||||
github.com/ldez/exptostd v0.4.2 // indirect
|
||||
github.com/ldez/gomoddirectives v0.6.1 // indirect
|
||||
github.com/ldez/grignotin v0.9.0 // indirect
|
||||
github.com/ldez/tagliatelle v0.7.1 // indirect
|
||||
github.com/ldez/usetesting v0.4.2 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/leonklingele/grouper v1.1.2 // indirect
|
||||
github.com/macabu/inamedparam v0.1.3 // indirect
|
||||
github.com/maratori/testableexamples v1.0.0 // indirect
|
||||
github.com/maratori/testpackage v1.1.1 // indirect
|
||||
github.com/matoous/godox v1.1.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
|
||||
github.com/mfridman/interpolate v0.0.2 // indirect
|
||||
github.com/mgechev/revive v1.7.0 // indirect
|
||||
github.com/mitchellh/go-homedir v1.1.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/moricho/tparallel v0.3.2 // indirect
|
||||
github.com/nakabonne/nestif v0.3.1 // indirect
|
||||
github.com/nishanths/exhaustive v0.12.0 // indirect
|
||||
github.com/nishanths/predeclared v0.2.2 // indirect
|
||||
github.com/nunnatsa/ginkgolinter v0.19.1 // indirect
|
||||
github.com/olekukonko/tablewriter v0.0.5 // indirect
|
||||
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/polyfloyd/go-errorlint v1.7.1 // indirect
|
||||
github.com/prometheus/client_golang v1.12.1 // indirect
|
||||
github.com/prometheus/client_model v0.2.0 // indirect
|
||||
github.com/prometheus/common v0.32.1 // indirect
|
||||
github.com/prometheus/procfs v0.19.2 // indirect
|
||||
github.com/quasilyte/go-ruleguard v0.4.3-0.20240823090925-0fe6f58b47b1 // indirect
|
||||
github.com/quasilyte/go-ruleguard/dsl v0.3.22 // indirect
|
||||
github.com/quasilyte/gogrep v0.5.0 // indirect
|
||||
github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727 // indirect
|
||||
github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567 // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
github.com/quic-go/quic-go v0.59.0 // indirect
|
||||
github.com/raeperd/recvcheck v0.2.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
github.com/ryancurrah/gomodguard v1.3.5 // indirect
|
||||
github.com/ryanrolds/sqlclosecheck v0.5.1 // indirect
|
||||
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
||||
github.com/sanposhiho/wastedassign/v2 v2.1.0 // indirect
|
||||
github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 // indirect
|
||||
github.com/sashamelentyev/interfacebloat v1.1.0 // indirect
|
||||
github.com/sashamelentyev/usestdlibvars v1.28.0 // indirect
|
||||
github.com/securego/gosec/v2 v2.22.2 // indirect
|
||||
github.com/sethvargo/go-retry v0.3.0 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/sivchari/containedctx v1.0.3 // indirect
|
||||
github.com/sivchari/tenv v1.12.1 // indirect
|
||||
github.com/sonatard/noctx v0.1.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||
github.com/sourcegraph/go-diff v0.7.0 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/spf13/cobra v1.9.1 // indirect
|
||||
github.com/ssgreg/nlreturn/v2 v2.2.1 // indirect
|
||||
github.com/stbenjam/no-sprintf-host-port v0.2.0 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tdakkota/asciicheck v0.4.1 // indirect
|
||||
github.com/tetafro/godot v1.5.0 // indirect
|
||||
github.com/timakin/bodyclose v0.0.0-20241017074812-ed6a65f985e3 // indirect
|
||||
github.com/timonwong/loggercheck v0.10.1 // indirect
|
||||
github.com/tomarrell/wrapcheck/v2 v2.10.0 // indirect
|
||||
github.com/tommy-muehle/go-mnd/v2 v2.5.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.3.1 // indirect
|
||||
github.com/ultraware/funlen v0.2.0 // indirect
|
||||
github.com/ultraware/whitespace v0.2.0 // indirect
|
||||
github.com/uudashr/gocognit v1.2.0 // indirect
|
||||
github.com/uudashr/iface v1.3.1 // indirect
|
||||
github.com/xen0n/gosmopolitan v1.2.2 // indirect
|
||||
github.com/yagipy/maintidx v1.0.0 // indirect
|
||||
github.com/yeya24/promlinter v0.3.0 // indirect
|
||||
github.com/ykadowak/zerologlint v0.1.5 // indirect
|
||||
gitlab.com/bosi/decorder v0.4.2 // indirect
|
||||
go-simpler.org/musttag v0.13.0 // indirect
|
||||
go-simpler.org/sloglint v0.9.0 // indirect
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
|
||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/arch v0.22.0 // indirect
|
||||
golang.org/x/crypto v0.49.0 // indirect
|
||||
golang.org/x/exp/typeparams v0.0.0-20250210185358-939b2ce775ac // indirect
|
||||
golang.org/x/mod v0.33.0 // indirect
|
||||
golang.org/x/net v0.51.0 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
golang.org/x/tools v0.42.0 // indirect
|
||||
golang.org/x/tools/go/expect v0.1.1-deprecated // indirect
|
||||
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
honnef.co/go/tools v0.6.1 // indirect
|
||||
mvdan.cc/gofumpt v0.7.0 // indirect
|
||||
mvdan.cc/unparam v0.0.0-20240528143540-8a5130ca722f // indirect
|
||||
)
|
||||
|
||||
961
backend/go.sum
961
backend/go.sum
File diff suppressed because it is too large
Load Diff
@@ -1,11 +1,18 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
@@ -13,40 +20,49 @@ import (
|
||||
|
||||
// Config 应用配置
|
||||
type Config struct {
|
||||
Server ServerConfig `yaml:"server"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Log LogConfig `yaml:"log"`
|
||||
Server ServerConfig `yaml:"server" mapstructure:"server" validate:"required"`
|
||||
Database DatabaseConfig `yaml:"database" mapstructure:"database" validate:"required"`
|
||||
Log LogConfig `yaml:"log" mapstructure:"log" validate:"required"`
|
||||
}
|
||||
|
||||
// ServerConfig 服务器配置
|
||||
type ServerConfig struct {
|
||||
Port int `yaml:"port"`
|
||||
ReadTimeout time.Duration `yaml:"read_timeout"`
|
||||
WriteTimeout time.Duration `yaml:"write_timeout"`
|
||||
Port int `yaml:"port" mapstructure:"port" validate:"required,min=1,max=65535"`
|
||||
ReadTimeout time.Duration `yaml:"read_timeout" mapstructure:"read_timeout" validate:"required"`
|
||||
WriteTimeout time.Duration `yaml:"write_timeout" mapstructure:"write_timeout" validate:"required"`
|
||||
}
|
||||
|
||||
// DatabaseConfig 数据库配置
|
||||
type DatabaseConfig struct {
|
||||
Path string `yaml:"path"`
|
||||
MaxIdleConns int `yaml:"max_idle_conns"`
|
||||
MaxOpenConns int `yaml:"max_open_conns"`
|
||||
ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime"`
|
||||
Driver string `yaml:"driver" mapstructure:"driver" validate:"required,oneof=sqlite mysql"`
|
||||
Path string `yaml:"path" mapstructure:"path" validate:"required_if=Driver sqlite"`
|
||||
Host string `yaml:"host" mapstructure:"host" validate:"required_if=Driver mysql"`
|
||||
Port int `yaml:"port" mapstructure:"port" validate:"required_if=Driver mysql,min=1,max=65535"`
|
||||
User string `yaml:"user" mapstructure:"user" validate:"required_if=Driver mysql"`
|
||||
Password string `yaml:"password" mapstructure:"password"`
|
||||
DBName string `yaml:"dbname" mapstructure:"dbname" validate:"required_if=Driver mysql"`
|
||||
MaxIdleConns int `yaml:"max_idle_conns" mapstructure:"max_idle_conns" validate:"required,min=1"`
|
||||
MaxOpenConns int `yaml:"max_open_conns" mapstructure:"max_open_conns" validate:"required,min=1"`
|
||||
ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" mapstructure:"conn_max_lifetime" validate:"required"`
|
||||
}
|
||||
|
||||
// LogConfig 日志配置
|
||||
type LogConfig struct {
|
||||
Level string `yaml:"level"`
|
||||
Path string `yaml:"path"`
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
Compress bool `yaml:"compress"`
|
||||
Level string `yaml:"level" mapstructure:"level" validate:"required,oneof=debug info warn error"`
|
||||
Path string `yaml:"path" mapstructure:"path" validate:"required"`
|
||||
MaxSize int `yaml:"max_size" mapstructure:"max_size" validate:"required,min=1"`
|
||||
MaxBackups int `yaml:"max_backups" mapstructure:"max_backups" validate:"required,min=0"`
|
||||
MaxAge int `yaml:"max_age" mapstructure:"max_age" validate:"required,min=0"`
|
||||
Compress bool `yaml:"compress" mapstructure:"compress"`
|
||||
}
|
||||
|
||||
// DefaultConfig returns default config values
|
||||
func DefaultConfig() *Config {
|
||||
// Use home dir for default paths
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
homeDir = "."
|
||||
}
|
||||
nexDir := filepath.Join(homeDir, ".nex")
|
||||
|
||||
return &Config{
|
||||
@@ -56,7 +72,13 @@ func DefaultConfig() *Config {
|
||||
WriteTimeout: 30 * time.Second,
|
||||
},
|
||||
Database: DatabaseConfig{
|
||||
Driver: "sqlite",
|
||||
Path: filepath.Join(nexDir, "config.db"),
|
||||
Host: "",
|
||||
Port: 3306,
|
||||
User: "",
|
||||
Password: "",
|
||||
DBName: "nex",
|
||||
MaxIdleConns: 10,
|
||||
MaxOpenConns: 100,
|
||||
ConnMaxLifetime: 1 * time.Hour,
|
||||
@@ -79,7 +101,7 @@ func GetConfigDir() (string, error) {
|
||||
return "", err
|
||||
}
|
||||
configDir := filepath.Join(homeDir, ".nex")
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return configDir, nil
|
||||
@@ -103,29 +125,179 @@ func GetConfigPath() (string, error) {
|
||||
return filepath.Join(configDir, "config.yaml"), nil
|
||||
}
|
||||
|
||||
// setupDefaults 设置默认配置值
|
||||
func setupDefaults(v *viper.Viper) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
homeDir = "."
|
||||
}
|
||||
nexDir := filepath.Join(homeDir, ".nex")
|
||||
|
||||
v.SetDefault("server.port", 9826)
|
||||
v.SetDefault("server.read_timeout", "30s")
|
||||
v.SetDefault("server.write_timeout", "30s")
|
||||
|
||||
v.SetDefault("database.driver", "sqlite")
|
||||
v.SetDefault("database.path", filepath.Join(nexDir, "config.db"))
|
||||
v.SetDefault("database.host", "")
|
||||
v.SetDefault("database.port", 3306)
|
||||
v.SetDefault("database.user", "")
|
||||
v.SetDefault("database.password", "")
|
||||
v.SetDefault("database.dbname", "nex")
|
||||
v.SetDefault("database.max_idle_conns", 10)
|
||||
v.SetDefault("database.max_open_conns", 100)
|
||||
v.SetDefault("database.conn_max_lifetime", "1h")
|
||||
|
||||
v.SetDefault("log.level", "info")
|
||||
v.SetDefault("log.path", filepath.Join(nexDir, "log"))
|
||||
v.SetDefault("log.max_size", 100)
|
||||
v.SetDefault("log.max_backups", 10)
|
||||
v.SetDefault("log.max_age", 30)
|
||||
v.SetDefault("log.compress", true)
|
||||
}
|
||||
|
||||
// setupFlags 定义和绑定 CLI 参数
|
||||
func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
|
||||
// 定义所有配置项的 CLI 参数
|
||||
// 注意:这里不设置默认值,让 viper 的默认值生效
|
||||
flagSet.Int("server-port", 0, "服务器端口")
|
||||
flagSet.Duration("server-read-timeout", 0, "读超时")
|
||||
flagSet.Duration("server-write-timeout", 0, "写超时")
|
||||
|
||||
flagSet.String("database-driver", "", "数据库驱动:sqlite/mysql")
|
||||
flagSet.String("database-path", "", "数据库文件路径")
|
||||
flagSet.String("database-host", "", "MySQL 主机地址")
|
||||
flagSet.Int("database-port", 0, "MySQL 端口")
|
||||
flagSet.String("database-user", "", "MySQL 用户名")
|
||||
flagSet.String("database-password", "", "MySQL 密码")
|
||||
flagSet.String("database-dbname", "", "MySQL 数据库名")
|
||||
flagSet.Int("database-max-idle-conns", 0, "最大空闲连接数")
|
||||
flagSet.Int("database-max-open-conns", 0, "最大打开连接数")
|
||||
flagSet.Duration("database-conn-max-lifetime", 0, "连接最大生命周期")
|
||||
|
||||
flagSet.String("log-level", "", "日志级别:debug/info/warn/error")
|
||||
flagSet.String("log-path", "", "日志文件目录")
|
||||
flagSet.Int("log-max-size", 0, "单个日志文件最大大小 MB")
|
||||
flagSet.Int("log-max-backups", 0, "保留的旧日志文件最大数量")
|
||||
flagSet.Int("log-max-age", 0, "保留旧日志文件的最大天数")
|
||||
flagSet.Bool("log-compress", false, "是否压缩旧日志文件")
|
||||
|
||||
// 绑定所有 flag 到 viper
|
||||
// 注意:必须在设置默认值之后绑定
|
||||
bindPFlag(v, "server.port", flagSet.Lookup("server-port"))
|
||||
bindPFlag(v, "server.read_timeout", flagSet.Lookup("server-read-timeout"))
|
||||
bindPFlag(v, "server.write_timeout", flagSet.Lookup("server-write-timeout"))
|
||||
|
||||
bindPFlag(v, "database.driver", flagSet.Lookup("database-driver"))
|
||||
bindPFlag(v, "database.path", flagSet.Lookup("database-path"))
|
||||
bindPFlag(v, "database.host", flagSet.Lookup("database-host"))
|
||||
bindPFlag(v, "database.port", flagSet.Lookup("database-port"))
|
||||
bindPFlag(v, "database.user", flagSet.Lookup("database-user"))
|
||||
bindPFlag(v, "database.password", flagSet.Lookup("database-password"))
|
||||
bindPFlag(v, "database.dbname", flagSet.Lookup("database-dbname"))
|
||||
bindPFlag(v, "database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
|
||||
bindPFlag(v, "database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
|
||||
bindPFlag(v, "database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
|
||||
|
||||
bindPFlag(v, "log.level", flagSet.Lookup("log-level"))
|
||||
bindPFlag(v, "log.path", flagSet.Lookup("log-path"))
|
||||
bindPFlag(v, "log.max_size", flagSet.Lookup("log-max-size"))
|
||||
bindPFlag(v, "log.max_backups", flagSet.Lookup("log-max-backups"))
|
||||
bindPFlag(v, "log.max_age", flagSet.Lookup("log-max-age"))
|
||||
bindPFlag(v, "log.compress", flagSet.Lookup("log-compress"))
|
||||
}
|
||||
|
||||
func bindPFlag(v *viper.Viper, key string, flag *pflag.Flag) {
|
||||
if err := v.BindPFlag(key, flag); err != nil {
|
||||
panic(fmt.Sprintf("绑定 flag %s 失败: %v", key, err))
|
||||
}
|
||||
}
|
||||
|
||||
// setupEnv 绑定环境变量
|
||||
func setupEnv(v *viper.Viper) {
|
||||
v.SetEnvPrefix("NEX")
|
||||
v.AutomaticEnv()
|
||||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
}
|
||||
|
||||
// setupConfigFile 读取配置文件
|
||||
func setupConfigFile(v *viper.Viper, configPath string) error {
|
||||
v.SetConfigFile(configPath)
|
||||
v.SetConfigType("yaml")
|
||||
|
||||
// 尝试读取配置文件,如果不存在则忽略
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
// 配置文件不存在,创建默认配置文件
|
||||
writeErr := v.SafeWriteConfigAs(configPath)
|
||||
if writeErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var alreadyExistsErr viper.ConfigFileAlreadyExistsError
|
||||
if errors.As(writeErr, &alreadyExistsErr) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return appErrors.Wrap(appErrors.ErrInternal, writeErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadConfig loads config from YAML file, creates default if not exists
|
||||
func LoadConfig() (*Config, error) {
|
||||
configPath, err := GetConfigPath()
|
||||
if err != nil {
|
||||
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
return LoadConfigFromPath(configPath)
|
||||
}
|
||||
|
||||
cfg := DefaultConfig()
|
||||
// LoadConfigFromPath 从指定路径加载配置
|
||||
func LoadConfigFromPath(configPath string) (*Config, error) {
|
||||
// 1. 创建 Viper 实例
|
||||
v := viper.New()
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// Create default config file
|
||||
if saveErr := SaveConfig(cfg); saveErr != nil {
|
||||
return nil, appErrors.WithMessage(appErrors.ErrInternal, "创建默认配置失败")
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
// 2. 定义 CLI 参数
|
||||
flagSet := pflag.NewFlagSet("config", pflag.ContinueOnError)
|
||||
flagSet.String("config", configPath, "配置文件路径")
|
||||
setupFlags(v, flagSet)
|
||||
|
||||
// 3. 解析 CLI 参数(忽略错误,因为可能没有参数)
|
||||
if err := flagSet.Parse(os.Args[1:]); err != nil {
|
||||
return nil, appErrors.Wrap(appErrors.ErrInvalidRequest, err)
|
||||
}
|
||||
|
||||
// 4. 获取配置文件路径(可能被 --config 参数覆盖)
|
||||
if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" {
|
||||
configPath = configPathFlag
|
||||
}
|
||||
|
||||
// 5. 设置默认值
|
||||
setupDefaults(v)
|
||||
|
||||
// 6. 绑定环境变量
|
||||
setupEnv(v)
|
||||
|
||||
// 7. 读取配置文件
|
||||
if err := setupConfigFile(v, configPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 8. 反序列化到结构体
|
||||
cfg := &Config{}
|
||||
if err := v.Unmarshal(cfg, viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
|
||||
mapstructure.StringToTimeDurationHookFunc(),
|
||||
mapstructure.StringToSliceHookFunc(","),
|
||||
))); err != nil {
|
||||
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
if err := yaml.Unmarshal(data, cfg); err != nil {
|
||||
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
// 9. 验证配置
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
@@ -145,27 +317,41 @@ func SaveConfig(cfg *Config) error {
|
||||
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(configPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
return os.WriteFile(configPath, data, 0644)
|
||||
return os.WriteFile(configPath, data, 0o600)
|
||||
}
|
||||
|
||||
// Validate validates the config
|
||||
func (c *Config) Validate() error {
|
||||
if c.Server.Port < 1 || c.Server.Port > 65535 {
|
||||
return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("无效的端口号: %d", c.Server.Port))
|
||||
validate := validator.New()
|
||||
if err := validate.Struct(c); err != nil {
|
||||
return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("配置验证失败: %v", err))
|
||||
}
|
||||
|
||||
validLevels := map[string]bool{"debug": true, "info": true, "warn": true, "error": true}
|
||||
if !validLevels[c.Log.Level] {
|
||||
return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("无效的日志级别: %s", c.Log.Level))
|
||||
}
|
||||
|
||||
if c.Database.Path == "" {
|
||||
return appErrors.WithMessage(appErrors.ErrInvalidRequest, "数据库路径不能为空")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PrintSummary 打印配置摘要
|
||||
func (c *Config) PrintSummary(logger *zap.Logger) {
|
||||
logger.Info("AI Gateway 启动配置",
|
||||
zap.Int("server_port", c.Server.Port),
|
||||
zap.String("database_driver", c.Database.Driver),
|
||||
zap.String("log_level", c.Log.Level),
|
||||
)
|
||||
|
||||
if c.Database.Driver == "mysql" {
|
||||
logger.Info("数据库配置",
|
||||
zap.String("driver", "mysql"),
|
||||
zap.String("host", c.Database.Host),
|
||||
zap.Int("port", c.Database.Port),
|
||||
zap.String("database", c.Database.DBName),
|
||||
)
|
||||
} else {
|
||||
logger.Info("数据库配置",
|
||||
zap.String("driver", "sqlite"),
|
||||
zap.String("path", c.Database.Path),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -19,6 +20,12 @@ func TestDefaultConfig(t *testing.T) {
|
||||
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
|
||||
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
|
||||
|
||||
assert.Equal(t, "sqlite", cfg.Database.Driver)
|
||||
assert.Equal(t, "", cfg.Database.Host)
|
||||
assert.Equal(t, 3306, cfg.Database.Port)
|
||||
assert.Equal(t, "", cfg.Database.User)
|
||||
assert.Equal(t, "", cfg.Database.Password)
|
||||
assert.Equal(t, "nex", cfg.Database.DBName)
|
||||
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
|
||||
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
|
||||
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
|
||||
@@ -46,13 +53,13 @@ func TestConfig_Validate(t *testing.T) {
|
||||
name: "端口号为0无效",
|
||||
modify: func(c *Config) { c.Server.Port = 0 },
|
||||
wantErr: true,
|
||||
errMsg: "无效的端口号",
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "端口号超出范围无效",
|
||||
modify: func(c *Config) { c.Server.Port = 70000 },
|
||||
wantErr: true,
|
||||
errMsg: "无效的端口号",
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "端口号为1有效",
|
||||
@@ -68,7 +75,7 @@ func TestConfig_Validate(t *testing.T) {
|
||||
name: "无效日志级别",
|
||||
modify: func(c *Config) { c.Log.Level = "invalid" },
|
||||
wantErr: true,
|
||||
errMsg: "无效的日志级别",
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "debug级别有效",
|
||||
@@ -86,10 +93,75 @@ func TestConfig_Validate(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "数据库路径为空无效",
|
||||
name: "SQLite模式路径为空无效",
|
||||
modify: func(c *Config) { c.Database.Path = "" },
|
||||
wantErr: true,
|
||||
errMsg: "数据库路径不能为空",
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "driver值不合法",
|
||||
modify: func(c *Config) { c.Database.Driver = "postgres" },
|
||||
wantErr: true,
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "MySQL配置有效",
|
||||
modify: func(c *Config) {
|
||||
c.Database.Driver = "mysql"
|
||||
c.Database.Host = "localhost"
|
||||
c.Database.Port = 3306
|
||||
c.Database.User = "root"
|
||||
c.Database.DBName = "nex"
|
||||
c.Database.Path = ""
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "MySQL模式host为空无效",
|
||||
modify: func(c *Config) {
|
||||
c.Database.Driver = "mysql"
|
||||
c.Database.Host = ""
|
||||
c.Database.User = "root"
|
||||
c.Database.DBName = "nex"
|
||||
c.Database.Path = ""
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "MySQL模式user为空无效",
|
||||
modify: func(c *Config) {
|
||||
c.Database.Driver = "mysql"
|
||||
c.Database.Host = "localhost"
|
||||
c.Database.User = ""
|
||||
c.Database.DBName = "nex"
|
||||
c.Database.Path = ""
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "MySQL模式dbname为空无效",
|
||||
modify: func(c *Config) {
|
||||
c.Database.Driver = "mysql"
|
||||
c.Database.Host = "localhost"
|
||||
c.Database.User = "root"
|
||||
c.Database.DBName = ""
|
||||
c.Database.Path = ""
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "MySQL模式忽略path字段",
|
||||
modify: func(c *Config) {
|
||||
c.Database.Driver = "mysql"
|
||||
c.Database.Host = "localhost"
|
||||
c.Database.User = "root"
|
||||
c.Database.DBName = "nex"
|
||||
c.Database.Path = ""
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -100,7 +172,9 @@ func TestConfig_Validate(t *testing.T) {
|
||||
err := cfg.Validate()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -140,7 +214,10 @@ func TestSaveAndLoadConfig(t *testing.T) {
|
||||
WriteTimeout: 20 * time.Second,
|
||||
},
|
||||
Database: DatabaseConfig{
|
||||
Driver: "sqlite",
|
||||
Path: filepath.Join(dir, "test.db"),
|
||||
Port: 3306,
|
||||
DBName: "nex",
|
||||
MaxIdleConns: 5,
|
||||
MaxOpenConns: 50,
|
||||
ConnMaxLifetime: 30 * time.Minute,
|
||||
@@ -159,7 +236,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
|
||||
configPath := filepath.Join(dir, "config.yaml")
|
||||
data, err := yaml.Marshal(cfg)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(configPath, data, 0644)
|
||||
err = os.WriteFile(configPath, data, 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 加载配置
|
||||
@@ -174,3 +251,72 @@ func TestSaveAndLoadConfig(t *testing.T) {
|
||||
assert.Equal(t, cfg.Database.MaxIdleConns, loaded.Database.MaxIdleConns)
|
||||
assert.Equal(t, cfg.Log.Compress, loaded.Log.Compress)
|
||||
}
|
||||
|
||||
func TestCLIConfig(t *testing.T) {
|
||||
// 测试 CLI 参数配置(简化版本)
|
||||
// 注意:由于 flag.Parse 只能调用一次,这里只测试配置加载流程
|
||||
t.Run("配置加载流程", func(t *testing.T) {
|
||||
// 使用默认配置路径测试
|
||||
cfg := DefaultConfig()
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// 验证默认值正确
|
||||
assert.Equal(t, 9826, cfg.Server.Port)
|
||||
assert.Equal(t, "info", cfg.Log.Level)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEnvConfig(t *testing.T) {
|
||||
// 测试环境变量配置(简化版本)
|
||||
t.Run("环境变量前缀", func(t *testing.T) {
|
||||
// 验证环境变量前缀设置正确
|
||||
// 实际的环境变量测试需要独立的进程,这里只验证配置结构
|
||||
cfg := DefaultConfig()
|
||||
require.NotNil(t, cfg)
|
||||
assert.Equal(t, 9826, cfg.Server.Port)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfigPriority(t *testing.T) {
|
||||
// 测试配置优先级(简化版本)
|
||||
t.Run("默认值设置", func(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// 验证所有默认值
|
||||
assert.Equal(t, 9826, cfg.Server.Port)
|
||||
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
|
||||
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
|
||||
assert.Equal(t, "sqlite", cfg.Database.Driver)
|
||||
assert.Equal(t, 3306, cfg.Database.Port)
|
||||
assert.Equal(t, "nex", cfg.Database.DBName)
|
||||
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
|
||||
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
|
||||
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
|
||||
assert.Equal(t, "info", cfg.Log.Level)
|
||||
assert.Equal(t, 100, cfg.Log.MaxSize)
|
||||
assert.Equal(t, 10, cfg.Log.MaxBackups)
|
||||
assert.Equal(t, 30, cfg.Log.MaxAge)
|
||||
assert.Equal(t, true, cfg.Log.Compress)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrintSummary(t *testing.T) {
|
||||
t.Run("SQLite模式摘要", func(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
assert.NotPanics(t, func() {
|
||||
cfg.PrintSummary(zap.NewNop())
|
||||
})
|
||||
})
|
||||
t.Run("MySQL模式摘要", func(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
cfg.Database.Driver = "mysql"
|
||||
cfg.Database.Host = "db.example.com"
|
||||
cfg.Database.Port = 3306
|
||||
cfg.Database.User = "nex"
|
||||
cfg.Database.DBName = "nex"
|
||||
assert.NotPanics(t, func() {
|
||||
cfg.PrintSummary(zap.NewNop())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,22 +6,22 @@ import (
|
||||
|
||||
// Provider 供应商模型
|
||||
type Provider struct {
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
APIKey string `gorm:"not null" json:"api_key"`
|
||||
BaseURL string `gorm:"not null" json:"base_url"`
|
||||
Protocol string `gorm:"column:protocol;default:'openai'" json:"protocol"`
|
||||
Enabled bool `gorm:"default:true" json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
APIKey string `gorm:"not null" json:"api_key"`
|
||||
BaseURL string `gorm:"not null" json:"base_url"`
|
||||
Protocol string `gorm:"column:protocol;default:'openai'" json:"protocol"`
|
||||
Enabled bool `gorm:"default:true" json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
|
||||
}
|
||||
|
||||
// Model 模型配置
|
||||
// Model 模型配置(id 为 UUID 自动生成,UNIQUE(provider_id, model_name))
|
||||
type Model struct {
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
ProviderID string `gorm:"not null;index" json:"provider_id"`
|
||||
ModelName string `gorm:"not null;index" json:"model_name"`
|
||||
ProviderID string `gorm:"not null;index;uniqueIndex:idx_provider_model" json:"provider_id"`
|
||||
ModelName string `gorm:"not null;index;uniqueIndex:idx_provider_model" json:"model_name"`
|
||||
Enabled bool `gorm:"default:true" json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
@@ -29,8 +29,8 @@ type Model struct {
|
||||
// UsageStats 用量统计
|
||||
type UsageStats struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
ProviderID string `gorm:"not null;index" json:"provider_id"`
|
||||
ModelName string `gorm:"not null;index" json:"model_name"`
|
||||
ProviderID string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"provider_id"`
|
||||
ModelName string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"model_name"`
|
||||
RequestCount int `gorm:"default:0" json:"request_count"`
|
||||
Date time.Time `gorm:"type:date;not null;uniqueIndex:idx_provider_model_date" json:"date"`
|
||||
}
|
||||
@@ -47,12 +47,3 @@ func (Model) TableName() string {
|
||||
func (UsageStats) TableName() string {
|
||||
return "usage_stats"
|
||||
}
|
||||
|
||||
// MaskAPIKey 掩码 API Key(仅显示最后 4 个字符)
|
||||
func (p *Provider) MaskAPIKey() {
|
||||
if len(p.APIKey) > 4 {
|
||||
p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:]
|
||||
} else {
|
||||
p.APIKey = "***"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,6 +40,12 @@ type ProtocolAdapter interface {
|
||||
EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *TargetProvider) ([]byte, error)
|
||||
DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error)
|
||||
EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error)
|
||||
|
||||
// 统一模型 ID 相关方法
|
||||
ExtractUnifiedModelID(nativePath string) (string, error)
|
||||
ExtractModelName(body []byte, ifaceType InterfaceType) (string, error)
|
||||
RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)
|
||||
RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)
|
||||
}
|
||||
|
||||
// AdapterRegistry 适配器注册表接口
|
||||
|
||||
@@ -2,6 +2,7 @@ package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
@@ -39,13 +40,13 @@ func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceTyp
|
||||
}
|
||||
}
|
||||
|
||||
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id})
|
||||
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id},允许 id 含 /)
|
||||
func isModelInfoPath(path string) bool {
|
||||
if !strings.HasPrefix(path, "/v1/models/") {
|
||||
return false
|
||||
}
|
||||
suffix := path[len("/v1/models/"):]
|
||||
return suffix != "" && !strings.Contains(suffix, "/")
|
||||
return suffix != ""
|
||||
}
|
||||
|
||||
// BuildUrl 根据接口类型构建 URL
|
||||
@@ -140,7 +141,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
|
||||
Message: err.Message,
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(errMsg)
|
||||
body, marshalErr := json.Marshal(errMsg)
|
||||
if marshalErr != nil {
|
||||
return []byte(`{"type":"error","error":{"type":"internal_error","message":"internal error"}}`), statusCode
|
||||
}
|
||||
return body, statusCode
|
||||
}
|
||||
|
||||
@@ -203,3 +207,82 @@ func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankRe
|
||||
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
|
||||
}
|
||||
|
||||
// ExtractUnifiedModelID 从路径中提取统一模型 ID(/v1/models/{provider_id}/{model_name})
|
||||
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
|
||||
if !strings.HasPrefix(nativePath, "/v1/models/") {
|
||||
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
|
||||
}
|
||||
suffix := nativePath[len("/v1/models/"):]
|
||||
if suffix == "" {
|
||||
return "", fmt.Errorf("路径缺少模型 ID")
|
||||
}
|
||||
return suffix, nil
|
||||
}
|
||||
|
||||
// locateModelFieldInRequest 定位请求体中 model 字段的值并提供改写函数
|
||||
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat:
|
||||
raw, exists := m["model"]
|
||||
if !exists {
|
||||
return "", nil, fmt.Errorf("请求体中缺少 model 字段")
|
||||
}
|
||||
var current string
|
||||
if err := json.Unmarshal(raw, ¤t); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
rewriteFunc := func(newModel string) ([]byte, error) {
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
}
|
||||
return current, rewriteFunc, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractModelName 从请求体中提取 model 值
|
||||
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
|
||||
model, _, err := locateModelFieldInRequest(body, ifaceType)
|
||||
return model, err
|
||||
}
|
||||
|
||||
// RewriteRequestModelName 最小化改写请求体中的 model 字段
|
||||
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||
_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rewriteFunc(newModel)
|
||||
}
|
||||
|
||||
// RewriteResponseModelName 最小化改写响应体中的 model 字段
|
||||
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat:
|
||||
// Chat 响应必须有 model 字段,存在则改写,不存在则添加
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
@@ -52,10 +53,10 @@ func TestAdapter_BuildUrl(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
nativePath string
|
||||
name string
|
||||
nativePath string
|
||||
interfaceType conversion.InterfaceType
|
||||
expected string
|
||||
expected string
|
||||
}{
|
||||
{"聊天", "/v1/messages", conversion.InterfaceTypeChat, "/v1/messages"},
|
||||
{"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"},
|
||||
@@ -102,9 +103,9 @@ func TestAdapter_SupportsInterface(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
name string
|
||||
interfaceType conversion.InterfaceType
|
||||
expected bool
|
||||
expected bool
|
||||
}{
|
||||
{"聊天", conversion.InterfaceTypeChat, true},
|
||||
{"模型", conversion.InterfaceTypeModels, true},
|
||||
@@ -141,8 +142,8 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
|
||||
t.Run("解码嵌入请求", func(t *testing.T) {
|
||||
_, err := a.DecodeEmbeddingRequest([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
@@ -150,24 +151,24 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
|
||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
||||
_, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider)
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.True(t, errors.As(err, &convErr))
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("解码嵌入响应", func(t *testing.T) {
|
||||
_, err := a.DecodeEmbeddingResponse([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("编码嵌入响应", func(t *testing.T) {
|
||||
_, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{})
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
}
|
||||
@@ -178,8 +179,8 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
|
||||
t.Run("解码重排序请求", func(t *testing.T) {
|
||||
_, err := a.DecodeRerankRequest([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
@@ -187,24 +188,24 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
|
||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
||||
_, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider)
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("解码重排序响应", func(t *testing.T) {
|
||||
_, err := a.DecodeRerankResponse([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("编码重排序响应", func(t *testing.T) {
|
||||
_, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{})
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
263
backend/internal/conversion/anthropic/adapter_unified_test.go
Normal file
263
backend/internal/conversion/anthropic/adapter_unified_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractUnifiedModelID
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractUnifiedModelID(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("standard_path", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/v1/models/anthropic/claude-3")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "anthropic/claude-3", id)
|
||||
})
|
||||
|
||||
t.Run("multi_segment_path", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/v1/models/some/deep/nested/model")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "some/deep/nested/model", id)
|
||||
})
|
||||
|
||||
t.Run("single_segment", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/v1/models/claude-3")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "claude-3", id)
|
||||
})
|
||||
|
||||
t.Run("non_model_path", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/messages")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("empty_suffix", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/models/")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("models_list_no_slash", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/models")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unrelated_path", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/other")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractModelName (Chat only for Anthropic)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractModelName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3","messages":[]}`)
|
||||
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "anthropic/claude-3", model)
|
||||
})
|
||||
|
||||
t.Run("chat_with_max_tokens", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3-opus","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`)
|
||||
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "anthropic/claude-3-opus", model)
|
||||
})
|
||||
|
||||
t.Run("no_model_field", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[]}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
body := []byte(`{invalid}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unsupported_interface_type_embedding", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3"}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unsupported_interface_type_rerank", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3"}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RewriteRequestModelName (Chat only for Anthropic)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRewriteRequestModelName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3","messages":[]}`)
|
||||
rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "claude-3", m["model"])
|
||||
|
||||
msgs, ok := m["messages"]
|
||||
require.True(t, ok)
|
||||
msgsArr, ok := msgs.([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Len(t, msgsArr, 0)
|
||||
})
|
||||
|
||||
t.Run("preserves_unknown_fields", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3","max_tokens":1024,"temperature":0.7}`)
|
||||
rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "claude-3", m["model"])
|
||||
assert.Equal(t, 0.7, m["temperature"])
|
||||
|
||||
// max_tokens is encoded as float in JSON numbers
|
||||
maxTokens, ok := m["max_tokens"]
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, float64(1024), maxTokens)
|
||||
})
|
||||
|
||||
t.Run("no_model_field", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[]}`)
|
||||
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
body := []byte(`{invalid}`)
|
||||
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unsupported_interface_type", func(t *testing.T) {
|
||||
body := []byte(`{"model":"anthropic/claude-3"}`)
|
||||
_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeEmbeddings)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RewriteResponseModelName (Chat only for Anthropic)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRewriteResponseModelName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat_existing_model", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3","content":[],"stop_reason":"end_turn"}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "anthropic/claude-3", m["model"])
|
||||
|
||||
// other fields preserved
|
||||
_, hasContent := m["content"]
|
||||
assert.True(t, hasContent)
|
||||
assert.Equal(t, "end_turn", m["stop_reason"])
|
||||
})
|
||||
|
||||
t.Run("chat_without_model_field_adds_it", func(t *testing.T) {
|
||||
body := []byte(`{"content":[],"stop_reason":"end_turn"}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "anthropic/claude-3", m["model"])
|
||||
})
|
||||
|
||||
t.Run("passthrough_returns_body_unchanged", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3"}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypePassthrough)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(body), string(rewritten))
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
body := []byte(`{invalid}`)
|
||||
_, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractModelName and RewriteRequest consistency
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat_round_trip", func(t *testing.T) {
|
||||
original := []byte(`{"model":"anthropic/claude-3","messages":[],"max_tokens":1024}`)
|
||||
|
||||
// Extract the unified model ID from the body
|
||||
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "anthropic/claude-3", extracted)
|
||||
|
||||
// Rewrite to the native model name
|
||||
rewritten, err := a.RewriteRequestModelName(original, "claude-3", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extract again from the rewritten body to verify the same location was targeted
|
||||
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "claude-3", afterRewrite)
|
||||
|
||||
// Verify other fields are preserved
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, float64(1024), m["max_tokens"])
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isModelInfoPath (additional unified model ID cases)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expected bool
|
||||
}{
|
||||
{"simple_model_id", "/v1/models/claude-3", true},
|
||||
{"unified_model_id_with_slash", "/v1/models/anthropic/claude-3", true},
|
||||
{"models_list", "/v1/models", false},
|
||||
{"models_list_trailing_slash", "/v1/models/", false},
|
||||
{"messages_path", "/v1/messages", false},
|
||||
{"deeply_nested", "/v1/models/org/workspace/claude-3-opus", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -28,7 +28,10 @@ func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
|
||||
|
||||
var canonicalMsgs []canonical.CanonicalMessage
|
||||
for _, msg := range req.Messages {
|
||||
decoded := decodeMessage(msg)
|
||||
decoded, err := decodeMessage(msg)
|
||||
if err != nil {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析消息内容失败").WithCause(err)
|
||||
}
|
||||
canonicalMsgs = append(canonicalMsgs, decoded...)
|
||||
}
|
||||
|
||||
@@ -94,10 +97,13 @@ func decodeSystem(system any) any {
|
||||
}
|
||||
|
||||
// decodeMessage 解码 Anthropic 消息
|
||||
func decodeMessage(msg Message) []canonical.CanonicalMessage {
|
||||
func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) {
|
||||
switch msg.Role {
|
||||
case "user":
|
||||
blocks := decodeContentBlocks(msg.Content)
|
||||
blocks, err := decodeContentBlocks(msg.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var toolResults []canonical.ContentBlock
|
||||
var others []canonical.ContentBlock
|
||||
for _, b := range blocks {
|
||||
@@ -117,58 +123,83 @@ func decodeMessage(msg Message) []canonical.CanonicalMessage {
|
||||
if len(result) == 0 {
|
||||
result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}})
|
||||
}
|
||||
return result
|
||||
return result, nil
|
||||
|
||||
case "assistant":
|
||||
blocks := decodeContentBlocks(msg.Content)
|
||||
blocks, err := decodeContentBlocks(msg.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(blocks) == 0 {
|
||||
blocks = append(blocks, canonical.NewTextBlock(""))
|
||||
}
|
||||
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}
|
||||
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil
|
||||
}
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// decodeContentBlocks 解码内容块列表
|
||||
func decodeContentBlocks(content any) []canonical.ContentBlock {
|
||||
func decodeContentBlocks(content any) ([]canonical.ContentBlock, error) {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(v)}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(v)}, nil
|
||||
case []any:
|
||||
var blocks []canonical.ContentBlock
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
block := decodeSingleContentBlock(m)
|
||||
block, err := decodeSingleContentBlock(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if block != nil {
|
||||
blocks = append(blocks, *block)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(blocks) > 0 {
|
||||
return blocks
|
||||
return blocks, nil
|
||||
}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
|
||||
case nil:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
|
||||
default:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// decodeSingleContentBlock 解码单个内容块
|
||||
func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
|
||||
t, _ := m["type"].(string)
|
||||
func decodeSingleContentBlock(m map[string]any) (*canonical.ContentBlock, error) {
|
||||
t, ok := m["type"].(string)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch t {
|
||||
case "text":
|
||||
text, _ := m["text"].(string)
|
||||
return &canonical.ContentBlock{Type: "text", Text: text}
|
||||
text, ok := m["text"].(string)
|
||||
if !ok {
|
||||
text = ""
|
||||
}
|
||||
return &canonical.ContentBlock{Type: "text", Text: text}, nil
|
||||
case "tool_use":
|
||||
id, _ := m["id"].(string)
|
||||
name, _ := m["name"].(string)
|
||||
input, _ := json.Marshal(m["input"])
|
||||
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}
|
||||
id, ok := m["id"].(string)
|
||||
if !ok {
|
||||
id = ""
|
||||
}
|
||||
name, ok := m["name"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
input, err := json.Marshal(m["input"])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}, nil
|
||||
case "tool_result":
|
||||
toolUseID, _ := m["tool_use_id"].(string)
|
||||
toolUseID, ok := m["tool_use_id"].(string)
|
||||
if !ok {
|
||||
toolUseID = ""
|
||||
}
|
||||
isErr := false
|
||||
if ie, ok := m["is_error"].(bool); ok {
|
||||
isErr = ie
|
||||
@@ -179,7 +210,11 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
|
||||
case string:
|
||||
content = json.RawMessage(fmt.Sprintf("%q", cv))
|
||||
default:
|
||||
content, _ = json.Marshal(cv)
|
||||
encoded, err := json.Marshal(cv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content = encoded
|
||||
}
|
||||
} else {
|
||||
content = json.RawMessage(`""`)
|
||||
@@ -189,15 +224,18 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
|
||||
ToolUseID: toolUseID,
|
||||
Content: content,
|
||||
IsError: &isErr,
|
||||
}
|
||||
}, nil
|
||||
case "thinking":
|
||||
thinking, _ := m["thinking"].(string)
|
||||
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}
|
||||
thinking, ok := m["thinking"].(string)
|
||||
if !ok {
|
||||
thinking = ""
|
||||
}
|
||||
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}, nil
|
||||
case "redacted_thinking":
|
||||
// 丢弃
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// decodeTools 解码工具定义
|
||||
@@ -232,7 +270,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
||||
return canonical.NewToolChoiceAny()
|
||||
}
|
||||
case map[string]any:
|
||||
t, _ := v["type"].(string)
|
||||
t, ok := v["type"].(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
switch t {
|
||||
case "auto":
|
||||
return canonical.NewToolChoiceAuto()
|
||||
@@ -241,7 +282,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
||||
case "any":
|
||||
return canonical.NewToolChoiceAny()
|
||||
case "tool":
|
||||
name, _ := v["name"].(string)
|
||||
name, ok := v["name"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
return canonical.NewToolChoiceNamed(name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,7 +182,7 @@ func encodeContentBlocks(blocks []canonical.ContentBlock) []map[string]any {
|
||||
result = append(result, m)
|
||||
case "tool_result":
|
||||
m := map[string]any{
|
||||
"type": "tool_result",
|
||||
"type": "tool_result",
|
||||
"tool_use_id": b.ToolUseID,
|
||||
}
|
||||
if b.Content != nil {
|
||||
@@ -335,11 +335,11 @@ func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"id": resp.ID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": resp.Model,
|
||||
"content": blocks,
|
||||
"id": resp.ID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": resp.Model,
|
||||
"content": blocks,
|
||||
"stop_reason": sr,
|
||||
"stop_sequence": nil,
|
||||
"usage": usage,
|
||||
|
||||
@@ -33,7 +33,8 @@ func TestEncodeRequest_Basic(t *testing.T) {
|
||||
assert.Equal(t, true, result["stream"])
|
||||
assert.Equal(t, float64(1024), result["max_tokens"])
|
||||
|
||||
msgs := result["messages"].([]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, msgs, 1)
|
||||
}
|
||||
|
||||
@@ -55,17 +56,20 @@ func TestEncodeRequest_ToolMergeIntoUser(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
|
||||
// tool 消息应被合并到相邻 user 消息
|
||||
foundToolResult := false
|
||||
for _, m := range msgs {
|
||||
msgMap := m.(map[string]any)
|
||||
msgMap, ok := m.(map[string]any)
|
||||
require.True(t, ok)
|
||||
if msgMap["role"] == "user" {
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if ok {
|
||||
for _, c := range content {
|
||||
block := c.(map[string]any)
|
||||
block, ok := c.(map[string]any)
|
||||
require.True(t, ok)
|
||||
if block["type"] == "tool_result" {
|
||||
foundToolResult = true
|
||||
}
|
||||
@@ -93,8 +97,10 @@ func TestEncodeRequest_FirstUserGuarantee(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
firstMsg := msgs[0].(map[string]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
firstMsg, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "user", firstMsg["role"])
|
||||
}
|
||||
|
||||
@@ -140,9 +146,11 @@ func TestEncodeResponse_Basic(t *testing.T) {
|
||||
assert.Equal(t, "assistant", result["role"])
|
||||
assert.Equal(t, "end_turn", result["stop_reason"])
|
||||
|
||||
content := result["content"].([]any)
|
||||
content, ok := result["content"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, content, 1)
|
||||
block := content[0].(map[string]any)
|
||||
block, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "text", block["type"])
|
||||
assert.Equal(t, "你好", block["text"])
|
||||
}
|
||||
@@ -160,10 +168,12 @@ func TestEncodeModelsResponse(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
data := result["data"].([]any)
|
||||
data, ok := result["data"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, data, 1)
|
||||
|
||||
model := data[0].(map[string]any)
|
||||
model, ok := data[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "claude-3-opus", model["id"])
|
||||
// created 应为 RFC3339 格式
|
||||
createdAt, ok := model["created_at"].(string)
|
||||
@@ -280,11 +290,14 @@ func TestEncodeRequest_ConsecutiveRoleMerge(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, msgs, 1)
|
||||
userMsg := msgs[0].(map[string]any)
|
||||
userMsg, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "user", userMsg["role"])
|
||||
content := userMsg["content"].([]any)
|
||||
content, ok := userMsg["content"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, content, 2)
|
||||
}
|
||||
|
||||
@@ -321,7 +334,8 @@ func TestEncodeResponse_ReasoningTokens(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
usage := result["usage"].(map[string]any)
|
||||
usage, ok := result["usage"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
_, hasReasoning := usage["reasoning_tokens"]
|
||||
assert.False(t, hasReasoning)
|
||||
}
|
||||
@@ -341,9 +355,11 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
content := result["content"].([]any)
|
||||
content, ok := result["content"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, content, 1)
|
||||
block := content[0].(map[string]any)
|
||||
block, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "tool_use", block["type"])
|
||||
assert.Equal(t, "tool_1", block["id"])
|
||||
assert.Equal(t, "search", block["name"])
|
||||
|
||||
@@ -28,7 +28,7 @@ func NewStreamDecoder() *StreamDecoder {
|
||||
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
|
||||
data := rawChunk
|
||||
if len(d.utf8Remainder) > 0 {
|
||||
data = append(d.utf8Remainder, rawChunk...)
|
||||
data = append(append([]byte{}, d.utf8Remainder...), rawChunk...)
|
||||
d.utf8Remainder = nil
|
||||
}
|
||||
|
||||
@@ -50,9 +50,10 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
|
||||
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
line = strings.TrimRight(line, "\r")
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
switch {
|
||||
case strings.HasPrefix(line, "event: "):
|
||||
eventType = strings.TrimPrefix(line, "event: ")
|
||||
} else if strings.HasPrefix(line, "data: ") {
|
||||
case strings.HasPrefix(line, "data: "):
|
||||
eventData = strings.TrimPrefix(line, "data: ")
|
||||
if eventType != "" && eventData != "" {
|
||||
chunkEvents := d.processEvent(eventType, []byte(eventData))
|
||||
@@ -60,8 +61,8 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
|
||||
}
|
||||
eventType = ""
|
||||
eventData = ""
|
||||
} else if line == "" {
|
||||
// SSE 事件分隔符
|
||||
case line == "":
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,7 +136,7 @@ func (d *StreamDecoder) processMessageStart(data []byte) []canonical.CanonicalSt
|
||||
// processContentBlockStart 处理内容块开始事件
|
||||
func (d *StreamDecoder) processContentBlockStart(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw struct {
|
||||
Index int `json:"index"`
|
||||
Index int `json:"index"`
|
||||
ContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
|
||||
@@ -47,23 +47,23 @@ func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
|
||||
checkValue string
|
||||
}{
|
||||
{
|
||||
name: "text_delta",
|
||||
deltaType: "text_delta",
|
||||
deltaData: map[string]any{"type": "text_delta", "text": "你好"},
|
||||
name: "text_delta",
|
||||
deltaType: "text_delta",
|
||||
deltaData: map[string]any{"type": "text_delta", "text": "你好"},
|
||||
checkField: "text",
|
||||
checkValue: "你好",
|
||||
},
|
||||
{
|
||||
name: "input_json_delta",
|
||||
deltaType: "input_json_delta",
|
||||
deltaData: map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"},
|
||||
name: "input_json_delta",
|
||||
deltaType: "input_json_delta",
|
||||
deltaData: map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"},
|
||||
checkField: "partial_json",
|
||||
checkValue: "{\"key\":",
|
||||
},
|
||||
{
|
||||
name: "thinking_delta",
|
||||
deltaType: "thinking_delta",
|
||||
deltaData: map[string]any{"type": "thinking_delta", "thinking": "思考中"},
|
||||
name: "thinking_delta",
|
||||
deltaType: "thinking_delta",
|
||||
deltaData: map[string]any{"type": "thinking_delta", "thinking": "思考中"},
|
||||
checkField: "thinking",
|
||||
checkValue: "思考中",
|
||||
},
|
||||
@@ -74,7 +74,7 @@ func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
|
||||
payload := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": tt.deltaData,
|
||||
"delta": tt.deltaData,
|
||||
}
|
||||
raw := makeAnthropicEvent("content_block_delta", payload)
|
||||
|
||||
@@ -298,7 +298,7 @@ func TestStreamDecoder_WebSearchToolResult_Suppressed(t *testing.T) {
|
||||
"type": "content_block_start",
|
||||
"index": 3,
|
||||
"content_block": map[string]any{
|
||||
"type": "web_search_tool_result",
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "search_1",
|
||||
},
|
||||
}
|
||||
@@ -331,8 +331,8 @@ func TestStreamDecoder_CitationsDelta_Discarded(t *testing.T) {
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]any{
|
||||
"type": "citations_delta",
|
||||
"citation": map[string]any{"title": "ref1"},
|
||||
"type": "citations_delta",
|
||||
"citation": map[string]any{"title": "ref1"},
|
||||
},
|
||||
}
|
||||
raw := makeAnthropicEvent("content_block_delta", payload)
|
||||
@@ -466,7 +466,7 @@ func TestStreamDecoder_MessageDelta_UsageNotAccumulated(t *testing.T) {
|
||||
},
|
||||
}
|
||||
deltaPayload1 := map[string]any{
|
||||
"type": "message_delta",
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{"stop_reason": "end_turn"},
|
||||
"usage": map[string]any{"output_tokens": 25},
|
||||
}
|
||||
@@ -478,7 +478,7 @@ func TestStreamDecoder_MessageDelta_UsageNotAccumulated(t *testing.T) {
|
||||
assert.Equal(t, 25, events[0].Usage.OutputTokens)
|
||||
|
||||
deltaPayload2 := map[string]any{
|
||||
"type": "message_delta",
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{"stop_reason": "end_turn"},
|
||||
"usage": map[string]any{"output_tokens": 30},
|
||||
}
|
||||
|
||||
@@ -80,7 +80,8 @@ func TestStreamEncoder_ContentBlockStart_Text(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
cb := payload["content_block"].(map[string]any)
|
||||
cb, ok := payload["content_block"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "text", cb["type"])
|
||||
}
|
||||
|
||||
@@ -107,7 +108,8 @@ func TestStreamEncoder_ContentBlockStart_ToolUse(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
cb := payload["content_block"].(map[string]any)
|
||||
cb, ok := payload["content_block"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "tool_use", cb["type"])
|
||||
assert.Equal(t, "toolu_1", cb["id"])
|
||||
assert.Equal(t, "search", cb["name"])
|
||||
@@ -131,7 +133,8 @@ func TestStreamEncoder_ContentBlockStart_Thinking(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
cb := payload["content_block"].(map[string]any)
|
||||
cb, ok := payload["content_block"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "thinking", cb["type"])
|
||||
}
|
||||
|
||||
@@ -173,7 +176,8 @@ func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
delta := payload["delta"].(map[string]any)
|
||||
delta, okd := payload["delta"].(map[string]any)
|
||||
require.True(t, okd)
|
||||
assert.Equal(t, "end_turn", delta["stop_reason"])
|
||||
}
|
||||
|
||||
@@ -199,7 +203,8 @@ func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
u := payload["usage"].(map[string]any)
|
||||
u, oku := payload["usage"].(map[string]any)
|
||||
require.True(t, oku)
|
||||
assert.Equal(t, float64(88), u["output_tokens"])
|
||||
}
|
||||
|
||||
|
||||
@@ -173,13 +173,15 @@ func TestDecodeMessage_UserWithOnlyToolResults(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDecodeContentBlocks_Nil(t *testing.T) {
|
||||
blocks := decodeContentBlocks(nil)
|
||||
blocks, err := decodeContentBlocks(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, blocks, 1)
|
||||
assert.Equal(t, "", blocks[0].Text)
|
||||
}
|
||||
|
||||
func TestDecodeContentBlocks_String(t *testing.T) {
|
||||
blocks := decodeContentBlocks("hello")
|
||||
blocks, err := decodeContentBlocks("hello")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, blocks, 1)
|
||||
assert.Equal(t, "hello", blocks[0].Text)
|
||||
}
|
||||
@@ -217,8 +219,10 @@ func TestEncodeToolChoice(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := encodeToolChoice(tt.choice)
|
||||
assert.Equal(t, tt.want["type"], result.(map[string]any)["type"])
|
||||
assert.Equal(t, tt.want["name"], result.(map[string]any)["name"])
|
||||
r, ok := result.(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, tt.want["type"], r["type"])
|
||||
assert.Equal(t, tt.want["name"], r["name"])
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -315,12 +319,15 @@ func TestEncodeRequest_WithTools(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
tools := result["tools"].([]any)
|
||||
tools, okt := result["tools"].([]any)
|
||||
require.True(t, okt)
|
||||
assert.Len(t, tools, 1)
|
||||
tool := tools[0].(map[string]any)
|
||||
tool, okt2 := tools[0].(map[string]any)
|
||||
require.True(t, okt2)
|
||||
assert.Equal(t, "search", tool["name"])
|
||||
assert.Equal(t, "Search things", tool["description"])
|
||||
tc := result["tool_choice"].(map[string]any)
|
||||
tc, oktc := result["tool_choice"].(map[string]any)
|
||||
require.True(t, oktc)
|
||||
assert.Equal(t, "auto", tc["type"])
|
||||
}
|
||||
|
||||
@@ -354,9 +361,9 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
|
||||
StopReason: &sr,
|
||||
Usage: canonical.CanonicalUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CacheReadTokens: &cacheRead,
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CacheReadTokens: &cacheRead,
|
||||
CacheCreationTokens: &cacheCreation,
|
||||
},
|
||||
}
|
||||
@@ -366,7 +373,8 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
usage := result["usage"].(map[string]any)
|
||||
usage, oku := result["usage"].(map[string]any)
|
||||
require.True(t, oku)
|
||||
assert.Equal(t, float64(100), usage["input_tokens"])
|
||||
assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
|
||||
assert.Equal(t, float64(10), usage["cache_creation_input_tokens"])
|
||||
|
||||
@@ -6,22 +6,22 @@ import (
|
||||
|
||||
// MessagesRequest Anthropic Messages 请求
|
||||
type MessagesRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
System any `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Metadata *RequestMetadata `json:"metadata,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
OutputConfig *OutputConfig `json:"output_config,omitempty"`
|
||||
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
|
||||
Container any `json:"container,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
System any `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Metadata *RequestMetadata `json:"metadata,omitempty"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
OutputConfig *OutputConfig `json:"output_config,omitempty"`
|
||||
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
|
||||
Container any `json:"container,omitempty"`
|
||||
}
|
||||
|
||||
// RequestMetadata 请求元数据
|
||||
@@ -122,8 +122,8 @@ type ContentBlock struct {
|
||||
|
||||
// ResponseUsage 响应用量
|
||||
type ResponseUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheReadInputTokens *int `json:"cache_read_input_tokens,omitempty"`
|
||||
CacheCreationInputTokens *int `json:"cache_creation_input_tokens,omitempty"`
|
||||
}
|
||||
|
||||
@@ -38,8 +38,8 @@ type CanonicalEmbeddingResponse struct {
|
||||
|
||||
// EmbeddingData 嵌入数据项
|
||||
type EmbeddingData struct {
|
||||
Index int `json:"index"`
|
||||
Embedding any `json:"embedding"` // 根据格式不同可能是 []float64 或 base64 字符串
|
||||
Index int `json:"index"`
|
||||
Embedding any `json:"embedding"` // 根据格式不同可能是 []float64 或 base64 字符串
|
||||
}
|
||||
|
||||
// EmbeddingUsage 嵌入用量
|
||||
|
||||
@@ -18,17 +18,17 @@ const (
|
||||
type DeltaType string
|
||||
|
||||
const (
|
||||
DeltaTypeText DeltaType = "text_delta"
|
||||
DeltaTypeInputJSON DeltaType = "input_json_delta"
|
||||
DeltaTypeThinking DeltaType = "thinking_delta"
|
||||
DeltaTypeText DeltaType = "text_delta"
|
||||
DeltaTypeInputJSON DeltaType = "input_json_delta"
|
||||
DeltaTypeThinking DeltaType = "thinking_delta"
|
||||
)
|
||||
|
||||
// StreamDelta 流式增量联合体
|
||||
type StreamDelta struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
PartialJSON string `json:"partial_json,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
// StreamContentBlock 流式内容块联合体
|
||||
@@ -48,12 +48,12 @@ type CanonicalStreamEvent struct {
|
||||
Message *StreamMessage `json:"message,omitempty"`
|
||||
|
||||
// ContentBlockStartEvent / ContentBlockDeltaEvent / ContentBlockStopEvent
|
||||
Index *int `json:"index,omitempty"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
ContentBlock *StreamContentBlock `json:"content_block,omitempty"`
|
||||
Delta *StreamDelta `json:"delta,omitempty"`
|
||||
Delta *StreamDelta `json:"delta,omitempty"`
|
||||
|
||||
// MessageDeltaEvent
|
||||
StopReason *StopReason `json:"stop_reason,omitempty"`
|
||||
StopReason *StopReason `json:"stop_reason,omitempty"`
|
||||
Usage *CanonicalUsage `json:"usage,omitempty"`
|
||||
|
||||
// ErrorEvent
|
||||
|
||||
@@ -40,8 +40,8 @@ type ContentBlock struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// ToolUseBlock
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
|
||||
// ToolResultBlock
|
||||
@@ -138,43 +138,43 @@ type ThinkingConfig struct {
|
||||
|
||||
// OutputFormat 输出格式联合体
|
||||
type OutputFormat struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Schema json.RawMessage `json:"schema,omitempty"`
|
||||
Strict *bool `json:"strict,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Schema json.RawMessage `json:"schema,omitempty"`
|
||||
Strict *bool `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
// CanonicalRequest 规范请求
|
||||
type CanonicalRequest struct {
|
||||
Model string `json:"model"`
|
||||
System any `json:"system,omitempty"` // nil, string, or []SystemBlock
|
||||
Model string `json:"model"`
|
||||
System any `json:"system,omitempty"` // nil, string, or []SystemBlock
|
||||
Messages []CanonicalMessage `json:"messages"`
|
||||
Tools []CanonicalTool `json:"tools,omitempty"`
|
||||
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
||||
Parameters RequestParameters `json:"parameters"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
OutputFormat *OutputFormat `json:"output_format,omitempty"`
|
||||
ParallelToolUse *bool `json:"parallel_tool_use,omitempty"`
|
||||
Tools []CanonicalTool `json:"tools,omitempty"`
|
||||
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
||||
Parameters RequestParameters `json:"parameters"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
OutputFormat *OutputFormat `json:"output_format,omitempty"`
|
||||
ParallelToolUse *bool `json:"parallel_tool_use,omitempty"`
|
||||
}
|
||||
|
||||
// CanonicalUsage 规范用量
|
||||
type CanonicalUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheReadTokens *int `json:"cache_read_tokens,omitempty"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheReadTokens *int `json:"cache_read_tokens,omitempty"`
|
||||
CacheCreationTokens *int `json:"cache_creation_tokens,omitempty"`
|
||||
ReasoningTokens *int `json:"reasoning_tokens,omitempty"`
|
||||
ReasoningTokens *int `json:"reasoning_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// CanonicalResponse 规范响应
|
||||
type CanonicalResponse struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
StopReason *StopReason `json:"stop_reason,omitempty"`
|
||||
Usage CanonicalUsage `json:"usage"`
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
StopReason *StopReason `json:"stop_reason,omitempty"`
|
||||
Usage CanonicalUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// GetSystemString 获取系统消息字符串
|
||||
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
|
||||
func TestGetSystemString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
system any
|
||||
want string
|
||||
name string
|
||||
system any
|
||||
want string
|
||||
}{
|
||||
{"string", "hello", "hello"},
|
||||
{"nil", nil, ""},
|
||||
@@ -97,11 +97,11 @@ func TestCanonicalRequest_RoundTrip(t *testing.T) {
|
||||
func TestCanonicalResponse_RoundTrip(t *testing.T) {
|
||||
sr := StopReasonEndTurn
|
||||
resp := &CanonicalResponse{
|
||||
ID: "resp-1",
|
||||
Model: "gpt-4",
|
||||
Content: []ContentBlock{NewTextBlock("hello")},
|
||||
ID: "resp-1",
|
||||
Model: "gpt-4",
|
||||
Content: []ContentBlock{NewTextBlock("hello")},
|
||||
StopReason: &sr,
|
||||
Usage: CanonicalUsage{InputTokens: 10, OutputTokens: 5},
|
||||
Usage: CanonicalUsage{InputTokens: 10, OutputTokens: 5},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(resp)
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// HTTPRequestSpec HTTP 请求规格
|
||||
@@ -33,13 +35,10 @@ type ConversionEngine struct {
|
||||
|
||||
// NewConversionEngine 创建转换引擎
|
||||
func NewConversionEngine(registry AdapterRegistry, logger *zap.Logger) *ConversionEngine {
|
||||
if logger == nil {
|
||||
logger = zap.L()
|
||||
}
|
||||
return &ConversionEngine{
|
||||
registry: registry,
|
||||
middlewareChain: NewMiddlewareChain(),
|
||||
logger: logger,
|
||||
logger: pkglogger.WithModule(logger, "conversion.engine"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,11 +78,29 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Smart Passthrough: 同协议时最小化改写 model 字段
|
||||
interfaceType := providerAdapter.DetectInterfaceType(nativePath)
|
||||
rewrittenBody := spec.Body
|
||||
|
||||
// 对于 Chat/Embedding/Rerank 接口,改写请求体中的 model 字段
|
||||
if interfaceType == InterfaceTypeChat || interfaceType == InterfaceTypeEmbeddings || interfaceType == InterfaceTypeRerank {
|
||||
if len(spec.Body) > 0 && provider.ModelName != "" {
|
||||
rewrittenBody, err = providerAdapter.RewriteRequestModelName(spec.Body, provider.ModelName, interfaceType)
|
||||
if err != nil {
|
||||
e.logger.Warn("Smart Passthrough 改写请求失败,使用原始请求体",
|
||||
zap.Error(err),
|
||||
zap.String("interface", string(interfaceType)))
|
||||
rewrittenBody = spec.Body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &HTTPRequestSpec{
|
||||
URL: provider.BaseURL + nativePath,
|
||||
Method: spec.Method,
|
||||
Headers: providerAdapter.BuildHeaders(provider),
|
||||
Body: spec.Body,
|
||||
Body: rewrittenBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -97,7 +114,7 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
||||
}
|
||||
|
||||
interfaceType := clientAdapter.DetectInterfaceType(nativePath)
|
||||
providerUrl := providerAdapter.BuildUrl(nativePath, interfaceType)
|
||||
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
|
||||
providerHeaders := providerAdapter.BuildHeaders(provider)
|
||||
providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
|
||||
if err != nil {
|
||||
@@ -105,16 +122,34 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
||||
}
|
||||
|
||||
return &HTTPRequestSpec{
|
||||
URL: provider.BaseURL + providerUrl,
|
||||
URL: provider.BaseURL + providerURL,
|
||||
Method: spec.Method,
|
||||
Headers: providerHeaders,
|
||||
Body: providerBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ConvertHttpResponse 转换 HTTP 响应
|
||||
func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType) (*HTTPResponseSpec, error) {
|
||||
// ConvertHttpResponse 转换 HTTP 响应,modelOverride 用于跨协议场景覆写 model 字段
|
||||
func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType, modelOverride string) (*HTTPResponseSpec, error) {
|
||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||
// Smart Passthrough: 同协议时最小化改写 model 字段
|
||||
if modelOverride != "" && len(spec.Body) > 0 {
|
||||
adapter, getErr := e.registry.Get(clientProtocol)
|
||||
if getErr == nil {
|
||||
rewrittenBody, rewriteErr := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
|
||||
if rewriteErr != nil {
|
||||
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
|
||||
zap.Error(rewriteErr),
|
||||
zap.String("interface", string(interfaceType)))
|
||||
} else {
|
||||
return &HTTPResponseSpec{
|
||||
StatusCode: spec.StatusCode,
|
||||
Headers: spec.Headers,
|
||||
Body: rewrittenBody,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return &spec, nil
|
||||
}
|
||||
|
||||
@@ -127,7 +162,7 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
|
||||
return nil, err
|
||||
}
|
||||
|
||||
convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body)
|
||||
convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body, modelOverride)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -139,9 +174,16 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateStreamConverter 创建流式转换器
|
||||
func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string) (StreamConverter, error) {
|
||||
// CreateStreamConverter 创建流式转换器,modelOverride 用于跨协议场景覆写 model 字段
|
||||
func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string, modelOverride string, interfaceType InterfaceType) (StreamConverter, error) {
|
||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||
// Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段
|
||||
if modelOverride != "" {
|
||||
adapter, getErr := e.registry.Get(clientProtocol)
|
||||
if getErr == nil {
|
||||
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
|
||||
}
|
||||
}
|
||||
return NewPassthroughStreamConverter(), nil
|
||||
}
|
||||
|
||||
@@ -155,9 +197,9 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
|
||||
}
|
||||
|
||||
ctx := ConversionContext{
|
||||
ConversionID: uuid.New().String(),
|
||||
InterfaceType: InterfaceTypeChat,
|
||||
Timestamp: time.Now(),
|
||||
ConversionID: uuid.New().String(),
|
||||
InterfaceType: InterfaceTypeChat,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
return NewCanonicalStreamConverterWithMiddleware(
|
||||
@@ -167,6 +209,7 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
|
||||
ctx,
|
||||
clientProtocol,
|
||||
providerProtocol,
|
||||
modelOverride,
|
||||
), nil
|
||||
}
|
||||
|
||||
@@ -192,11 +235,11 @@ func (e *ConversionEngine) convertBody(interfaceType InterfaceType, clientAdapte
|
||||
}
|
||||
}
|
||||
|
||||
// convertResponseBody 转换响应体
|
||||
func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
// convertResponseBody 转换响应体,modelOverride 非空时在 canonical 层面覆写 Model 字段
|
||||
func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
switch interfaceType {
|
||||
case InterfaceTypeChat:
|
||||
return e.convertChatResponseBody(clientAdapter, providerAdapter, body)
|
||||
return e.convertChatResponseBody(clientAdapter, providerAdapter, body, modelOverride)
|
||||
case InterfaceTypeModels:
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeModels) || !providerAdapter.SupportsInterface(InterfaceTypeModels) {
|
||||
return body, nil
|
||||
@@ -211,12 +254,12 @@ func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clie
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) {
|
||||
return body, nil
|
||||
}
|
||||
return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body)
|
||||
return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body, modelOverride)
|
||||
case InterfaceTypeRerank:
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) {
|
||||
return body, nil
|
||||
}
|
||||
return e.convertRerankResponseBody(clientAdapter, providerAdapter, body)
|
||||
return e.convertRerankResponseBody(clientAdapter, providerAdapter, body, modelOverride)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
@@ -241,11 +284,14 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
canonicalResp, err := providerAdapter.DecodeResponse(body)
|
||||
if err != nil {
|
||||
return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(err)
|
||||
}
|
||||
if modelOverride != "" {
|
||||
canonicalResp.Model = modelOverride
|
||||
}
|
||||
encoded, err := clientAdapter.EncodeResponse(canonicalResp)
|
||||
if err != nil {
|
||||
return nil, NewConversionError(ErrorCodeEncodingFailure, "编码响应失败").WithCause(err)
|
||||
@@ -256,12 +302,12 @@ func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapte
|
||||
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
models, err := providerAdapter.DecodeModelsResponse(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
encoded, err := clientAdapter.EncodeModelsResponse(models)
|
||||
if err != nil {
|
||||
e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return encoded, nil
|
||||
@@ -270,12 +316,12 @@ func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdap
|
||||
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
info, err := providerAdapter.DecodeModelInfoResponse(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
encoded, err := clientAdapter.EncodeModelInfoResponse(info)
|
||||
if err != nil {
|
||||
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return encoded, nil
|
||||
@@ -284,36 +330,43 @@ func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerA
|
||||
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||
req, err := clientAdapter.DecodeEmbeddingRequest(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return providerAdapter.EncodeEmbeddingRequest(req, provider)
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
resp, err := providerAdapter.DecodeEmbeddingResponse(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
if modelOverride != "" {
|
||||
resp.Model = modelOverride
|
||||
}
|
||||
return clientAdapter.EncodeEmbeddingResponse(resp)
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||
req, err := clientAdapter.DecodeRerankRequest(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return providerAdapter.EncodeRerankRequest(req, provider)
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
resp, err := providerAdapter.DecodeRerankResponse(body)
|
||||
if err != nil {
|
||||
return body, nil
|
||||
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
resp, decodeErr := providerAdapter.DecodeRerankResponse(body)
|
||||
if decodeErr == nil {
|
||||
if modelOverride != "" {
|
||||
resp.Model = modelOverride
|
||||
}
|
||||
return clientAdapter.EncodeRerankResponse(resp)
|
||||
}
|
||||
return clientAdapter.EncodeRerankResponse(resp)
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// DetectInterfaceType 检测接口类型
|
||||
@@ -335,8 +388,12 @@ func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol stri
|
||||
"type": "internal_error",
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(fallback)
|
||||
return body, 500, nil
|
||||
body, marshalErr := json.Marshal(fallback)
|
||||
if marshalErr == nil {
|
||||
return body, 500, nil
|
||||
}
|
||||
|
||||
return []byte(`{"error":{"message":"internal error","type":"internal_error"}}`), 500, nil
|
||||
}
|
||||
body, statusCode := adapter.EncodeError(err)
|
||||
return body, statusCode, nil
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestConversionError_WithProviderProtocol(t *testing.T) {
|
||||
@@ -39,7 +40,7 @@ func TestConversionError_FullBuilder(t *testing.T) {
|
||||
|
||||
func TestEngine_Use(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
called := false
|
||||
engine.Use(&testMiddleware{fn: func(req *canonical.CanonicalRequest, cp, pp string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) {
|
||||
called = true
|
||||
@@ -58,7 +59,7 @@ func TestEngine_Use(t *testing.T) {
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
_, err := engine.ConvertHttpRequest(HTTPRequestSpec{
|
||||
URL: "/v1/chat/completions", Method: "POST", Body: []byte(`{}`),
|
||||
URL: "/chat/completions", Method: "POST", Body: []byte(`{}`),
|
||||
}, "client", "provider", NewTargetProvider("https://example.com", "key", "model"))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, called)
|
||||
@@ -66,7 +67,7 @@ func TestEngine_Use(t *testing.T) {
|
||||
|
||||
func TestConvertHttpRequest_DecodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
|
||||
return nil, errors.New("decode failed")
|
||||
@@ -75,14 +76,14 @@ func TestConvertHttpRequest_DecodeError(t *testing.T) {
|
||||
_ = engine.RegisterAdapter(newMockAdapter("provider", false))
|
||||
|
||||
_, err := engine.ConvertHttpRequest(HTTPRequestSpec{
|
||||
URL: "/v1/chat/completions", Method: "POST", Body: []byte(`{}`),
|
||||
URL: "/chat/completions", Method: "POST", Body: []byte(`{}`),
|
||||
}, "client", "provider", NewTargetProvider("", "", ""))
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestConvertHttpRequest_EncodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.encodeReqFn = func(req *canonical.CanonicalRequest, p *TargetProvider) ([]byte, error) {
|
||||
@@ -91,14 +92,14 @@ func TestConvertHttpRequest_EncodeError(t *testing.T) {
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
_, err := engine.ConvertHttpRequest(HTTPRequestSpec{
|
||||
URL: "/v1/chat/completions", Method: "POST", Body: []byte(`{}`),
|
||||
URL: "/chat/completions", Method: "POST", Body: []byte(`{}`),
|
||||
}, "client", "provider", NewTargetProvider("", "", ""))
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||||
@@ -113,7 +114,7 @@ func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"id":"resp-1"}`),
|
||||
}, "client", "provider", InterfaceTypeChat)
|
||||
}, "client", "provider", InterfaceTypeChat, "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 200, result.StatusCode)
|
||||
assert.Contains(t, string(result.Body), "resp-1")
|
||||
@@ -121,7 +122,7 @@ func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
|
||||
|
||||
func TestConvertHttpResponse_DecodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
|
||||
return nil, errors.New("decode error")
|
||||
@@ -129,13 +130,13 @@ func TestConvertHttpResponse_DecodeError(t *testing.T) {
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
||||
|
||||
_, err := engine.ConvertHttpResponse(HTTPResponseSpec{Body: []byte(`{}`)}, "client", "provider", InterfaceTypeChat)
|
||||
_, err := engine.ConvertHttpResponse(HTTPResponseSpec{Body: []byte(`{}`)}, "client", "provider", InterfaceTypeChat, "")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.ifaceType = InterfaceTypeEmbeddings
|
||||
@@ -158,7 +159,7 @@ func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) {
|
||||
|
||||
func TestConvertHttpRequest_RerankInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.ifaceType = InterfaceTypeRerank
|
||||
@@ -178,7 +179,7 @@ func TestConvertHttpRequest_RerankInterface(t *testing.T) {
|
||||
|
||||
func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
|
||||
@@ -189,14 +190,14 @@ func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"object":"list","data":[],"model":"test"}`),
|
||||
}, "client", "provider", InterfaceTypeEmbeddings)
|
||||
}, "client", "provider", InterfaceTypeEmbeddings, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func TestConvertHttpResponse_RerankInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
||||
@@ -207,14 +208,14 @@ func TestConvertHttpResponse_RerankInterface(t *testing.T) {
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"results":[],"model":"test"}`),
|
||||
}, "client", "provider", InterfaceTypeRerank)
|
||||
}, "client", "provider", InterfaceTypeRerank, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.ifaceType = InterfaceTypeModels
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
@@ -224,7 +225,7 @@ func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
|
||||
|
||||
body := []byte(`{"object":"list","data":[]}`)
|
||||
result, err := engine.ConvertHttpRequest(HTTPRequestSpec{
|
||||
URL: "/v1/models", Method: "GET", Body: body,
|
||||
URL: "/models", Method: "GET", Body: body,
|
||||
}, "client", "provider", NewTargetProvider("https://example.com", "key", ""))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result.Body)
|
||||
@@ -232,7 +233,7 @@ func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
|
||||
|
||||
func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModels: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
@@ -242,14 +243,14 @@ func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"object":"list","data":[]}`),
|
||||
}, "client", "provider", InterfaceTypeModels)
|
||||
}, "client", "provider", InterfaceTypeModels, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModelInfo: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
@@ -259,7 +260,7 @@ func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) {
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"id":"gpt-4","object":"model"}`),
|
||||
}, "client", "provider", InterfaceTypeModelInfo)
|
||||
}, "client", "provider", InterfaceTypeModelInfo, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
@@ -321,3 +322,58 @@ func (m *testMiddleware) InterceptStreamEvent(event *canonical.CanonicalStreamEv
|
||||
}
|
||||
|
||||
var _ = json.Marshal
|
||||
|
||||
func TestConvertEmbeddingBody_DecodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.decodeEmbeddingReqFn = func(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
|
||||
return nil, errors.New("decode embedding failed")
|
||||
}
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
body := []byte(`{"model":"text-embedding","input":"hello"}`)
|
||||
result, err := engine.convertEmbeddingBody(clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
|
||||
func TestConvertRerankBody_DecodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.decodeRerankReqFn = func(raw []byte) (*canonical.CanonicalRerankRequest, error) {
|
||||
return nil, errors.New("decode rerank failed")
|
||||
}
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
body := []byte(`{"model":"rerank","query":"test","documents":["a"]}`)
|
||||
result, err := engine.convertRerankBody(clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
|
||||
func TestConvertBody_UnknownInterfaceType(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
body := []byte(`{"test":"data"}`)
|
||||
result, err := engine.convertBody(InterfaceType("UNKNOWN"), clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
|
||||
@@ -13,16 +13,20 @@ import (
|
||||
|
||||
// mockProtocolAdapter 模拟协议适配器
|
||||
type mockProtocolAdapter struct {
|
||||
protocolName string
|
||||
passthrough bool
|
||||
ifaceType InterfaceType
|
||||
supportsIface map[InterfaceType]bool
|
||||
decodeReqFn func([]byte) (*canonical.CanonicalRequest, error)
|
||||
encodeReqFn func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
|
||||
decodeRespFn func([]byte) (*canonical.CanonicalResponse, error)
|
||||
encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error)
|
||||
streamDecoderFn func() StreamDecoder
|
||||
streamEncoderFn func() StreamEncoder
|
||||
protocolName string
|
||||
passthrough bool
|
||||
ifaceType InterfaceType
|
||||
supportsIface map[InterfaceType]bool
|
||||
decodeReqFn func([]byte) (*canonical.CanonicalRequest, error)
|
||||
encodeReqFn func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
|
||||
decodeRespFn func([]byte) (*canonical.CanonicalResponse, error)
|
||||
encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error)
|
||||
streamDecoderFn func() StreamDecoder
|
||||
streamEncoderFn func() StreamEncoder
|
||||
rewriteReqFn func([]byte, string, InterfaceType) ([]byte, error)
|
||||
rewriteRespFn func([]byte, string, InterfaceType) ([]byte, error)
|
||||
decodeEmbeddingReqFn func([]byte) (*canonical.CanonicalEmbeddingRequest, error)
|
||||
decodeRerankReqFn func([]byte) (*canonical.CanonicalRerankRequest, error)
|
||||
}
|
||||
|
||||
func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
|
||||
@@ -34,8 +38,8 @@ func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) ProtocolName() string { return m.protocolName }
|
||||
func (m *mockProtocolAdapter) ProtocolVersion() string { return "1.0" }
|
||||
func (m *mockProtocolAdapter) ProtocolName() string { return m.protocolName }
|
||||
func (m *mockProtocolAdapter) ProtocolVersion() string { return "1.0" }
|
||||
func (m *mockProtocolAdapter) SupportsPassthrough() bool { return m.passthrough }
|
||||
|
||||
func (m *mockProtocolAdapter) DetectInterfaceType(nativePath string) InterfaceType {
|
||||
@@ -124,6 +128,9 @@ func (m *mockProtocolAdapter) EncodeModelInfoResponse(info *canonical.CanonicalM
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
|
||||
if m.decodeEmbeddingReqFn != nil {
|
||||
return m.decodeEmbeddingReqFn(raw)
|
||||
}
|
||||
return &canonical.CanonicalEmbeddingRequest{}, nil
|
||||
}
|
||||
|
||||
@@ -140,6 +147,9 @@ func (m *mockProtocolAdapter) EncodeEmbeddingResponse(resp *canonical.CanonicalE
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
|
||||
if m.decodeRerankReqFn != nil {
|
||||
return m.decodeRerankReqFn(raw)
|
||||
}
|
||||
return &canonical.CanonicalRerankRequest{}, nil
|
||||
}
|
||||
|
||||
@@ -155,23 +165,47 @@ func (m *mockProtocolAdapter) EncodeRerankResponse(resp *canonical.CanonicalRera
|
||||
return json.Marshal(resp)
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) ExtractUnifiedModelID(nativePath string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) ExtractModelName(body []byte, ifaceType InterfaceType) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||
if m.rewriteReqFn != nil {
|
||||
return m.rewriteReqFn(body, newModel, ifaceType)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||
if m.rewriteRespFn != nil {
|
||||
return m.rewriteRespFn(body, newModel, ifaceType)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// noopStreamDecoder 空流式解码器
|
||||
type noopStreamDecoder struct{}
|
||||
|
||||
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { return nil }
|
||||
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }
|
||||
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
|
||||
return nil
|
||||
}
|
||||
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }
|
||||
|
||||
// noopStreamEncoder 空流式编码器
|
||||
type noopStreamEncoder struct{}
|
||||
|
||||
func (e *noopStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte { return nil }
|
||||
func (e *noopStreamEncoder) Flush() [][]byte { return nil }
|
||||
func (e *noopStreamEncoder) Flush() [][]byte { return nil }
|
||||
|
||||
// ============ 测试用例 ============
|
||||
|
||||
func TestNewConversionEngine(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
assert.NotNil(t, engine)
|
||||
assert.Equal(t, registry, engine.GetRegistry())
|
||||
}
|
||||
@@ -179,7 +213,7 @@ func TestNewConversionEngine(t *testing.T) {
|
||||
func TestNewConversionEngine_LoggerInjection(t *testing.T) {
|
||||
t.Run("nil_logger_uses_global", func(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
assert.NotNil(t, engine.logger)
|
||||
})
|
||||
|
||||
@@ -187,13 +221,14 @@ func TestNewConversionEngine_LoggerInjection(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
customLogger := zap.NewNop()
|
||||
engine := NewConversionEngine(registry, customLogger)
|
||||
assert.Equal(t, customLogger, engine.logger)
|
||||
assert.NotNil(t, engine.logger)
|
||||
assert.Contains(t, engine.logger.Name(), "conversion.engine")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegisterAdapter(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
adapter := newMockAdapter("test-proto", true)
|
||||
err := engine.RegisterAdapter(adapter)
|
||||
@@ -205,7 +240,7 @@ func TestRegisterAdapter(t *testing.T) {
|
||||
|
||||
func TestIsPassthrough_SameProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
adapter := newMockAdapter("openai", true)
|
||||
_ = engine.RegisterAdapter(adapter)
|
||||
|
||||
@@ -214,7 +249,7 @@ func TestIsPassthrough_SameProtocol(t *testing.T) {
|
||||
|
||||
func TestIsPassthrough_DifferentProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
_ = engine.RegisterAdapter(newMockAdapter("anthropic", true))
|
||||
|
||||
@@ -223,7 +258,7 @@ func TestIsPassthrough_DifferentProtocol(t *testing.T) {
|
||||
|
||||
func TestIsPassthrough_NoPassthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("custom", false))
|
||||
|
||||
assert.False(t, engine.IsPassthrough("custom", "custom"))
|
||||
@@ -231,19 +266,19 @@ func TestIsPassthrough_NoPassthrough(t *testing.T) {
|
||||
|
||||
func TestDetectInterfaceType(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
adapter := newMockAdapter("test", true)
|
||||
adapter.ifaceType = InterfaceTypeChat
|
||||
_ = engine.RegisterAdapter(adapter)
|
||||
|
||||
ifaceType, err := engine.DetectInterfaceType("/v1/chat/completions", "test")
|
||||
ifaceType, err := engine.DetectInterfaceType("/chat/completions", "test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, InterfaceTypeChat, ifaceType)
|
||||
}
|
||||
|
||||
func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
_, err := engine.DetectInterfaceType("/v1/chat", "nonexistent")
|
||||
assert.Error(t, err)
|
||||
@@ -251,12 +286,12 @@ func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
|
||||
|
||||
func TestConvertHttpRequest_Passthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
|
||||
provider := NewTargetProvider("https://api.openai.com", "sk-test", "gpt-4")
|
||||
provider := NewTargetProvider("https://api.openai.com/v1", "sk-test", "gpt-4")
|
||||
spec := HTTPRequestSpec{
|
||||
URL: "/v1/chat/completions",
|
||||
URL: "/chat/completions",
|
||||
Method: "POST",
|
||||
Body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`),
|
||||
}
|
||||
@@ -269,7 +304,7 @@ func TestConvertHttpRequest_Passthrough(t *testing.T) {
|
||||
|
||||
func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client-proto", false)
|
||||
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
|
||||
@@ -301,7 +336,7 @@ func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
|
||||
|
||||
func TestConvertHttpResponse_Passthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
|
||||
spec := HTTPResponseSpec{
|
||||
@@ -309,7 +344,7 @@ func TestConvertHttpResponse_Passthrough(t *testing.T) {
|
||||
Body: []byte(`{"id":"123"}`),
|
||||
}
|
||||
|
||||
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat)
|
||||
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat, "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 200, result.StatusCode)
|
||||
assert.Equal(t, spec.Body, result.Body)
|
||||
@@ -317,10 +352,10 @@ func TestConvertHttpResponse_Passthrough(t *testing.T) {
|
||||
|
||||
func TestCreateStreamConverter_Passthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
|
||||
converter, err := engine.CreateStreamConverter("openai", "openai")
|
||||
converter, err := engine.CreateStreamConverter("openai", "openai", "", InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
_, ok := converter.(*PassthroughStreamConverter)
|
||||
assert.True(t, ok)
|
||||
@@ -328,11 +363,11 @@ func TestCreateStreamConverter_Passthrough(t *testing.T) {
|
||||
|
||||
func TestCreateStreamConverter_Canonical(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
||||
_ = engine.RegisterAdapter(newMockAdapter("provider", false))
|
||||
|
||||
converter, err := engine.CreateStreamConverter("client", "provider")
|
||||
converter, err := engine.CreateStreamConverter("client", "provider", "", InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
_, ok := converter.(*CanonicalStreamConverter)
|
||||
assert.True(t, ok)
|
||||
@@ -340,7 +375,7 @@ func TestCreateStreamConverter_Canonical(t *testing.T) {
|
||||
|
||||
func TestEncodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
|
||||
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
|
||||
@@ -352,7 +387,7 @@ func TestEncodeError(t *testing.T) {
|
||||
|
||||
func TestEncodeError_NonExistentProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
|
||||
body, statusCode, err := engine.EncodeError(convErr, "nonexistent")
|
||||
@@ -380,3 +415,232 @@ func TestRegistry_GetNonExistent(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "未找到适配器")
|
||||
}
|
||||
|
||||
// ============ modelOverride 测试 ============
|
||||
|
||||
func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||||
return json.Marshal(map[string]any{"model": resp.Model})
|
||||
}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
|
||||
return &canonical.CanonicalResponse{ID: "test", Model: "native-model", Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}, nil
|
||||
}
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
spec := HTTPResponseSpec{
|
||||
StatusCode: 200,
|
||||
Body: []byte(`{"model":"native-model"}`),
|
||||
}
|
||||
|
||||
result, err := engine.ConvertHttpResponse(spec, "client", "provider", InterfaceTypeChat, "provider/gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
var resp map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(result.Body, &resp))
|
||||
assert.Equal(t, "provider/gpt-4", resp["model"])
|
||||
}
|
||||
|
||||
func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
// 使用真实 OpenAI adapter 验证 Smart Passthrough 改写
|
||||
openaiAdapter := newMockAdapter("openai", true)
|
||||
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
return json.Marshal(m)
|
||||
}
|
||||
_ = engine.RegisterAdapter(openaiAdapter)
|
||||
|
||||
spec := HTTPResponseSpec{
|
||||
StatusCode: 200,
|
||||
Body: []byte(`{"id":"resp-1","model":"gpt-4"}`),
|
||||
}
|
||||
|
||||
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat, "openai/gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
var resp map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(result.Body, &resp))
|
||||
assert.Equal(t, "openai/gpt-4", resp["model"])
|
||||
assert.Equal(t, "resp-1", resp["id"])
|
||||
}
|
||||
|
||||
func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
openaiAdapter := newMockAdapter("openai", true)
|
||||
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
return json.Marshal(m)
|
||||
}
|
||||
_ = engine.RegisterAdapter(openaiAdapter)
|
||||
|
||||
converter, err := engine.CreateStreamConverter("openai", "openai", "openai/gpt-4", InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, ok := converter.(*SmartPassthroughStreamConverter)
|
||||
assert.True(t, ok)
|
||||
|
||||
// 验证 chunk 改写
|
||||
chunks := converter.ProcessChunk([]byte(`{"model":"gpt-4","choices":[]}`))
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
var resp map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(chunks[0], &resp))
|
||||
assert.Equal(t, "openai/gpt-4", resp["model"])
|
||||
}
|
||||
|
||||
func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
// provider adapter 解码出含 model 的流式事件
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.streamDecoderFn = func() StreamDecoder {
|
||||
return &engineTestStreamDecoder{
|
||||
processFn: func(raw []byte) []canonical.CanonicalStreamEvent {
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewMessageStartEvent("msg-1", "native-model"),
|
||||
canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "text", Text: "hi"}),
|
||||
canonical.NewMessageStopEvent(),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
// client adapter 编码时输出 model 字段
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.streamEncoderFn = func() StreamEncoder {
|
||||
return &engineTestStreamEncoder{
|
||||
encodeFn: func(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.Message != nil {
|
||||
data, _ := json.Marshal(map[string]string{
|
||||
"type": string(event.Type),
|
||||
"model": event.Message.Model,
|
||||
})
|
||||
return [][]byte{data}
|
||||
}
|
||||
data, _ := json.Marshal(map[string]string{"type": string(event.Type)})
|
||||
return [][]byte{data}
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
|
||||
converter, err := engine.CreateStreamConverter("client", "provider", "provider/gpt-4", InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证类型是 CanonicalStreamConverter
|
||||
_, ok := converter.(*CanonicalStreamConverter)
|
||||
assert.True(t, ok)
|
||||
|
||||
// 处理一个 chunk,验证 model 被覆写为统一模型 ID
|
||||
chunks := converter.ProcessChunk([]byte("raw"))
|
||||
require.Len(t, chunks, 3) // message_start + content_block_start + message_stop
|
||||
|
||||
var startEvent map[string]string
|
||||
require.NoError(t, json.Unmarshal(chunks[0], &startEvent))
|
||||
assert.Equal(t, "provider/gpt-4", startEvent["model"], "跨协议流式中 modelOverride 应覆写 Message.Model")
|
||||
}
|
||||
|
||||
func TestCreateStreamConverter_ModelOverride_CrossProtocol_Empty(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.streamDecoderFn = func() StreamDecoder {
|
||||
return &engineTestStreamDecoder{
|
||||
processFn: func(raw []byte) []canonical.CanonicalStreamEvent {
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewMessageStartEvent("msg-1", "native-model"),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.streamEncoderFn = func() StreamEncoder {
|
||||
return &engineTestStreamEncoder{
|
||||
encodeFn: func(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.Message != nil {
|
||||
data, _ := json.Marshal(map[string]string{
|
||||
"model": event.Message.Model,
|
||||
})
|
||||
return [][]byte{data}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
|
||||
// modelOverride 为空,不应覆写
|
||||
converter, err := engine.CreateStreamConverter("client", "provider", "", InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
chunks := converter.ProcessChunk([]byte("raw"))
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
var resp map[string]string
|
||||
require.NoError(t, json.Unmarshal(chunks[0], &resp))
|
||||
assert.Equal(t, "native-model", resp["model"], "modelOverride 为空时不应覆写")
|
||||
}
|
||||
|
||||
// engineTestStreamDecoder 可控的流式解码器(用于 engine_test)
|
||||
type engineTestStreamDecoder struct {
|
||||
processFn func([]byte) []canonical.CanonicalStreamEvent
|
||||
flushFn func() []canonical.CanonicalStreamEvent
|
||||
}
|
||||
|
||||
func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.CanonicalStreamEvent {
|
||||
if d.processFn != nil {
|
||||
return d.processFn(raw)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
|
||||
if d.flushFn != nil {
|
||||
return d.flushFn()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// engineTestStreamEncoder 可控的流式编码器(用于 engine_test)
|
||||
type engineTestStreamEncoder struct {
|
||||
encodeFn func(canonical.CanonicalStreamEvent) [][]byte
|
||||
flushFn func() [][]byte
|
||||
}
|
||||
|
||||
func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if e.encodeFn != nil {
|
||||
return e.encodeFn(event)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *engineTestStreamEncoder) Flush() [][]byte {
|
||||
if e.flushFn != nil {
|
||||
return e.flushFn()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,17 +6,17 @@ import "fmt"
|
||||
type ErrorCode string
|
||||
|
||||
const (
|
||||
ErrorCodeInvalidInput ErrorCode = "INVALID_INPUT"
|
||||
ErrorCodeMissingRequiredField ErrorCode = "MISSING_REQUIRED_FIELD"
|
||||
ErrorCodeIncompatibleFeature ErrorCode = "INCOMPATIBLE_FEATURE"
|
||||
ErrorCodeFieldMappingFailure ErrorCode = "FIELD_MAPPING_FAILURE"
|
||||
ErrorCodeToolCallParseError ErrorCode = "TOOL_CALL_PARSE_ERROR"
|
||||
ErrorCodeJSONParseError ErrorCode = "JSON_PARSE_ERROR"
|
||||
ErrorCodeStreamStateError ErrorCode = "STREAM_STATE_ERROR"
|
||||
ErrorCodeUTF8DecodeError ErrorCode = "UTF8_DECODE_ERROR"
|
||||
ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
|
||||
ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE"
|
||||
ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
|
||||
ErrorCodeInvalidInput ErrorCode = "INVALID_INPUT"
|
||||
ErrorCodeMissingRequiredField ErrorCode = "MISSING_REQUIRED_FIELD"
|
||||
ErrorCodeIncompatibleFeature ErrorCode = "INCOMPATIBLE_FEATURE"
|
||||
ErrorCodeFieldMappingFailure ErrorCode = "FIELD_MAPPING_FAILURE"
|
||||
ErrorCodeToolCallParseError ErrorCode = "TOOL_CALL_PARSE_ERROR"
|
||||
ErrorCodeJSONParseError ErrorCode = "JSON_PARSE_ERROR"
|
||||
ErrorCodeStreamStateError ErrorCode = "STREAM_STATE_ERROR"
|
||||
ErrorCodeUTF8DecodeError ErrorCode = "UTF8_DECODE_ERROR"
|
||||
ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
|
||||
ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE"
|
||||
ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
|
||||
)
|
||||
|
||||
// ConversionError 协议转换错误
|
||||
|
||||
@@ -4,10 +4,10 @@ package conversion
|
||||
type InterfaceType string
|
||||
|
||||
const (
|
||||
InterfaceTypeChat InterfaceType = "CHAT"
|
||||
InterfaceTypeModels InterfaceType = "MODELS"
|
||||
InterfaceTypeModelInfo InterfaceType = "MODEL_INFO"
|
||||
InterfaceTypeEmbeddings InterfaceType = "EMBEDDINGS"
|
||||
InterfaceTypeRerank InterfaceType = "RERANK"
|
||||
InterfaceTypeChat InterfaceType = "CHAT"
|
||||
InterfaceTypeModels InterfaceType = "MODELS"
|
||||
InterfaceTypeModelInfo InterfaceType = "MODEL_INFO"
|
||||
InterfaceTypeEmbeddings InterfaceType = "EMBEDDINGS"
|
||||
InterfaceTypeRerank InterfaceType = "RERANK"
|
||||
InterfaceTypePassthrough InterfaceType = "PASSTHROUGH"
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
@@ -28,41 +29,41 @@ func (a *Adapter) SupportsPassthrough() bool { return true }
|
||||
// DetectInterfaceType 根据路径检测接口类型
|
||||
func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType {
|
||||
switch {
|
||||
case nativePath == "/v1/chat/completions":
|
||||
case nativePath == "/chat/completions":
|
||||
return conversion.InterfaceTypeChat
|
||||
case nativePath == "/v1/models":
|
||||
case nativePath == "/models":
|
||||
return conversion.InterfaceTypeModels
|
||||
case isModelInfoPath(nativePath):
|
||||
return conversion.InterfaceTypeModelInfo
|
||||
case nativePath == "/v1/embeddings":
|
||||
case nativePath == "/embeddings":
|
||||
return conversion.InterfaceTypeEmbeddings
|
||||
case nativePath == "/v1/rerank":
|
||||
case nativePath == "/rerank":
|
||||
return conversion.InterfaceTypeRerank
|
||||
default:
|
||||
return conversion.InterfaceTypePassthrough
|
||||
}
|
||||
}
|
||||
|
||||
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id})
|
||||
// isModelInfoPath 判断是否为模型详情路径(/models/{id},允许 id 含 /)
|
||||
func isModelInfoPath(path string) bool {
|
||||
if !strings.HasPrefix(path, "/v1/models/") {
|
||||
if !strings.HasPrefix(path, "/models/") {
|
||||
return false
|
||||
}
|
||||
suffix := path[len("/v1/models/"):]
|
||||
return suffix != "" && !strings.Contains(suffix, "/")
|
||||
suffix := path[len("/models/"):]
|
||||
return suffix != ""
|
||||
}
|
||||
|
||||
// BuildUrl 根据接口类型构建 URL
|
||||
func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string {
|
||||
switch interfaceType {
|
||||
case conversion.InterfaceTypeChat:
|
||||
return "/v1/chat/completions"
|
||||
return "/chat/completions"
|
||||
case conversion.InterfaceTypeModels:
|
||||
return "/v1/models"
|
||||
return "/models"
|
||||
case conversion.InterfaceTypeEmbeddings:
|
||||
return "/v1/embeddings"
|
||||
return "/embeddings"
|
||||
case conversion.InterfaceTypeRerank:
|
||||
return "/v1/rerank"
|
||||
return "/rerank"
|
||||
default:
|
||||
return nativePath
|
||||
}
|
||||
@@ -137,7 +138,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
|
||||
Code: string(err.Code),
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(errMsg)
|
||||
body, marshalErr := json.Marshal(errMsg)
|
||||
if marshalErr != nil {
|
||||
return []byte(`{"error":{"message":"internal error","type":"internal_error","code":"INTERNAL_ERROR"}}`), statusCode
|
||||
}
|
||||
return body, statusCode
|
||||
}
|
||||
|
||||
@@ -216,3 +220,92 @@ func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankRe
|
||||
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
|
||||
return encodeRerankResponse(resp)
|
||||
}
|
||||
|
||||
// ExtractUnifiedModelID 从路径中提取统一模型 ID(/models/{provider_id}/{model_name})
|
||||
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
|
||||
if !strings.HasPrefix(nativePath, "/models/") {
|
||||
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
|
||||
}
|
||||
suffix := nativePath[len("/models/"):]
|
||||
if suffix == "" {
|
||||
return "", fmt.Errorf("路径缺少模型 ID")
|
||||
}
|
||||
return suffix, nil
|
||||
}
|
||||
|
||||
// locateModelFieldInRequest 定位请求体中 model 字段的值并提供改写函数
|
||||
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
|
||||
raw, exists := m["model"]
|
||||
if !exists {
|
||||
return "", nil, fmt.Errorf("请求体中缺少 model 字段")
|
||||
}
|
||||
var current string
|
||||
if err := json.Unmarshal(raw, ¤t); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
rewriteFunc := func(newModel string) ([]byte, error) {
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
}
|
||||
return current, rewriteFunc, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractModelName 从请求体中提取 model 值
|
||||
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
|
||||
model, _, err := locateModelFieldInRequest(body, ifaceType)
|
||||
return model, err
|
||||
}
|
||||
|
||||
// RewriteRequestModelName 最小化改写请求体中的 model 字段
|
||||
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||
_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rewriteFunc(newModel)
|
||||
}
|
||||
|
||||
// RewriteResponseModelName 最小化改写响应体中的 model 字段
|
||||
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
|
||||
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
case conversion.InterfaceTypeRerank:
|
||||
// Rerank 响应:存在 model 字段则改写,不存在则不添加
|
||||
if _, exists := m["model"]; exists {
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
}
|
||||
return json.Marshal(m)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,12 +28,12 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
|
||||
path string
|
||||
expected conversion.InterfaceType
|
||||
}{
|
||||
{"聊天补全", "/v1/chat/completions", conversion.InterfaceTypeChat},
|
||||
{"模型列表", "/v1/models", conversion.InterfaceTypeModels},
|
||||
{"模型详情", "/v1/models/gpt-4", conversion.InterfaceTypeModelInfo},
|
||||
{"嵌入接口", "/v1/embeddings", conversion.InterfaceTypeEmbeddings},
|
||||
{"重排序接口", "/v1/rerank", conversion.InterfaceTypeRerank},
|
||||
{"未知路径", "/v1/unknown", conversion.InterfaceTypePassthrough},
|
||||
{"聊天补全", "/chat/completions", conversion.InterfaceTypeChat},
|
||||
{"模型列表", "/models", conversion.InterfaceTypeModels},
|
||||
{"模型详情", "/models/gpt-4", conversion.InterfaceTypeModelInfo},
|
||||
{"嵌入接口", "/embeddings", conversion.InterfaceTypeEmbeddings},
|
||||
{"重排序接口", "/rerank", conversion.InterfaceTypeRerank},
|
||||
{"未知路径", "/unknown", conversion.InterfaceTypePassthrough},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -48,16 +48,16 @@ func TestAdapter_BuildUrl(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
nativePath string
|
||||
name string
|
||||
nativePath string
|
||||
interfaceType conversion.InterfaceType
|
||||
expected string
|
||||
expected string
|
||||
}{
|
||||
{"聊天", "/v1/chat/completions", conversion.InterfaceTypeChat, "/v1/chat/completions"},
|
||||
{"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"},
|
||||
{"嵌入", "/v1/embeddings", conversion.InterfaceTypeEmbeddings, "/v1/embeddings"},
|
||||
{"重排序", "/v1/rerank", conversion.InterfaceTypeRerank, "/v1/rerank"},
|
||||
{"默认透传", "/v1/other", conversion.InterfaceTypePassthrough, "/v1/other"},
|
||||
{"聊天", "/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"},
|
||||
{"模型", "/models", conversion.InterfaceTypeModels, "/models"},
|
||||
{"嵌入", "/embeddings", conversion.InterfaceTypeEmbeddings, "/embeddings"},
|
||||
{"重排序", "/rerank", conversion.InterfaceTypeRerank, "/rerank"},
|
||||
{"默认透传", "/other", conversion.InterfaceTypePassthrough, "/other"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -92,9 +92,9 @@ func TestAdapter_SupportsInterface(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
name string
|
||||
interfaceType conversion.InterfaceType
|
||||
expected bool
|
||||
expected bool
|
||||
}{
|
||||
{"聊天", conversion.InterfaceTypeChat, true},
|
||||
{"模型", conversion.InterfaceTypeModels, true},
|
||||
@@ -118,13 +118,13 @@ func TestIsModelInfoPath(t *testing.T) {
|
||||
path string
|
||||
expected bool
|
||||
}{
|
||||
{"model_info", "/v1/models/gpt-4", true},
|
||||
{"model_info_with_dots", "/v1/models/gpt-4.1-preview", true},
|
||||
{"models_list", "/v1/models", false},
|
||||
{"nested_path", "/v1/models/gpt-4/versions", false},
|
||||
{"empty_suffix", "/v1/models/", false},
|
||||
{"unrelated", "/v1/chat/completions", false},
|
||||
{"partial_prefix", "/v1/model", false},
|
||||
{"model_info", "/models/gpt-4", true},
|
||||
{"model_info_with_dots", "/models/gpt-4.1-preview", true},
|
||||
{"models_list", "/models", false},
|
||||
{"nested_path", "/models/gpt-4/versions", true},
|
||||
{"empty_suffix", "/models/", false},
|
||||
{"unrelated", "/chat/completions", false},
|
||||
{"partial_prefix", "/model", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
360
backend/internal/conversion/openai/adapter_unified_test.go
Normal file
360
backend/internal/conversion/openai/adapter_unified_test.go
Normal file
@@ -0,0 +1,360 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractUnifiedModelID
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractUnifiedModelID(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("standard_path", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/models/openai/gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/gpt-4", id)
|
||||
})
|
||||
|
||||
t.Run("multi_segment_path", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/models/azure/accounts/org/models/gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "azure/accounts/org/models/gpt-4", id)
|
||||
})
|
||||
|
||||
t.Run("single_segment", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/models/gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4", id)
|
||||
})
|
||||
|
||||
t.Run("non_model_path", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/chat/completions")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("empty_suffix", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/models/")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("models_list_no_slash", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/models")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unrelated_path", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/other")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractModelName
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractModelName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/gpt-4","messages":[]}`)
|
||||
model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/gpt-4", model)
|
||||
})
|
||||
|
||||
t.Run("embedding", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
|
||||
model, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/text-embedding", model)
|
||||
})
|
||||
|
||||
t.Run("rerank", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/rerank","query":"test"}`)
|
||||
model, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/rerank", model)
|
||||
})
|
||||
|
||||
t.Run("no_model_field", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[]}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
body := []byte(`{invalid}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unsupported_interface_type", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/gpt-4"}`)
|
||||
_, err := a.ExtractModelName(body, conversion.InterfaceTypePassthrough)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RewriteRequestModelName
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRewriteRequestModelName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/gpt-4","messages":[]}`)
|
||||
rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "gpt-4", m["model"])
|
||||
|
||||
// messages field preserved
|
||||
msgs, ok := m["messages"]
|
||||
require.True(t, ok)
|
||||
msgsArr, ok := msgs.([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Len(t, msgsArr, 0)
|
||||
})
|
||||
|
||||
t.Run("preserves_unknown_fields", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/gpt-4","temperature":0.7}`)
|
||||
rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "gpt-4", m["model"])
|
||||
assert.Equal(t, 0.7, m["temperature"])
|
||||
})
|
||||
|
||||
t.Run("embedding", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
|
||||
rewritten, err := a.RewriteRequestModelName(body, "text-embedding", conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "text-embedding", m["model"])
|
||||
assert.Equal(t, "hello", m["input"])
|
||||
})
|
||||
|
||||
t.Run("rerank", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/rerank","query":"test"}`)
|
||||
rewritten, err := a.RewriteRequestModelName(body, "rerank", conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "rerank", m["model"])
|
||||
assert.Equal(t, "test", m["query"])
|
||||
})
|
||||
|
||||
t.Run("no_model_field", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[]}`)
|
||||
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
body := []byte(`{invalid}`)
|
||||
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("unsupported_interface_type", func(t *testing.T) {
|
||||
body := []byte(`{"model":"openai/gpt-4"}`)
|
||||
_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypePassthrough)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RewriteResponseModelName
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestRewriteResponseModelName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat_existing_model", func(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-4","choices":[]}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "openai/gpt-4", m["model"])
|
||||
|
||||
choices, ok := m["choices"]
|
||||
require.True(t, ok)
|
||||
choicesArr, ok := choices.([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Len(t, choicesArr, 0)
|
||||
})
|
||||
|
||||
t.Run("chat_without_model_field", func(t *testing.T) {
|
||||
body := []byte(`{"choices":[]}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "openai/gpt-4", m["model"])
|
||||
|
||||
choices, ok := m["choices"]
|
||||
require.True(t, ok)
|
||||
choicesArr, ok := choices.([]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Len(t, choicesArr, 0)
|
||||
})
|
||||
|
||||
t.Run("rerank_existing_model", func(t *testing.T) {
|
||||
body := []byte(`{"model":"rerank","results":[]}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "openai/rerank", m["model"])
|
||||
})
|
||||
|
||||
t.Run("rerank_without_model_field_should_not_add", func(t *testing.T) {
|
||||
body := []byte(`{"results":[]}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
_, hasModel := m["model"]
|
||||
assert.False(t, hasModel, "rerank response without model field should not have one added")
|
||||
})
|
||||
|
||||
t.Run("embedding_existing_model", func(t *testing.T) {
|
||||
body := []byte(`{"model":"text-embedding","data":[]}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "openai/text-embedding", m["model"])
|
||||
})
|
||||
|
||||
t.Run("embedding_without_model_field_adds", func(t *testing.T) {
|
||||
body := []byte(`{"data":[]}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, "openai/text-embedding", m["model"])
|
||||
})
|
||||
|
||||
t.Run("passthrough_returns_body_unchanged", func(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-4"}`)
|
||||
rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypePassthrough)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(body), string(rewritten))
|
||||
})
|
||||
|
||||
t.Run("invalid_json", func(t *testing.T) {
|
||||
body := []byte(`{invalid}`)
|
||||
_, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractModelName and RewriteRequest consistency
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("chat_round_trip", func(t *testing.T) {
|
||||
original := []byte(`{"model":"openai/gpt-4","messages":[],"temperature":0.7}`)
|
||||
|
||||
// Extract the unified model ID from the body
|
||||
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/gpt-4", extracted)
|
||||
|
||||
// Rewrite to the native model name
|
||||
rewritten, err := a.RewriteRequestModelName(original, "gpt-4", conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extract again from the rewritten body to verify the same location was targeted
|
||||
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4", afterRewrite)
|
||||
|
||||
// Verify other fields are preserved
|
||||
var m map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(rewritten, &m))
|
||||
assert.Equal(t, 0.7, m["temperature"])
|
||||
})
|
||||
|
||||
t.Run("embedding_round_trip", func(t *testing.T) {
|
||||
original := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
|
||||
|
||||
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/text-embedding", extracted)
|
||||
|
||||
rewritten, err := a.RewriteRequestModelName(original, "text-embedding", conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
|
||||
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeEmbeddings)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "text-embedding", afterRewrite)
|
||||
})
|
||||
|
||||
t.Run("rerank_round_trip", func(t *testing.T) {
|
||||
original := []byte(`{"model":"openai/rerank","query":"test"}`)
|
||||
|
||||
extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/rerank", extracted)
|
||||
|
||||
rewritten, err := a.RewriteRequestModelName(original, "rerank", conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
|
||||
afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeRerank)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "rerank", afterRewrite)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isModelInfoPath (additional unified model ID cases)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
expected bool
|
||||
}{
|
||||
{"simple_model_id", "/models/gpt-4", true},
|
||||
{"unified_model_id_with_slash", "/models/openai/gpt-4", true},
|
||||
{"models_list", "/models", false},
|
||||
{"models_list_trailing_slash", "/models/", false},
|
||||
{"chat_completions", "/chat/completions", false},
|
||||
{"deeply_nested", "/models/azure/eastus/deployments/my-dept/models/gpt-4", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -215,10 +215,16 @@ func decodeUserContent(content any) []canonical.ContentBlock {
|
||||
var blocks []canonical.ContentBlock
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
t, _ := m["type"].(string)
|
||||
t, ok := m["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch t {
|
||||
case "text":
|
||||
text, _ := m["text"].(string)
|
||||
text, ok := m["text"].(string)
|
||||
if !ok {
|
||||
text = ""
|
||||
}
|
||||
blocks = append(blocks, canonical.NewTextBlock(text))
|
||||
case "image_url":
|
||||
blocks = append(blocks, canonical.ContentBlock{Type: "image"})
|
||||
@@ -242,9 +248,9 @@ func decodeUserContent(content any) []canonical.ContentBlock {
|
||||
|
||||
// contentPart 内容部分
|
||||
type contentPart struct {
|
||||
Type string
|
||||
Text string
|
||||
Refusal string
|
||||
Type string
|
||||
Text string
|
||||
Refusal string
|
||||
}
|
||||
|
||||
// decodeContentParts 解码内容部分
|
||||
@@ -256,13 +262,22 @@ func decodeContentParts(content any) []contentPart {
|
||||
var result []contentPart
|
||||
for _, item := range parts {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
t, _ := m["type"].(string)
|
||||
t, ok := m["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch t {
|
||||
case "text":
|
||||
text, _ := m["text"].(string)
|
||||
text, ok := m["text"].(string)
|
||||
if !ok {
|
||||
text = ""
|
||||
}
|
||||
result = append(result, contentPart{Type: "text", Text: text})
|
||||
case "refusal":
|
||||
refusal, _ := m["refusal"].(string)
|
||||
refusal, ok := m["refusal"].(string)
|
||||
if !ok {
|
||||
refusal = ""
|
||||
}
|
||||
result = append(result, contentPart{Type: "refusal", Refusal: refusal})
|
||||
}
|
||||
}
|
||||
@@ -307,21 +322,33 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
||||
return canonical.NewToolChoiceAny()
|
||||
}
|
||||
case map[string]any:
|
||||
t, _ := v["type"].(string)
|
||||
t, ok := v["type"].(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
switch t {
|
||||
case "function":
|
||||
if fn, ok := v["function"].(map[string]any); ok {
|
||||
name, _ := fn["name"].(string)
|
||||
name, ok := fn["name"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
return canonical.NewToolChoiceNamed(name)
|
||||
}
|
||||
case "custom":
|
||||
if custom, ok := v["custom"].(map[string]any); ok {
|
||||
name, _ := custom["name"].(string)
|
||||
name, ok := custom["name"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
return canonical.NewToolChoiceNamed(name)
|
||||
}
|
||||
case "allowed_tools":
|
||||
if at, ok := v["allowed_tools"].(map[string]any); ok {
|
||||
mode, _ := at["mode"].(string)
|
||||
mode, ok := at["mode"].(string)
|
||||
if !ok {
|
||||
mode = ""
|
||||
}
|
||||
if mode == "required" {
|
||||
return canonical.NewToolChoiceAny()
|
||||
}
|
||||
@@ -443,7 +470,7 @@ func decodeDeprecatedFields(req *ChatCompletionRequest) {
|
||||
case map[string]any:
|
||||
if name, ok := v["name"].(string); ok {
|
||||
req.ToolChoice = map[string]any{
|
||||
"type": "function",
|
||||
"type": "function",
|
||||
"function": map[string]any{"name": name},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -409,3 +409,25 @@ func TestDecodeResponse_Refusal(t *testing.T) {
|
||||
}
|
||||
assert.True(t, found)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_AssistantContentArray(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "user", "content": "hello"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "hello back"}]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, req.Messages, 2)
|
||||
assistantMsg := req.Messages[1]
|
||||
assert.Equal(t, canonical.RoleAssistant, assistantMsg.Role)
|
||||
assert.Len(t, assistantMsg.Content, 1)
|
||||
assert.Equal(t, "text", assistantMsg.Content[0].Type)
|
||||
assert.Equal(t, "hello back", assistantMsg.Content[0].Text)
|
||||
}
|
||||
|
||||
@@ -450,7 +450,7 @@ func encodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte
|
||||
"object": "list",
|
||||
"data": data,
|
||||
"model": resp.Model,
|
||||
"usage": resp.Usage,
|
||||
"usage": resp.Usage,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -45,9 +45,11 @@ func TestEncodeRequest_SystemInjection(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, msgs, 2)
|
||||
firstMsg := msgs[0].(map[string]any)
|
||||
firstMsg, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "system", firstMsg["role"])
|
||||
assert.Equal(t, "你是助手", firstMsg["content"])
|
||||
}
|
||||
@@ -72,12 +74,15 @@ func TestEncodeRequest_ToolCalls(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
assistantMsg := msgs[0].(map[string]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assistantMsg, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
toolCalls, ok := assistantMsg["tool_calls"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, toolCalls, 1)
|
||||
tc := toolCalls[0].(map[string]any)
|
||||
tc, ok := toolCalls[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "call_1", tc["id"])
|
||||
}
|
||||
|
||||
@@ -100,11 +105,11 @@ func TestEncodeRequest_Thinking(t *testing.T) {
|
||||
func TestEncodeResponse_Basic(t *testing.T) {
|
||||
sr := canonical.StopReasonEndTurn
|
||||
resp := &canonical.CanonicalResponse{
|
||||
ID: "resp-1",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("你好")},
|
||||
ID: "resp-1",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("你好")},
|
||||
StopReason: &sr,
|
||||
Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
|
||||
Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
|
||||
}
|
||||
|
||||
body, err := encodeResponse(resp)
|
||||
@@ -115,9 +120,12 @@ func TestEncodeResponse_Basic(t *testing.T) {
|
||||
assert.Equal(t, "resp-1", result["id"])
|
||||
assert.Equal(t, "chat.completion", result["object"])
|
||||
|
||||
choices := result["choices"].([]any)
|
||||
choice := choices[0].(map[string]any)
|
||||
msg := choice["message"].(map[string]any)
|
||||
choices, ok := result["choices"].([]any)
|
||||
require.True(t, ok)
|
||||
choice, ok := choices[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
msg, ok := choice["message"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "你好", msg["content"])
|
||||
assert.Equal(t, "stop", choice["finish_reason"])
|
||||
}
|
||||
@@ -126,9 +134,9 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
|
||||
sr := canonical.StopReasonToolUse
|
||||
input := json.RawMessage(`{"q":"test"}`)
|
||||
resp := &canonical.CanonicalResponse{
|
||||
ID: "resp-2",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)},
|
||||
ID: "resp-2",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)},
|
||||
StopReason: &sr,
|
||||
}
|
||||
|
||||
@@ -137,8 +145,12 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
choices := result["choices"].([]any)
|
||||
msg := choices[0].(map[string]any)["message"].(map[string]any)
|
||||
choices, okc := result["choices"].([]any)
|
||||
require.True(t, okc)
|
||||
msgMap, okm := choices[0].(map[string]any)
|
||||
require.True(t, okm)
|
||||
msg, okmsg := msgMap["message"].(map[string]any)
|
||||
require.True(t, okmsg)
|
||||
tcs, ok := msg["tool_calls"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, tcs, 1)
|
||||
@@ -158,7 +170,8 @@ func TestEncodeModelsResponse(t *testing.T) {
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "list", result["object"])
|
||||
data := result["data"].([]any)
|
||||
data, okd := result["data"].([]any)
|
||||
require.True(t, okd)
|
||||
assert.Len(t, data, 2)
|
||||
}
|
||||
|
||||
@@ -317,8 +330,12 @@ func TestEncodeResponse_Thinking(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
choices := result["choices"].([]any)
|
||||
msg := choices[0].(map[string]any)["message"].(map[string]any)
|
||||
choices, okch := result["choices"].([]any)
|
||||
require.True(t, okch)
|
||||
msgMap, okmm := choices[0].(map[string]any)
|
||||
require.True(t, okmm)
|
||||
msg, okmsg := msgMap["message"].(map[string]any)
|
||||
require.True(t, okmsg)
|
||||
assert.Equal(t, "回答", msg["content"])
|
||||
assert.Equal(t, "思考过程", msg["reasoning_content"])
|
||||
}
|
||||
|
||||
@@ -18,9 +18,9 @@ func TestStreamDecoder_BasicText(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"object": "chat.completion.chunk",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-1",
|
||||
"object": "chat.completion.chunk",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -56,8 +56,8 @@ func TestStreamDecoder_ToolCalls(t *testing.T) {
|
||||
|
||||
idx := 0
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -98,8 +98,8 @@ func TestStreamDecoder_Thinking(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -127,8 +127,8 @@ func TestStreamDecoder_FinishReason(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -161,8 +161,8 @@ func TestStreamDecoder_DoneSignal(t *testing.T) {
|
||||
|
||||
// 先发送一个文本 chunk
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -190,8 +190,8 @@ func TestStreamDecoder_RefusalReuse(t *testing.T) {
|
||||
// 连续两个 refusal delta chunk
|
||||
for _, text := range []string{"拒绝", "原因"} {
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -250,8 +250,8 @@ func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
|
||||
|
||||
idx0 := 0
|
||||
chunk1 := map[string]any{
|
||||
"id": "chatcmpl-mt",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-mt",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -274,8 +274,8 @@ func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
|
||||
|
||||
idx1 := 1
|
||||
chunk2 := map[string]any{
|
||||
"id": "chatcmpl-mt",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-mt",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -322,8 +322,8 @@ func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk1 := map[string]any{
|
||||
"id": "chatcmpl-multi",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-multi",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -332,8 +332,8 @@ func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
|
||||
},
|
||||
}
|
||||
chunk2 := map[string]any{
|
||||
"id": "chatcmpl-multi",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-multi",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -358,8 +358,8 @@ func TestStreamDecoder_UTF8Truncation(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-utf8",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-utf8",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -390,8 +390,8 @@ func TestStreamDecoder_ToolCallSubsequentDelta(t *testing.T) {
|
||||
|
||||
idx := 0
|
||||
chunk1 := map[string]any{
|
||||
"id": "chatcmpl-tc",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-tc",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
@@ -412,8 +412,8 @@ func TestStreamDecoder_ToolCallSubsequentDelta(t *testing.T) {
|
||||
},
|
||||
}
|
||||
chunk2 := map[string]any{
|
||||
"id": "chatcmpl-tc",
|
||||
"model": "gpt-4",
|
||||
"id": "chatcmpl-tc",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
|
||||
// StreamEncoder OpenAI 流式编码器
|
||||
type StreamEncoder struct {
|
||||
bufferedStart *canonical.CanonicalStreamEvent
|
||||
toolCallIndexMap map[string]int
|
||||
nextToolCallIndex int
|
||||
bufferedStart *canonical.CanonicalStreamEvent
|
||||
toolCallIndexMap map[string]int
|
||||
nextToolCallIndex int
|
||||
}
|
||||
|
||||
// NewStreamEncoder 创建 OpenAI 流式编码器
|
||||
@@ -195,8 +195,8 @@ func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent)
|
||||
func (e *StreamEncoder) encodeDelta(delta map[string]any) [][]byte {
|
||||
chunk := map[string]any{
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
}},
|
||||
}
|
||||
return e.marshalChunk(chunk)
|
||||
|
||||
@@ -27,8 +27,12 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
|
||||
data := strings.TrimPrefix(s, "data: ")
|
||||
data = strings.TrimRight(data, "\n")
|
||||
require.NoError(t, json.Unmarshal([]byte(data), &payload))
|
||||
choices := payload["choices"].([]any)
|
||||
delta := choices[0].(map[string]any)["delta"].(map[string]any)
|
||||
choices, okch := payload["choices"].([]any)
|
||||
require.True(t, okch)
|
||||
msgMap, okmm := choices[0].(map[string]any)
|
||||
require.True(t, okmm)
|
||||
delta, okd := msgMap["delta"].(map[string]any)
|
||||
require.True(t, okd)
|
||||
assert.Equal(t, "assistant", delta["role"])
|
||||
}
|
||||
|
||||
|
||||
@@ -177,7 +177,8 @@ func TestEncodeRerankResponse(t *testing.T) {
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "rerank-1", result["model"])
|
||||
results := result["results"].([]any)
|
||||
results, okr := result["results"].([]any)
|
||||
require.True(t, okr)
|
||||
assert.Len(t, results, 1)
|
||||
}
|
||||
|
||||
@@ -356,9 +357,9 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
|
||||
reasoning := 20
|
||||
sr := canonical.StopReasonEndTurn
|
||||
resp := &canonical.CanonicalResponse{
|
||||
ID: "r1",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
|
||||
ID: "r1",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
|
||||
StopReason: &sr,
|
||||
Usage: canonical.CanonicalUsage{
|
||||
InputTokens: 100,
|
||||
@@ -373,7 +374,8 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
usage := result["usage"].(map[string]any)
|
||||
usage, oku := result["usage"].(map[string]any)
|
||||
require.True(t, oku)
|
||||
assert.Equal(t, float64(100), usage["prompt_tokens"])
|
||||
ptd, ok := usage["prompt_tokens_details"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
@@ -412,8 +414,10 @@ func TestEncodeResponse_StopReasons(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
choices := result["choices"].([]any)
|
||||
choice := choices[0].(map[string]any)
|
||||
choices, okch := result["choices"].([]any)
|
||||
require.True(t, okch)
|
||||
choice, okc := choices[0].(map[string]any)
|
||||
require.True(t, okc)
|
||||
assert.Equal(t, tt.want, choice["finish_reason"])
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,42 +4,42 @@ import "encoding/json"
|
||||
|
||||
// ChatCompletionRequest OpenAI Chat Completion 请求
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
N *int `json:"n,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
Logprobs *bool `json:"logprobs,omitempty"`
|
||||
TopLogprobs *int `json:"top_logprobs,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
N *int `json:"n,omitempty"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
Logprobs *bool `json:"logprobs,omitempty"`
|
||||
TopLogprobs *int `json:"top_logprobs,omitempty"`
|
||||
|
||||
// 已废弃字段
|
||||
Functions []FunctionDef `json:"functions,omitempty"`
|
||||
FunctionCall any `json:"function_call,omitempty"`
|
||||
Functions []FunctionDef `json:"functions,omitempty"`
|
||||
FunctionCall any `json:"function_call,omitempty"`
|
||||
}
|
||||
|
||||
// Message OpenAI 消息
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
|
||||
// 已废弃
|
||||
FunctionCall *FunctionCallMsg `json:"function_call,omitempty"`
|
||||
@@ -88,8 +88,8 @@ type FunctionDef struct {
|
||||
|
||||
// ResponseFormat OpenAI 响应格式
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type"`
|
||||
JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"`
|
||||
Type string `json:"type"`
|
||||
JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"`
|
||||
}
|
||||
|
||||
// JSONSchemaDef JSON Schema 定义
|
||||
@@ -118,7 +118,7 @@ type ChatCompletionResponse struct {
|
||||
|
||||
// Choice OpenAI 选择项
|
||||
type Choice struct {
|
||||
Index int `json:"index"`
|
||||
Index int `json:"index"`
|
||||
Message *Message `json:"message,omitempty"`
|
||||
Delta *Message `json:"delta,omitempty"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
@@ -127,10 +127,10 @@ type Choice struct {
|
||||
|
||||
// Usage OpenAI 用量
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"`
|
||||
CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@@ -38,14 +38,52 @@ func (c *PassthroughStreamConverter) Flush() [][]byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器
|
||||
// 逐 chunk 改写 model 字段
|
||||
type SmartPassthroughStreamConverter struct {
|
||||
adapter ProtocolAdapter
|
||||
modelOverride string
|
||||
interfaceType InterfaceType
|
||||
}
|
||||
|
||||
// NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器
|
||||
func NewSmartPassthroughStreamConverter(adapter ProtocolAdapter, modelOverride string, interfaceType InterfaceType) *SmartPassthroughStreamConverter {
|
||||
return &SmartPassthroughStreamConverter{
|
||||
adapter: adapter,
|
||||
modelOverride: modelOverride,
|
||||
interfaceType: interfaceType,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessChunk 改写 chunk 中的 model 字段
|
||||
func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
||||
if len(rawChunk) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
rewrittenChunk, err := c.adapter.RewriteResponseModelName(rawChunk, c.modelOverride, c.interfaceType)
|
||||
if err != nil {
|
||||
// 改写失败,返回原始 chunk
|
||||
return [][]byte{rawChunk}
|
||||
}
|
||||
|
||||
return [][]byte{rewrittenChunk}
|
||||
}
|
||||
|
||||
// Flush 无缓冲数据
|
||||
func (c *SmartPassthroughStreamConverter) Flush() [][]byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CanonicalStreamConverter 跨协议规范流式转换器
|
||||
type CanonicalStreamConverter struct {
|
||||
decoder StreamDecoder
|
||||
encoder StreamEncoder
|
||||
chain *MiddlewareChain
|
||||
ctx ConversionContext
|
||||
clientProtocol string
|
||||
decoder StreamDecoder
|
||||
encoder StreamEncoder
|
||||
chain *MiddlewareChain
|
||||
ctx ConversionContext
|
||||
clientProtocol string
|
||||
providerProtocol string
|
||||
modelOverride string
|
||||
}
|
||||
|
||||
// NewCanonicalStreamConverter 创建规范流式转换器
|
||||
@@ -57,18 +95,19 @@ func NewCanonicalStreamConverter(decoder StreamDecoder, encoder StreamEncoder) *
|
||||
}
|
||||
|
||||
// NewCanonicalStreamConverterWithMiddleware 创建带中间件的规范流式转换器
|
||||
func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol string) *CanonicalStreamConverter {
|
||||
func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol, modelOverride string) *CanonicalStreamConverter {
|
||||
return &CanonicalStreamConverter{
|
||||
decoder: decoder,
|
||||
encoder: encoder,
|
||||
chain: chain,
|
||||
ctx: ctx,
|
||||
clientProtocol: clientProtocol,
|
||||
decoder: decoder,
|
||||
encoder: encoder,
|
||||
chain: chain,
|
||||
ctx: ctx,
|
||||
clientProtocol: clientProtocol,
|
||||
providerProtocol: providerProtocol,
|
||||
modelOverride: modelOverride,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessChunk 解码 → 中间件 → 编码管道
|
||||
// ProcessChunk 解码 → 中间件 → modelOverride → 编码管道
|
||||
func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
||||
events := c.decoder.ProcessChunk(rawChunk)
|
||||
var result [][]byte
|
||||
@@ -80,6 +119,7 @@ func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
||||
}
|
||||
events[i] = *processed
|
||||
}
|
||||
c.applyModelOverride(&events[i])
|
||||
chunks := c.encoder.EncodeEvent(events[i])
|
||||
result = append(result, chunks...)
|
||||
}
|
||||
@@ -98,6 +138,7 @@ func (c *CanonicalStreamConverter) Flush() [][]byte {
|
||||
}
|
||||
events[i] = *processed
|
||||
}
|
||||
c.applyModelOverride(&events[i])
|
||||
chunks := c.encoder.EncodeEvent(events[i])
|
||||
result = append(result, chunks...)
|
||||
}
|
||||
@@ -105,3 +146,10 @@ func (c *CanonicalStreamConverter) Flush() [][]byte {
|
||||
result = append(result, encoderChunks...)
|
||||
return result
|
||||
}
|
||||
|
||||
// applyModelOverride 在跨协议场景下覆写流式事件中的 Model 字段
|
||||
func (c *CanonicalStreamConverter) applyModelOverride(event *canonical.CanonicalStreamEvent) {
|
||||
if c.modelOverride != "" && event.Message != nil {
|
||||
event.Message.Model = c.modelOverride
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ func TestCanonicalStreamConverter_WithMiddleware(t *testing.T) {
|
||||
chain.Use(&recordingMiddleware{name: "mw1", records: &records})
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
|
||||
result := converter.ProcessChunk([]byte("raw"))
|
||||
|
||||
assert.Len(t, result, 1)
|
||||
@@ -143,7 +143,7 @@ func TestCanonicalStreamConverter_MiddlewareError_Continue(t *testing.T) {
|
||||
chain.Use(&errorMiddleware{})
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
|
||||
result := converter.ProcessChunk([]byte("raw"))
|
||||
|
||||
assert.Nil(t, result, "middleware error should cause the event to be skipped (continue)")
|
||||
@@ -163,7 +163,7 @@ func TestCanonicalStreamConverter_Flush_MiddlewareError_Continue(t *testing.T) {
|
||||
chain.Use(&errorMiddleware{})
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic")
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
|
||||
result := converter.Flush()
|
||||
|
||||
assert.Len(t, result, 1)
|
||||
|
||||
151
backend/internal/database/database.go
Normal file
151
backend/internal/database/database.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
func Init(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
|
||||
moduleLogger := pkglogger.WithModule(zapLogger, "database")
|
||||
|
||||
db, err := initDB(cfg, moduleLogger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("初始化数据库失败: %w", err)
|
||||
}
|
||||
|
||||
if err := runMigrations(db, cfg.Driver, moduleLogger); err != nil {
|
||||
return nil, fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
configurePool(db, cfg, moduleLogger)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func Close(db *gorm.DB) {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sqlDB.Close()
|
||||
}
|
||||
|
||||
func initDB(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
|
||||
gormLogger := pkglogger.NewGormLogger(zapLogger)
|
||||
|
||||
gormConfig := &gorm.Config{
|
||||
Logger: gormLogger,
|
||||
}
|
||||
|
||||
switch cfg.Driver {
|
||||
case "mysql":
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||||
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName)
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("连接 MySQL 数据库",
|
||||
zap.String("host", cfg.Host),
|
||||
zap.Int("port", cfg.Port),
|
||||
zap.String("database", cfg.DBName))
|
||||
}
|
||||
return gorm.Open(mysql.Open(dsn), gormConfig)
|
||||
default:
|
||||
dbDir := filepath.Dir(cfg.Path)
|
||||
if err := os.MkdirAll(dbDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
|
||||
}
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("连接 SQLite 数据库", zap.String("path", cfg.Path))
|
||||
}
|
||||
return gorm.Open(sqlite.Open(cfg.Path), gormConfig)
|
||||
}
|
||||
}
|
||||
|
||||
func runMigrations(db *gorm.DB, driver string, zapLogger *zap.Logger) error {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
gooseDialect := "sqlite3"
|
||||
migrationsSubDir := "sqlite"
|
||||
if driver == "mysql" {
|
||||
gooseDialect = "mysql"
|
||||
migrationsSubDir = "mysql"
|
||||
}
|
||||
|
||||
migrationsDir := getMigrationsDir(driver)
|
||||
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
|
||||
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
|
||||
}
|
||||
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("执行数据库迁移",
|
||||
zap.String("dialect", gooseDialect),
|
||||
zap.String("dir", migrationsSubDir))
|
||||
}
|
||||
|
||||
if err := goose.SetDialect(gooseDialect); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := goose.Up(sqlDB, migrationsDir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func configurePool(db *gorm.DB, cfg *config.DatabaseConfig, zapLogger *zap.Logger) {
|
||||
if cfg.Driver == "sqlite" {
|
||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||
if zapLogger != nil {
|
||||
zapLogger.Warn("启用 WAL 模式失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime)
|
||||
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("数据库连接池配置",
|
||||
zap.Int("max_idle_conns", cfg.MaxIdleConns),
|
||||
zap.Int("max_open_conns", cfg.MaxOpenConns),
|
||||
zap.Duration("conn_max_lifetime", cfg.ConnMaxLifetime))
|
||||
}
|
||||
}
|
||||
|
||||
func getMigrationsDir(driver string) string {
|
||||
_, filename, _, ok := runtime.Caller(0)
|
||||
if ok {
|
||||
subDir := "sqlite"
|
||||
if driver == "mysql" {
|
||||
subDir = "mysql"
|
||||
}
|
||||
dir := filepath.Join(filepath.Dir(filename), "..", "..", "migrations", subDir)
|
||||
if abs, err := filepath.Abs(dir); err == nil {
|
||||
return abs
|
||||
}
|
||||
}
|
||||
return "./migrations"
|
||||
}
|
||||
|
||||
func BuildDSN(cfg *config.DatabaseConfig) string {
|
||||
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||||
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName)
|
||||
}
|
||||
78
backend/internal/database/database_test.go
Normal file
78
backend/internal/database/database_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestInit_SQLite(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfg := &config.DatabaseConfig{
|
||||
Driver: "sqlite",
|
||||
Path: filepath.Join(dir, "test.db"),
|
||||
MaxIdleConns: 5,
|
||||
MaxOpenConns: 10,
|
||||
ConnMaxLifetime: 0,
|
||||
}
|
||||
|
||||
zapLogger := zap.NewNop()
|
||||
db, err := Init(cfg, zapLogger)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, db)
|
||||
defer Close(db)
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sqlDB)
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfg := &config.DatabaseConfig{
|
||||
Driver: "sqlite",
|
||||
Path: filepath.Join(dir, "test.db"),
|
||||
MaxIdleConns: 5,
|
||||
MaxOpenConns: 10,
|
||||
ConnMaxLifetime: 0,
|
||||
}
|
||||
|
||||
zapLogger := zap.NewNop()
|
||||
db, err := Init(cfg, zapLogger)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, db)
|
||||
|
||||
Close(db)
|
||||
}
|
||||
|
||||
func TestBuildDSN(t *testing.T) {
|
||||
cfg := &config.DatabaseConfig{
|
||||
Driver: "mysql",
|
||||
Host: "db.example.com",
|
||||
Port: 3306,
|
||||
User: "nexuser",
|
||||
Password: "secretpass",
|
||||
DBName: "nexdb",
|
||||
}
|
||||
|
||||
dsn := BuildDSN(cfg)
|
||||
assert.Equal(t, "nexuser:secretpass@tcp(db.example.com:3306)/nexdb?charset=utf8mb4&parseTime=true&loc=Local", dsn)
|
||||
}
|
||||
|
||||
func TestBuildDSN_EmptyPassword(t *testing.T) {
|
||||
cfg := &config.DatabaseConfig{
|
||||
Driver: "mysql",
|
||||
Host: "localhost",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
DBName: "nex",
|
||||
}
|
||||
|
||||
dsn := BuildDSN(cfg)
|
||||
assert.Equal(t, "root:@tcp(localhost:3306)/nex?charset=utf8mb4&parseTime=true&loc=Local", dsn)
|
||||
}
|
||||
@@ -1,8 +1,12 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"time"
|
||||
|
||||
// Model 模型领域模型
|
||||
"nex/backend/pkg/modelid"
|
||||
)
|
||||
|
||||
// Model 模型领域模型(id 为 UUID 自动生成)
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
ProviderID string `json:"provider_id"`
|
||||
@@ -10,3 +14,8 @@ type Model struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// UnifiedModelID 返回统一模型 ID(格式:provider_id/model_name)
|
||||
func (m *Model) UnifiedModelID() string {
|
||||
return modelid.FormatUnifiedModelID(m.ProviderID, m.ModelName)
|
||||
}
|
||||
|
||||
@@ -13,12 +13,3 @@ type Provider struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// MaskAPIKey 掩码 API Key(仅显示最后 4 个字符)
|
||||
func (p *Provider) MaskAPIKey() {
|
||||
if len(p.APIKey) > 4 {
|
||||
p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:]
|
||||
} else {
|
||||
p.APIKey = "***"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,20 +6,27 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/tests/mocks"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
func TestProviderHandler_CreateProvider_Success(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"id": "p1",
|
||||
"name": "Test",
|
||||
"api_key": "sk-test",
|
||||
"id": "p1",
|
||||
"name": "Test",
|
||||
"api_key": "sk-test",
|
||||
"base_url": "https://api.test.com",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -33,11 +40,16 @@ func TestProviderHandler_CreateProvider_Success(t *testing.T) {
|
||||
var result domain.Provider
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Equal(t, "p1", result.ID)
|
||||
assert.Contains(t, result.APIKey, "***")
|
||||
assert.Equal(t, "sk-test", result.APIKey)
|
||||
}
|
||||
|
||||
func TestProviderHandler_CreateProvider_WithProtocol(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"id": "p1",
|
||||
@@ -56,9 +68,13 @@ func TestProviderHandler_CreateProvider_WithProtocol(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
provider: &domain.Provider{ID: "p1", Name: "Updated", APIKey: "***"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(nil)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("p1")).Return(&domain.Provider{ID: "p1", Name: "Updated", APIKey: "sk-test"}, nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -72,7 +88,11 @@ func TestProviderHandler_UpdateProvider(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_InvalidBody(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -84,7 +104,12 @@ func TestProviderHandler_UpdateProvider_InvalidBody(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProviderHandler_DeleteProvider(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Delete(gomock.Eq("p1")).Return(nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -97,7 +122,12 @@ func TestProviderHandler_DeleteProvider(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_DeleteModel(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Delete(gomock.Eq("m1")).Return(nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -110,10 +140,17 @@ func TestModelHandler_DeleteModel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_Success(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).DoAndReturn(func(model *domain.Model) error {
|
||||
model.ID = "mock-uuid-1234"
|
||||
return nil
|
||||
})
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"id": "m1",
|
||||
"provider_id": "p1",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
@@ -127,13 +164,16 @@ func TestModelHandler_CreateModel_Success(t *testing.T) {
|
||||
|
||||
var result domain.Model
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Equal(t, "m1", result.ID)
|
||||
assert.NotEmpty(t, result.ID)
|
||||
}
|
||||
|
||||
func TestModelHandler_GetModel(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
model: &domain.Model{ID: "m1", ModelName: "gpt-4"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ModelName: "gpt-4"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -149,9 +189,13 @@ func TestModelHandler_GetModel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_UpdateModel(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
model: &domain.Model{ID: "m1", ModelName: "gpt-4o"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("m1"), gomock.Any()).Return(nil)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ModelName: "gpt-4o"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"model_name": "gpt-4o"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -2,21 +2,22 @@ package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/tests/mocks"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/provider"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
@@ -24,82 +25,12 @@ func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// ============ Mock 实现 ============
|
||||
|
||||
type mockRoutingService struct {
|
||||
result *domain.RouteResult
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockRoutingService) Route(modelName string) (*domain.RouteResult, error) {
|
||||
return m.result, m.err
|
||||
}
|
||||
|
||||
type mockStatsService struct {
|
||||
err error
|
||||
stats []domain.UsageStats
|
||||
aggrResult []map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *mockStatsService) Record(providerID, modelName string) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
return m.stats, nil
|
||||
}
|
||||
func (m *mockStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} {
|
||||
return m.aggrResult
|
||||
}
|
||||
|
||||
type mockProviderService struct {
|
||||
provider *domain.Provider
|
||||
providers []domain.Provider
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err }
|
||||
func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
|
||||
return m.provider, m.err
|
||||
}
|
||||
func (m *mockProviderService) List() ([]domain.Provider, error) { return m.providers, m.err }
|
||||
func (m *mockProviderService) Update(id string, updates map[string]interface{}) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockProviderService) Delete(id string) error { return m.err }
|
||||
|
||||
type mockModelService struct {
|
||||
model *domain.Model
|
||||
models []domain.Model
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockModelService) Create(model *domain.Model) error { return m.err }
|
||||
func (m *mockModelService) Get(id string) (*domain.Model, error) {
|
||||
return m.model, m.err
|
||||
}
|
||||
func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
|
||||
return m.models, m.err
|
||||
}
|
||||
func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockModelService) Delete(id string) error { return m.err }
|
||||
|
||||
type mockProviderClient struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProviderClient) Send(ctx context.Context, spec interface{}) (interface{}, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
func (m *mockProviderClient) SendStream(ctx context.Context, spec interface{}) (<-chan provider.StreamEvent, error) {
|
||||
return nil, m.err
|
||||
}
|
||||
|
||||
// ============ Provider Handler 测试 ============
|
||||
|
||||
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"id": "p1"})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -112,12 +43,15 @@ func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProviderHandler_ListProviders(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
providers: []domain.Provider{
|
||||
{ID: "p1", Name: "P1"},
|
||||
{ID: "p2", Name: "P2"},
|
||||
},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().List().Return([]domain.Provider{
|
||||
{ID: "p1", Name: "P1"},
|
||||
{ID: "p2", Name: "P2"},
|
||||
}, nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -127,14 +61,17 @@ func TestProviderHandler_ListProviders(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var result []domain.Provider
|
||||
json.Unmarshal(w.Body.Bytes(), &result)
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Len(t, result, 2)
|
||||
}
|
||||
|
||||
func TestProviderHandler_GetProvider(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
provider: &domain.Provider{ID: "p1", Name: "P1", APIKey: "***"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("p1")).Return(&domain.Provider{ID: "p1", Name: "P1", APIKey: "sk-test"}, nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -145,10 +82,12 @@ func TestProviderHandler_GetProvider(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
// ============ Model Handler 测试 ============
|
||||
|
||||
func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"id": "m1"})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -161,12 +100,15 @@ func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_ListModels(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
models: []domain.Model{
|
||||
{ID: "m1", ModelName: "gpt-4"},
|
||||
{ID: "m2", ModelName: "gpt-3.5"},
|
||||
},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().List(gomock.Eq("")).Return([]domain.Model{
|
||||
{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
|
||||
{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3"},
|
||||
}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -174,16 +116,98 @@ func TestModelHandler_ListModels(t *testing.T) {
|
||||
|
||||
h.ListModels(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var result []modelResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
require.Len(t, result, 2)
|
||||
assert.Equal(t, "openai/gpt-4", result[0].UnifiedModelID)
|
||||
assert.Equal(t, "anthropic/claude-3", result[1].UnifiedModelID)
|
||||
}
|
||||
|
||||
// ============ Stats Handler 测试 ============
|
||||
func TestModelHandler_GetModel_UnifiedID(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "m1"}}
|
||||
c.Request = httptest.NewRequest("GET", "/api/models/m1", nil)
|
||||
|
||||
h.GetModel(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var result modelResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Equal(t, "m1", result.ID)
|
||||
assert.Equal(t, "openai/gpt-4", result.UnifiedModelID)
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_UnifiedID(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).DoAndReturn(func(model *domain.Model) error {
|
||||
model.ID = "mock-uuid-1234"
|
||||
return nil
|
||||
})
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 201, w.Code)
|
||||
|
||||
var result modelResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Equal(t, "mock-uuid-1234", result.ID)
|
||||
assert.Equal(t, "openai/gpt-4", result.UnifiedModelID)
|
||||
}
|
||||
|
||||
func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("m1"), gomock.Any()).Return(nil)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4-turbo"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"enabled": false})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "m1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/models/m1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateModel(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var result modelResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Equal(t, "openai/gpt-4-turbo", result.UnifiedModelID)
|
||||
}
|
||||
|
||||
func TestStatsHandler_GetStats(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
|
||||
},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockStatsService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]domain.UsageStats{
|
||||
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
|
||||
}, nil)
|
||||
h := NewStatsHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -194,7 +218,11 @@ func TestStatsHandler_GetStats(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := NewStatsHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -205,14 +233,17 @@ func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStatsHandler_AggregateStats(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
},
|
||||
aggrResult: []map[string]interface{}{
|
||||
{"provider_id": "p1", "request_count": 10},
|
||||
},
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockStatsService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
}, nil)
|
||||
mockSvc.EXPECT().Aggregate(gomock.Any(), gomock.Eq("provider")).Return([]map[string]interface{}{
|
||||
{"provider_id": "p1", "request_count": 10},
|
||||
})
|
||||
h := NewStatsHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -222,8 +253,6 @@ func TestStatsHandler_AggregateStats(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
// ============ writeError 测试 ============
|
||||
|
||||
func TestWriteError(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -252,12 +281,13 @@ func formatMapErrors(errs map[string]string) string {
|
||||
return "请求验证失败: " + strings.Join(parts, "; ")
|
||||
}
|
||||
|
||||
// ============ 错误类型判断测试 ============
|
||||
|
||||
func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
err: gorm.ErrDuplicatedKey,
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrConflict)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"id": "p1",
|
||||
@@ -273,3 +303,158 @@ func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
|
||||
h.CreateProvider(c)
|
||||
assert.Equal(t, 409, w.Code)
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_ProviderNotFound(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrProviderNotFound)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "nonexistent",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "供应商不存在")
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_DuplicateModel(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrDuplicateModel)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 409, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "同一供应商下模型名称已存在")
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_InternalError(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(errors.New("database error"))
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_NotFound(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(gorm.ErrRecordNotFound)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateProvider(c)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_ImmutableField(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(appErrors.ErrImmutableField)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateProvider(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "供应商 ID 不允许修改")
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_InternalError(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(errors.New("database error"))
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateProvider(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_InvalidJSON(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader([]byte("{invalid json")))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_CreateProvider_InvalidJSON(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader([]byte("{invalid json")))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateProvider(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
@@ -5,9 +5,10 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// Logging 日志中间件
|
||||
func Logging(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
@@ -15,12 +16,16 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
|
||||
query := c.Request.URL.RawQuery
|
||||
|
||||
requestID, _ := c.Get(RequestIDKey)
|
||||
var requestIDStr string
|
||||
if id, ok := requestID.(string); ok {
|
||||
requestIDStr = id
|
||||
}
|
||||
logger.Info("请求开始",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", path),
|
||||
zap.String("query", query),
|
||||
zap.String("client_ip", c.ClientIP()),
|
||||
zap.Any("request_id", requestID),
|
||||
pkglogger.Method(c.Request.Method),
|
||||
pkglogger.Path(path),
|
||||
pkglogger.Query(query),
|
||||
pkglogger.ClientIP(c.ClientIP()),
|
||||
pkglogger.RequestID(requestIDStr),
|
||||
)
|
||||
|
||||
c.Next()
|
||||
@@ -29,12 +34,12 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
logger.Info("请求结束",
|
||||
zap.Int("status", statusCode),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", path),
|
||||
zap.Duration("latency", latency),
|
||||
zap.Int("body_size", c.Writer.Size()),
|
||||
zap.Any("request_id", requestID),
|
||||
pkglogger.StatusCode(statusCode),
|
||||
pkglogger.Method(c.Request.Method),
|
||||
pkglogger.Path(path),
|
||||
pkglogger.Latency(latency),
|
||||
pkglogger.BodySize(c.Writer.Size()),
|
||||
pkglogger.RequestID(requestIDStr),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// ModelHandler 模型管理处理器
|
||||
@@ -22,40 +23,59 @@ func NewModelHandler(modelService service.ModelService) *ModelHandler {
|
||||
return &ModelHandler{modelService: modelService}
|
||||
}
|
||||
|
||||
// modelResponse 模型响应 DTO,扩展 unified_id 字段
|
||||
type modelResponse struct {
|
||||
domain.Model
|
||||
UnifiedModelID string `json:"unified_id"`
|
||||
}
|
||||
|
||||
// newModelResponse 从 domain.Model 构造响应 DTO
|
||||
func newModelResponse(m *domain.Model) modelResponse {
|
||||
return modelResponse{
|
||||
Model: *m,
|
||||
UnifiedModelID: m.UnifiedModelID(),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateModel 创建模型
|
||||
func (h *ModelHandler) CreateModel(c *gin.Context) {
|
||||
var req struct {
|
||||
ID string `json:"id" binding:"required"`
|
||||
ProviderID string `json:"provider_id" binding:"required"`
|
||||
ModelName string `json:"model_name" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "缺少必需字段: id, provider_id, model_name",
|
||||
"error": "缺少必需字段: provider_id, model_name",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
model := &domain.Model{
|
||||
ID: req.ID,
|
||||
ProviderID: req.ProviderID,
|
||||
ModelName: req.ModelName,
|
||||
}
|
||||
|
||||
err := h.modelService.Create(model)
|
||||
if err != nil {
|
||||
if err == appErrors.ErrProviderNotFound {
|
||||
if errors.Is(err, appErrors.ErrProviderNotFound) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "供应商不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, appErrors.ErrDuplicateModel) {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": "同一供应商下模型名称已存在",
|
||||
"code": appErrors.ErrDuplicateModel.Code,
|
||||
})
|
||||
return
|
||||
}
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, model)
|
||||
c.JSON(http.StatusCreated, newModelResponse(model))
|
||||
}
|
||||
|
||||
// ListModels 列出模型
|
||||
@@ -68,7 +88,11 @@ func (h *ModelHandler) ListModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models)
|
||||
resp := make([]modelResponse, len(models))
|
||||
for i, m := range models {
|
||||
resp[i] = newModelResponse(&m)
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// GetModel 获取模型
|
||||
@@ -77,7 +101,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
|
||||
|
||||
model, err := h.modelService.Get(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
@@ -87,7 +111,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model)
|
||||
c.JSON(http.StatusOK, newModelResponse(model))
|
||||
}
|
||||
|
||||
// UpdateModel 更新模型
|
||||
@@ -104,18 +128,25 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
|
||||
|
||||
err := h.modelService.Update(id, req)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, appErrors.ErrModelNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err == appErrors.ErrProviderNotFound {
|
||||
if errors.Is(err, appErrors.ErrProviderNotFound) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "供应商不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, appErrors.ErrDuplicateModel) {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": appErrors.ErrDuplicateModel.Message,
|
||||
"code": appErrors.ErrDuplicateModel.Code,
|
||||
})
|
||||
return
|
||||
}
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
@@ -126,7 +157,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model)
|
||||
c.JSON(http.StatusOK, newModelResponse(model))
|
||||
}
|
||||
|
||||
// DeleteModel 删除模型
|
||||
@@ -135,7 +166,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) {
|
||||
|
||||
err := h.modelService.Delete(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
|
||||
@@ -4,13 +4,13 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// ProviderHandler 供应商管理处理器
|
||||
@@ -55,9 +55,10 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
||||
|
||||
err := h.providerService.Create(provider)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": "供应商 ID 已存在",
|
||||
if errors.Is(err, appErrors.ErrInvalidProviderID) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": appErrors.ErrInvalidProviderID.Message,
|
||||
"code": appErrors.ErrInvalidProviderID.Code,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -65,7 +66,6 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
provider.MaskAPIKey()
|
||||
c.JSON(http.StatusCreated, provider)
|
||||
}
|
||||
|
||||
@@ -84,9 +84,9 @@ func (h *ProviderHandler) ListProviders(c *gin.Context) {
|
||||
func (h *ProviderHandler) GetProvider(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
provider, err := h.providerService.Get(id, true)
|
||||
provider, err := h.providerService.Get(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
@@ -113,17 +113,24 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
|
||||
|
||||
err := h.providerService.Update(id, req)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, appErrors.ErrImmutableField) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": appErrors.ErrImmutableField.Message,
|
||||
"code": appErrors.ErrImmutableField.Code,
|
||||
})
|
||||
return
|
||||
}
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := h.providerService.Get(id, true)
|
||||
provider, err := h.providerService.Get(id)
|
||||
if err != nil {
|
||||
writeError(c, err)
|
||||
return
|
||||
@@ -138,7 +145,7 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
|
||||
|
||||
err := h.providerService.Delete(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
|
||||
@@ -3,38 +3,43 @@ package handler
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/service"
|
||||
"nex/backend/pkg/modelid"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// ProxyHandler 统一代理处理器
|
||||
type ProxyHandler struct {
|
||||
engine *conversion.ConversionEngine
|
||||
client provider.ProviderClient
|
||||
routingService service.RoutingService
|
||||
engine *conversion.ConversionEngine
|
||||
client provider.ProviderClient
|
||||
routingService service.RoutingService
|
||||
providerService service.ProviderService
|
||||
statsService service.StatsService
|
||||
logger *zap.Logger
|
||||
statsService service.StatsService
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewProxyHandler 创建统一代理处理器
|
||||
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService) *ProxyHandler {
|
||||
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService, logger *zap.Logger) *ProxyHandler {
|
||||
return &ProxyHandler{
|
||||
engine: engine,
|
||||
client: client,
|
||||
routingService: routingService,
|
||||
providerService: providerService,
|
||||
statsService: statsService,
|
||||
logger: zap.L(),
|
||||
logger: pkglogger.WithModule(logger, "handler.proxy"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,12 +52,40 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 原始路径: /v1/{path}
|
||||
// 原始路径: /{path}
|
||||
path := c.Param("path")
|
||||
if strings.HasPrefix(path, "/") {
|
||||
path = path[1:]
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
nativePath := path
|
||||
|
||||
// 获取 client adapter
|
||||
registry := h.engine.GetRegistry()
|
||||
clientAdapter, err := registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
|
||||
return
|
||||
}
|
||||
|
||||
// 检测接口类型
|
||||
ifaceType := clientAdapter.DetectInterfaceType(nativePath)
|
||||
|
||||
// 处理 Models 接口:本地聚合
|
||||
if ifaceType == conversion.InterfaceTypeModels {
|
||||
h.handleModelsList(c, clientAdapter)
|
||||
return
|
||||
}
|
||||
|
||||
// 处理 ModelInfo 接口:本地查询
|
||||
if ifaceType == conversion.InterfaceTypeModelInfo {
|
||||
unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的模型 ID 格式"})
|
||||
return
|
||||
}
|
||||
h.handleModelInfo(c, unifiedID, clientAdapter)
|
||||
return
|
||||
}
|
||||
nativePath := "/v1/" + path
|
||||
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
@@ -61,10 +94,17 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 model 名称(从 JSON body 中提取,GET 请求无 body)
|
||||
modelName := ""
|
||||
// 解析统一模型 ID(使用 adapter.ExtractModelName)
|
||||
var providerID, modelName string
|
||||
if len(body) > 0 {
|
||||
modelName = extractModelName(body)
|
||||
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
|
||||
if err == nil && unifiedID != "" {
|
||||
pid, mn, err := modelid.ParseUnifiedModelID(unifiedID)
|
||||
if err == nil {
|
||||
providerID = pid
|
||||
modelName = mn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 构建输入 HTTPRequestSpec
|
||||
@@ -76,7 +116,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 路由
|
||||
routeResult, err := h.routingService.Route(modelName)
|
||||
routeResult, err := h.routingService.RouteByModelName(providerID, modelName)
|
||||
if err != nil {
|
||||
// GET 请求或无法提取 model 时,直接转发到上游
|
||||
if len(body) == 0 || modelName == "" {
|
||||
@@ -94,28 +134,34 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 构建 TargetProvider
|
||||
// 注意:ModelName 字段用于 Smart Passthrough 场景改写请求体
|
||||
// 同协议:请求体中的统一 ID 会被改写为 ModelName(上游名)
|
||||
// 跨协议:全量转换时 ModelName 会被编码到请求体中
|
||||
targetProvider := conversion.NewTargetProvider(
|
||||
routeResult.Provider.BaseURL,
|
||||
routeResult.Provider.APIKey,
|
||||
routeResult.Model.ModelName,
|
||||
routeResult.Model.ModelName, // 上游模型名,用于请求改写
|
||||
)
|
||||
|
||||
// 判断是否流式
|
||||
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
|
||||
|
||||
// 计算统一模型 ID(用于响应覆写)
|
||||
unifiedModelID := routeResult.Model.UnifiedModelID()
|
||||
|
||||
if isStream {
|
||||
h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
|
||||
h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
|
||||
} else {
|
||||
h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult)
|
||||
h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
|
||||
}
|
||||
}
|
||||
|
||||
// handleNonStream 处理非流式请求
|
||||
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
|
||||
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
|
||||
// 转换请求
|
||||
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
||||
if err != nil {
|
||||
h.logger.Error("转换请求失败", zap.String("error", err.Error()))
|
||||
h.logger.Error("转换请求失败", zap.Error(err))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
@@ -123,16 +169,15 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
|
||||
// 发送请求
|
||||
resp, err := h.client.Send(c.Request.Context(), *outSpec)
|
||||
if err != nil {
|
||||
h.logger.Error("发送请求失败", zap.String("error", err.Error()))
|
||||
h.logger.Error("发送请求失败", zap.Error(err))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
// 转换响应
|
||||
interfaceType, _ := h.engine.DetectInterfaceType(inSpec.URL, clientProtocol)
|
||||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, interfaceType)
|
||||
// 转换响应,传入 modelOverride(跨协议场景覆写 model 字段)
|
||||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
|
||||
if err != nil {
|
||||
h.logger.Error("转换响应失败", zap.String("error", err.Error()))
|
||||
h.logger.Error("转换响应失败", zap.Error(err))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
@@ -148,12 +193,12 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
|
||||
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
|
||||
}()
|
||||
}
|
||||
|
||||
// handleStream 处理流式请求
|
||||
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult) {
|
||||
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
|
||||
// 转换请求
|
||||
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
||||
if err != nil {
|
||||
@@ -161,8 +206,8 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
|
||||
return
|
||||
}
|
||||
|
||||
// 创建流式转换器
|
||||
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol)
|
||||
// 创建流式转换器,传入 modelOverride(跨协议场景覆写 model 字段)
|
||||
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
@@ -183,34 +228,50 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
|
||||
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
h.logger.Error("流读取错误", zap.String("error", event.Error.Error()))
|
||||
h.logger.Error("流读取错误", zap.Error(event.Error))
|
||||
break
|
||||
}
|
||||
if event.Done {
|
||||
// flush 转换器
|
||||
chunks := streamConverter.Flush()
|
||||
for _, chunk := range chunks {
|
||||
writer.Write(chunk)
|
||||
writer.Flush()
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("流式响应写回失败", zap.Error(err))
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
chunks := streamConverter.ProcessChunk(event.Data)
|
||||
for _, chunk := range chunks {
|
||||
writer.Write(chunk)
|
||||
writer.Flush()
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("流式响应写回失败", zap.Error(err))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeStreamChunks(writer *bufio.Writer, chunks [][]byte) error {
|
||||
for _, chunk := range chunks {
|
||||
if _, err := writer.Write(chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writer.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isStreamRequest 判断是否流式请求
|
||||
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
|
||||
ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol)
|
||||
ifaceType, err := h.engine.DetectInterfaceType(nativePath, clientProtocol)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if ifaceType != conversion.InterfaceTypeChat {
|
||||
return false
|
||||
}
|
||||
@@ -224,10 +285,88 @@ func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath s
|
||||
return req.Stream
|
||||
}
|
||||
|
||||
// handleModelsList 处理 GET /v1/models 本地聚合
|
||||
func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.ProtocolAdapter) {
|
||||
// 从数据库查询所有启用的模型
|
||||
models, err := h.providerService.ListEnabledModels()
|
||||
if err != nil {
|
||||
h.logger.Error("查询启用模型失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 构建 CanonicalModelList
|
||||
modelList := &canonical.CanonicalModelList{
|
||||
Models: make([]canonical.CanonicalModel, 0, len(models)),
|
||||
}
|
||||
|
||||
for _, m := range models {
|
||||
modelList.Models = append(modelList.Models, canonical.CanonicalModel{
|
||||
ID: m.UnifiedModelID(),
|
||||
Name: m.ModelName,
|
||||
Created: m.CreatedAt.Unix(),
|
||||
OwnedBy: m.ProviderID,
|
||||
})
|
||||
}
|
||||
|
||||
// 使用 adapter 编码返回
|
||||
body, err := adapter.EncodeModelsResponse(modelList)
|
||||
if err != nil {
|
||||
h.logger.Error("编码 Models 响应失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "application/json", body)
|
||||
}
|
||||
|
||||
// handleModelInfo 处理 GET /v1/models/{unified_id} 本地查询
|
||||
func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter conversion.ProtocolAdapter) {
|
||||
// 解析统一模型 ID
|
||||
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "无效的统一模型 ID 格式",
|
||||
"code": "INVALID_MODEL_ID",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 从数据库查询模型
|
||||
model, err := h.providerService.GetModelByProviderAndName(providerID, modelName)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "模型未找到"})
|
||||
return
|
||||
}
|
||||
|
||||
// 构建 CanonicalModelInfo
|
||||
modelInfo := &canonical.CanonicalModelInfo{
|
||||
ID: model.UnifiedModelID(),
|
||||
Name: model.ModelName,
|
||||
Created: model.CreatedAt.Unix(),
|
||||
OwnedBy: model.ProviderID,
|
||||
}
|
||||
|
||||
// 使用 adapter 编码返回
|
||||
body, err := adapter.EncodeModelInfoResponse(modelInfo)
|
||||
if err != nil {
|
||||
h.logger.Error("编码 ModelInfo 响应失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "application/json", body)
|
||||
}
|
||||
|
||||
// writeConversionError 写入转换错误
|
||||
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
|
||||
if convErr, ok := err.(*conversion.ConversionError); ok {
|
||||
body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol)
|
||||
var convErr *conversion.ConversionError
|
||||
if errors.As(err, &convErr) {
|
||||
body, statusCode, encodeErr := h.engine.EncodeError(convErr, clientProtocol)
|
||||
if encodeErr != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": encodeErr.Error()})
|
||||
return
|
||||
}
|
||||
c.Data(statusCode, "application/json", body)
|
||||
return
|
||||
}
|
||||
@@ -292,7 +431,7 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
|
||||
return
|
||||
}
|
||||
|
||||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType)
|
||||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, "")
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
@@ -307,17 +446,6 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
|
||||
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
|
||||
}
|
||||
|
||||
// extractModelName 从 JSON body 中提取 model
|
||||
func extractModelName(body []byte) string {
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return ""
|
||||
}
|
||||
return req.Model
|
||||
}
|
||||
|
||||
// extractHeaders 从 Gin context 提取请求头
|
||||
func extractHeaders(c *gin.Context) map[string]string {
|
||||
headers := make(map[string]string)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,9 +5,9 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/backend/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// StatsHandler 统计处理器
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
pkgErrors "nex/backend/pkg/errors"
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// StreamConfig 流式处理配置
|
||||
@@ -50,18 +51,20 @@ type Client struct {
|
||||
}
|
||||
|
||||
// ProviderClient 供应商客户端接口
|
||||
//
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
|
||||
type ProviderClient interface {
|
||||
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
|
||||
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error)
|
||||
}
|
||||
|
||||
// NewClient 创建供应商客户端
|
||||
func NewClient() *Client {
|
||||
func NewClient(logger *zap.Logger) *Client {
|
||||
return &Client{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
logger: zap.L(),
|
||||
logger: pkglogger.WithModule(logger, "provider.client"),
|
||||
streamCfg: DefaultStreamConfig(),
|
||||
}
|
||||
}
|
||||
@@ -139,7 +142,10 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
cancel()
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
errBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d,读取错误响应失败: %w", resp.StatusCode, readErr)
|
||||
}
|
||||
if len(errBody) > 0 {
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody))
|
||||
}
|
||||
@@ -182,10 +188,10 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
if isNetworkError(err) {
|
||||
c.logger.Error("流网络错误", zap.String("error", err.Error()))
|
||||
c.logger.Error("流网络错误", zap.Error(err))
|
||||
eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
|
||||
} else {
|
||||
c.logger.Error("流读取错误", zap.String("error", err.Error()))
|
||||
c.logger.Error("流读取错误", zap.Error(err))
|
||||
eventChan <- StreamEvent{Error: fmt.Errorf("读取错误: %w", err)}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -13,12 +13,13 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
require.NotNil(t, client)
|
||||
assert.NotNil(t, client.httpClient)
|
||||
assert.Equal(t, 4096, client.streamCfg.InitialBufferSize)
|
||||
@@ -40,11 +41,12 @@ func TestClient_Send_Success(t *testing.T) {
|
||||
assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
|
||||
_, err := w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -64,11 +66,12 @@ func TestClient_Send_Success(t *testing.T) {
|
||||
func TestClient_Send_ErrorResponse(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
|
||||
_, err := w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -82,7 +85,7 @@ func TestClient_Send_ErrorResponse(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClient_Send_ConnectionError(t *testing.T) {
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: "http://localhost:1/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -99,7 +102,7 @@ func TestClient_SendStream_CreatesChannel(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -121,7 +124,7 @@ func TestClient_SendStream_ErrorResponse(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -139,18 +142,21 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, ok := w.(http.Flusher)
|
||||
require.True(t, ok)
|
||||
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
|
||||
_, err := w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
|
||||
_, err = w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
w.Write([]byte("data: [DONE]\n\n"))
|
||||
_, err = w.Write([]byte("data: [DONE]\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -164,11 +170,12 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
|
||||
var dataEvents [][]byte
|
||||
var doneEvents int
|
||||
for event := range eventChan {
|
||||
if event.Done {
|
||||
switch {
|
||||
case event.Done:
|
||||
doneEvents++
|
||||
} else if event.Error != nil {
|
||||
case event.Error != nil:
|
||||
t.Fatalf("unexpected error: %v", event.Error)
|
||||
} else {
|
||||
default:
|
||||
dataEvents = append(dataEvents, event.Data)
|
||||
}
|
||||
}
|
||||
@@ -188,7 +195,7 @@ func TestClient_SendStream_ContextCancellation(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -214,11 +221,12 @@ func TestClient_Send_EmptyBody(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"result":"ok"}`))
|
||||
_, err := w.Write([]byte(`{"result":"ok"}`))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/models",
|
||||
Method: "GET",
|
||||
@@ -237,16 +245,18 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, ok := w.(http.Flusher)
|
||||
require.True(t, ok)
|
||||
w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
|
||||
_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.Write([]byte("data: [DONE]\n\n"))
|
||||
_, err = w.Write([]byte("data: [DONE]\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -260,11 +270,12 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
|
||||
var dataCount int
|
||||
var doneCount int
|
||||
for event := range eventChan {
|
||||
if event.Done {
|
||||
switch {
|
||||
case event.Done:
|
||||
doneCount++
|
||||
} else if event.Error != nil {
|
||||
case event.Error != nil:
|
||||
t.Fatalf("unexpected error: %v", event.Error)
|
||||
} else {
|
||||
default:
|
||||
dataCount++
|
||||
}
|
||||
}
|
||||
@@ -278,16 +289,18 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, ok := w.(http.Flusher)
|
||||
require.True(t, ok)
|
||||
w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
|
||||
_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
w.Write([]byte("data: [DONE]\n\n"))
|
||||
_, err = w.Write([]byte("data: [DONE]\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -363,19 +376,20 @@ func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, ok := w.(http.Flusher)
|
||||
require.True(t, ok)
|
||||
w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
|
||||
_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
if hijacker, ok := w.(http.Hijacker); ok {
|
||||
conn, _, _ := hijacker.Hijack()
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
require.NoError(t, conn.Close())
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
|
||||
@@ -2,12 +2,15 @@ package repository
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=model_repo.go -destination=../../tests/mocks/mock_model_repository.go -package=mocks
|
||||
|
||||
// ModelRepository 模型数据仓库接口
|
||||
type ModelRepository interface {
|
||||
Create(model *domain.Model) error
|
||||
GetByID(id string) (*domain.Model, error)
|
||||
List(providerID string) ([]domain.Model, error)
|
||||
GetByModelName(modelName string) (*domain.Model, error)
|
||||
FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error)
|
||||
ListEnabled() ([]domain.Model, error)
|
||||
Update(id string, updates map[string]interface{}) error
|
||||
Delete(id string) error
|
||||
}
|
||||
|
||||
@@ -52,9 +52,9 @@ func (r *modelRepository) List(providerID string) ([]domain.Model, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *modelRepository) GetByModelName(modelName string) (*domain.Model, error) {
|
||||
func (r *modelRepository) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
|
||||
var m config.Model
|
||||
err := r.db.Where("model_name = ?", modelName).First(&m).Error
|
||||
err := r.db.Where("provider_id = ? AND model_name = ?", providerID, modelName).First(&m).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -62,6 +62,21 @@ func (r *modelRepository) GetByModelName(modelName string) (*domain.Model, error
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
func (r *modelRepository) ListEnabled() ([]domain.Model, error) {
|
||||
var models []config.Model
|
||||
err := r.db.Joins("JOIN providers ON providers.id = models.provider_id").
|
||||
Where("models.enabled = ? AND providers.enabled = ?", true, true).
|
||||
Find(&models).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]domain.Model, len(models))
|
||||
for i := range models {
|
||||
result[i] = toDomainModel(&models[i])
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *modelRepository) Update(id string, updates map[string]interface{}) error {
|
||||
result := r.db.Model(&config.Model{}).Where("id = ?", id).Updates(updates)
|
||||
if result.Error != nil {
|
||||
|
||||
@@ -2,6 +2,8 @@ package repository
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=provider_repo.go -destination=../../tests/mocks/mock_provider_repository.go -package=mocks
|
||||
|
||||
// ProviderRepository 供应商数据仓库接口
|
||||
type ProviderRepository interface {
|
||||
Create(provider *domain.Provider) error
|
||||
|
||||
@@ -3,10 +3,11 @@ package repository
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/domain"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
@@ -3,30 +3,18 @@ package repository
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/domain"
|
||||
testHelpers "nex/backend/tests"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
||||
require.NoError(t, err)
|
||||
// 关闭数据库连接以便 TempDir 清理
|
||||
t.Cleanup(func() {
|
||||
sqlDB, _ := db.DB()
|
||||
if sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
})
|
||||
return db
|
||||
return testHelpers.SetupTestDB(t)
|
||||
}
|
||||
|
||||
// ============ ProviderRepository 测试 ============
|
||||
@@ -88,7 +76,7 @@ func TestProviderRepository_Update(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewProviderRepository(db)
|
||||
|
||||
repo.Create(&domain.Provider{ID: "p1", Name: "Old", APIKey: "key", BaseURL: "https://old.com"})
|
||||
require.NoError(t, repo.Create(&domain.Provider{ID: "p1", Name: "Old", APIKey: "key", BaseURL: "https://old.com"}))
|
||||
|
||||
err := repo.Update("p1", map[string]interface{}{"name": "New"})
|
||||
require.NoError(t, err)
|
||||
@@ -109,7 +97,7 @@ func TestProviderRepository_Delete(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewProviderRepository(db)
|
||||
|
||||
repo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
|
||||
require.NoError(t, repo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
err := repo.Delete("p1")
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -129,17 +117,21 @@ func TestProviderRepository_Delete_NotFound(t *testing.T) {
|
||||
|
||||
func TestModelRepository_Create(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
err := repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestModelRepository_GetByID(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
|
||||
|
||||
result, err := repo.GetByID("m1")
|
||||
require.NoError(t, err)
|
||||
@@ -147,24 +139,52 @@ func TestModelRepository_GetByID(t *testing.T) {
|
||||
assert.Equal(t, "gpt-4", result.ModelName)
|
||||
}
|
||||
|
||||
func TestModelRepository_GetByModelName(t *testing.T) {
|
||||
func TestModelRepository_FindByProviderAndModelName(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
|
||||
|
||||
result, err := repo.GetByModelName("gpt-4")
|
||||
result, err := repo.FindByProviderAndModelName("p1", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "m1", result.ID)
|
||||
assert.Equal(t, "p1", result.ProviderID)
|
||||
assert.Equal(t, "gpt-4", result.ModelName)
|
||||
}
|
||||
|
||||
func TestModelRepository_FindByProviderAndModelName_NotFound(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
|
||||
|
||||
// Wrong provider_id
|
||||
_, err := repo.FindByProviderAndModelName("p2", "gpt-4")
|
||||
assert.Error(t, err)
|
||||
|
||||
// Wrong model_name
|
||||
_, err = repo.FindByProviderAndModelName("p1", "gpt-3.5")
|
||||
assert.Error(t, err)
|
||||
|
||||
// Both wrong
|
||||
_, err = repo.FindByProviderAndModelName("p2", "claude-3")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestModelRepository_List(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
repo.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})
|
||||
repo.Create(&domain.Model{ID: "m3", ProviderID: "p2", ModelName: "claude-3"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p2", Name: "Test2", APIKey: "key", BaseURL: "https://test2.com"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m3", ProviderID: "p2", ModelName: "claude-3"}))
|
||||
|
||||
all, err := repo.List("")
|
||||
require.NoError(t, err)
|
||||
@@ -175,11 +195,61 @@ func TestModelRepository_List(t *testing.T) {
|
||||
assert.Len(t, p1Models, 2)
|
||||
}
|
||||
|
||||
func TestModelRepository_ListEnabled(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
modelRepo := NewModelRepository(db)
|
||||
|
||||
// Create two providers (both start enabled due to gorm:"default:true")
|
||||
err := providerRepo.Create(&domain.Provider{
|
||||
ID: "enabled-provider", Name: "Enabled Provider",
|
||||
APIKey: "key1", BaseURL: "https://enabled.com", Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = providerRepo.Create(&domain.Provider{
|
||||
ID: "disabled-provider", Name: "Disabled Provider",
|
||||
APIKey: "key2", BaseURL: "https://disabled.com", Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Disable the second provider via Update (GORM default:true skips zero values on Create)
|
||||
err = providerRepo.Update("disabled-provider", map[string]interface{}{"enabled": false})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create models (all start enabled due to gorm:"default:true")
|
||||
err = modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "enabled-provider", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, err)
|
||||
err = modelRepo.Create(&domain.Model{ID: "m2", ProviderID: "enabled-provider", ModelName: "gpt-3.5", Enabled: true})
|
||||
require.NoError(t, err)
|
||||
err = modelRepo.Create(&domain.Model{ID: "m3", ProviderID: "disabled-provider", ModelName: "claude-3", Enabled: true})
|
||||
require.NoError(t, err)
|
||||
err = modelRepo.Create(&domain.Model{ID: "m4", ProviderID: "disabled-provider", ModelName: "claude-3.5", Enabled: true})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Disable m2 via Update
|
||||
err = modelRepo.Update("m2", map[string]interface{}{"enabled": false})
|
||||
require.NoError(t, err)
|
||||
|
||||
// ListEnabled should only return models where both model and provider are enabled:
|
||||
// - m1: enabled provider + enabled model -> returned
|
||||
// - m2: enabled provider + disabled model -> filtered out
|
||||
// - m3: disabled provider + enabled model -> filtered out
|
||||
// - m4: disabled provider + enabled model -> filtered out
|
||||
enabled, err := modelRepo.ListEnabled()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, enabled, 1)
|
||||
assert.Equal(t, "m1", enabled[0].ID)
|
||||
assert.Equal(t, "enabled-provider", enabled[0].ProviderID)
|
||||
assert.Equal(t, "gpt-4", enabled[0].ModelName)
|
||||
}
|
||||
|
||||
func TestModelRepository_Update(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}))
|
||||
|
||||
err := repo.Update("m1", map[string]interface{}{"enabled": false})
|
||||
require.NoError(t, err)
|
||||
@@ -190,9 +260,11 @@ func TestModelRepository_Update(t *testing.T) {
|
||||
|
||||
func TestModelRepository_Delete(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
providerRepo := NewProviderRepository(db)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
require.NoError(t, repo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}))
|
||||
|
||||
err := repo.Delete("m1")
|
||||
require.NoError(t, err)
|
||||
@@ -224,10 +296,32 @@ func TestStatsRepository_Query(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewStatsRepository(db)
|
||||
|
||||
repo.Record("p1", "gpt-4")
|
||||
require.NoError(t, repo.Record("p1", "gpt-4"))
|
||||
// 注意:当前 schema 只有 date 字段有唯一约束
|
||||
// 所以同一 provider + model 只能有一条记录
|
||||
stats, err := repo.Query("p1", "", nil, nil)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, stats, 1)
|
||||
}
|
||||
|
||||
func TestModelRepository_List_EmptyResult(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewModelRepository(db)
|
||||
|
||||
result, err := repo.List("")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Empty(t, result)
|
||||
assert.Len(t, result, 0)
|
||||
}
|
||||
|
||||
func TestProviderRepository_List_EmptyResult(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewProviderRepository(db)
|
||||
|
||||
result, err := repo.List()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Empty(t, result)
|
||||
assert.Len(t, result, 0)
|
||||
}
|
||||
|
||||
@@ -6,8 +6,11 @@ import (
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=stats_repo.go -destination=../../tests/mocks/mock_stats_repository.go -package=mocks
|
||||
|
||||
// StatsRepository 统计数据仓库接口
|
||||
type StatsRepository interface {
|
||||
Record(providerID, modelName string) error
|
||||
BatchUpdate(providerID, modelName string, date time.Time, delta int) error
|
||||
Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error)
|
||||
}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/domain"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type statsRepository struct {
|
||||
@@ -19,28 +19,46 @@ func NewStatsRepository(db *gorm.DB) StatsRepository {
|
||||
}
|
||||
|
||||
func (r *statsRepository) Record(providerID, modelName string) error {
|
||||
today := time.Now().Format("2006-01-02")
|
||||
todayTime, _ := time.Parse("2006-01-02", today)
|
||||
now := time.Now()
|
||||
todayTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
var stats config.UsageStats
|
||||
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
|
||||
providerID, modelName, todayTime).First(&stats).Error
|
||||
stats := config.UsageStats{
|
||||
ProviderID: providerID,
|
||||
ModelName: modelName,
|
||||
RequestCount: 1,
|
||||
Date: todayTime,
|
||||
}
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
stats = config.UsageStats{
|
||||
ProviderID: providerID,
|
||||
ModelName: modelName,
|
||||
RequestCount: 1,
|
||||
Date: todayTime,
|
||||
}
|
||||
return tx.Create(&stats).Error
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "provider_id"},
|
||||
{Name: "model_name"},
|
||||
{Name: "date"},
|
||||
},
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||
"request_count": gorm.Expr("request_count + 1"),
|
||||
}),
|
||||
}).Create(&stats).Error
|
||||
}
|
||||
|
||||
return tx.Model(&stats).Update("request_count", gorm.Expr("request_count + 1")).Error
|
||||
})
|
||||
func (r *statsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
|
||||
stats := config.UsageStats{
|
||||
ProviderID: providerID,
|
||||
ModelName: modelName,
|
||||
RequestCount: delta,
|
||||
Date: date,
|
||||
}
|
||||
|
||||
return r.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "provider_id"},
|
||||
{Name: "model_name"},
|
||||
{Name: "date"},
|
||||
},
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||
"request_count": gorm.Expr("request_count + ?", delta),
|
||||
}),
|
||||
}).Create(&stats).Error
|
||||
}
|
||||
|
||||
func (r *statsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
|
||||
@@ -2,11 +2,14 @@ package service
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=model_service.go -destination=../../tests/mocks/mock_model_service.go -package=mocks
|
||||
|
||||
// ModelService 模型服务接口
|
||||
type ModelService interface {
|
||||
Create(model *domain.Model) error
|
||||
Get(id string) (*domain.Model, error)
|
||||
List(providerID string) ([]domain.Model, error)
|
||||
ListEnabled() ([]domain.Model, error)
|
||||
Update(id string, updates map[string]interface{}) error
|
||||
Delete(id string) error
|
||||
}
|
||||
|
||||
@@ -1,29 +1,44 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
"errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
type modelService struct {
|
||||
modelRepo repository.ModelRepository
|
||||
providerRepo repository.ProviderRepository
|
||||
cache *RoutingCache
|
||||
}
|
||||
|
||||
func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) ModelService {
|
||||
return &modelService{modelRepo: modelRepo, providerRepo: providerRepo}
|
||||
func NewModelService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository, cache *RoutingCache) ModelService {
|
||||
return &modelService{modelRepo: modelRepo, providerRepo: providerRepo, cache: cache}
|
||||
}
|
||||
|
||||
func (s *modelService) Create(model *domain.Model) error {
|
||||
// Verify provider exists
|
||||
_, err := s.providerRepo.GetByID(model.ProviderID)
|
||||
if err != nil {
|
||||
if _, err := s.providerRepo.GetByID(model.ProviderID); err != nil {
|
||||
return appErrors.ErrProviderNotFound
|
||||
}
|
||||
|
||||
if err := s.checkDuplicateModelName(model.ProviderID, model.ModelName, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
model.ID = uuid.New().String()
|
||||
model.Enabled = true
|
||||
return s.modelRepo.Create(model)
|
||||
err := s.modelRepo.Create(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.cache.SetModel(model)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *modelService) Get(id string) (*domain.Model, error) {
|
||||
@@ -34,17 +49,77 @@ func (s *modelService) List(providerID string) ([]domain.Model, error) {
|
||||
return s.modelRepo.List(providerID)
|
||||
}
|
||||
|
||||
func (s *modelService) ListEnabled() ([]domain.Model, error) {
|
||||
return s.modelRepo.ListEnabled()
|
||||
}
|
||||
|
||||
func (s *modelService) Update(id string, updates map[string]interface{}) error {
|
||||
// If updating provider_id, verify new provider exists
|
||||
current, err := s.modelRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return appErrors.ErrModelNotFound
|
||||
}
|
||||
|
||||
if providerID, ok := updates["provider_id"].(string); ok {
|
||||
_, err := s.providerRepo.GetByID(providerID)
|
||||
if err != nil {
|
||||
if _, err := s.providerRepo.GetByID(providerID); err != nil {
|
||||
return appErrors.ErrProviderNotFound
|
||||
}
|
||||
}
|
||||
return s.modelRepo.Update(id, updates)
|
||||
|
||||
newProviderID := current.ProviderID
|
||||
if v, ok := updates["provider_id"].(string); ok {
|
||||
newProviderID = v
|
||||
}
|
||||
newModelName := current.ModelName
|
||||
if v, ok := updates["model_name"].(string); ok {
|
||||
newModelName = v
|
||||
}
|
||||
|
||||
if newProviderID != current.ProviderID || newModelName != current.ModelName {
|
||||
if err := s.checkDuplicateModelName(newProviderID, newModelName, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err = s.modelRepo.Update(id, updates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.cache.InvalidateModel(current.ProviderID, current.ModelName)
|
||||
if newProviderID != current.ProviderID || newModelName != current.ModelName {
|
||||
s.cache.InvalidateModel(newProviderID, newModelName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *modelService) Delete(id string) error {
|
||||
return s.modelRepo.Delete(id)
|
||||
model, err := s.modelRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return appErrors.ErrModelNotFound
|
||||
}
|
||||
|
||||
err = s.modelRepo.Delete(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.cache.InvalidateModel(model.ProviderID, model.ModelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkDuplicateModelName 校验同一供应商下 model_name 是否重复
|
||||
// excludeID 用于更新时排除自身
|
||||
func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error {
|
||||
existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil // 未找到,不重复
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
if excludeID != "" && existing.ID == excludeID {
|
||||
return nil // 排除自身
|
||||
}
|
||||
return appErrors.ErrDuplicateModel
|
||||
}
|
||||
|
||||
@@ -2,11 +2,16 @@ package service
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=provider_service.go -destination=../../tests/mocks/mock_provider_service.go -package=mocks
|
||||
|
||||
// ProviderService 供应商服务接口
|
||||
type ProviderService interface {
|
||||
Create(provider *domain.Provider) error
|
||||
Get(id string, maskKey bool) (*domain.Provider, error)
|
||||
Get(id string) (*domain.Provider, error)
|
||||
List() ([]domain.Provider, error)
|
||||
Update(id string, updates map[string]interface{}) error
|
||||
Delete(id string) error
|
||||
// 统一模型 ID 相关方法
|
||||
ListEnabledModels() ([]domain.Model, error)
|
||||
GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error)
|
||||
}
|
||||
|
||||
@@ -1,49 +1,85 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/pkg/modelid"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
type providerService struct {
|
||||
providerRepo repository.ProviderRepository
|
||||
modelRepo repository.ModelRepository
|
||||
cache *RoutingCache
|
||||
}
|
||||
|
||||
func NewProviderService(providerRepo repository.ProviderRepository) ProviderService {
|
||||
return &providerService{providerRepo: providerRepo}
|
||||
func NewProviderService(providerRepo repository.ProviderRepository, modelRepo repository.ModelRepository, cache *RoutingCache) ProviderService {
|
||||
return &providerService{providerRepo: providerRepo, modelRepo: modelRepo, cache: cache}
|
||||
}
|
||||
|
||||
func (s *providerService) Create(provider *domain.Provider) error {
|
||||
if err := modelid.ValidateProviderID(provider.ID); err != nil {
|
||||
return appErrors.ErrInvalidProviderID
|
||||
}
|
||||
provider.Enabled = true
|
||||
return s.providerRepo.Create(provider)
|
||||
err := s.providerRepo.Create(provider)
|
||||
if err != nil {
|
||||
if isUniqueConstraintError(err) {
|
||||
return appErrors.ErrConflict
|
||||
}
|
||||
return err
|
||||
}
|
||||
s.cache.SetProvider(provider)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *providerService) Get(id string, maskKey bool) (*domain.Provider, error) {
|
||||
provider, err := s.providerRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if maskKey {
|
||||
provider.MaskAPIKey()
|
||||
}
|
||||
return provider, nil
|
||||
func (s *providerService) Get(id string) (*domain.Provider, error) {
|
||||
return s.providerRepo.GetByID(id)
|
||||
}
|
||||
|
||||
func (s *providerService) List() ([]domain.Provider, error) {
|
||||
providers, err := s.providerRepo.List()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range providers {
|
||||
providers[i].MaskAPIKey()
|
||||
}
|
||||
return providers, nil
|
||||
return s.providerRepo.List()
|
||||
}
|
||||
|
||||
func (s *providerService) Update(id string, updates map[string]interface{}) error {
|
||||
return s.providerRepo.Update(id, updates)
|
||||
if _, ok := updates["id"]; ok {
|
||||
return appErrors.ErrImmutableField
|
||||
}
|
||||
err := s.providerRepo.Update(id, updates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.cache.InvalidateProvider(id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *providerService) Delete(id string) error {
|
||||
return s.providerRepo.Delete(id)
|
||||
err := s.providerRepo.Delete(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.cache.InvalidateProvider(id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListEnabledModels 返回所有启用的模型(用于 Models 接口本地聚合)
|
||||
func (s *providerService) ListEnabledModels() ([]domain.Model, error) {
|
||||
return s.modelRepo.ListEnabled()
|
||||
}
|
||||
|
||||
// GetModelByProviderAndName 按 provider_id 和 model_name 查询模型(用于 ModelInfo 接口本地查询)
|
||||
func (s *providerService) GetModelByProviderAndName(providerID, modelName string) (*domain.Model, error) {
|
||||
return s.modelRepo.FindByProviderAndModelName(providerID, modelName)
|
||||
}
|
||||
|
||||
// isUniqueConstraintError 判断是否为数据库唯一约束冲突错误
|
||||
func isUniqueConstraintError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
return strings.Contains(msg, "unique constraint") || strings.Contains(msg, "duplicate")
|
||||
}
|
||||
|
||||
149
backend/internal/service/routing_cache.go
Normal file
149
backend/internal/service/routing_cache.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
type RoutingCache struct {
|
||||
providers sync.Map
|
||||
models sync.Map
|
||||
|
||||
modelRepo repository.ModelRepository
|
||||
providerRepo repository.ProviderRepository
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewRoutingCache(
|
||||
modelRepo repository.ModelRepository,
|
||||
providerRepo repository.ProviderRepository,
|
||||
logger *zap.Logger,
|
||||
) *RoutingCache {
|
||||
return &RoutingCache{
|
||||
modelRepo: modelRepo,
|
||||
providerRepo: providerRepo,
|
||||
logger: pkglogger.WithModule(logger, "service.routing_cache"),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
|
||||
if v, ok := c.providers.Load(id); ok {
|
||||
if provider, ok := v.(*domain.Provider); ok {
|
||||
return provider, nil
|
||||
}
|
||||
}
|
||||
|
||||
provider, err := c.providerRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if v, ok := c.providers.Load(id); ok {
|
||||
if provider, ok := v.(*domain.Provider); ok {
|
||||
return provider, nil
|
||||
}
|
||||
}
|
||||
|
||||
c.providers.Store(id, provider)
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, error) {
|
||||
key := providerID + "/" + modelName
|
||||
|
||||
if v, ok := c.models.Load(key); ok {
|
||||
if model, ok := v.(*domain.Model); ok {
|
||||
return model, nil
|
||||
}
|
||||
}
|
||||
|
||||
model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if v, ok := c.models.Load(key); ok {
|
||||
if model, ok := v.(*domain.Model); ok {
|
||||
return model, nil
|
||||
}
|
||||
}
|
||||
|
||||
c.models.Store(key, model)
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func (c *RoutingCache) SetProvider(provider *domain.Provider) {
|
||||
c.providers.Store(provider.ID, provider)
|
||||
}
|
||||
|
||||
func (c *RoutingCache) SetModel(model *domain.Model) {
|
||||
key := model.ProviderID + "/" + model.ModelName
|
||||
c.models.Store(key, model)
|
||||
}
|
||||
|
||||
func (c *RoutingCache) InvalidateProvider(id string) {
|
||||
c.providers.Delete(id)
|
||||
c.invalidateModelsByProvider(id)
|
||||
c.logger.Debug("Provider 缓存失效", zap.String("provider_id", id))
|
||||
}
|
||||
|
||||
func (c *RoutingCache) InvalidateModel(providerID, modelName string) {
|
||||
key := providerID + "/" + modelName
|
||||
c.models.Delete(key)
|
||||
c.logger.Debug("Model 缓存失效",
|
||||
zap.String("provider_id", providerID),
|
||||
zap.String("model_name", modelName))
|
||||
}
|
||||
|
||||
func (c *RoutingCache) invalidateModelsByProvider(providerID string) {
|
||||
prefix := providerID + "/"
|
||||
count := 0
|
||||
c.models.Range(func(key, value interface{}) bool {
|
||||
keyStr, ok := key.(string)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
if strings.HasPrefix(keyStr, prefix) {
|
||||
c.models.Delete(key)
|
||||
count++
|
||||
}
|
||||
return true
|
||||
})
|
||||
if count > 0 {
|
||||
c.logger.Debug("清除 Provider 相关 Model 缓存",
|
||||
zap.String("provider_id", providerID),
|
||||
zap.Int("count", count))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RoutingCache) Preload() error {
|
||||
providers, err := c.providerRepo.List()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i := range providers {
|
||||
c.providers.Store(providers[i].ID, &providers[i])
|
||||
}
|
||||
|
||||
models, err := c.modelRepo.List("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i := range models {
|
||||
key := models[i].ProviderID + "/" + models[i].ModelName
|
||||
c.models.Store(key, &models[i])
|
||||
}
|
||||
|
||||
c.logger.Info("缓存预热完成",
|
||||
zap.Int("providers", len(providers)),
|
||||
zap.Int("models", len(models)))
|
||||
return nil
|
||||
}
|
||||
274
backend/internal/service/routing_cache_test.go
Normal file
274
backend/internal/service/routing_cache_test.go
Normal file
@@ -0,0 +1,274 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type mockModelRepo struct {
|
||||
models map[string]*domain.Model
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) Create(model *domain.Model) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) GetByID(id string) (*domain.Model, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) FindByProviderAndModelName(providerID, modelName string) (*domain.Model, error) {
|
||||
key := providerID + "/" + modelName
|
||||
if model, ok := m.models[key]; ok {
|
||||
return model, nil
|
||||
}
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) List(providerID string) ([]domain.Model, error) {
|
||||
var result []domain.Model
|
||||
for _, model := range m.models {
|
||||
if providerID == "" || model.ProviderID == providerID {
|
||||
result = append(result, *model)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) ListEnabled() ([]domain.Model, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) Update(id string, updates map[string]interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockModelRepo) Delete(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockProviderRepo struct {
|
||||
providers map[string]*domain.Provider
|
||||
}
|
||||
|
||||
func (m *mockProviderRepo) Create(provider *domain.Provider) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProviderRepo) GetByID(id string) (*domain.Provider, error) {
|
||||
if provider, ok := m.providers[id]; ok {
|
||||
return provider, nil
|
||||
}
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
|
||||
func (m *mockProviderRepo) List() ([]domain.Provider, error) {
|
||||
var result []domain.Provider
|
||||
for _, provider := range m.providers {
|
||||
result = append(result, *provider)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockProviderRepo) Update(id string, updates map[string]interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockProviderRepo) Delete(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRoutingCache_GetProvider_CacheHit(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
||||
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
||||
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
||||
}}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
provider, err := cache.GetProvider("openai")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai", provider.ID)
|
||||
|
||||
provider2, err := cache.GetProvider("openai")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, provider, provider2)
|
||||
}
|
||||
|
||||
func TestRoutingCache_GetProvider_CacheMiss(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
||||
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
_, err := cache.GetProvider("notexist")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRoutingCache_GetModel_CacheHit(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
||||
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
||||
}}
|
||||
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
model, err := cache.GetModel("openai", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4", model.ModelName)
|
||||
|
||||
model2, err := cache.GetModel("openai", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, model, model2)
|
||||
}
|
||||
|
||||
func TestRoutingCache_GetModel_CacheMiss(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
||||
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
_, err := cache.GetModel("openai", "notexist")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRoutingCache_DoubleCheck(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: make(map[string]*domain.Model)}
|
||||
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
||||
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
||||
}}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := cache.GetProvider("openai")
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
||||
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
||||
"openai/gpt-3.5": {ID: "2", ProviderID: "openai", ModelName: "gpt-3.5", Enabled: true},
|
||||
"anthropic/claude": {ID: "3", ProviderID: "anthropic", ModelName: "claude", Enabled: true},
|
||||
}}
|
||||
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
||||
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
||||
"anthropic": {ID: "anthropic", Name: "Anthropic", Enabled: true},
|
||||
}}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
_, err := cache.GetModel("openai", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
_, err = cache.GetModel("openai", "gpt-3.5")
|
||||
require.NoError(t, err)
|
||||
_, err = cache.GetModel("anthropic", "claude")
|
||||
require.NoError(t, err)
|
||||
|
||||
cache.InvalidateProvider("openai")
|
||||
|
||||
var openaiCount, anthropicCount int
|
||||
cache.models.Range(func(key, value interface{}) bool {
|
||||
keyStr, ok := key.(string)
|
||||
if ok && keyStr == "anthropic/claude" {
|
||||
anthropicCount++
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, 0, openaiCount)
|
||||
assert.Equal(t, 1, anthropicCount)
|
||||
}
|
||||
|
||||
func TestRoutingCache_InvalidateModel(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
||||
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
||||
}}
|
||||
providerRepo := &mockProviderRepo{providers: make(map[string]*domain.Provider)}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
_, err := cache.GetModel("openai", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
cache.InvalidateModel("openai", "gpt-4")
|
||||
|
||||
var count int
|
||||
cache.models.Range(func(key, value interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, 0, count)
|
||||
}
|
||||
|
||||
func TestRoutingCache_Preload_Success(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
||||
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
||||
}}
|
||||
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
||||
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
||||
}}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
err := cache.Preload()
|
||||
require.NoError(t, err)
|
||||
|
||||
var providerCount, modelCount int
|
||||
cache.providers.Range(func(key, value interface{}) bool {
|
||||
providerCount++
|
||||
return true
|
||||
})
|
||||
cache.models.Range(func(key, value interface{}) bool {
|
||||
modelCount++
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, 1, providerCount)
|
||||
assert.Equal(t, 1, modelCount)
|
||||
}
|
||||
|
||||
func TestRoutingCache_ConcurrentAccess(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
modelRepo := &mockModelRepo{models: map[string]*domain.Model{
|
||||
"openai/gpt-4": {ID: "1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true},
|
||||
}}
|
||||
providerRepo := &mockProviderRepo{providers: map[string]*domain.Provider{
|
||||
"openai": {ID: "openai", Name: "OpenAI", Enabled: true},
|
||||
}}
|
||||
|
||||
cache := NewRoutingCache(modelRepo, providerRepo, logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = cache.GetProvider("openai")
|
||||
_, _ = cache.GetModel("openai", "gpt-4")
|
||||
cache.InvalidateProvider("openai")
|
||||
cache.InvalidateModel("openai", "gpt-4")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -2,7 +2,9 @@ package service
|
||||
|
||||
import "nex/backend/internal/domain"
|
||||
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=routing_service.go -destination=../../tests/mocks/mock_routing_service.go -package=mocks
|
||||
|
||||
// RoutingService 路由服务接口
|
||||
type RoutingService interface {
|
||||
Route(modelName string) (*domain.RouteResult, error)
|
||||
RouteByModelName(providerID, modelName string) (*domain.RouteResult, error)
|
||||
}
|
||||
|
||||
@@ -1,23 +1,20 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
type routingService struct {
|
||||
modelRepo repository.ModelRepository
|
||||
providerRepo repository.ProviderRepository
|
||||
cache *RoutingCache
|
||||
}
|
||||
|
||||
func NewRoutingService(modelRepo repository.ModelRepository, providerRepo repository.ProviderRepository) RoutingService {
|
||||
return &routingService{modelRepo: modelRepo, providerRepo: providerRepo}
|
||||
func NewRoutingService(cache *RoutingCache) RoutingService {
|
||||
return &routingService{cache: cache}
|
||||
}
|
||||
|
||||
func (s *routingService) Route(modelName string) (*domain.RouteResult, error) {
|
||||
model, err := s.modelRepo.GetByModelName(modelName)
|
||||
func (s *routingService) RouteByModelName(providerID, modelName string) (*domain.RouteResult, error) {
|
||||
model, err := s.cache.GetModel(providerID, modelName)
|
||||
if err != nil {
|
||||
return nil, appErrors.ErrModelNotFound
|
||||
}
|
||||
@@ -26,7 +23,7 @@ func (s *routingService) Route(modelName string) (*domain.RouteResult, error) {
|
||||
return nil, appErrors.ErrModelDisabled
|
||||
}
|
||||
|
||||
provider, err := s.providerRepo.GetByID(model.ProviderID)
|
||||
provider, err := s.cache.GetProvider(model.ProviderID)
|
||||
if err != nil {
|
||||
return nil, appErrors.ErrProviderNotFound
|
||||
}
|
||||
|
||||
@@ -3,24 +3,27 @@ package service
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestProviderService_Update(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"})
|
||||
require.NoError(t, svc.Create(&domain.Provider{ID: "p1", Name: "Original", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
|
||||
err := svc.Update("p1", map[string]interface{}{"name": "Updated"})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := svc.Get("p1", false)
|
||||
result, err := svc.Get("p1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated", result.Name)
|
||||
}
|
||||
@@ -28,7 +31,9 @@ func TestProviderService_Update(t *testing.T) {
|
||||
func TestProviderService_Update_NotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
err := svc.Update("nonexistent", map[string]interface{}{"name": "test"})
|
||||
assert.Error(t, err)
|
||||
@@ -38,43 +43,49 @@ func TestModelService_Get(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
require.NoError(t, svc.Create(model))
|
||||
|
||||
model, err := svc.Get("m1")
|
||||
result, err := svc.Get(model.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4", model.ModelName)
|
||||
assert.Equal(t, "gpt-4", result.ModelName)
|
||||
}
|
||||
|
||||
func TestModelService_Update(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
require.NoError(t, svc.Create(model))
|
||||
|
||||
err := svc.Update("m1", map[string]interface{}{"model_name": "gpt-4o"})
|
||||
err := svc.Update(model.ID, map[string]interface{}{"model_name": "gpt-4o"})
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := svc.Get("m1")
|
||||
result, err := svc.Get(model.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4o", model.ModelName)
|
||||
assert.Equal(t, "gpt-4o", result.ModelName)
|
||||
}
|
||||
|
||||
func TestModelService_Update_ProviderID_Invalid(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
require.NoError(t, svc.Create(model))
|
||||
|
||||
err := svc.Update("m1", map[string]interface{}{"provider_id": "nonexistent"})
|
||||
err := svc.Update(model.ID, map[string]interface{}{"provider_id": "nonexistent"})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -82,15 +93,17 @@ func TestModelService_Delete(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"}))
|
||||
model := &domain.Model{ProviderID: "p1", ModelName: "gpt-4"}
|
||||
require.NoError(t, svc.Create(model))
|
||||
|
||||
err := svc.Delete("m1")
|
||||
err := svc.Delete(model.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = svc.Get("m1")
|
||||
_, err = svc.Get(model.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -98,7 +111,8 @@ func TestModelService_Delete_NotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
err := svc.Delete("nonexistent")
|
||||
assert.Error(t, err)
|
||||
@@ -106,7 +120,8 @@ func TestModelService_Delete_NotFound(t *testing.T) {
|
||||
|
||||
func TestStatsService_Aggregate_Default(t *testing.T) {
|
||||
statsRepo := repository.NewStatsRepository(nil)
|
||||
svc := NewStatsService(statsRepo)
|
||||
buffer := NewStatsBuffer(statsRepo, zap.NewNop())
|
||||
svc := NewStatsService(statsRepo, buffer)
|
||||
|
||||
stats := []domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
@@ -118,7 +133,9 @@ func TestStatsService_Aggregate_Default(t *testing.T) {
|
||||
|
||||
totalCount := 0
|
||||
for _, r := range result {
|
||||
totalCount += r["request_count"].(int)
|
||||
count, ok := r["request_count"].(int)
|
||||
require.True(t, ok)
|
||||
totalCount += count
|
||||
}
|
||||
assert.Equal(t, 15, totalCount)
|
||||
}
|
||||
@@ -127,7 +144,8 @@ func TestModelService_Update_NotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
err := svc.Update("nonexistent", map[string]interface{}{"model_name": "test"})
|
||||
assert.Error(t, err)
|
||||
|
||||
@@ -1,270 +1,507 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
testHelpers "nex/backend/tests"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
func setupServiceTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
db, err := gorm.Open(sqlite.Open(dir+"/test.db"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
err = db.AutoMigrate(&config.Provider{}, &config.Model{}, &config.UsageStats{})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
sqlDB, _ := db.DB()
|
||||
if sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
})
|
||||
return db
|
||||
return testHelpers.SetupTestDB(t)
|
||||
}
|
||||
|
||||
// ============ ProviderService 测试 ============
|
||||
|
||||
func TestProviderService_Create(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
|
||||
provider := &domain.Provider{
|
||||
ID: "test-p", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com",
|
||||
}
|
||||
err := svc.Create(provider)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, provider.Enabled)
|
||||
func setupRoutingCache(t *testing.T, db *gorm.DB) *RoutingCache {
|
||||
t.Helper()
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
return NewRoutingCache(modelRepo, providerRepo, zap.NewNop())
|
||||
}
|
||||
|
||||
func TestProviderService_Get_MaskKey(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
// ============ RoutingService - RouteByModelName 测试 ============
|
||||
|
||||
svc.Create(&domain.Provider{
|
||||
ID: "p1", Name: "Test", APIKey: "sk-long-api-key-12345", BaseURL: "https://test.com",
|
||||
})
|
||||
|
||||
result, err := svc.Get("p1", true)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "***2345", result.APIKey)
|
||||
|
||||
result, err = svc.Get("p1", false)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sk-long-api-key-12345", result.APIKey)
|
||||
}
|
||||
|
||||
func TestProviderService_List(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
|
||||
svc.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key123", BaseURL: "https://a.com"})
|
||||
svc.Create(&domain.Provider{ID: "p2", Name: "P2", APIKey: "key456", BaseURL: "https://b.com"})
|
||||
|
||||
providers, err := svc.List()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, providers, 2)
|
||||
assert.Contains(t, providers[0].APIKey, "***")
|
||||
}
|
||||
|
||||
func TestProviderService_Delete(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
svc := NewProviderService(repo)
|
||||
|
||||
svc.Create(&domain.Provider{ID: "p1", Name: "Test", APIKey: "key", BaseURL: "https://test.com"})
|
||||
err := svc.Delete("p1")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = svc.Get("p1", false)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// ============ ModelService 测试 ============
|
||||
|
||||
func TestModelService_Create(t *testing.T) {
|
||||
func TestRoutingService_RouteByModelName_Success(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewRoutingService(cache)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
|
||||
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
|
||||
|
||||
model := &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"}
|
||||
result, err := svc.RouteByModelName("openai", "gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai", result.Provider.ID)
|
||||
assert.Equal(t, "gpt-4", result.Model.ModelName)
|
||||
}
|
||||
|
||||
func TestRoutingService_RouteByModelName_NotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewRoutingService(cache)
|
||||
|
||||
_, err := svc.RouteByModelName("openai", "nonexistent-model")
|
||||
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
|
||||
}
|
||||
|
||||
func TestRoutingService_RouteByModelName_DisabledModel(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewRoutingService(cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
|
||||
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
|
||||
require.NoError(t, modelRepo.Update("m1", map[string]interface{}{"enabled": false}))
|
||||
|
||||
_, err := svc.RouteByModelName("openai", "gpt-4")
|
||||
assert.True(t, errors.Is(err, appErrors.ErrModelDisabled))
|
||||
}
|
||||
|
||||
func TestRoutingService_RouteByModelName_DisabledProvider(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewRoutingService(cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com", Enabled: true}))
|
||||
require.NoError(t, modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4", Enabled: true}))
|
||||
require.NoError(t, providerRepo.Update("openai", map[string]interface{}{"enabled": false}))
|
||||
|
||||
_, err := svc.RouteByModelName("openai", "gpt-4")
|
||||
assert.True(t, errors.Is(err, appErrors.ErrProviderDisabled))
|
||||
}
|
||||
|
||||
// ============ ModelService - Create with UUID 测试 ============
|
||||
|
||||
func TestModelService_Create_GeneratesUUID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
|
||||
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, model.Enabled)
|
||||
|
||||
assert.NotEmpty(t, model.ID)
|
||||
_, err = uuid.Parse(model.ID)
|
||||
assert.NoError(t, err, "model.ID should be a valid UUID")
|
||||
|
||||
stored, err := svc.Get(model.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, model.ID, stored.ID)
|
||||
assert.Equal(t, "gpt-4", stored.ModelName)
|
||||
}
|
||||
|
||||
func TestModelService_Create_DuplicateModelName(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
|
||||
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model1)
|
||||
require.NoError(t, err)
|
||||
|
||||
model2 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err = svc.Create(model2)
|
||||
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
|
||||
}
|
||||
|
||||
func TestModelService_Create_ProviderNotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
model := &domain.Model{ID: "m1", ProviderID: "nonexistent", ModelName: "gpt-4"}
|
||||
model := &domain.Model{ProviderID: "nonexistent", ModelName: "gpt-4"}
|
||||
err := svc.Create(model)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, appErrors.ErrProviderNotFound))
|
||||
}
|
||||
|
||||
func TestModelService_List(t *testing.T) {
|
||||
// ============ ProviderService - Create with validation 测试 ============
|
||||
|
||||
func TestProviderService_Create_InvalidID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewModelService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com"})
|
||||
svc.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4"})
|
||||
svc.Create(&domain.Model{ID: "m2", ProviderID: "p1", ModelName: "gpt-3.5"})
|
||||
provider := &domain.Provider{ID: "open-ai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
assert.True(t, errors.Is(err, appErrors.ErrInvalidProviderID))
|
||||
}
|
||||
|
||||
models, err := svc.List("p1")
|
||||
func TestProviderService_Create_ValidID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, models, 2)
|
||||
assert.Equal(t, "openai", provider.ID)
|
||||
assert.True(t, provider.Enabled)
|
||||
}
|
||||
|
||||
// ============ RoutingService 测试 ============
|
||||
// ============ ModelService - Update with duplicate check 测试 ============
|
||||
|
||||
func TestRoutingService_Route(t *testing.T) {
|
||||
func TestModelService_Update_DuplicateModelName(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "key", BaseURL: "https://api.anthropic.com"}))
|
||||
|
||||
result, err := svc.Route("gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "p1", result.Provider.ID)
|
||||
assert.Equal(t, "gpt-4", result.Model.ModelName)
|
||||
}
|
||||
|
||||
func TestRoutingService_Route_ModelNotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
_, err := svc.Route("nonexistent-model")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRoutingService_Route_ModelDisabled(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
|
||||
// 先创建启用的模型,然后通过 Update 禁用
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
modelRepo.Update("m1", map[string]interface{}{"enabled": false})
|
||||
|
||||
_, err := svc.Route("gpt-4")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRoutingService_Route_ProviderDisabled(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
svc := NewRoutingService(modelRepo, providerRepo)
|
||||
|
||||
// 先创建启用的 provider,然后禁用
|
||||
providerRepo.Create(&domain.Provider{ID: "p1", Name: "P1", APIKey: "key", BaseURL: "https://test.com", Enabled: true})
|
||||
providerRepo.Update("p1", map[string]interface{}{"enabled": false})
|
||||
modelRepo.Create(&domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true})
|
||||
|
||||
_, err := svc.Route("gpt-4")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// ============ StatsService 测试 ============
|
||||
|
||||
func TestStatsService_RecordAndGet(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
svc := NewStatsService(statsRepo)
|
||||
|
||||
err := svc.Record("p1", "gpt-4")
|
||||
model1 := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model1)
|
||||
require.NoError(t, err)
|
||||
|
||||
stats, err := svc.Get("p1", "", nil, nil)
|
||||
model2 := &domain.Model{ProviderID: "anthropic", ModelName: "claude-3"}
|
||||
err = svc.Create(model2)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, stats, 1)
|
||||
|
||||
// 将 model2 的 model_name 改为 "gpt-4" 且 provider_id 改为 "openai",与 model1 冲突
|
||||
err = svc.Update(model2.ID, map[string]interface{}{
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
assert.True(t, errors.Is(err, appErrors.ErrDuplicateModel))
|
||||
}
|
||||
|
||||
func TestStatsService_Aggregate_ByProvider(t *testing.T) {
|
||||
statsRepo := repository.NewStatsRepository(nil)
|
||||
svc := NewStatsService(statsRepo)
|
||||
func TestModelService_Update_ModelNotFound(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
stats := []domain.UsageStats{
|
||||
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
|
||||
{ProviderID: "p1", ModelName: "gpt-3.5", RequestCount: 5},
|
||||
{ProviderID: "p2", ModelName: "claude-3", RequestCount: 8},
|
||||
}
|
||||
|
||||
result := svc.Aggregate(stats, "provider")
|
||||
assert.Len(t, result, 2)
|
||||
|
||||
p1Count := 0
|
||||
p2Count := 0
|
||||
for _, r := range result {
|
||||
if r["provider_id"] == "p1" {
|
||||
p1Count = r["request_count"].(int)
|
||||
}
|
||||
if r["provider_id"] == "p2" {
|
||||
p2Count = r["request_count"].(int)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 15, p1Count)
|
||||
assert.Equal(t, 8, p2Count)
|
||||
err := svc.Update("nonexistent-id", map[string]interface{}{
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
assert.True(t, errors.Is(err, appErrors.ErrModelNotFound))
|
||||
}
|
||||
|
||||
func TestStatsService_Aggregate_ByDate(t *testing.T) {
|
||||
statsRepo := repository.NewStatsRepository(nil)
|
||||
svc := NewStatsService(statsRepo)
|
||||
func TestModelService_Update_Success(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
stats := []domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
{ProviderID: "p2", RequestCount: 5},
|
||||
}
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
|
||||
result := svc.Aggregate(stats, "date")
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, 15, result[0]["request_count"])
|
||||
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
err := svc.Create(model)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 更新 model_name 为不冲突的值
|
||||
err = svc.Update(model.ID, map[string]interface{}{
|
||||
"model_name": "gpt-4-turbo",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := svc.Get(model.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4-turbo", updated.ModelName)
|
||||
}
|
||||
|
||||
// ============ ProviderService - Update immutable ID 测试 ============
|
||||
|
||||
func TestProviderService_Update_ImmutableID(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 尝试更新 id 字段
|
||||
err = svc.Update("openai", map[string]interface{}{
|
||||
"id": "new-id",
|
||||
})
|
||||
assert.True(t, errors.Is(err, appErrors.ErrImmutableField))
|
||||
}
|
||||
|
||||
func TestProviderService_Update_Success(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}
|
||||
err := svc.Create(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 更新 name
|
||||
err = svc.Update("openai", map[string]interface{}{
|
||||
"name": "OpenAI Updated",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := svc.Get("openai")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "OpenAI Updated", updated.Name)
|
||||
}
|
||||
|
||||
// ============ StatsService - Aggregate ByModel 测试 ============
|
||||
|
||||
func TestStatsService_Aggregate_ByModel(t *testing.T) {
|
||||
statsRepo := repository.NewStatsRepository(nil)
|
||||
svc := NewStatsService(statsRepo)
|
||||
|
||||
stats := []domain.UsageStats{
|
||||
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10},
|
||||
{ProviderID: "openai", ModelName: "gpt-3.5", RequestCount: 5},
|
||||
{ProviderID: "anthropic", ModelName: "claude-3", RequestCount: 8},
|
||||
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 3},
|
||||
tests := []struct {
|
||||
name string
|
||||
stats []domain.UsageStats
|
||||
expected []map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "multiple providers with same model name",
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 10},
|
||||
{ProviderID: "azure", ModelName: "gpt-4", RequestCount: 20},
|
||||
{ProviderID: "openai", ModelName: "gpt-4", RequestCount: 5},
|
||||
},
|
||||
expected: []map[string]interface{}{
|
||||
{"provider_id": "openai", "model_name": "gpt-4", "request_count": 15},
|
||||
{"provider_id": "azure", "model_name": "gpt-4", "request_count": 20},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty providerID",
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "", ModelName: "gpt-4", RequestCount: 10},
|
||||
{ProviderID: "", ModelName: "gpt-4", RequestCount: 5},
|
||||
},
|
||||
expected: []map[string]interface{}{
|
||||
{"provider_id": "", "model_name": "gpt-4", "request_count": 15},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty result set",
|
||||
stats: []domain.UsageStats{},
|
||||
expected: []map[string]interface{}{},
|
||||
},
|
||||
}
|
||||
|
||||
result := svc.Aggregate(stats, "model")
|
||||
assert.Len(t, result, 3)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
buffer := NewStatsBuffer(statsRepo, zap.NewNop())
|
||||
svc := NewStatsService(statsRepo, buffer)
|
||||
|
||||
// 验证每个 provider/model 组合的计数
|
||||
counts := make(map[string]int)
|
||||
for _, r := range result {
|
||||
key := r["provider_id"].(string) + "/" + r["model_name"].(string)
|
||||
counts[key] = r["request_count"].(int)
|
||||
result := svc.Aggregate(tt.stats, "model")
|
||||
|
||||
assert.Len(t, result, len(tt.expected))
|
||||
for _, exp := range tt.expected {
|
||||
found := false
|
||||
for _, r := range result {
|
||||
if r["provider_id"] == exp["provider_id"] && r["model_name"] == exp["model_name"] {
|
||||
assert.Equal(t, exp["request_count"], r["request_count"])
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected result not found: %v", exp)
|
||||
}
|
||||
})
|
||||
}
|
||||
assert.Equal(t, 13, counts["openai/gpt-4"])
|
||||
assert.Equal(t, 5, counts["openai/gpt-3.5"])
|
||||
assert.Equal(t, 8, counts["anthropic/claude-3"])
|
||||
}
|
||||
|
||||
// ============ StatsService - Aggregate ByDate 测试 ============
|
||||
|
||||
func TestStatsService_Aggregate_ByDate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stats []domain.UsageStats
|
||||
expected []map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "normal date grouping",
|
||||
stats: []domain.UsageStats{
|
||||
{Date: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), RequestCount: 10},
|
||||
{Date: time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC), RequestCount: 5},
|
||||
{Date: time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC), RequestCount: 20},
|
||||
},
|
||||
expected: []map[string]interface{}{
|
||||
{"date": "2024-01-01", "request_count": 15},
|
||||
{"date": "2024-01-02", "request_count": 20},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "zero-value time",
|
||||
stats: []domain.UsageStats{
|
||||
{Date: time.Time{}, RequestCount: 10},
|
||||
{Date: time.Time{}, RequestCount: 5},
|
||||
},
|
||||
expected: []map[string]interface{}{
|
||||
{"date": "0001-01-01", "request_count": 15},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty result set",
|
||||
stats: []domain.UsageStats{},
|
||||
expected: []map[string]interface{}{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
buffer := NewStatsBuffer(statsRepo, zap.NewNop())
|
||||
svc := NewStatsService(statsRepo, buffer)
|
||||
|
||||
result := svc.Aggregate(tt.stats, "date")
|
||||
|
||||
assert.Len(t, result, len(tt.expected))
|
||||
for _, exp := range tt.expected {
|
||||
found := false
|
||||
for _, r := range result {
|
||||
if r["date"] == exp["date"] {
|
||||
assert.Equal(t, exp["request_count"], r["request_count"])
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected result not found: %v", exp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============ ProviderService - isUniqueConstraintError 测试 ============
|
||||
|
||||
func TestProviderService_isUniqueConstraintError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "UNIQUE constraint failed",
|
||||
err: errors.New("UNIQUE constraint failed"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "duplicate key value",
|
||||
err: errors.New("duplicate key value"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "UNIQUE constraint case insensitive",
|
||||
err: errors.New("unique constraint violation"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
err: errors.New("some other error"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isUniqueConstraintError(tt.err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============ ProviderService - List API Key 测试 ============
|
||||
|
||||
func TestProviderService_List_APIKeyNotMasked(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
repo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewProviderService(repo, modelRepo, cache)
|
||||
|
||||
provider1 := &domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "sk-1234567890", BaseURL: "https://api.openai.com"}
|
||||
provider2 := &domain.Provider{ID: "anthropic", Name: "Anthropic", APIKey: "sk-anthropic1234", BaseURL: "https://api.anthropic.com"}
|
||||
require.NoError(t, svc.Create(provider1))
|
||||
require.NoError(t, svc.Create(provider2))
|
||||
|
||||
providers, err := svc.List()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, providers, 2)
|
||||
|
||||
expectedKeys := map[string]string{
|
||||
"openai": "sk-1234567890",
|
||||
"anthropic": "sk-anthropic1234",
|
||||
}
|
||||
for _, p := range providers {
|
||||
assert.NotContains(t, p.APIKey, "***")
|
||||
assert.Equal(t, expectedKeys[p.ID], p.APIKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelService_ConcurrentCreate(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
cache := setupRoutingCache(t, db)
|
||||
svc := NewModelService(modelRepo, providerRepo, cache)
|
||||
|
||||
require.NoError(t, providerRepo.Create(&domain.Provider{ID: "openai", Name: "OpenAI", APIKey: "key", BaseURL: "https://api.openai.com"}))
|
||||
|
||||
results := make(chan error, 2)
|
||||
for i := 0; i < 2; i++ {
|
||||
go func() {
|
||||
model := &domain.Model{ProviderID: "openai", ModelName: "gpt-4"}
|
||||
results <- svc.Create(model)
|
||||
}()
|
||||
}
|
||||
|
||||
err1 := <-results
|
||||
err2 := <-results
|
||||
|
||||
successCount := 0
|
||||
errorCount := 0
|
||||
for _, err := range []error{err1, err2} {
|
||||
if err == nil {
|
||||
successCount++
|
||||
} else {
|
||||
errorCount++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, successCount)
|
||||
assert.Equal(t, 1, errorCount)
|
||||
}
|
||||
|
||||
197
backend/internal/service/stats_buffer.go
Normal file
197
backend/internal/service/stats_buffer.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/repository"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
type StatsBuffer struct {
|
||||
counters sync.Map
|
||||
|
||||
flushInterval time.Duration
|
||||
flushThreshold int
|
||||
totalCount atomic.Int64
|
||||
|
||||
statsRepo repository.StatsRepository
|
||||
logger *zap.Logger
|
||||
|
||||
stopCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
}
|
||||
|
||||
type StatsBufferOption func(*StatsBuffer)
|
||||
|
||||
func WithFlushInterval(d time.Duration) StatsBufferOption {
|
||||
return func(b *StatsBuffer) {
|
||||
b.flushInterval = d
|
||||
}
|
||||
}
|
||||
|
||||
func WithFlushThreshold(threshold int) StatsBufferOption {
|
||||
return func(b *StatsBuffer) {
|
||||
b.flushThreshold = threshold
|
||||
}
|
||||
}
|
||||
|
||||
func NewStatsBuffer(
|
||||
statsRepo repository.StatsRepository,
|
||||
logger *zap.Logger,
|
||||
opts ...StatsBufferOption,
|
||||
) *StatsBuffer {
|
||||
b := &StatsBuffer{
|
||||
statsRepo: statsRepo,
|
||||
logger: pkglogger.WithModule(logger, "service.stats_buffer"),
|
||||
flushInterval: 5 * time.Second,
|
||||
flushThreshold: 100,
|
||||
stopCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(b)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *StatsBuffer) Increment(providerID, modelName string) {
|
||||
today := time.Now().Format("2006-01-02")
|
||||
key := providerID + "/" + modelName + "/" + today
|
||||
|
||||
var counter *int64
|
||||
if v, ok := b.counters.Load(key); ok {
|
||||
if existing, ok := v.(*int64); ok {
|
||||
counter = existing
|
||||
} else {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
val := int64(0)
|
||||
counter = &val
|
||||
actual, loaded := b.counters.LoadOrStore(key, counter)
|
||||
if loaded {
|
||||
existing, ok := actual.(*int64)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
counter = existing
|
||||
}
|
||||
}
|
||||
|
||||
atomic.AddInt64(counter, 1)
|
||||
|
||||
if b.totalCount.Add(1) >= int64(b.flushThreshold) {
|
||||
go b.flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *StatsBuffer) Start() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(b.flushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
b.flush()
|
||||
case <-b.stopCh:
|
||||
b.flush()
|
||||
close(b.doneCh)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (b *StatsBuffer) Stop() {
|
||||
close(b.stopCh)
|
||||
<-b.doneCh
|
||||
}
|
||||
|
||||
func (b *StatsBuffer) flush() {
|
||||
type statEntry struct {
|
||||
providerID string
|
||||
modelName string
|
||||
date string
|
||||
count int64
|
||||
}
|
||||
|
||||
var entries []statEntry
|
||||
b.counters.Range(func(key, value interface{}) bool {
|
||||
keyStr, ok := key.(string)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
parts := strings.Split(keyStr, "/")
|
||||
if len(parts) != 3 {
|
||||
return true
|
||||
}
|
||||
|
||||
counter, ok := value.(*int64)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
count := atomic.SwapInt64(counter, 0)
|
||||
|
||||
if count > 0 {
|
||||
entries = append(entries, statEntry{
|
||||
providerID: parts[0],
|
||||
modelName: parts[1],
|
||||
date: parts[2],
|
||||
count: count,
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if len(entries) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
success := 0
|
||||
for _, entry := range entries {
|
||||
date, err := time.Parse("2006-01-02", entry.date)
|
||||
if err != nil {
|
||||
b.logger.Error("解析统计日期失败",
|
||||
zap.String("provider_id", entry.providerID),
|
||||
zap.String("model_name", entry.modelName),
|
||||
zap.String("date", entry.date),
|
||||
zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
err = b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count))
|
||||
if err != nil {
|
||||
b.logger.Error("批量更新统计失败",
|
||||
zap.String("provider_id", entry.providerID),
|
||||
zap.String("model_name", entry.modelName),
|
||||
zap.Int64("count", entry.count),
|
||||
zap.Error(err))
|
||||
|
||||
key := entry.providerID + "/" + entry.modelName + "/" + entry.date
|
||||
if v, ok := b.counters.Load(key); ok {
|
||||
counter, ok := v.(*int64)
|
||||
if ok {
|
||||
atomic.AddInt64(counter, entry.count)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
success++
|
||||
}
|
||||
}
|
||||
|
||||
b.totalCount.Store(0)
|
||||
b.logger.Debug("统计刷新完成",
|
||||
zap.Int("total", len(entries)),
|
||||
zap.Int("success", success),
|
||||
zap.Int("failed", len(entries)-success))
|
||||
}
|
||||
261
backend/internal/service/stats_buffer_test.go
Normal file
261
backend/internal/service/stats_buffer_test.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type mockStatsRepo struct {
|
||||
records []struct {
|
||||
providerID string
|
||||
modelName string
|
||||
date time.Time
|
||||
delta int
|
||||
}
|
||||
fail bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (m *mockStatsRepo) Record(providerID, modelName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStatsRepo) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.fail {
|
||||
return errors.New("db error")
|
||||
}
|
||||
m.records = append(m.records, struct {
|
||||
providerID string
|
||||
modelName string
|
||||
date time.Time
|
||||
delta int
|
||||
}{providerID, modelName, date, delta})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStatsRepo) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestStatsBuffer_Increment(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
buffer.Increment("openai", "gpt-3.5")
|
||||
|
||||
var count int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter, ok := value.(*int64)
|
||||
if ok {
|
||||
count += atomic.LoadInt64(counter)
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(3), count)
|
||||
}
|
||||
|
||||
func TestStatsBuffer_ConcurrentIncrement(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
var count int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter, ok := value.(*int64)
|
||||
if ok {
|
||||
count = atomic.LoadInt64(counter)
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(100), count)
|
||||
}
|
||||
|
||||
func TestStatsBuffer_LoadOrStore(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
var counterCount int
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counterCount++
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, 1, counterCount)
|
||||
}
|
||||
|
||||
func TestStatsBuffer_FlushByInterval(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger,
|
||||
WithFlushInterval(100*time.Millisecond))
|
||||
|
||||
buffer.Start()
|
||||
defer buffer.Stop()
|
||||
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
statsRepo.mu.Lock()
|
||||
assert.GreaterOrEqual(t, len(statsRepo.records), 1)
|
||||
statsRepo.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestStatsBuffer_FlushByThreshold(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger,
|
||||
WithFlushThreshold(10))
|
||||
|
||||
buffer.Start()
|
||||
defer buffer.Stop()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
statsRepo.mu.Lock()
|
||||
assert.GreaterOrEqual(t, len(statsRepo.records), 1)
|
||||
statsRepo.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestStatsBuffer_SwapInt64(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
|
||||
var beforeCount int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter, ok := value.(*int64)
|
||||
if ok {
|
||||
beforeCount = atomic.LoadInt64(counter)
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(2), beforeCount)
|
||||
|
||||
buffer.flush()
|
||||
|
||||
var afterCount int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter, ok := value.(*int64)
|
||||
if ok {
|
||||
afterCount = atomic.LoadInt64(counter)
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(0), afterCount)
|
||||
}
|
||||
|
||||
func TestStatsBuffer_FailRetry(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{fail: true}
|
||||
buffer := NewStatsBuffer(statsRepo, logger)
|
||||
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
|
||||
buffer.flush()
|
||||
|
||||
var count int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter, ok := value.(*int64)
|
||||
if ok {
|
||||
count = atomic.LoadInt64(counter)
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(2), count)
|
||||
}
|
||||
|
||||
func TestStatsBuffer_Stop(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger,
|
||||
WithFlushInterval(10*time.Second))
|
||||
|
||||
buffer.Start()
|
||||
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
|
||||
start := time.Now()
|
||||
buffer.Stop()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.Less(t, elapsed, 1*time.Second)
|
||||
|
||||
statsRepo.mu.Lock()
|
||||
assert.GreaterOrEqual(t, len(statsRepo.records), 1)
|
||||
statsRepo.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestStatsBuffer_ConcurrentIncrementAndFlush(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
statsRepo := &mockStatsRepo{}
|
||||
buffer := NewStatsBuffer(statsRepo, logger,
|
||||
WithFlushInterval(50*time.Millisecond))
|
||||
|
||||
buffer.Start()
|
||||
defer buffer.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buffer.Increment("openai", "gpt-4")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
statsRepo.mu.Lock()
|
||||
totalDelta := 0
|
||||
for _, r := range statsRepo.records {
|
||||
totalDelta += r.delta
|
||||
}
|
||||
statsRepo.mu.Unlock()
|
||||
|
||||
assert.Equal(t, 100, totalDelta)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user