1
0

46 Commits

Author SHA1 Message Date
235efb0e62 chore: 版本升迁 v0.1.1 2026-05-05 04:39:31 +08:00
6b1af27ea2 fix: 移除 version-bump 的工作区干净检查 2026-05-05 04:39:21 +08:00
32f48777f3 feat: make version-bump 默认 BUMP=patch,无需显式传参 2026-05-05 04:32:38 +08:00
bc7a7c6e81 feat: 迁移 versionctl 为独立模块并新增 make version-bump 命令
- 将 backend/cmd/versionctl 和 backend/pkg/projectversion 迁移至独立 versionctl/ Go 模块
- 新增 bump 子命令支持 major/minor/patch 和指定版本号,含版本倒退防护
- 新增 make version-bump 编排完整升迁流程(bump + sync + check + commit + tag)
- 更新所有引用路径:根 Makefile、backend/Makefile、release.yml、.golangci.yml
- 新增 versionctl/.golangci.yml(精简配置)和 Makefile(lint/test/coverage)
- 根 Makefile lint/test 集成 versionctl 模块
- 同步 openspec specs:新增 version-bump spec,更新 release-pipeline spec
2026-05-05 04:18:10 +08:00
3cd0458c2c fix: 修复 golangci-lint 报告的 gosec/gocyclo/forbidigo 问题 2026-05-05 03:35:20 +08:00
8eea30ea11 feat: 统一品牌标识、关于页面三卡片布局与版本诊断功能
- 统一品牌为 Nex:侧边栏、托盘 tooltip、HTML 标题、favicon (PNG 替代 SVG)
- 重构关于页面为三卡片布局(品牌/版本/链接),版本状态 Tag 绝对定位右上角
- 新增 GET /api/version 后端接口,返回 version/commit/build_time
- 新增前端版本一致性诊断:匹配/不匹配/不可判断三种状态
- 同步 delta specs 到主 specs 并归档变更
2026-05-05 03:28:22 +08:00
9e33e570af fix: 降低请求生命周期日志级别 2026-05-05 01:54:53 +08:00
7653385838 fix: 加固发布流水线运行环境
修复 Windows 发布作业在 MSYS2 环境下无法访问 Go 工具链的问题。

为三平台发布增加工具链预检并升级 release workflow 运行时兼容性,减少版本检查噪音和 CI 告警。
2026-05-05 01:27:38 +08:00
2c401f7ae6 chore: streamline workspace make workflows
Clarify product-level server and desktop commands while moving backend-only maintenance tasks into backend/Makefile. This keeps root automation focused on core flows and aligns the main OpenSpec specs with the new command boundaries.
2026-04-28 17:44:23 +08:00
a9972360c2 feat: 增加版本化构建与发布流程
引入 VERSION 作为统一版本源,避免前端、后端、桌面打包和发布资产之间的版本漂移。
新增 tag 驱动的 Draft Release 流程与版本化资产命名,使本地演进和 GitHub 发布共享同一套约束。
2026-04-28 14:20:27 +08:00
b00fa4dcee chore: 调整开源许可证为 Apache 2.0
新增根目录 Apache 2.0 许可证文件,并同步更新仓库文档与前端包元数据声明。
2026-04-27 11:24:33 +08:00
92525b39c3 chore: 将 assets 图标文件迁移到 Git LFS 2026-04-27 10:34:38 +08:00
38a2555c7b fix: Anthropic 流式编码器补全 message_start/message_delta 必填字段
跨协议流式转换时,Anthropic 客户端 Zod 校验因 SSE 事件缺少必填字段报错。
由 Anthropic encoder 层(而非 OpenAI decoder 层)负责补全协议默认值,保持权责分离。

- encodeMessageStart 补全 type/content/stop_reason/stop_sequence,usage nil 时输出零值
- encodeMessageDelta usage nil 时输出零值
- 更新相关测试覆盖新增行为
2026-04-26 23:27:34 +08:00
9622d44aac fix: 完善转换代理行为 2026-04-26 21:48:17 +08:00
155244433f fix: 完善桌面应用图标打包
统一 macOS 图标命名为 icon.icns。\n补充 Linux hicolor 图标资源。\n修复 Windows make 构建兼容性并为 exe 嵌入图标资源。\n清理旧版图标说明与不再使用的 SVG 源文件。
2026-04-26 12:24:03 +08:00
2c043c6cf7 fix: 修正 conversion 代理路径和错误边界 2026-04-25 23:12:54 +08:00
f5c82b6980 chore: 合并 dev-test-config 到 master 2026-04-24 23:23:12 +08:00
9105a36097 feat: 将"关于"从系统托盘原生对话框迁移到前端页面
移除系统托盘右键菜单中的"关于"选项及各平台原生对话框实现,
在前端新增 /about 路由和关于页面展示品牌信息,侧边栏增加关于导航入口
2026-04-24 23:17:22 +08:00
f1ee646ca4 fix: 修复 TestSaveAndLoadConfig 测试隔离问题,使用临时目录替代真实用户配置 2026-04-24 22:59:26 +08:00
b9b487c591 chore: 统一项目编辑器配置,移至仓库根目录 2026-04-24 22:28:01 +08:00
4c62c071fb fix: 修复 macOS 桌面应用打包与元数据
将 macOS 桌面应用改为通用二进制并动态写入最低系统版本,避免 Intel Mac 无法启动。统一桌面应用名称与托盘展示,并补充测试确保相关行为稳定。
2026-04-24 19:35:51 +08:00
b2e9dd8b7f refactor: 合并 macOS 桌面打包流程
将 macOS .app 打包直接并入 desktop-build-mac,减少重复的桌面构建入口。\n\n同时移除未实现或已废弃的 desktop-package-* 命令和独立打包脚本,降低维护成本。
2026-04-24 19:02:21 +08:00
d143c5f3df fix: 补齐前端生成物忽略并消除构建告警
统一 Git、ESLint、Prettier 对测试和构建生成物的忽略规则,避免本地产物导致 frontend-build 失败。

补齐表单 effect 依赖,移除无关告警,让前端构建链路恢复稳定。
2026-04-24 18:53:53 +08:00
4eebdfb8db chore: 合并 dev-code-format-frontend 到 master 2026-04-24 18:21:27 +08:00
b517946585 chore: 合并 dev-code-backend-format 到 master 2026-04-24 18:20:12 +08:00
4ddae6be74 refactor: 优化提示词文档,增强智能合并与规范审查能力 2026-04-24 18:12:23 +08:00
195762ff97 fix: 修复后端配置加载测试失败
- 修复 viper SafeWriteConfig 与 SetConfigFile 不兼容问题
  - 将 SafeWriteConfig() 替换为 SafeWriteConfigAs(configPath)
  - 绕过 viper 的 configPaths 检查
- 调整 Makefile 测试命令分类
  - backend-test: 仅运行后端核心测试
  - backend-test-all: 运行全部后端测试(含 desktop)
  - desktop-test: 单独运行桌面应用测试
- 同步 config-management 和 test-coverage 规范
2026-04-24 14:06:03 +08:00
bcf5ca89e5 refactor: Makefile 前端命令自动安装依赖 2026-04-24 13:50:51 +08:00
365943e4c4 feat: 前端集成 Prettier 代码格式化 2026-04-24 13:40:53 +08:00
4c6b49099d feat: 配置 golangci-lint 静态分析并修复存量违规
- 新增 backend/.golangci.yml 配置 12 个 linter(forbidigo、errorlint、errcheck、staticcheck、revive、gocritic、gosec、bodyclose、noctx、nilerr、goimports、gocyclo)
- 新增 lefthook.yml 配置 pre-commit hook 自动运行 lint
- 修复存量代码违规:errors.Is/As 替换、zap.Error 替换、import 排序、errcheck 修复
- 更新 README 补充编码规范说明
- 归档 backend-code-lint 变更
2026-04-24 13:01:48 +08:00
4c78ab6cc8 chore: 新增 backend-code-lint 变更计划,更新 .gitignore 2026-04-24 00:19:56 +08:00
52007c9461 feat: 前端 ESLint 规则增强,自动检测 LLM 编码违规
- 启用 TanStack Query flat/recommended(7 条规则)
- 新增 no-console(允许 warn/error)、consistent-type-imports(inline 风格)、no-non-null-assertion 规则
- 新增自定义规则 no-hardcoded-color-in-style,检测 JSX style 中硬编码颜色值
- 将 ESLint 检查集成到 build 命令(tsc -b && eslint . && vite build)
- 修复现有代码中的 lint 违规(import 顺序、type import 风格、unused vars)
- 使用 @typescript-eslint/rule-tester 编写自定义规则集成测试
2026-04-23 22:47:32 +08:00
086dd1fed7 refactor: 重构智能合并提示词,优化交互体验并增加语义审查能力
- 三阶段模型:计划审批(1次) → 自动合并(仅异常中断) → 汇总收尾(1次),减少冗余介入
- 新增语义审查:合并后自动分析代码冗余/过时模式/未集成基础设施/风格不一致
- 完善安全锚点体系:全局锚点+分支级锚点,分级回退机制
- 修复多个逻辑缺陷:stash含未跟踪文件、分支匹配过宽、远端分支拉取重名、语义修复交叉污染等
- 明确提问工具选项和交互方式,消除歧义引导
2026-04-23 22:35:24 +08:00
1d7e839b49 Merge branch 'dev-frontend-encrypt' into master 2026-04-23 18:42:42 +08:00
fa7babf13b chore: 归档 fix-windows-desktop-packaging 变更 2026-04-23 18:40:23 +08:00
280099b89c refactor: 后端日志系统重构
- 新增模块化日志器(pkg/logger/module.go)
- 新增 GORM 日志适配器
- 统一日志入口,移除所有 zap.L() 全局 logger 调用
- 字段标准化
- 启动阶段使用结构化日志
- 更新所有相关测试
2026-04-23 18:37:51 +08:00
0a92a25451 feat: 前端生产构建添加代码混淆
- 集成 vite-plugin-javascript-obfuscator 插件
- 配置中等偏高强度混淆策略(变量名、字符串、对象键、数字)
- 仅生产构建时启用,不影响开发体验
- 仅混淆业务代码,排除第三方库
- 不生成 Source Map
- 新增 frontend-obfuscation 规范
2026-04-23 18:23:07 +08:00
8c075194e5 fix: 修复合并后代码质量问题
- 修正 Makefile 迁移目录路径(sqlite3 → sqlite)
- 统一 database.go 日志风格(log.Printf → zapLogger)
- 修复 config.go validator 标签大小写
- 修复 database_test.go 测试使用 nil logger
- 移除未使用的 log 导入
2026-04-23 16:58:01 +08:00
53e477d383 Merge branch 'dev-mysql-support' into master
- 新增 MySQL 数据库驱动支持,支持跨设备数据同步
- 新增 MySQL 专项测试能力(并发、约束、迁移)
- 重构迁移目录结构:migrations/sqlite 和 migrations/mysql
- 修复 statsRepo 并发竞态条件,使用 upsert 保证原子性
- Makefile 合并:保留完整命令体系 + 新增 MySQL 测试命令
2026-04-23 16:31:29 +08:00
1522c87c74 fix: 修复 statsRepo 并发竞态条件,使用 upsert 保证原子性
- 使用 GORM clause.OnConflict 替代事务包装
- Record 和 BatchUpdate 方法改用 upsert 模式
- 修复 UsageStats 的 GORM struct tag,确保 AutoMigrate 创建正确的 UNIQUE 约束
- 更新 usage-statistics spec 以反映 upsert 操作

MySQL 并发测试验证:10 并发调用 → request_count = 10
2026-04-23 15:54:56 +08:00
e0d05c9869 refactor: Makefile 命名规范化,新增顶层便捷命令
统一命名规范为 <namespace>-<action>[-<variant>] 格式:
- 重命名 desktop-mac/win/linux → desktop-build-mac/win/linux
- 重命名 backend-migrate-* → backend-db-*
- 重命名 frontend-build-desktop → desktop-prepare-frontend
- 重命名 embedfs-prepare → desktop-prepare-embedfs
- 重命名 package-macos → desktop-package-mac

新增顶层便捷命令:
- dev: 并行启动开发环境
- build: 构建所有产物
- test: 运行所有测试
- lint: 检查所有代码
- clean: 清理所有构建产物
2026-04-23 12:30:02 +08:00
5b401e29cb feat: 新增 MySQL 专项测试能力
- 新增 backend/tests/mysql/ 目录,包含 Docker Compose 配置和测试文件
- 新增 Makefile 命令: test-mysql, test-mysql-up, test-mysql-down, test-mysql-quick
- 使用 build tag 控制测试启用,默认不运行
- 测试覆盖: 迁移正确性、外键约束、UNIQUE 约束、并发写入
- 发现 statsRepo.Record 存在并发 bug(检查-然后-操作竞态条件)
2026-04-23 12:25:55 +08:00
65ac9f740a refactor: 桌面应用对话框代码拆分为平台专用文件
- 新增 dialog_windows.go、dialog_darwin.go、dialog_linux.go
- 使用 Go 构建标签实现条件编译
- 修复跨平台编译错误(syscall.NewLazyDLL 在 macOS/Linux 未定义)
- 实现 Linux 多工具降级策略(zenity → kdialog → notify-send → xmessage → stderr)
- 实现 macOS AppleScript 字符转义
- 更新 messagebox_test.go 构建标签
- 更新 desktop-app spec 新增 Linux 降级策略和 macOS 字符转义规范
2026-04-23 11:47:48 +08:00
58ebcaa299 refactor(scripts): 开发分支初始化脚本从 bash 重构为 Python,增强跨平台支持和用户体验 2026-04-23 10:43:59 +08:00
b3258e76df perf: 前端打包产物优化——路由级懒加载和 vendor 分包
- 使用 React.lazy() + Suspense 实现路由级代码分割
- 配置 manualChunks 将 react/tdesign/recharts 拆分为独立 vendor chunk
- 页面组件改为 export default 以支持动态导入
- 新增 bundle-optimization 规范,更新 frontend 导航规范
2026-04-23 00:26:54 +08:00
64dc66afa6 fix: Windows 桌面应用打包问题修复
- 删除通用 desktop target,重命名 platform targets 为简短形式 (desktop-mac/win/linux)
- 构建产物文件名统一为 nex-{os}-{arch}[.exe] 格式
- Windows 托盘图标使用 .ico 格式(运行时按平台选择)
- Windows 原生对话框使用 user32.MessageBoxW 替代 msg * 命令
- 更新 README.md 和 package-macos.sh 中的引用
- 添加单元测试覆盖 MessageBoxW 封装和图标选择逻辑
- 同步更新 desktop-app spec 规范文档
2026-04-22 23:20:39 +08:00
272 changed files with 13923 additions and 3912 deletions

7
.editorconfig Normal file
View File

@@ -0,0 +1,7 @@
root = true
[*]
end_of_line = lf
trim_trailing_whitespace = true
insert_final_newline = true
charset = utf-8

9
.gitattributes vendored Normal file
View File

@@ -0,0 +1,9 @@
* text=auto eol=lf
assets/*.png filter=lfs diff=lfs merge=lfs -text
assets/**/*.png filter=lfs diff=lfs merge=lfs -text
assets/*.icns filter=lfs diff=lfs merge=lfs -text
assets/**/*.icns filter=lfs diff=lfs merge=lfs -text
assets/*.ico filter=lfs diff=lfs merge=lfs -text
assets/**/*.ico filter=lfs diff=lfs merge=lfs -text
frontend/public/*.png filter=lfs diff=lfs merge=lfs -text
frontend/public/**/*.png filter=lfs diff=lfs merge=lfs -text

210
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,210 @@
name: Release
on:
push:
tags:
- 'v*.*.*'
permissions:
contents: read
jobs:
prepare:
name: Prepare Release
runs-on: ubuntu-latest
permissions:
contents: read
outputs:
version: ${{ steps.version.outputs.version }}
steps:
- name: Checkout
uses: actions/checkout@v5
- name: Setup Go
uses: actions/setup-go@v6
with:
go-version-file: go.work
cache-dependency-path: |
backend/go.sum
versionctl/go.sum
- name: Verify tag and VERSION
id: version
run: |
version=$(go run ./versionctl print)
go run ./versionctl verify-tag "${GITHUB_REF_NAME}"
printf 'version=%s\n' "$version" >> "$GITHUB_OUTPUT"
build-linux:
name: Build Linux Assets
needs: prepare
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- name: Checkout
uses: actions/checkout@v5
- name: Setup Go
uses: actions/setup-go@v6
with:
go-version-file: go.work
cache-dependency-path: |
backend/go.sum
versionctl/go.sum
- name: Setup Bun
uses: oven-sh/setup-bun@v2
- name: Install Linux desktop build dependencies
run: |
sudo apt-get update
sudo apt-get install -y libayatana-appindicator3-dev libgtk-3-dev
- name: Preflight Linux release toolchain
run: |
set -euo pipefail
command -v go
go version
command -v bun
bun --version
command -v gcc
gcc --version
command -v pkg-config
pkg-config --modversion ayatana-appindicator3-0.1
pkg-config --modversion gtk+-3.0
- name: Build Linux release assets
run: make release-assets-linux
- name: Upload Linux release assets
uses: actions/upload-artifact@v4
with:
name: release-linux
path: build/release/*
build-windows:
name: Build Windows Assets
needs: prepare
runs-on: windows-latest
permissions:
contents: read
steps:
- name: Checkout
uses: actions/checkout@v5
- name: Setup Go
uses: actions/setup-go@v6
with:
go-version-file: go.work
cache-dependency-path: |
backend/go.sum
versionctl/go.sum
- name: Setup Bun
uses: oven-sh/setup-bun@v2
- name: Setup MSYS2 toolchain
uses: msys2/setup-msys2@v2
with:
msystem: MINGW64
path-type: inherit
update: true
install: >-
make
mingw-w64-x86_64-gcc
- name: Preflight Windows release toolchain
shell: msys2 {0}
run: |
set -euo pipefail
command -v go
go version
command -v bun
bun --version
command -v make
make --version
command -v gcc
gcc --version
command -v windres
windres --version
if command -v powershell.exe >/dev/null 2>&1; then
powershell.exe -NoProfile -Command '$PSVersionTable.PSVersion.ToString()'
else
command -v powershell
powershell -NoProfile -Command '$PSVersionTable.PSVersion.ToString()'
fi
- name: Build Windows release assets
shell: msys2 {0}
run: make release-assets-windows
- name: Upload Windows release assets
uses: actions/upload-artifact@v4
with:
name: release-windows
path: build/release/*
build-macos:
name: Build macOS Assets
needs: prepare
runs-on: macos-latest
permissions:
contents: read
steps:
- name: Checkout
uses: actions/checkout@v5
- name: Setup Go
uses: actions/setup-go@v6
with:
go-version-file: go.work
cache-dependency-path: |
backend/go.sum
versionctl/go.sum
- name: Setup Bun
uses: oven-sh/setup-bun@v2
- name: Preflight macOS release toolchain
run: |
set -euo pipefail
command -v go
go version
command -v bun
bun --version
command -v ditto
xcrun --find lipo
xcrun --find vtool
- name: Build macOS release assets
run: make release-assets-macos
- name: Upload macOS release assets
uses: actions/upload-artifact@v4
with:
name: release-macos
path: build/release/*
draft-release:
name: Create Draft Release
needs: [prepare, build-linux, build-windows, build-macos]
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- name: Download release assets
uses: actions/download-artifact@v4
with:
pattern: release-*
merge-multiple: true
path: dist
- name: Publish draft release
uses: softprops/action-gh-release@v2
with:
name: v${{ needs.prepare.outputs.version }}
tag_name: ${{ github.ref_name }}
draft: true
files: |
dist/*

3
.gitignore vendored
View File

@@ -401,13 +401,16 @@ cython_debug/
# Custom # Custom
.claude .claude
.opencode .opencode
.codex
openspec/changes/archive openspec/changes/archive
temp temp
.agents .agents
skills-lock.json skills-lock.json
.worktrees .worktrees
!scripts/build/ !scripts/build/
backend/bin
# Embedfs generated # Embedfs generated
embedfs/assets/ embedfs/assets/
embedfs/frontend-dist/ embedfs/frontend-dist/
backend/cmd/desktop/rsrc_windows_*.syso

3
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,3 @@
{
"files.eol": "\n"
}

184
LICENSE Normal file
View File

@@ -0,0 +1,184 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction, and
distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by the copyright
owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all other entities
that control, are controlled by, or are under common control with that entity.
For the purposes of this definition, "control" means (i) the power, direct or
indirect, to cause the direction or management of such entity, whether by
contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity exercising
permissions granted by this License.
"Source" form shall mean the preferred form for making modifications, including
but not limited to software source code, documentation source, and configuration
files.
"Object" form shall mean any form resulting from mechanical transformation or
translation of a Source form, including but not limited to compiled object code,
generated documentation, and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or Object form,
made available under the License, as indicated by a copyright notice that is
included in or attached to the work (an example is provided in the Appendix
below).
"Derivative Works" shall mean any work, whether in Source or Object form, that
is based on (or derived from) the Work and for which the editorial revisions,
annotations, elaborations, or other modifications represent, as a whole, an
original work of authorship. For the purposes of this License, Derivative Works
shall not include works that remain separable from, or merely link (or bind by
name) to the interfaces of, the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including the original version
of the Work and any modifications or additions to that Work or Derivative Works
thereof, that is intentionally submitted to Licensor for inclusion in the Work
by the copyright owner or by an individual or Legal Entity authorized to submit
on behalf of the copyright owner. For the purposes of this definition,
"submitted" means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems, and
issue tracking systems that are managed by, or on behalf of, the Licensor for
the purpose of discussing and improving the Work, but excluding communication
that is conspicuously marked or otherwise designated in writing by the copyright
owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf
of whom a Contribution has been received by Licensor and subsequently
incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of this
License, each Contributor hereby grants to You a perpetual, worldwide,
non-exclusive, no-charge, royalty-free, irrevocable copyright license to
reproduce, prepare Derivative Works of, publicly display, publicly perform,
sublicense, and distribute the Work and such Derivative Works in Source or
Object form.
3. Grant of Patent License. Subject to the terms and conditions of this License,
each Contributor hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section) patent
license to make, have made, use, offer to sell, sell, import, and otherwise
transfer the Work, where such license applies only to those patent claims
licensable by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s) with the Work
to which such Contribution(s) was submitted. If You institute patent litigation
against any entity (including a cross-claim or counterclaim in a lawsuit)
alleging that the Work or a Contribution incorporated within the Work
constitutes direct or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate as of the date
such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the Work or
Derivative Works thereof in any medium, with or without modifications, and in
Source or Object form, provided that You meet the following conditions:
(a) You must give any other recipients of the Work or Derivative Works a copy of
this License; and
(b) You must cause any modified files to carry prominent notices stating that
You changed the files; and
(c) You must retain, in the Source form of any Derivative Works that You
distribute, all copyright, patent, trademark, and attribution notices from the
Source form of the Work, excluding those notices that do not pertain to any part
of the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its distribution, then
any Derivative Works that You distribute must include a readable copy of the
attribution notices contained within such NOTICE file, excluding those notices
that do not pertain to any part of the Derivative Works, in at least one of the
following places: within a NOTICE text file distributed as part of the
Derivative Works; within the Source form or documentation, if provided along
with the Derivative Works; or, within a display generated by the Derivative
Works, if and wherever such third-party notices normally appear. The contents of
the NOTICE file are for informational purposes only and do not modify the
License. You may add Your own attribution notices within Derivative Works that
You distribute, alongside or as an addendum to the NOTICE text from the Work,
provided that such additional attribution notices cannot be construed as
modifying the License.
You may add Your own copyright statement to Your modifications and may provide
additional or different license terms and conditions for use, reproduction, or
distribution of Your modifications, or for any such Derivative Works as a whole,
provided Your use, reproduction, and distribution of the Work otherwise complies
with the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise, any
Contribution intentionally submitted for inclusion in the Work by You to the
Licensor shall be under the terms and conditions of this License, without any
additional terms or conditions. Notwithstanding the above, nothing herein shall
supersede or modify the terms of any separate license agreement you may have
executed with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade names,
trademarks, service marks, or product names of the Licensor, except as required
for reasonable and customary use in describing the origin of the Work and
reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in
writing, Licensor provides the Work (and each Contributor provides its
Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied, including, without limitation, any warranties
or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any risks
associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory, whether in
tort (including negligence), contract, or otherwise, unless required by
applicable law (such as deliberate and grossly negligent acts) or agreed to in
writing, shall any Contributor be liable to You for damages, including any
direct, indirect, special, incidental, or consequential damages of any character
arising as a result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill, work stoppage,
computer failure or malfunction, or any and all other commercial damages or
losses), even if such Contributor has been advised of the possibility of such
damages.
9. Accepting Warranty or Additional Liability. While redistributing the Work or
Derivative Works thereof, You may choose to offer, and charge a fee for,
acceptance of support, warranty, indemnity, or other liability obligations
and/or rights consistent with this License. However, in accepting such
obligations, You may act only on Your own behalf and on Your sole
responsibility, not on behalf of any other Contributor, and only if You agree to
indemnify, defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason of your
accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following boilerplate
notice, with the fields enclosed by brackets "[]" replaced with your own
identifying information. (Don't include the brackets!) The text should be
enclosed in the appropriate comment syntax for the file format. We also
recommend that a file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier identification within
third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

318
Makefile
View File

@@ -1,119 +1,251 @@
.PHONY: all clean \ .PHONY: \
backend-build backend-run backend-test backend-test-unit backend-test-integration backend-test-coverage \ lint test clean \
backend-lint backend-deps backend-generate \ version-sync version-check version-bump \
backend-migrate-up backend-migrate-down backend-migrate-status backend-migrate-create \ server-run server-build server-lint server-test server-clean \
frontend-build frontend-dev frontend-test frontend-test-watch frontend-test-coverage frontend-test-e2e frontend-lint \ desktop-build-mac desktop-build-win desktop-build-linux \
desktop desktop-darwin desktop-windows desktop-linux package-macos desktop-lint desktop-test desktop-clean \
release-assets-linux release-assets-windows release-assets-macos \
_backend-lint _backend-test _backend-clean _backend-build \
_versionctl-lint _versionctl-test \
_frontend-install _frontend-build _frontend-check _frontend-test _frontend-dev _frontend-clean \
_desktop-test _desktop-clean _desktop-prepare-frontend _desktop-prepare-embedfs _desktop-prepare-windows-resource \
_server-run-backend _server-run-frontend
# Delay shell lookups until a target needs them, then cache the result for this make run.
lazy_shell = $(or $($(1)),$(eval $(1) := $(shell $(2)))$($(1)))
VERSION = $(call lazy_shell,_VERSION,go run ./versionctl print)
GIT_COMMIT ?= $(call lazy_shell,_GIT_COMMIT,git rev-parse --short HEAD 2>/dev/null || printf 'unknown')
BUILD_TIME ?= $(call lazy_shell,_BUILD_TIME,date -u +"%Y-%m-%dT%H:%M:%SZ")
GO_LDFLAGS = -X nex/backend/pkg/buildinfo.version=$(VERSION) -X nex/backend/pkg/buildinfo.commit=$(GIT_COMMIT) -X nex/backend/pkg/buildinfo.buildTime=$(BUILD_TIME)
GO_LDFLAGS_WIN = $(GO_LDFLAGS) -H=windowsgui
RELEASE_DIR := build/release
SERVER_LINUX_ASSET = $(call lazy_shell,_SERVER_LINUX_ASSET,go run ./versionctl asset-name server linux amd64)
SERVER_WINDOWS_ASSET = $(call lazy_shell,_SERVER_WINDOWS_ASSET,go run ./versionctl asset-name server windows amd64)
SERVER_DARWIN_AMD64_ASSET = $(call lazy_shell,_SERVER_DARWIN_AMD64_ASSET,go run ./versionctl asset-name server darwin amd64)
SERVER_DARWIN_ARM64_ASSET = $(call lazy_shell,_SERVER_DARWIN_ARM64_ASSET,go run ./versionctl asset-name server darwin arm64)
DESKTOP_LINUX_ASSET = $(call lazy_shell,_DESKTOP_LINUX_ASSET,go run ./versionctl asset-name desktop linux)
DESKTOP_WINDOWS_ASSET = $(call lazy_shell,_DESKTOP_WINDOWS_ASSET,go run ./versionctl asset-name desktop windows)
DESKTOP_MACOS_ASSET = $(call lazy_shell,_DESKTOP_MACOS_ASSET,go run ./versionctl asset-name desktop macos)
# ============================================ # ============================================
# 后端 # 全局命令
# ============================================ # ============================================
all: backend-build lint: _backend-lint _frontend-check _versionctl-lint
@printf 'Lint complete\n'
backend-build: test: _backend-test _frontend-test _desktop-test _versionctl-test
cd backend && go build -o bin/server ./cmd/server @printf 'All tests passed\n'
backend-run: clean: _backend-clean _frontend-clean _desktop-clean
cd backend && go run ./cmd/server @printf 'Clean complete\n'
backend-test:
cd backend && go test ./... -v
backend-test-unit:
cd backend && go test ./internal/... ./pkg/... -v
backend-test-integration:
cd backend && go test ./tests/... -v
backend-test-coverage:
cd backend && go test ./... -coverprofile=coverage.out
cd backend && go tool cover -html=coverage.out -o coverage.html
@echo "Coverage report generated: backend/coverage.html"
backend-lint:
cd backend && go tool golangci-lint run ./...
backend-deps:
cd backend && go mod tidy
backend-generate:
cd backend && go generate ./...
DB_DRIVER ?= sqlite3
DB_DSN ?= $(DB_PATH)
backend-migrate-up:
cd backend && goose -dir migrations/$(DB_DRIVER) $(DB_DRIVER) "$(DB_DSN)" up
backend-migrate-down:
cd backend && goose -dir migrations/$(DB_DRIVER) $(DB_DRIVER) "$(DB_DSN)" down
backend-migrate-status:
cd backend && goose -dir migrations/$(DB_DRIVER) $(DB_DRIVER) "$(DB_DSN)" status
backend-migrate-create:
@read -p "Migration name: " name; \
cd backend && goose -dir migrations/sqlite create $$name sql; \
cd backend && goose -dir migrations/mysql create $$name sql
# ============================================ # ============================================
# 前端 # 版本管理
# ============================================ # ============================================
frontend-build: version-sync:
cd frontend && bun install && bun run build go run ./versionctl sync
frontend-dev: version-check:
cd frontend && bun dev go run ./versionctl check
frontend-test: version-bump: BUMP ?= patch
cd frontend && bun run test version-bump:
$(eval _BUMP_ARG := $(if $(SET_VERSION),$(SET_VERSION),$(BUMP)))
frontend-test-watch: $(eval _NEW_VERSION := $(shell go run ./versionctl bump $(_BUMP_ARG)))
cd frontend && bun run test:watch git add VERSION frontend/
git commit -m "chore: 版本升迁 v$(_NEW_VERSION)"
frontend-test-coverage: git tag "v$(_NEW_VERSION)"
cd frontend && bun run test:coverage @printf '版本升迁完成: v%s\n' "$(_NEW_VERSION)"
frontend-test-e2e:
cd frontend && bun run test:e2e
frontend-lint:
cd frontend && bun run lint
# ============================================ # ============================================
# 桌面应用 # Server 模式
# ============================================ # ============================================
desktop: frontend-build-desktop embedfs-prepare server-run:
cd backend && CGO_ENABLED=1 go build -o ../build/nex ./cmd/desktop @$(MAKE) -j2 _server-run-backend _server-run-frontend
frontend-build-desktop: server-build: version-check _backend-build _frontend-build
cd frontend && cp .env.desktop .env.production.local && bun install && bun run build && rm -f .env.production.local @printf 'Server build complete\n'
embedfs-prepare: server-lint: _backend-lint _frontend-check
@printf 'Server lint complete\n'
server-test: _backend-test _frontend-test
@printf 'Server tests passed\n'
server-clean: _backend-clean _frontend-clean
@printf 'Server artifacts cleaned\n'
_server-run-backend:
@$(MAKE) -C backend run
_server-run-frontend: _frontend-install
cd frontend && bun run dev
# ============================================
# Desktop 模式
# ============================================
desktop-build-mac: version-check _desktop-prepare-frontend _desktop-prepare-embedfs
@printf 'Building macOS desktop...\n'
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-mac-arm64 ./cmd/desktop
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-mac-amd64 ./cmd/desktop
lipo -create build/nex-mac-arm64 build/nex-mac-amd64 -output build/nex-mac-universal
@printf 'Packaging macOS app bundle...\n'
mkdir -p build/Nex.app/Contents/MacOS build/Nex.app/Contents/Resources
cp build/nex-mac-universal build/Nex.app/Contents/MacOS/nex
@if [ -f assets/icon.icns ]; then \
cp assets/icon.icns build/Nex.app/Contents/Resources/; \
else \
printf 'Missing assets/icon.icns\n'; \
fi
@MIN_MACOS_VERSION=$$(vtool -show-build build/nex-mac-universal | awk '/minos / {print $$2; exit}'); \
if [ -z "$$MIN_MACOS_VERSION" ]; then \
printf 'Unable to read macOS minimum version\n'; \
exit 1; \
fi; \
go run ./versionctl macos-plist "$$MIN_MACOS_VERSION" > build/Nex.app/Contents/Info.plist
chmod +x build/Nex.app/Contents/MacOS/nex
@printf 'macOS desktop build complete\n'
desktop-build-win: version-check _desktop-prepare-frontend _desktop-prepare-embedfs _desktop-prepare-windows-resource
@printf 'Building Windows desktop...\n'
ifeq ($(OS),Windows_NT)
powershell -NoProfile -Command "New-Item -ItemType Directory -Path 'build' -Force | Out-Null"
cd backend && set "CGO_ENABLED=1"&& set "GOOS=windows"&& set "GOARCH=amd64"&& go build -ldflags "$(GO_LDFLAGS_WIN)" -o ../build/nex-win-amd64.exe ./cmd/desktop
else
mkdir -p build
cd backend && CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS_WIN)" -o ../build/nex-win-amd64.exe ./cmd/desktop
endif
@printf 'Windows desktop build complete\n'
desktop-build-linux: version-check _desktop-prepare-frontend _desktop-prepare-embedfs
@printf 'Building Linux desktop...\n'
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-linux-amd64 ./cmd/desktop
@printf 'Linux desktop build complete\n'
desktop-lint: _backend-lint _frontend-check
@printf 'Desktop lint complete\n'
desktop-test: _desktop-test
@printf 'Desktop tests passed\n'
desktop-clean: _desktop-clean
@printf 'Desktop artifacts cleaned\n'
_desktop-test:
cd backend && go test ./cmd/desktop/... -v
_desktop-clean:
rm -rf build/ embedfs/assets embedfs/frontend-dist backend/cmd/desktop/rsrc_windows_amd64.syso
_desktop-prepare-frontend: _frontend-install
@printf 'Preparing frontend for desktop...\n'
ifeq ($(OS),Windows_NT)
powershell -NoProfile -Command "Copy-Item -LiteralPath 'frontend/.env.desktop' -Destination 'frontend/.env.production.local' -Force"
cd frontend && bun run build
powershell -NoProfile -Command "Remove-Item -LiteralPath 'frontend/.env.production.local' -Force -ErrorAction SilentlyContinue"
else
cd frontend && cp .env.desktop .env.production.local
cd frontend && bun run build
rm -f frontend/.env.production.local
endif
_desktop-prepare-embedfs:
@printf 'Preparing embedded filesystem...\n'
ifeq ($(OS),Windows_NT)
powershell -NoProfile -Command "Remove-Item -LiteralPath 'embedfs/assets' -Recurse -Force -ErrorAction SilentlyContinue; Remove-Item -LiteralPath 'embedfs/frontend-dist' -Recurse -Force -ErrorAction SilentlyContinue; Copy-Item -LiteralPath 'assets' -Destination 'embedfs/assets' -Recurse; Copy-Item -LiteralPath 'frontend/dist' -Destination 'embedfs/frontend-dist' -Recurse"
else
rm -rf embedfs/assets embedfs/frontend-dist rm -rf embedfs/assets embedfs/frontend-dist
cp -r assets embedfs/assets cp -r assets embedfs/assets
cp -r frontend/dist embedfs/frontend-dist cp -r frontend/dist embedfs/frontend-dist
endif
desktop-darwin: frontend-build-desktop embedfs-prepare _desktop-prepare-windows-resource:
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -o ../build/nex-darwin-arm64 ./cmd/desktop @printf 'Preparing Windows executable icon...\n'
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -o ../build/nex-darwin-amd64 ./cmd/desktop ifeq ($(OS),Windows_NT)
cd backend/cmd/desktop && windres -O coff -F pe-x86-64 -i icon_windows.rc -o rsrc_windows_amd64.syso
desktop-windows: frontend-build-desktop embedfs-prepare else
cd backend && CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -ldflags "-H=windowsgui" -o ../build/nex-windows-amd64.exe ./cmd/desktop @if command -v x86_64-w64-mingw32-windres >/dev/null 2>&1; then \
cd backend/cmd/desktop && x86_64-w64-mingw32-windres -O coff -F pe-x86-64 -i icon_windows.rc -o rsrc_windows_amd64.syso; \
desktop-linux: frontend-build-desktop embedfs-prepare elif command -v windres >/dev/null 2>&1; then \
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -o ../build/nex-linux-amd64 ./cmd/desktop cd backend/cmd/desktop && windres -O coff -F pe-x86-64 -i icon_windows.rc -o rsrc_windows_amd64.syso; \
else \
package-macos: printf 'Missing windres for Windows icon resource generation\n'; \
./scripts/build/package-macos.sh exit 1; \
fi
endif
# ============================================ # ============================================
# 清理 # 发布资产
# ============================================ # ============================================
clean: release-assets-linux: version-check desktop-build-linux
rm -rf backend/bin/ backend/coverage.out backend/coverage.html rm -rf "$(RELEASE_DIR)"
rm -rf build/ mkdir -p "$(RELEASE_DIR)"
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-linux-amd64 ./cmd/server
tar -C build -czf "$(RELEASE_DIR)/$(SERVER_LINUX_ASSET)" nex-server-linux-amd64
tar -C build -czf "$(RELEASE_DIR)/$(DESKTOP_LINUX_ASSET)" nex-linux-amd64
release-assets-windows: version-check desktop-build-win
ifeq ($(OS),Windows_NT)
powershell -NoProfile -Command "Remove-Item -LiteralPath '$(RELEASE_DIR)' -Recurse -Force -ErrorAction SilentlyContinue; New-Item -ItemType Directory -Path '$(RELEASE_DIR)' -Force | Out-Null"
cd backend && set "CGO_ENABLED=1"&& set "GOOS=windows"&& set "GOARCH=amd64"&& go build -ldflags "$(GO_LDFLAGS_WIN)" -o ../build/nex-server-win-amd64.exe ./cmd/server
powershell -NoProfile -Command "Compress-Archive -LiteralPath 'build/nex-server-win-amd64.exe' -DestinationPath '$(RELEASE_DIR)/$(SERVER_WINDOWS_ASSET)' -Force"
powershell -NoProfile -Command "Compress-Archive -LiteralPath 'build/nex-win-amd64.exe' -DestinationPath '$(RELEASE_DIR)/$(DESKTOP_WINDOWS_ASSET)' -Force"
else
@printf 'release-assets-windows requires Windows\n'
@exit 1
endif
release-assets-macos: version-check desktop-build-mac
rm -rf "$(RELEASE_DIR)"
mkdir -p "$(RELEASE_DIR)"
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-darwin-amd64 ./cmd/server
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-darwin-arm64 ./cmd/server
tar -C build -czf "$(RELEASE_DIR)/$(SERVER_DARWIN_AMD64_ASSET)" nex-server-darwin-amd64
tar -C build -czf "$(RELEASE_DIR)/$(SERVER_DARWIN_ARM64_ASSET)" nex-server-darwin-arm64
ditto -c -k --keepParent build/Nex.app "$(RELEASE_DIR)/$(DESKTOP_MACOS_ASSET)"
# ============================================
# 共享 helper targets
# ============================================
_backend-build:
@$(MAKE) -C backend build
_backend-lint:
@$(MAKE) -C backend lint
_backend-test:
@$(MAKE) -C backend test
_backend-clean:
@$(MAKE) -C backend clean
_versionctl-lint:
@$(MAKE) -C versionctl lint
_versionctl-test:
@$(MAKE) -C versionctl test
_frontend-install:
cd frontend && bun install
_frontend-build: _frontend-install
cd frontend && bun run build
_frontend-check: _frontend-install
cd frontend && bun run check
_frontend-test: _frontend-install
cd frontend && bun run test
_frontend-dev: _frontend-install
cd frontend && bun run dev
_frontend-clean:
rm -rf frontend/dist frontend/.next frontend/coverage frontend/playwright-report frontend/test-results frontend/tsconfig.tsbuildinfo

184
README.md
View File

@@ -36,13 +36,9 @@ nex/
├── assets/ # 应用资源 ├── assets/ # 应用资源
│ ├── icon.png # 托盘图标 │ ├── icon.png # 托盘图标
│ ├── AppIcon.icns # macOS 应用图标 │ ├── icon.icns # macOS 应用图标
│ └── icon.ico # Windows 应用图标 │ └── icon.ico # Windows 应用图标
├── scripts/ # 构建脚本
│ └── build/
│ └── package-macos.sh # macOS .app 打包脚本
└── README.md # 本文件 └── README.md # 本文件
``` ```
@@ -51,7 +47,7 @@ nex/
- **双协议支持**:同时支持 OpenAI 和 Anthropic 协议 - **双协议支持**:同时支持 OpenAI 和 Anthropic 协议
- **跨协议转换**Hub-and-Spoke 架构实现 OpenAI ↔ Anthropic 双向转换 - **跨协议转换**Hub-and-Spoke 架构实现 OpenAI ↔ Anthropic 双向转换
- **统一模型 ID**`provider_id/model_name` 格式全局唯一标识模型(如 `openai/gpt-4` - **统一模型 ID**`provider_id/model_name` 格式全局唯一标识模型(如 `openai/gpt-4`
- **Smart Passthrough**:同协议请求零序列化开销,仅改写 model 字段 - **Smart Passthrough**:同协议请求跳过 Canonical 全量转换,仅在 JSON 层改写 model 字段
- **流式响应**:完整支持 SSE 流式传输,包括跨协议流式转换 - **流式响应**:完整支持 SSE 流式传输,包括跨协议流式转换
- **Function Calling**支持工具调用Tools - **Function Calling**支持工具调用Tools
- **Thinking / Reasoning**:支持 OpenAI `reasoning_effort` 和 Anthropic `thinking` 配置 - **Thinking / Reasoning**:支持 OpenAI `reasoning_effort` 和 Anthropic `thinking` 配置
@@ -67,11 +63,25 @@ nex/
- **HTTP 框架**: Gin - **HTTP 框架**: Gin
- **ORM**: GORM - **ORM**: GORM
- **数据库**: SQLite / MySQL - **数据库**: SQLite / MySQL
- **日志**: zap + lumberjack结构化日志 + 日志轮转) - **日志**: zap + lumberjack结构化日志 + 日志轮转 + 模块标识
- **配置**: Viper + pflag多层配置CLI > 环境变量 > 配置文件 > 默认值) - **配置**: Viper + pflag多层配置CLI > 环境变量 > 配置文件 > 默认值)
- **验证**: go-playground/validator/v10 - **验证**: go-playground/validator/v10
- **迁移**: goose - **迁移**: goose
#### 日志模块标识规范
每个模块通过依赖注入获取带模块标识的 logger日志输出格式为 `[module.name]`
```
Console: INFO [handler.proxy] 处理请求 method=POST path=/v1/chat
JSON: {"level":"info","logger":"handler.proxy","msg":"处理请求","method":"POST"}
```
模块命名规范:
- 单一职责包:`database``config`
- 多实体包:`handler.proxy``service.provider`
- 子包:`handler.middleware`
### 前端 ### 前端
- **运行时**: Bun - **运行时**: Bun
- **构建工具**: Vite - **构建工具**: Vite
@@ -81,7 +91,7 @@ nex/
- **图表库**: Recharts - **图表库**: Recharts
- **路由**: React Router v7 - **路由**: React Router v7
- **数据获取**: TanStack Query v5 - **数据获取**: TanStack Query v5
- **样式**: SCSS Modules - **样式**: TDesign 组件 props 优先TDesign tokens 次之SCSS 作为兜底补充
- **测试**: Vitest + React Testing Library + Playwright - **测试**: Vitest + React Testing Library + Playwright
## 快速开始 ## 快速开始
@@ -91,22 +101,18 @@ nex/
**构建桌面应用** **构建桌面应用**
```bash ```bash
# 当前平台 # macOS (arm64 + amd64并打包为 .app)
make desktop make desktop-build-mac
# macOS (arm64 + amd64)
make desktop-darwin
make package-macos # 打包为 .app
# Windows # Windows
make desktop-windows make desktop-build-win
# Linux # Linux
make desktop-linux make desktop-build-linux
``` ```
**使用桌面应用** **使用桌面应用**
- 双击启动应用macOS: Nex.appWindows: nex.exeLinux: nex - 双击启动应用macOS: Nex.appWindows: nex-win-amd64.exeLinux: nex-linux-amd64
- 系统托盘图标出现,浏览器自动打开管理界面 - 系统托盘图标出现,浏览器自动打开管理界面
- 点击托盘图标显示菜单,可打开管理界面或退出 - 点击托盘图标显示菜单,可打开管理界面或退出
- 关闭浏览器后服务继续运行,可通过托盘重新打开 - 关闭浏览器后服务继续运行,可通过托盘重新打开
@@ -123,50 +129,54 @@ make desktop-linux
- Xfce: 需要 libappindicator - Xfce: 需要 libappindicator
- 其他支持 StatusNotifierItem 规范的环境 - 其他支持 StatusNotifierItem 规范的环境
### CLI 模式 ### Server 模式(前后端分离)
#### 后端
```bash ```bash
cd backend make server-run
go mod download
go run cmd/server/main.go
``` ```
后端服务将在 `http://localhost:9826` 启动。首次启动会自动: `make server-run` 会并行启动:
- 后端服务:`http://localhost:9826`
- 前端开发服务器:`http://localhost:5173`
前端请求会继续通过 Vite proxy 转发到后端。后端首次启动会自动:
- 创建配置文件 `~/.nex/config.yaml` - 创建配置文件 `~/.nex/config.yaml`
- 初始化数据库 `~/.nex/config.db` - 初始化数据库 `~/.nex/config.db`
- 运行数据库迁移 - 运行数据库迁移
- 创建日志目录 `~/.nex/log/` - 创建日志目录 `~/.nex/log/`
### 前端 **构建 server 模式产物**
```bash ```bash
cd frontend make server-build
bun install
bun dev
``` ```
前端开发服务器将在 `http://localhost:5173` 启动API 请求通过 Vite proxy 转发到后端。
## API 接口 ## API 接口
### 代理接口(对外部应用) ### 代理接口(对外部应用)
代理接口统一使用 `/{protocol}/*path` 路由格式,模型 ID 使用 `provider_id/model_name` 格式(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough最小化 JSON 改写保持参数保真;跨协议请求走完整 decode/encode 转换。 代理接口统一使用 `/{protocol}/*path` 路由格式,模型 ID 使用 `provider_id/model_name` 格式(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough最小化 JSON 改写保持未改写字段的 JSON 内容和类型不变;跨协议请求走完整 decode/encode 转换。
**OpenAI 协议**`protocol=openai` **OpenAI 协议**`protocol=openai`
- `POST /openai/chat/completions` - 对话补全 - `POST /openai/v1/chat/completions` - 对话补全
- `GET /openai/models` - 模型列表(本地数据库聚合) - `GET /openai/v1/models` - 模型列表(本地数据库聚合)
- `GET /openai/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询) - `GET /openai/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
- `POST /openai/embeddings` - 嵌入 - `POST /openai/v1/embeddings` - 嵌入
- `POST /openai/rerank` - 重排序 - `POST /openai/v1/rerank` - 重排序
**Anthropic 协议**`protocol=anthropic` **Anthropic 协议**`protocol=anthropic`
- `POST /anthropic/v1/messages` - 消息对话 - `POST /anthropic/v1/messages` - 消息对话
- `GET /anthropic/v1/models` - 模型列表(本地数据库聚合) - `GET /anthropic/v1/models` - 模型列表(本地数据库聚合)
- `GET /anthropic/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询) - `GET /anthropic/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
路径边界:网关只剥离第一段协议前缀,剩余路径保持协议原生形态交给 adapter。OpenAI adapter 接收 `/v1/chat/completions``/v1/models``/v1/embeddings``/v1/rerank`,并在构建上游 URL 时去掉 `/v1`Anthropic adapter 接收 `/v1/messages``/v1/models`。因此 OpenAI 供应商 `base_url` 配置到版本路径一级(如 `https://api.openai.com/v1`Anthropic 供应商 `base_url` 配置到域名级(如 `https://api.anthropic.com`)。
代理错误边界:网关层错误统一返回 `{"error":"...","code":"..."}`,例如 `INVALID_JSON``MODEL_NOT_FOUND``CONVERSION_FAILED``UPSTREAM_UNAVAILABLE`。只要上游已经返回 HTTP 响应,非 2xx 的 status、过滤 hop-by-hop header 后的 headers 和 body 会直接透传,不包装为应用错误或协议错误。
模型路由边界:只有 adapter 明确适配的接口会解析请求体中的 `model` 并使用统一模型 ID 路由;未知接口即使包含顶层 `model` 也按无 model 透传处理。
流式边界:同协议无响应 model 改写时原样透传 SSE frame 和 `[DONE]`;同协议需要响应 model 改写时只解析 SSE frame 的 `data` JSON 并改写 `model`;跨协议流式仍走 provider decoder → Canonical stream event → client encoder。
### 管理接口(对前端) ### 管理接口(对前端)
#### 供应商管理 #### 供应商管理
@@ -189,6 +199,9 @@ bun dev
查询参数支持:`provider_id``model_name``start_date``end_date``group_by` 查询参数支持:`provider_id``model_name``start_date``end_date``group_by`
#### 版本信息
- `GET /api/version` - 获取后端构建版本信息(`version``commit``build_time`),用于前端 About 页面诊断前后端版本一致性
## 配置 ## 配置
配置支持多种方式,优先级为:**CLI 参数 > 环境变量 > 配置文件 > 默认值** 配置支持多种方式,优先级为:**CLI 参数 > 环境变量 > 配置文件 > 默认值**
@@ -262,25 +275,100 @@ export NEX_DATABASE_DBNAME=nex
## 测试 ## 测试
```bash ```bash
make backend-test # 后端测试 # 全局默认测试(不含 MySQL 和前端 E2E
make backend-test-coverage # 后端覆盖率 make test
make frontend-test # 前端测试
make frontend-test-e2e # 前端 E2E 测试 # 产品级测试
make server-test
make desktop-test
``` ```
backend 分类测试、MySQL 专项测试和前端 E2E 测试请分别查看 `backend/README.md``frontend/README.md`
## 开发 ## 开发
```bash ```bash
make backend-build # 构建后端 # 首次克隆后安装 Git hooks
make backend-run # 运行后端 lefthook install
make backend-lint # 后端代码检查
make backend-migrate-up # 数据库迁移
make frontend-build # 构建前端 # 全局命令
make frontend-dev # 前端开发模式 make lint # 前后端共享检查
make frontend-lint # 前端代码检查 make test # 默认全量测试(不含 MySQL/E2E
make clean # 清理所有构建产物和测试报告
# server 模式
make server-run # 并行启动后端和前端开发服务
make server-build # 构建 backend/bin/server 和 frontend/dist
make server-lint # server 模式检查
make server-test # server 模式测试
make server-clean # 清理 server 模式产物
# desktop 模式
make desktop-build-mac # 构建 macOS 桌面应用
make desktop-build-win # 构建 Windows 桌面应用
make desktop-build-linux # 构建 Linux 桌面应用
make desktop-lint # desktop 模式检查
make desktop-test # desktop 专属测试
make desktop-clean # 清理 desktop 产物
``` ```
## 版本与发布
### 统一版本源
- 仓库根目录 `VERSION` 是全仓唯一版本源,格式固定为 `x.y.z`
- `frontend/package.json` 和前端 `.env.*` 中的 `VITE_APP_VERSION` 由仓库工具同步,不能手工漂移
### 本地版本演进
```bash
# 递增版本(自动 sync + check + commit + tag
make version-bump BUMP=minor
# 或指定具体版本号
make version-bump SET_VERSION=1.0.0
# 推送到远程
git push --follow-tags
```
手动同步和校验:
```bash
make version-sync
make version-check
```
### 本地生成发布资产
```bash
# Linux: server + desktop
make release-assets-linux
# Windows: server + desktop需在 Windows 环境执行)
make release-assets-windows
# macOS: darwin-amd64 server、darwin-arm64 server、desktop universal
make release-assets-macos
```
生成的版本化发布资产位于 `build/release/`
### GitHub Draft Release
- 推送 `vX.Y.Z` tag 后,`.github/workflows/release.yml` 会自动执行发布流水线
- 三个平台 job 会在正式构建前先检查 `go``bun` 和各自的平台打包工具链,缺失时快速失败并在日志中输出诊断信息
- Windows 发布 job 在 `MSYS2 / MINGW64` shell 中执行,并继承 `setup-go` / `setup-bun` 准备好的工具链路径
- 流水线会先校验 tag 与 `VERSION` 一致,再构建以下资产并上传到 GitHub Draft Release
- Linux server
- Windows server
- darwin-amd64 server
- darwin-arm64 server
- Linux desktop
- Windows desktop
- macOS desktop universal
- Release 默认以 Draft 形式创建,需人工检查后再公开发布
## 开发规范 ## 开发规范
详见各子项目的 README.md 详见各子项目的 README.md
@@ -289,4 +377,4 @@ make frontend-lint # 前端代码检查
## 许可证 ## 许可证
MIT Apache License 2.0

1
VERSION Normal file
View File

@@ -0,0 +1 @@
0.1.1

Binary file not shown.

View File

@@ -1,64 +0,0 @@
# Assets
应用资源文件目录。
## 文件说明
| 文件 | 用途 | 尺寸 | 格式 |
|------|------|------|------|
| `icon.svg` | 源图标 | 64x64 | SVG |
| `icon.png` | 托盘图标 | 64x64 | PNG |
| `AppIcon.icns` | macOS 应用图标 | 多尺寸 | ICNS |
| `icon.ico` | Windows 应用图标 | 256x256 | ICO |
## 替换图标
### 1. 准备图标
推荐使用 SVG 格式的源图标,尺寸至少 256x256。
### 2. 生成各平台图标
**托盘图标 (PNG)**
```bash
magick your-icon.svg -resize 64x64 icon.png
```
**macOS 应用图标 (ICNS)**
```bash
mkdir icon.iconset
magick your-icon.svg -resize 16x16 icon.iconset/icon_16x16.png
magick your-icon.svg -resize 32x32 icon.iconset/icon_16x16@2x.png
magick your-icon.svg -resize 32x32 icon.iconset/icon_32x32.png
magick your-icon.svg -resize 64x64 icon.iconset/icon_32x32@2x.png
magick your-icon.svg -resize 128x128 icon.iconset/icon_128x128.png
magick your-icon.svg -resize 256x256 icon.iconset/icon_128x128@2x.png
iconutil -c icns icon.iconset -o AppIcon.icns
rm -rf icon.iconset
```
**Windows 应用图标 (ICO)**
```bash
magick your-icon.svg -resize 256x256 icon.ico
```
### 3. 替换文件
将生成的文件放入此目录,然后重新构建桌面应用:
```bash
./scripts/build/build-darwin-arm64.sh
```
## macOS Template 图标
macOS 支持 Template 图标,自动适配深浅色模式:
- 使用黑色 + 透明设计
- 文件名以 `Template` 结尾(如 `iconTemplate.png`
- 黑色在深色模式下自动变为白色
## 设计建议
- 托盘图标应简洁,在小尺寸下清晰可辨
- 避免过多细节和文字
- 使用高对比度颜色
- macOS 建议使用 Template 图标风格

BIN
assets/icon.icns LFS Normal file

Binary file not shown.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 264 KiB

After

Width:  |  Height:  |  Size: 130 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.0 KiB

After

Width:  |  Height:  |  Size: 131 B

View File

@@ -1,13 +0,0 @@
<svg width="64" height="64" viewBox="0 0 64 64" xmlns="http://www.w3.org/2000/svg">
<rect width="64" height="64" rx="12" fill="#4A90D9"/>
<polygon points="32,8 52,20 52,44 32,56 12,44 12,20" fill="none" stroke="white" stroke-width="3"/>
<circle cx="32" cy="32" r="6" fill="white"/>
<line x1="32" y1="32" x2="20" y2="20" stroke="white" stroke-width="2"/>
<line x1="32" y1="32" x2="44" y2="20" stroke="white" stroke-width="2"/>
<line x1="32" y1="32" x2="20" y2="44" stroke="white" stroke-width="2"/>
<line x1="32" y1="32" x2="44" y2="44" stroke="white" stroke-width="2"/>
<circle cx="20" cy="20" r="3" fill="white"/>
<circle cx="44" cy="20" r="3" fill="white"/>
<circle cx="20" cy="44" r="3" fill="white"/>
<circle cx="44" cy="44" r="3" fill="white"/>
</svg>

Before

Width:  |  Height:  |  Size: 779 B

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

91
backend/.golangci.yml Normal file
View File

@@ -0,0 +1,91 @@
run:
timeout: 5m
tests: true
linters:
disable-all: true
enable:
- forbidigo
- errorlint
- errcheck
- staticcheck
- revive
- gocritic
- gosec
- bodyclose
- noctx
- nilerr
- goimports
- gocyclo
linters-settings:
errcheck:
check-blank: true
check-type-assertions: true
exclude-functions:
- fmt.Fprintf
forbidigo:
analyze-types: true
forbid:
- p: '^fmt\.Print.*$'
msg: 使用 zap logger不要直接输出到 stdout/stderr
- p: '^fmt\.Fprint.*$'
msg: 使用 zap logger不要直接输出到 stdout/stderr
- p: '^log\.(Print|Println|Printf|Fatal|Fatalln|Fatalf|Panic|Panicln|Panicf)$'
msg: 使用 zap logger不要使用标准库 log
- p: '^zap\.L$'
msg: 通过依赖注入传递 *zap.Logger不要使用全局 logger
- p: '^zap\.S$'
msg: 不使用 Sugar logger
revive:
rules:
- name: exported
- name: var-naming
- name: indent-error-flow
- name: error-strings
- name: error-return
- name: blank-imports
- name: context-as-argument
- name: unexported-return
goimports:
local-prefixes: nex/backend
gocyclo:
min-complexity: 10
issues:
exclude-dirs:
- tests/mocks
exclude-generated: true
exclude-rules:
- path: '(_test\.go|tests/)'
linters:
- forbidigo
- path: '(_test\.go|tests/)'
linters:
- errcheck
source: '(^\s*_\s*=|,\s*_)'
- path: 'tests/integration/e2e_conversion_test\.go'
linters:
- errcheck
- path: '(_test\.go|tests/)'
linters:
- revive
text: '^exported:'
- path: '(_test\.go|tests/)'
linters:
- gosec
text: 'G(101|401|501)'
- path: '(_test\.go|tests/)'
linters:
- gocyclo
text: 'cyclomatic complexity (1[1-9]|20) of .* is high \(> 10\)'
- linters:
- revive
text: '(that stutters|BuildUrl should be BuildURL|ConvertHttpRequest should be ConvertHTTPRequest|ConvertHttpResponse should be ConvertHTTPResponse)'
- path: 'internal/conversion/.*\.go'
linters:
- gocyclo
- gocritic
- path: '(internal/provider/client\.go|internal/service/model_service_impl\.go|internal/service/stats_buffer\.go|internal/handler/proxy_handler\.go|cmd/(desktop|server)/main\.go)'
linters:
- gocyclo

97
backend/Makefile Normal file
View File

@@ -0,0 +1,97 @@
.PHONY: \
build run \
test test-unit test-integration test-coverage \
lint clean \
migrate-up migrate-down migrate-status migrate-create \
mysql-up mysql-down mysql-test mysql-test-quick
VERSION := $(shell go run ../versionctl print)
GIT_COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || printf 'unknown')
BUILD_TIME ?= $(shell date -u +"%Y-%m-%dT%H:%M:%SZ")
GO_LDFLAGS := -X nex/backend/pkg/buildinfo.version=$(VERSION) -X nex/backend/pkg/buildinfo.commit=$(GIT_COMMIT) -X nex/backend/pkg/buildinfo.buildTime=$(BUILD_TIME)
DB_DRIVER ?= sqlite3
DB_DSN ?= $(HOME)/.nex/config.db
ifeq ($(DB_DRIVER),mysql)
GOOSE_DIR := migrations/mysql
GOOSE_DRIVER := mysql
else ifeq ($(DB_DRIVER),sqlite3)
GOOSE_DIR := migrations/sqlite
GOOSE_DRIVER := sqlite3
else
$(error unsupported DB_DRIVER '$(DB_DRIVER)', use sqlite3 or mysql)
endif
build:
go build -ldflags "$(GO_LDFLAGS)" -o bin/server ./cmd/server
run:
go run -ldflags "$(GO_LDFLAGS)" ./cmd/server
test:
go test ./internal/... ./pkg/... ./tests/... ./cmd/server/... -v
test-unit:
go test ./internal/... ./pkg/... -v
test-integration:
go test ./tests/... -v
test-coverage:
go test ./... -coverprofile=coverage.out
go tool cover -html=coverage.out -o coverage.html
@printf 'Coverage report generated: backend/coverage.html\n'
lint:
go tool golangci-lint run ./...
clean:
rm -rf bin/ coverage.out coverage.html
migrate-up:
@printf 'Running database migration up...\n'
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" up
migrate-down:
@printf 'Running database migration down...\n'
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" down
migrate-status:
@printf 'Checking database migration status...\n'
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" status
migrate-create:
@printf 'Migration name: '; \
read name; \
goose -dir migrations/sqlite create $$name sql; \
goose -dir migrations/mysql create $$name sql
mysql-up:
@printf 'Starting MySQL test container...\n'
cd tests/mysql && docker-compose up -d
@printf 'Waiting for MySQL to be ready...\n'
@for i in $$(seq 1 30); do \
if docker exec nex-mysql-test mysqladmin ping -h localhost -u root -ptestpass --silent 2>/dev/null; then \
printf 'MySQL is ready\n'; \
exit 0; \
fi; \
printf 'Waiting... (%s/30)\n' $$i; \
sleep 1; \
done; \
printf 'MySQL failed to start\n'; \
exit 1
mysql-down:
@printf 'Stopping MySQL test container...\n'
cd tests/mysql && docker-compose down -v
mysql-test:
@set -e; \
$(MAKE) mysql-up; \
trap '$(MAKE) mysql-down' EXIT; \
go test -tags=mysql ./tests/mysql/... -v -count=1
mysql-test-quick:
@printf 'Running MySQL tests without container management...\n'
go test -tags=mysql ./tests/mysql/... -v -count=1

View File

@@ -4,21 +4,67 @@ AI 网关后端服务,提供统一的大模型 API 代理接口。
## 功能特性 ## 功能特性
- 支持 OpenAI 协议(`/openai/v1/...` - 支持 OpenAI 协议(`/openai/v1/...`,例如 `/openai/v1/chat/completions`
- 支持 Anthropic 协议(`/anthropic/v1/...` - 支持 Anthropic 协议(`/anthropic/v1/...`
- 支持 Hub-and-Spoke 跨协议双向转换OpenAI ↔ Anthropic - 支持 Hub-and-Spoke 跨协议双向转换OpenAI ↔ Anthropic
- 同协议透传(零语义损失、零序列化开销 - 同协议透传(跳过 Canonical 全量转换,保持协议语义
- 支持流式响应SSE - 支持流式响应SSE
- 支持 Function Calling / Tools - 支持 Function Calling / Tools
- 支持 Thinking / Reasoning - 支持 Thinking / Reasoning
- 支持扩展层接口Models、Embeddings、Rerank - 支持扩展层接口Models、Embeddings、Rerank
- 多供应商配置和路由 - 多供应商配置和路由
- 用量统计 - 用量统计
- 结构化日志zap + lumberjack - 结构化日志zap + lumberjack + 模块标识
- YAML 配置管理 - YAML 配置管理
- 请求验证 - 请求验证
- 中间件支持(请求 ID、日志、恢复、CORS - 中间件支持(请求 ID、日志、恢复、CORS
## 日志规范
### 模块标识
每个模块通过依赖注入获取带模块标识的 logger
```go
func NewProxyHandler(..., logger *zap.Logger) *ProxyHandler {
return &ProxyHandler{
logger: pkglogger.WithModule(logger, "handler.proxy"),
}
}
```
输出格式:
- Console: `INFO [handler.proxy] 处理请求 method=POST path=/v1/chat`
- JSON: `{"level":"info","logger":"handler.proxy","msg":"处理请求"}`
### 模块命名规范
| 模块 | 命名 |
|------|------|
| ProxyHandler | `handler.proxy` |
| ProviderHandler | `handler.provider` |
| Provider Client | `provider.client` |
| ConversionEngine | `conversion.engine` |
| RoutingCache | `service.routing_cache` |
| StatsBuffer | `service.stats_buffer` |
| Database | `database` |
### 标准字段
使用 `pkg/logger/field.go` 中定义的字段构造函数:
```go
logger.Debug("请求开始",
pkglogger.Method("POST"),
pkglogger.Path("/v1/chat"),
pkglogger.RequestID("xxx"),
)
```
### GORM 日志
GORM 日志自动桥接到 zapSQL 查询映射到 Debug 级别。
## 技术栈 ## 技术栈
- **语言**: Go 1.26+ - **语言**: Go 1.26+
@@ -105,9 +151,13 @@ backend/
│ │ ├── errors.go │ │ ├── errors.go
│ │ └── wrap.go │ │ └── wrap.go
│ ├── logger/ # 日志系统 │ ├── logger/ # 日志系统
│ │ ├── logger.go │ │ ├── logger.go # 核心初始化
│ │ ├── rotate.go │ │ ├── field.go # 标准字段定义
│ │ ── context.go │ │ ── module.go # 模块日志器
│ │ ├── context.go # Context 辅助函数
│ │ ├── gorm.go # GORM 适配器
│ │ ├── minimal.go # 最小化 logger
│ │ └── rotate.go # 日志轮转
│ ├── modelid/ # 统一模型 ID 工具包 │ ├── modelid/ # 统一模型 ID 工具包
│ │ ├── model_id.go │ │ ├── model_id.go
│ │ └── model_id_test.go │ │ └── model_id_test.go
@@ -170,7 +220,7 @@ OpenAI Response ← Canonical Response ← Anthropic Response
### Smart Passthrough 机制 ### Smart Passthrough 机制
同协议请求走 Smart Passthrough 路径,**零序列化开销** 同协议请求走 Smart Passthrough 路径,不进入 Canonical 全量转换
``` ```
1. 检测 clientProtocol == providerProtocol 1. 检测 clientProtocol == providerProtocol
@@ -179,12 +229,14 @@ OpenAI Response ← Canonical Response ← Anthropic Response
4. 响应中仅改写 model 字段upstream_model_name → unified_id 4. 响应中仅改写 model 字段upstream_model_name → unified_id
``` ```
Smart Passthrough 保持未改写 JSON 字段的内容和类型不变,不承诺保留原始字节顺序、空白或对象字段顺序。
### 流式转换器层次 ### 流式转换器层次
``` ```
StreamConverter (接口) StreamConverter (接口)
├── PassthroughStreamConverter # 直接透传,无任何处理 ├── PassthroughStreamConverter # 直接透传,无任何处理
├── SmartPassthroughStreamConverter # 同协议 + 逐 chunk 改写 model ├── SmartPassthroughStreamConverter # 同协议 + 按 SSE frame 改写 data JSON model
└── CanonicalStreamConverter # 跨协议完整转换decode → encode └── CanonicalStreamConverter # 跨协议完整转换decode → encode
``` ```
@@ -251,6 +303,7 @@ StreamConverter (接口)
| `PROTOCOL_CONSTRAINT_VIOLATION` | 协议约束违反 | | `PROTOCOL_CONSTRAINT_VIOLATION` | 协议约束违反 |
| `ENCODING_FAILURE` | 编码失败 | | `ENCODING_FAILURE` | 编码失败 |
| `INTERFACE_NOT_SUPPORTED` | 接口不支持(如 Anthropic Embeddings | | `INTERFACE_NOT_SUPPORTED` | 接口不支持(如 Anthropic Embeddings |
| `UNSUPPORTED_MULTIMODAL` | 跨协议暂不支持多模态内容 |
### AppError 预定义错误 ### AppError 预定义错误
@@ -384,24 +437,37 @@ docker run -d -e NEX_SERVER_PORT=9000 -e NEX_LOG_LEVEL=info nex-server
## 测试 ## 测试
```bash ```bash
# 运行所有测试 # 运行 backend 默认测试
make test make test
# 分类测试
make test-unit
make test-integration
# 生成覆盖率报告 # 生成覆盖率报告
make test-coverage make test-coverage
# MySQL 专项测试
make mysql-up
make mysql-down
make mysql-test
make mysql-test-quick
``` ```
## 数据库迁移 ## 数据库迁移
```bash ```bash
# 使用 Makefile # 使用 Makefile
make migrate-up DB_PATH=~/.nex/config.db make migrate-up DB_DSN=~/.nex/config.db
make migrate-down DB_PATH=~/.nex/config.db make migrate-down DB_DSN=~/.nex/config.db
make migrate-status DB_PATH=~/.nex/config.db make migrate-status DB_DSN=~/.nex/config.db
# 创建新迁移 # 创建新迁移
make migrate-create make migrate-create
# MySQL 迁移
make migrate-up DB_DRIVER=mysql DB_DSN='user:pass@tcp(localhost:3306)/nex'
# 或直接使用 goose # 或直接使用 goose
goose -dir migrations sqlite3 ~/.nex/config.db up goose -dir migrations sqlite3 ~/.nex/config.db up
``` ```
@@ -410,15 +476,15 @@ goose -dir migrations sqlite3 ~/.nex/config.db up
### 代理接口 ### 代理接口
使用 `/{protocol}/v1/{path}` URL 前缀路由 使用 `/{protocol}/*path` URL 前缀路由。网关只剥离第一段协议前缀,不在 Handler 中统一添加或移除 `/v1`;剩余 path 是协议原生 nativePath由对应 adapter 识别和组合上游 URL。
#### OpenAI 协议 #### OpenAI 协议
``` ```
POST /openai/chat/completions POST /openai/v1/chat/completions
GET /openai/models GET /openai/v1/models
POST /openai/embeddings POST /openai/v1/embeddings
POST /openai/rerank POST /openai/v1/rerank
``` ```
#### Anthropic 协议 #### Anthropic 协议
@@ -428,10 +494,20 @@ POST /anthropic/v1/messages
GET /anthropic/v1/models GET /anthropic/v1/models
``` ```
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传,零序列化开销 **协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传或 Smart Passthrough跳过 Canonical 全量转换
**统一模型 ID**:代理请求中的 `model` 字段使用 `provider_id/model_name` 格式(如 `openai/gpt-4`),网关据此路由到对应供应商。同协议时自动改写为上游 `model_name`,跨协议时通过全量转换处理。 **统一模型 ID**:代理请求中的 `model` 字段使用 `provider_id/model_name` 格式(如 `openai/gpt-4`),网关据此路由到对应供应商。同协议时自动改写为上游 `model_name`,跨协议时通过全量转换处理。
**base_url 约定**
- OpenAI 供应商配置到版本路径一级,例如 `https://api.openai.com/v1`;当客户端请求 `/openai/v1/chat/completions`OpenAI adapter 会把 nativePath `/v1/chat/completions` 映射为上游 path `/chat/completions`,最终 URL 为 `https://api.openai.com/v1/chat/completions`
- Anthropic 供应商配置到域名级,例如 `https://api.anthropic.com`
**模型提取边界**:只有 adapter 明确适配的 Chat、Embeddings、Rerank 等接口会提取 `model` 并尝试统一模型 ID 路由。未知接口不做顶层 `model` 猜测,直接按无 model 透传。
**流式透传边界**:同协议无响应 model 改写时 raw passthrough保留 SSE frame 边界和 `[DONE]`;同协议需要改写时按 SSE frame 解析 `data` JSON仅改写 `model`;跨协议继续使用 StreamDecoder → CanonicalStreamConverter → StreamEncoder。
**错误边界**:网关层代理错误返回 `{"error":"...","code":"..."}`。已收到上游 HTTP 响应时,非 2xx status、过滤 hop-by-hop header 后的 headers 和 body 直接透传;没有收到上游响应的连接/DNS/TLS/超时错误返回 `UPSTREAM_UNAVAILABLE`
### 管理接口 ### 管理接口
#### 供应商管理 #### 供应商管理
@@ -459,7 +535,7 @@ GET /anthropic/v1/models
- Anthropic 协议:配置到域名,不包含版本路径,如 `https://api.anthropic.com` - Anthropic 协议:配置到域名,不包含版本路径,如 `https://api.anthropic.com`
**对外 URL 格式** **对外 URL 格式**
- OpenAI 协议:`/{protocol}/{endpoint}`,如 `/openai/chat/completions``/openai/models``/openai/embeddings` - OpenAI 协议:`/{protocol}/v1/{endpoint}`,如 `/openai/v1/chat/completions``/openai/v1/models``/openai/v1/embeddings`
- Anthropic 协议:`/{protocol}/v1/{endpoint}`,如 `/anthropic/v1/messages``/anthropic/v1/models` - Anthropic 协议:`/{protocol}/v1/{endpoint}`,如 `/anthropic/v1/messages``/anthropic/v1/models`
#### 模型管理 #### 模型管理
@@ -501,6 +577,20 @@ GET /anthropic/v1/models
查询参数:`provider_id``model_name``start_date`YYYY-MM-DD`end_date``group_by`provider/model/date 查询参数:`provider_id``model_name``start_date`YYYY-MM-DD`end_date``group_by`provider/model/date
#### 版本信息
- `GET /api/version` - 获取后端构建版本信息
响应字段来源于构建阶段注入的 `buildinfo` 元数据:
```json
{
"version": "0.1.0",
"commit": "abc1234",
"build_time": "2026-05-05T00:00:00Z"
}
```
#### 健康检查 #### 健康检查
- `GET /health` - 返回 `{"status": "ok"}` - `GET /health` - 返回 `{"status": "ok"}`
@@ -508,9 +598,12 @@ GET /anthropic/v1/models
## 开发 ## 开发
```bash ```bash
make build # 构建 make build # 构建 backend/bin/server
make run # 运行后端服务
make lint # 代码检查 make lint # 代码检查
make deps # 整理依赖 make clean # 清理 backend 构建产物
go mod tidy # 整理依赖
go generate ./... # 刷新 mock 等生成代码
``` ```
环境要求Go 1.26 或更高版本 环境要求Go 1.26 或更高版本
@@ -559,6 +652,7 @@ err := v.Validate(myStruct)
- **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节 - **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节
- **字符串拼接**:使用 `strings.Join`,不手写循环拼接 - **字符串拼接**:使用 `strings.Join`,不手写循环拼接
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配 - **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配lint 强约束errorlint 禁止 `err ==` 直接比较和 `err.(*T)` 直接断言)
- **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()` - **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()`
- **日志 error 字段**:使用 `zap.Error(err)`,不使用 `zap.String("error", err.Error())` 手工字符串化
- **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片 - **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片

View File

@@ -0,0 +1,25 @@
//go:build darwin
package main
import (
"fmt"
"os/exec"
"strings"
"go.uber.org/zap"
)
func showError(title, message string) {
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`,
escapeAppleScript(message), escapeAppleScript(title))
if err := exec.Command("osascript", "-e", script).Run(); err != nil {
dialogLogger().Warn("显示错误对话框失败", zap.Error(err))
}
}
func escapeAppleScript(s string) string {
s = strings.ReplaceAll(s, "\\", "\\\\")
s = strings.ReplaceAll(s, "\"", "\\\"")
return s
}

View File

@@ -0,0 +1,67 @@
//go:build linux
package main
import (
"fmt"
"os/exec"
"sync"
)
type dialogToolType int
const (
toolNone dialogToolType = iota
toolZenity
toolKdialog
toolNotifySend
toolXmessage
)
var (
dialogTool dialogToolType
dialogToolOnce sync.Once
)
func init() {
dialogToolOnce.Do(detectDialogTool)
}
func detectDialogTool() {
tools := []struct {
name string
typ dialogToolType
}{
{"zenity", toolZenity},
{"kdialog", toolKdialog},
{"notify-send", toolNotifySend},
{"xmessage", toolXmessage},
}
for _, tool := range tools {
if _, err := exec.LookPath(tool.name); err == nil {
dialogTool = tool.typ
return
}
}
dialogTool = toolNone
}
func showError(title, message string) {
switch dialogTool {
case toolZenity:
exec.Command("zenity", "--error",
fmt.Sprintf("--title=%s", title),
fmt.Sprintf("--text=%s", message)).Run()
case toolKdialog:
exec.Command("kdialog", "--error", message, "--title", title).Run()
case toolNotifySend:
exec.Command("notify-send", "-u", "critical", title, message).Run()
case toolXmessage:
exec.Command("xmessage", "-center",
fmt.Sprintf("%s: %s", title, message)).Run()
default:
dialogLogger().Error("无法显示错误对话框")
}
}

View File

@@ -0,0 +1,15 @@
package main
import (
"go.uber.org/zap"
pkgLogger "nex/backend/pkg/logger"
)
func dialogLogger() *zap.Logger {
if zapLogger != nil {
return zapLogger
}
return pkgLogger.NewMinimal()
}

View File

@@ -0,0 +1,62 @@
//go:build windows
package main
import (
"errors"
"fmt"
"syscall"
"unsafe"
"go.uber.org/zap"
)
const (
mbIconError = 0x10
mbIconInformation = 0x40
)
var (
user32 = syscall.NewLazyDLL("user32.dll")
procMessageBoxW = user32.NewProc("MessageBoxW")
callMessageBoxW = func(hwnd, text, caption, flags uintptr) (uintptr, error) {
ret, _, err := procMessageBoxW.Call(hwnd, text, caption, flags)
return ret, err
}
)
func showError(title, message string) {
if err := messageBox(title, message, mbIconError); err != nil {
if zapLogger != nil {
zapLogger.Warn("显示错误对话框失败", zap.Error(err))
}
}
}
func messageBox(title, message string, flags uint) error {
titlePtr, err := syscall.UTF16PtrFromString(title)
if err != nil {
return err
}
messagePtr, err := syscall.UTF16PtrFromString(message)
if err != nil {
return err
}
ret, callErr := callMessageBoxW(
0,
uintptr(unsafe.Pointer(messagePtr)),
uintptr(unsafe.Pointer(titlePtr)),
uintptr(flags),
)
if ret != 0 {
return nil
}
if callErr != nil && !errors.Is(callErr, syscall.Errno(0)) {
return callErr
}
return fmt.Errorf("MessageBoxW 调用失败")
}

View File

@@ -0,0 +1,33 @@
package main
import (
"runtime"
"testing"
"nex/embedfs"
)
func TestIconSelection_Windows(t *testing.T) {
if runtime.GOOS != "windows" {
t.Skip("图标格式选择测试仅在 Windows 上运行")
}
if err := testIconLoad("assets/icon.ico"); err != nil {
t.Fatalf("Windows 应加载 .ico 文件: %v", err)
}
}
func TestIconSelection_NonWindows(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("图标格式选择测试在非 Windows 平台运行")
}
if err := testIconLoad("assets/icon.png"); err != nil {
t.Fatalf("非 Windows 平台应加载 .png 文件: %v", err)
}
}
func testIconLoad(path string) error {
_, err := embedfs.Assets.ReadFile(path)
return err
}

View File

@@ -0,0 +1 @@
1 ICON "../../../assets/icon.ico"

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"io/fs" "io/fs"
"log"
"net" "net"
"net/http" "net/http"
"os" "os"
@@ -14,10 +13,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/gin-gonic/gin" "nex/embedfs"
"github.com/getlantern/systray"
"github.com/gofrs/flock"
"go.uber.org/zap"
"nex/backend/internal/config" "nex/backend/internal/config"
"nex/backend/internal/conversion" "nex/backend/internal/conversion"
@@ -29,9 +25,14 @@ import (
"nex/backend/internal/provider" "nex/backend/internal/provider"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"nex/backend/internal/service" "nex/backend/internal/service"
pkgLogger "nex/backend/pkg/logger" "nex/backend/pkg/buildinfo"
"nex/embedfs" "github.com/getlantern/systray"
"github.com/gin-gonic/gin"
"github.com/gofrs/flock"
"go.uber.org/zap"
pkgLogger "nex/backend/pkg/logger"
) )
var ( var (
@@ -44,25 +45,32 @@ var (
func main() { func main() {
port := 9826 port := 9826
minimalLogger := pkgLogger.NewMinimal()
singleLock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway.lock")) singleLock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway.lock"))
if err := singleLock.Lock(); err != nil { if err := singleLock.Lock(); err != nil {
showError("Nex Gateway", "已有 Nex 实例运行") minimalLogger.Error("已有 Nex 实例运行")
showError(appName, "已有 Nex 实例运行")
os.Exit(1) os.Exit(1)
} }
defer singleLock.Unlock() defer func() {
if err := singleLock.Unlock(); err != nil {
minimalLogger.Warn("释放实例锁失败", zap.Error(err))
}
}()
if err := checkPortAvailable(port); err != nil { if err := checkPortAvailable(port); err != nil {
showError("Nex Gateway", err.Error()) minimalLogger.Error("端口不可用", zap.Error(err))
os.Exit(1) showError(appName, err.Error())
return
} }
cfg, err := config.LoadConfig() cfg, err := config.LoadConfig()
if err != nil { if err != nil {
showError("Nex Gateway", fmt.Sprintf("加载配置失败: %v", err)) minimalLogger.Fatal("加载配置失败", zap.Error(err))
os.Exit(1)
} }
zapLogger, err = pkgLogger.New(pkgLogger.Config{ zapLogger, err = pkgLogger.Upgrade(minimalLogger, pkgLogger.Config{
Level: cfg.Log.Level, Level: cfg.Log.Level,
Path: cfg.Log.Path, Path: cfg.Log.Path,
MaxSize: cfg.Log.MaxSize, MaxSize: cfg.Log.MaxSize,
@@ -71,15 +79,19 @@ func main() {
Compress: cfg.Log.Compress, Compress: cfg.Log.Compress,
}) })
if err != nil { if err != nil {
showError("Nex Gateway", fmt.Sprintf("初始化日志失败: %v", err)) minimalLogger.Fatal("初始化日志失败", zap.Error(err))
os.Exit(1)
} }
defer zapLogger.Sync() defer func() {
if err := zapLogger.Sync(); err != nil {
minimalLogger.Warn("同步日志失败", zap.Error(err))
}
}()
cfg.PrintSummary(zapLogger)
db, err := database.Init(&cfg.Database, zapLogger) db, err := database.Init(&cfg.Database, zapLogger)
if err != nil { if err != nil {
showError("Nex Gateway", fmt.Sprintf("初始化数据库失败: %v", err)) zapLogger.Fatal("初始化数据库失败", zap.Error(err))
os.Exit(1)
} }
defer database.Close(db) defer database.Close(db)
@@ -105,19 +117,20 @@ func main() {
registry := conversion.NewMemoryRegistry() registry := conversion.NewMemoryRegistry()
if err := registry.Register(openai.NewAdapter()); err != nil { if err := registry.Register(openai.NewAdapter()); err != nil {
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error())) zapLogger.Fatal("注册 OpenAI 适配器失败", zap.Error(err))
} }
if err := registry.Register(anthropic.NewAdapter()); err != nil { if err := registry.Register(anthropic.NewAdapter()); err != nil {
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.String("error", err.Error())) zapLogger.Fatal("注册 Anthropic 适配器失败", zap.Error(err))
} }
engine := conversion.NewConversionEngine(registry, zapLogger) engine := conversion.NewConversionEngine(registry, zapLogger)
providerClient := provider.NewClient() providerClient := provider.NewClient(zapLogger)
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService) proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService, zapLogger)
providerHandler := handler.NewProviderHandler(providerService) providerHandler := handler.NewProviderHandler(providerService)
modelHandler := handler.NewModelHandler(modelService) modelHandler := handler.NewModelHandler(modelService)
statsHandler := handler.NewStatsHandler(statsService) statsHandler := handler.NewStatsHandler(statsService)
versionHandler := handler.NewVersionHandler()
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
r := gin.New() r := gin.New()
@@ -127,7 +140,7 @@ func main() {
r.Use(middleware.Logging(zapLogger)) r.Use(middleware.Logging(zapLogger))
r.Use(middleware.CORS()) r.Use(middleware.CORS())
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler) setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler, versionHandler)
setupStaticFiles(r) setupStaticFiles(r)
server = &http.Server{ server = &http.Server{
@@ -140,24 +153,30 @@ func main() {
shutdownCtx, shutdownCancel = context.WithCancel(context.Background()) shutdownCtx, shutdownCancel = context.WithCancel(context.Background())
go func() { go func() {
zapLogger.Info("AI Gateway 启动", zap.String("addr", server.Addr)) zapLogger.Info("AI Gateway 启动",
zap.String("addr", server.Addr),
zap.String("version", buildinfo.Version()),
zap.String("commit", buildinfo.Commit()),
zap.String("build_time", buildinfo.BuildTime()))
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error())) zapLogger.Fatal("服务器启动失败", zap.Error(err))
} }
}() }()
go func() { go func() {
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil { if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
zapLogger.Warn("无法打开浏览器", zap.String("error", err.Error())) zapLogger.Warn("无法打开浏览器", zap.Error(err))
} }
}() }()
setupSystray(port) setupSystray(port)
} }
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) { func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler, versionHandler *handler.VersionHandler) {
r.Any("/v1/*path", proxyHandler.HandleProxy) r.Any("/openai/*path", withProtocol("openai", proxyHandler.HandleProxy))
r.Any("/anthropic/*path", withProtocol("anthropic", proxyHandler.HandleProxy))
r.GET("/api/version", versionHandler.GetVersion)
providers := r.Group("/api/providers") providers := r.Group("/api/providers")
{ {
@@ -188,12 +207,26 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
}) })
} }
func setupStaticFiles(r *gin.Engine) { func withProtocol(protocol string, next gin.HandlerFunc) gin.HandlerFunc {
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist") return func(c *gin.Context) {
if err != nil { c.Params = append(c.Params, gin.Param{Key: "protocol", Value: protocol})
zapLogger.Fatal("无法加载前端资源", zap.String("error", err.Error())) next(c)
} }
}
func setupStaticFiles(r *gin.Engine) {
distFS, err := frontendDistFS()
if err != nil {
zapLogger.Fatal("无法加载前端资源", zap.Error(err))
}
setupStaticFilesWithFS(r, distFS)
}
func frontendDistFS() (fs.FS, error) {
return fs.Sub(embedfs.FrontendDist, "frontend-dist")
}
func setupStaticFilesWithFS(r *gin.Engine, distFS fs.FS) {
getContentType := func(path string) string { getContentType := func(path string) string {
if strings.HasSuffix(path, ".js") { if strings.HasSuffix(path, ".js") {
return "application/javascript" return "application/javascript"
@@ -226,20 +259,23 @@ func setupStaticFiles(r *gin.Engine) {
c.Data(200, getContentType(filepath), data) c.Data(200, getContentType(filepath), data)
}) })
r.GET("/favicon.svg", func(c *gin.Context) { r.GET("/icon.png", func(c *gin.Context) {
data, err := fs.ReadFile(distFS, "favicon.svg") data, err := fs.ReadFile(distFS, "icon.png")
if err != nil { if err != nil {
c.Status(404) c.Status(404)
return return
} }
c.Data(200, "image/svg+xml", data) c.Data(200, "image/png", data)
}) })
r.NoRoute(func(c *gin.Context) { r.NoRoute(func(c *gin.Context) {
path := c.Request.URL.Path path := c.Request.URL.Path
if strings.HasPrefix(path, "/api/") || if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") || strings.HasPrefix(path, "/openai/") ||
strings.HasPrefix(path, "/anthropic/") ||
path == "/openai" ||
path == "/anthropic" ||
strings.HasPrefix(path, "/health") { strings.HasPrefix(path, "/health") {
c.JSON(404, gin.H{"error": "not found"}) c.JSON(404, gin.H{"error": "not found"})
return return
@@ -256,13 +292,18 @@ func setupStaticFiles(r *gin.Engine) {
func setupSystray(port int) { func setupSystray(port int) {
systray.Run(func() { systray.Run(func() {
icon, err := embedfs.Assets.ReadFile("assets/icon.png") var icon []byte
var err error
if runtime.GOOS == "windows" {
icon, err = embedfs.Assets.ReadFile("assets/icon.ico")
} else {
icon, err = embedfs.Assets.ReadFile("assets/icon.png")
}
if err != nil { if err != nil {
zapLogger.Error("无法加载托盘图标", zap.String("error", err.Error())) zapLogger.Error("无法加载托盘图标", zap.Error(err))
} }
systray.SetIcon(icon) systray.SetIcon(icon)
systray.SetTitle("Nex Gateway") systray.SetTooltip(appTooltip)
systray.SetTooltip("AI Gateway")
mOpen := systray.AddMenuItem("打开管理界面", "在浏览器中打开") mOpen := systray.AddMenuItem("打开管理界面", "在浏览器中打开")
systray.AddSeparator() systray.AddSeparator()
@@ -271,17 +312,15 @@ func setupSystray(port int) {
mPort := systray.AddMenuItem(fmt.Sprintf("端口: %d", port), "") mPort := systray.AddMenuItem(fmt.Sprintf("端口: %d", port), "")
mPort.Disable() mPort.Disable()
systray.AddSeparator() systray.AddSeparator()
mAbout := systray.AddMenuItem("关于", "")
systray.AddSeparator()
mQuit := systray.AddMenuItem("退出", "停止服务并退出") mQuit := systray.AddMenuItem("退出", "停止服务并退出")
go func() { go func() {
for { for {
select { select {
case <-mOpen.ClickedCh: case <-mOpen.ClickedCh:
openBrowser(fmt.Sprintf("http://localhost:%d", port)) if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
case <-mAbout.ClickedCh: zapLogger.Warn("打开浏览器失败", zap.Error(err))
showAbout() }
case <-mQuit.ClickedCh: case <-mQuit.ClickedCh:
doShutdown() doShutdown()
systray.Quit() systray.Quit()
@@ -300,7 +339,9 @@ func doShutdown() {
if server != nil { if server != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
server.Shutdown(ctx) if err := server.Shutdown(ctx); err != nil && zapLogger != nil {
zapLogger.Warn("关闭服务器失败", zap.Error(err))
}
} }
if shutdownCancel != nil { if shutdownCancel != nil {
@@ -338,8 +379,8 @@ func (s *SingletonLock) Lock() error {
return nil return nil
} }
func (s *SingletonLock) Unlock() { func (s *SingletonLock) Unlock() error {
s.flock.Unlock() return s.flock.Unlock()
} }
func openBrowser(url string) error { func openBrowser(url string) error {
@@ -366,28 +407,3 @@ func openBrowser(url string) error {
return cmd.Start() return cmd.Start()
} }
func showError(title, message string) {
switch runtime.GOOS {
case "darwin":
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`, message, title)
exec.Command("osascript", "-e", script).Run()
case "windows":
exec.Command("msg", "*", message).Run()
case "linux":
exec.Command("zenity", "--error", fmt.Sprintf("--title=%s", title), fmt.Sprintf("--text=%s", message)).Run()
}
}
func showAbout() {
message := "Nex Gateway\n\nAI Gateway - 统一的大模型 API 网关\n\nhttps://github.com/nex/gateway"
switch runtime.GOOS {
case "darwin":
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "关于 Nex Gateway"`, message)
exec.Command("osascript", "-e", script).Run()
case "windows":
exec.Command("msg", "*", message).Run()
case "linux":
exec.Command("zenity", "--info", "--title=关于 Nex Gateway", fmt.Sprintf("--text=%s", message)).Run()
}
}

View File

@@ -0,0 +1,61 @@
//go:build windows
package main
import (
"errors"
"syscall"
"testing"
)
func withMessageBoxW(t *testing.T, fn func(hwnd, text, caption, flags uintptr) (uintptr, error)) {
t.Helper()
old := callMessageBoxW
callMessageBoxW = fn
t.Cleanup(func() {
callMessageBoxW = old
})
}
func TestMessageBoxW_WindowsOnly_InvalidUTF16(t *testing.T) {
err := messageBox("bad\x00title", "测试消息", mbIconInformation)
if err == nil {
t.Fatal("包含 NUL 字符时应该返回错误")
}
}
func TestMessageBoxW_WindowsOnly_SuccessIgnoresLastError(t *testing.T) {
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
return 1, syscall.Errno(123)
})
if err := messageBox("测试标题", "测试消息", mbIconInformation); err != nil {
t.Fatalf("MessageBoxW 返回成功时应忽略 last error: %v", err)
}
}
func TestMessageBoxW_WindowsOnly_FailureUsesReturnValue(t *testing.T) {
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
return 0, syscall.Errno(5)
})
err := messageBox("测试标题", "测试消息", mbIconInformation)
if !errors.Is(err, syscall.Errno(5)) {
t.Fatalf("MessageBoxW 返回 0 时应返回调用错误: %v", err)
}
}
func TestShowError_WindowsBranch(t *testing.T) {
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
return 0, syscall.Errno(5)
})
defer func() {
if recovered := recover(); recovered != nil {
t.Fatalf("showError 不应因 MessageBoxW 失败而 panic: %v", recovered)
}
}()
showError("测试错误", "这是一条测试错误消息")
}

View File

@@ -0,0 +1,9 @@
package main
const (
appName = "Nex"
appTooltip = appName
appDescription = "AI Gateway - 统一的大模型 API 网关"
// #nosec G101 -- 项目官网地址不是凭据
appWebsite = "https://github.com/nex/gateway"
)

View File

@@ -0,0 +1,13 @@
package main
import "testing"
func TestDesktopMetadata(t *testing.T) {
if appName != "Nex" {
t.Fatalf("appName = %q, want %q", appName, "Nex")
}
if appTooltip != appName {
t.Fatalf("appTooltip = %q, want %q", appTooltip, appName)
}
}

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"errors"
"net" "net"
"net/http" "net/http"
"testing" "testing"
@@ -21,19 +22,12 @@ func TestCheckPortAvailable(t *testing.T) {
func TestCheckPortOccupied(t *testing.T) { func TestCheckPortOccupied(t *testing.T) {
port := 19827 port := 19827
listener, err := net.Listen("tcp", ":19827") listener, err := net.Listen("tcp", ":19827") //nolint:gosec // 需要验证 checkPortAvailable 对通配地址占用的检测行为
if err != nil { if err != nil {
t.Fatalf("无法启动测试服务器: %v", err) t.Fatalf("无法启动测试服务器: %v", err)
} }
defer listener.Close() defer listener.Close()
go func() {
conn, err := listener.Accept()
if err == nil {
conn.Close()
}
}()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
err = checkPortAvailable(port) err = checkPortAvailable(port)
@@ -47,13 +41,19 @@ func TestCheckPortOccupied(t *testing.T) {
func TestCheckPortAvailableAfterClose(t *testing.T) { func TestCheckPortAvailableAfterClose(t *testing.T) {
port := 19828 port := 19828
listener, err := net.Listen("tcp", ":19828") listener, err := net.Listen("tcp", "127.0.0.1:19828")
if err != nil { if err != nil {
t.Fatalf("无法启动测试服务器: %v", err) t.Fatalf("无法启动测试服务器: %v", err)
} }
server := &http.Server{} server := &http.Server{ReadHeaderTimeout: time.Second}
go server.Serve(listener) defer server.Close()
go func() {
err := server.Serve(listener)
if err != nil && err != http.ErrServerClosed && !errors.Is(err, net.ErrClosed) {
t.Errorf("serve failed: %v", err)
}
}()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)

View File

@@ -0,0 +1,44 @@
package main
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"testing/fstest"
"nex/backend/internal/handler"
"github.com/gin-gonic/gin"
)
func TestSetupRoutes_VersionDoesNotFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
setupRoutes(r, &handler.ProxyHandler{}, &handler.ProviderHandler{}, &handler.ModelHandler{}, &handler.StatsHandler{}, handler.NewVersionHandler())
setupStaticFilesWithFS(r, fstest.MapFS{
"index.html": {Data: []byte("<html>fallback</html>")},
})
req := httptest.NewRequest(http.MethodGet, "/api/version", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
}
if contentType := w.Header().Get("Content-Type"); contentType == "text/html; charset=utf-8" {
t.Fatalf("版本接口不应返回 SPA fallback HTML")
}
var result map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
for _, key := range []string{"version", "commit", "build_time"} {
if result[key] == "" {
t.Fatalf("响应缺少 %s 字段: %#v", key, result)
}
}
}

View File

@@ -14,7 +14,11 @@ func TestSingletonLock_FirstLockSuccess(t *testing.T) {
if err := lock.Lock(); err != nil { if err := lock.Lock(); err != nil {
t.Fatalf("首次加锁应成功,但返回错误: %v", err) t.Fatalf("首次加锁应成功,但返回错误: %v", err)
} }
defer lock.Unlock() defer func() {
if err := lock.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
}()
} }
func TestSingletonLock_DuplicateLockFails(t *testing.T) { func TestSingletonLock_DuplicateLockFails(t *testing.T) {
@@ -25,12 +29,18 @@ func TestSingletonLock_DuplicateLockFails(t *testing.T) {
if err := lock1.Lock(); err != nil { if err := lock1.Lock(); err != nil {
t.Fatalf("首次加锁应成功: %v", err) t.Fatalf("首次加锁应成功: %v", err)
} }
defer lock1.Unlock() defer func() {
if err := lock1.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
}()
lock2 := NewSingletonLock(lockPath) lock2 := NewSingletonLock(lockPath)
err := lock2.Lock() err := lock2.Lock()
if err == nil { if err == nil {
lock2.Unlock() if unlockErr := lock2.Unlock(); unlockErr != nil {
t.Fatalf("解锁失败: %v", unlockErr)
}
t.Fatal("重复加锁应失败,但返回 nil") t.Fatal("重复加锁应失败,但返回 nil")
} }
} }
@@ -43,16 +53,22 @@ func TestSingletonLock_UnlockThenRelock(t *testing.T) {
if err := lock1.Lock(); err != nil { if err := lock1.Lock(); err != nil {
t.Fatalf("首次加锁应成功: %v", err) t.Fatalf("首次加锁应成功: %v", err)
} }
lock1.Unlock() if err := lock1.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
lock2 := NewSingletonLock(lockPath) lock2 := NewSingletonLock(lockPath)
if err := lock2.Lock(); err != nil { if err := lock2.Lock(); err != nil {
t.Fatalf("释放后重新加锁应成功: %v", err) t.Fatalf("释放后重新加锁应成功: %v", err)
} }
lock2.Unlock() if err := lock2.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
} }
func TestSingletonLock_UnlockWithoutLock(t *testing.T) { func TestSingletonLock_UnlockWithoutLock(t *testing.T) {
lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock")) lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock"))
lock.Unlock() if err := lock.Unlock(); err != nil {
t.Fatalf("未加锁时解锁失败: %v", err)
}
} }

View File

@@ -1,73 +1,26 @@
package main package main
import ( import (
"io/fs" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
"testing/fstest"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"nex/embedfs"
) )
func TestSetupStaticFiles(t *testing.T) { func TestSetupStaticFiles(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist") distFS, err := frontendDistFS()
if err != nil { if err != nil {
t.Skipf("跳过测试: 前端资源未构建: %v", err) t.Skipf("跳过测试: 前端资源未构建: %v", err)
return return
} }
getContentType := func(path string) string {
if strings.HasSuffix(path, ".js") {
return "application/javascript"
}
if strings.HasSuffix(path, ".css") {
return "text/css"
}
if strings.HasSuffix(path, ".svg") {
return "image/svg+xml"
}
return "application/octet-stream"
}
r := gin.New() r := gin.New()
r.GET("/assets/*filepath", func(c *gin.Context) { setupStaticFilesWithFS(r, distFS)
filepath := c.Param("filepath")
data, err := fs.ReadFile(distFS, "assets"+filepath)
if err != nil {
c.Status(404)
return
}
c.Data(200, getContentType(filepath), data)
})
r.GET("/favicon.svg", func(c *gin.Context) {
data, err := fs.ReadFile(distFS, "favicon.svg")
if err != nil {
c.Status(404)
return
}
c.Data(200, "image/svg+xml", data)
})
r.NoRoute(func(c *gin.Context) {
path := c.Request.URL.Path
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/health") {
c.JSON(404, gin.H{"error": "not found"})
return
}
data, err := fs.ReadFile(distFS, "index.html")
if err != nil {
c.Status(500)
return
}
c.Data(200, "text/html; charset=utf-8", data)
})
t.Run("API 404", func(t *testing.T) { t.Run("API 404", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil) req := httptest.NewRequest("GET", "/api/test", nil)
@@ -79,6 +32,32 @@ func TestSetupStaticFiles(t *testing.T) {
} }
}) })
t.Run("OpenAI proxy prefix 404", func(t *testing.T) {
req := httptest.NewRequest("GET", "/openai/", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("期望状态码 404, 实际 %d", w.Code)
}
if !strings.Contains(w.Body.String(), "not found") {
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
}
})
t.Run("Anthropic proxy prefix 404", func(t *testing.T) {
req := httptest.NewRequest("GET", "/anthropic/", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("期望状态码 404, 实际 %d", w.Code)
}
if !strings.Contains(w.Body.String(), "not found") {
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
}
})
t.Run("SPA fallback", func(t *testing.T) { t.Run("SPA fallback", func(t *testing.T) {
req := httptest.NewRequest("GET", "/providers", nil) req := httptest.NewRequest("GET", "/providers", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -121,3 +100,139 @@ func TestSetupStaticFiles(t *testing.T) {
t.Log("静态文件服务测试通过") t.Log("静态文件服务测试通过")
} }
func TestSetupStaticFilesWithFS_IconPNG(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
setupStaticFilesWithFS(r, fstest.MapFS{
"icon.png": {Data: []byte("png")},
"index.html": {Data: []byte("<html>fallback</html>")},
})
req := httptest.NewRequest("GET", "/icon.png", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码 200, 实际 %d", w.Code)
}
if w.Header().Get("Content-Type") != "image/png" {
t.Fatalf("期望 Content-Type image/png, 实际 %s", w.Header().Get("Content-Type"))
}
if w.Body.String() != "png" {
t.Fatalf("期望返回 PNG 内容,实际 %q", w.Body.String())
}
}
func TestWithProtocolAndStaticRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
distFS, err := frontendDistFS()
if err != nil {
t.Skipf("跳过测试: 前端资源未构建: %v", err)
return
}
r := gin.New()
var gotProtocol string
var gotPath string
r.Any("/openai/*path", withProtocol("openai", func(c *gin.Context) {
gotProtocol = c.Param("protocol")
gotPath = c.Param("path")
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
}))
r.Any("/anthropic/*path", withProtocol("anthropic", func(c *gin.Context) {
gotProtocol = c.Param("protocol")
gotPath = c.Param("path")
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
}))
setupStaticFilesWithFS(r, distFS)
t.Run("OpenAI route enters proxy handler wrapper", func(t *testing.T) {
gotProtocol = ""
gotPath = ""
req := httptest.NewRequest("POST", "/openai/v1/chat/completions", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("期望状态码 200, 实际 %d", w.Code)
}
if gotProtocol != "openai" {
t.Errorf("期望 protocol=openai, 实际 %s", gotProtocol)
}
if gotPath != "/v1/chat/completions" {
t.Errorf("期望 path=/v1/chat/completions, 实际 %s", gotPath)
}
})
t.Run("Anthropic route enters proxy handler wrapper", func(t *testing.T) {
gotProtocol = ""
gotPath = ""
req := httptest.NewRequest("POST", "/anthropic/v1/messages", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("期望状态码 200, 实际 %d", w.Code)
}
if gotProtocol != "anthropic" {
t.Errorf("期望 protocol=anthropic, 实际 %s", gotProtocol)
}
if gotPath != "/v1/messages" {
t.Errorf("期望 path=/v1/messages, 实际 %s", gotPath)
}
})
t.Run("Static assets are not hijacked", func(t *testing.T) {
gotProtocol = ""
gotPath = ""
req := httptest.NewRequest("GET", "/assets/test.js", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if gotProtocol != "" || gotPath != "" {
t.Errorf("静态资源不应进入代理包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
}
if w.Code == http.StatusOK {
if !strings.HasPrefix(w.Header().Get("Content-Type"), "application/javascript") {
t.Errorf("期望 JS Content-Type, 实际 %s", w.Header().Get("Content-Type"))
}
return
}
if w.Code != http.StatusNotFound {
t.Errorf("期望静态资源返回 200 或 404, 实际 %d", w.Code)
}
})
t.Run("SPA path keeps fallback", func(t *testing.T) {
req := httptest.NewRequest("GET", "/providers", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("期望状态码 200, 实际 %d", w.Code)
}
if !strings.Contains(w.Header().Get("Content-Type"), "text/html") {
t.Errorf("期望返回 HTML实际 %s", w.Header().Get("Content-Type"))
}
})
t.Run("Unknown proxy-like path does not return index html", func(t *testing.T) {
req := httptest.NewRequest("GET", "/openai/unknown", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("显式代理路由应进入代理包装器,实际状态码 %d", w.Code)
}
if gotProtocol != "openai" || gotPath != "/unknown" {
t.Errorf("期望 unknown 代理路径进入 openai 包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
}
})
}

View File

@@ -3,7 +3,6 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
@@ -23,18 +22,19 @@ import (
"nex/backend/internal/provider" "nex/backend/internal/provider"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"nex/backend/internal/service" "nex/backend/internal/service"
"nex/backend/pkg/buildinfo"
pkgLogger "nex/backend/pkg/logger" pkgLogger "nex/backend/pkg/logger"
) )
func main() { func main() {
minimalLogger := pkgLogger.NewMinimal()
cfg, err := config.LoadConfig() cfg, err := config.LoadConfig()
if err != nil { if err != nil {
log.Fatalf("加载配置失败: %v", err) minimalLogger.Fatal("加载配置失败", zap.Error(err))
} }
cfg.PrintSummary() zapLogger, err := pkgLogger.Upgrade(minimalLogger, pkgLogger.Config{
zapLogger, err := pkgLogger.New(pkgLogger.Config{
Level: cfg.Log.Level, Level: cfg.Log.Level,
Path: cfg.Log.Path, Path: cfg.Log.Path,
MaxSize: cfg.Log.MaxSize, MaxSize: cfg.Log.MaxSize,
@@ -43,13 +43,19 @@ func main() {
Compress: cfg.Log.Compress, Compress: cfg.Log.Compress,
}) })
if err != nil { if err != nil {
log.Fatalf("初始化日志失败: %v", err) minimalLogger.Fatal("初始化日志失败", zap.Error(err))
} }
defer zapLogger.Sync() defer func() {
if err := zapLogger.Sync(); err != nil {
minimalLogger.Warn("同步日志失败", zap.Error(err))
}
}()
cfg.PrintSummary(zapLogger)
db, err := database.Init(&cfg.Database, zapLogger) db, err := database.Init(&cfg.Database, zapLogger)
if err != nil { if err != nil {
zapLogger.Fatal("初始化数据库失败", zap.String("error", err.Error())) zapLogger.Fatal("初始化数据库失败", zap.Error(err))
} }
defer database.Close(db) defer database.Close(db)
@@ -74,19 +80,20 @@ func main() {
registry := conversion.NewMemoryRegistry() registry := conversion.NewMemoryRegistry()
if err := registry.Register(openai.NewAdapter()); err != nil { if err := registry.Register(openai.NewAdapter()); err != nil {
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error())) zapLogger.Fatal("注册 OpenAI 适配器失败", zap.Error(err))
} }
if err := registry.Register(anthropic.NewAdapter()); err != nil { if err := registry.Register(anthropic.NewAdapter()); err != nil {
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.String("error", err.Error())) zapLogger.Fatal("注册 Anthropic 适配器失败", zap.Error(err))
} }
engine := conversion.NewConversionEngine(registry, zapLogger) engine := conversion.NewConversionEngine(registry, zapLogger)
providerClient := provider.NewClient() providerClient := provider.NewClient(zapLogger)
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService) proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService, zapLogger)
providerHandler := handler.NewProviderHandler(providerService) providerHandler := handler.NewProviderHandler(providerService)
modelHandler := handler.NewModelHandler(modelService) modelHandler := handler.NewModelHandler(modelService)
statsHandler := handler.NewStatsHandler(statsService) statsHandler := handler.NewStatsHandler(statsService)
versionHandler := handler.NewVersionHandler()
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
r := gin.New() r := gin.New()
@@ -96,7 +103,7 @@ func main() {
r.Use(middleware.Logging(zapLogger)) r.Use(middleware.Logging(zapLogger))
r.Use(middleware.CORS()) r.Use(middleware.CORS())
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler) setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler, versionHandler)
srv := &http.Server{ srv := &http.Server{
Addr: fmt.Sprintf(":%d", cfg.Server.Port), Addr: fmt.Sprintf(":%d", cfg.Server.Port),
@@ -106,9 +113,13 @@ func main() {
} }
go func() { go func() {
zapLogger.Info("AI Gateway 启动", zap.String("addr", srv.Addr)) zapLogger.Info("AI Gateway 启动",
zap.String("addr", srv.Addr),
zap.String("version", buildinfo.Version()),
zap.String("commit", buildinfo.Commit()),
zap.String("build_time", buildinfo.BuildTime()))
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error())) zapLogger.Fatal("服务器启动失败", zap.Error(err))
} }
}() }()
@@ -122,7 +133,7 @@ func main() {
defer cancel() defer cancel()
if err := srv.Shutdown(ctx); err != nil { if err := srv.Shutdown(ctx); err != nil {
zapLogger.Fatal("服务器强制关闭", zap.String("error", err.Error())) zapLogger.Fatal("服务器强制关闭", zap.Error(err))
} }
statsBuffer.Stop() statsBuffer.Stop()
@@ -130,8 +141,9 @@ func main() {
zapLogger.Info("服务器已关闭") zapLogger.Info("服务器已关闭")
} }
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) { func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler, versionHandler *handler.VersionHandler) {
r.Any("/:protocol/*path", proxyHandler.HandleProxy) r.Any("/:protocol/*path", proxyHandler.HandleProxy)
r.GET("/api/version", versionHandler.GetVersion)
providers := r.Group("/api/providers") providers := r.Group("/api/providers")
{ {

View File

@@ -0,0 +1,37 @@
package main
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"nex/backend/internal/handler"
"github.com/gin-gonic/gin"
)
func TestSetupRoutes_Version(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
setupRoutes(r, &handler.ProxyHandler{}, &handler.ProviderHandler{}, &handler.ModelHandler{}, &handler.StatsHandler{}, handler.NewVersionHandler())
req := httptest.NewRequest(http.MethodGet, "/api/version", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
}
var result map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
for _, key := range []string{"version", "commit", "build_time"} {
if result[key] == "" {
t.Fatalf("响应缺少 %s 字段: %#v", key, result)
}
}
}

View File

@@ -1,6 +1,7 @@
package config package config
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@@ -11,6 +12,7 @@ import (
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"github.com/spf13/pflag" "github.com/spf13/pflag"
"github.com/spf13/viper" "github.com/spf13/viper"
"go.uber.org/zap"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
@@ -57,7 +59,10 @@ type LogConfig struct {
// DefaultConfig returns default config values // DefaultConfig returns default config values
func DefaultConfig() *Config { func DefaultConfig() *Config {
// Use home dir for default paths // Use home dir for default paths
homeDir, _ := os.UserHomeDir() homeDir, err := os.UserHomeDir()
if err != nil {
homeDir = "."
}
nexDir := filepath.Join(homeDir, ".nex") nexDir := filepath.Join(homeDir, ".nex")
return &Config{ return &Config{
@@ -96,7 +101,7 @@ func GetConfigDir() (string, error) {
return "", err return "", err
} }
configDir := filepath.Join(homeDir, ".nex") configDir := filepath.Join(homeDir, ".nex")
if err := os.MkdirAll(configDir, 0755); err != nil { if err := os.MkdirAll(configDir, 0o755); err != nil {
return "", err return "", err
} }
return configDir, nil return configDir, nil
@@ -122,7 +127,10 @@ func GetConfigPath() (string, error) {
// setupDefaults 设置默认配置值 // setupDefaults 设置默认配置值
func setupDefaults(v *viper.Viper) { func setupDefaults(v *viper.Viper) {
homeDir, _ := os.UserHomeDir() homeDir, err := os.UserHomeDir()
if err != nil {
homeDir = "."
}
nexDir := filepath.Join(homeDir, ".nex") nexDir := filepath.Join(homeDir, ".nex")
v.SetDefault("server.port", 9826) v.SetDefault("server.port", 9826)
@@ -176,27 +184,33 @@ func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
// 绑定所有 flag 到 viper // 绑定所有 flag 到 viper
// 注意:必须在设置默认值之后绑定 // 注意:必须在设置默认值之后绑定
v.BindPFlag("server.port", flagSet.Lookup("server-port")) bindPFlag(v, "server.port", flagSet.Lookup("server-port"))
v.BindPFlag("server.read_timeout", flagSet.Lookup("server-read-timeout")) bindPFlag(v, "server.read_timeout", flagSet.Lookup("server-read-timeout"))
v.BindPFlag("server.write_timeout", flagSet.Lookup("server-write-timeout")) bindPFlag(v, "server.write_timeout", flagSet.Lookup("server-write-timeout"))
v.BindPFlag("database.driver", flagSet.Lookup("database-driver")) bindPFlag(v, "database.driver", flagSet.Lookup("database-driver"))
v.BindPFlag("database.path", flagSet.Lookup("database-path")) bindPFlag(v, "database.path", flagSet.Lookup("database-path"))
v.BindPFlag("database.host", flagSet.Lookup("database-host")) bindPFlag(v, "database.host", flagSet.Lookup("database-host"))
v.BindPFlag("database.port", flagSet.Lookup("database-port")) bindPFlag(v, "database.port", flagSet.Lookup("database-port"))
v.BindPFlag("database.user", flagSet.Lookup("database-user")) bindPFlag(v, "database.user", flagSet.Lookup("database-user"))
v.BindPFlag("database.password", flagSet.Lookup("database-password")) bindPFlag(v, "database.password", flagSet.Lookup("database-password"))
v.BindPFlag("database.dbname", flagSet.Lookup("database-dbname")) bindPFlag(v, "database.dbname", flagSet.Lookup("database-dbname"))
v.BindPFlag("database.max_idle_conns", flagSet.Lookup("database-max-idle-conns")) bindPFlag(v, "database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
v.BindPFlag("database.max_open_conns", flagSet.Lookup("database-max-open-conns")) bindPFlag(v, "database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
v.BindPFlag("database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime")) bindPFlag(v, "database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
v.BindPFlag("log.level", flagSet.Lookup("log-level")) bindPFlag(v, "log.level", flagSet.Lookup("log-level"))
v.BindPFlag("log.path", flagSet.Lookup("log-path")) bindPFlag(v, "log.path", flagSet.Lookup("log-path"))
v.BindPFlag("log.max_size", flagSet.Lookup("log-max-size")) bindPFlag(v, "log.max_size", flagSet.Lookup("log-max-size"))
v.BindPFlag("log.max_backups", flagSet.Lookup("log-max-backups")) bindPFlag(v, "log.max_backups", flagSet.Lookup("log-max-backups"))
v.BindPFlag("log.max_age", flagSet.Lookup("log-max-age")) bindPFlag(v, "log.max_age", flagSet.Lookup("log-max-age"))
v.BindPFlag("log.compress", flagSet.Lookup("log-compress")) bindPFlag(v, "log.compress", flagSet.Lookup("log-compress"))
}
func bindPFlag(v *viper.Viper, key string, flag *pflag.Flag) {
if err := v.BindPFlag(key, flag); err != nil {
panic(fmt.Sprintf("绑定 flag %s 失败: %v", key, err))
}
} }
// setupEnv 绑定环境变量 // setupEnv 绑定环境变量
@@ -217,10 +231,17 @@ func setupConfigFile(v *viper.Viper, configPath string) error {
return appErrors.Wrap(appErrors.ErrInternal, err) return appErrors.Wrap(appErrors.ErrInternal, err)
} }
// 配置文件不存在,创建默认配置文件 // 配置文件不存在,创建默认配置文件
if err := v.SafeWriteConfig(); err != nil { writeErr := v.SafeWriteConfigAs(configPath)
// 忽略写入错误(可能目录已存在等) if writeErr == nil {
return nil return nil
} }
var alreadyExistsErr viper.ConfigFileAlreadyExistsError
if errors.As(writeErr, &alreadyExistsErr) {
return nil
}
return appErrors.Wrap(appErrors.ErrInternal, writeErr)
} }
return nil return nil
} }
@@ -245,7 +266,9 @@ func LoadConfigFromPath(configPath string) (*Config, error) {
setupFlags(v, flagSet) setupFlags(v, flagSet)
// 3. 解析 CLI 参数(忽略错误,因为可能没有参数) // 3. 解析 CLI 参数(忽略错误,因为可能没有参数)
flagSet.Parse(os.Args[1:]) if err := flagSet.Parse(os.Args[1:]); err != nil {
return nil, appErrors.Wrap(appErrors.ErrInvalidRequest, err)
}
// 4. 获取配置文件路径(可能被 --config 参数覆盖) // 4. 获取配置文件路径(可能被 --config 参数覆盖)
if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" { if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" {
@@ -294,11 +317,11 @@ func SaveConfig(cfg *Config) error {
// Ensure directory exists // Ensure directory exists
dir := filepath.Dir(configPath) dir := filepath.Dir(configPath)
if err := os.MkdirAll(dir, 0755); err != nil { if err := os.MkdirAll(dir, 0o755); err != nil {
return appErrors.Wrap(appErrors.ErrInternal, err) return appErrors.Wrap(appErrors.ErrInternal, err)
} }
return os.WriteFile(configPath, data, 0600) return os.WriteFile(configPath, data, 0o600)
} }
// Validate validates the config // Validate validates the config
@@ -311,22 +334,24 @@ func (c *Config) Validate() error {
} }
// PrintSummary 打印配置摘要 // PrintSummary 打印配置摘要
func (c *Config) PrintSummary() { func (c *Config) PrintSummary(logger *zap.Logger) {
fmt.Println("\nAI Gateway 启动配置") logger.Info("AI Gateway 启动配置",
fmt.Println("==================") zap.Int("server_port", c.Server.Port),
fmt.Printf("服务器端口: %d\n", c.Server.Port) zap.String("database_driver", c.Database.Driver),
zap.String("log_level", c.Log.Level),
)
if c.Database.Driver == "mysql" { if c.Database.Driver == "mysql" {
fmt.Printf("数据库类型: mysql\n") logger.Info("数据库配置",
fmt.Printf("数据库地址: %s:%d/%s\n", c.Database.Host, c.Database.Port, c.Database.DBName) zap.String("driver", "mysql"),
zap.String("host", c.Database.Host),
zap.Int("port", c.Database.Port),
zap.String("database", c.Database.DBName),
)
} else { } else {
fmt.Printf("数据库类型: sqlite\n") logger.Info("数据库配置",
fmt.Printf("数据库路径: %s\n", c.Database.Path) zap.String("driver", "sqlite"),
zap.String("path", c.Database.Path),
)
} }
fmt.Printf("日志级别: %s\n", c.Log.Level)
fmt.Println("\n配置来源:")
configPath, _ := GetConfigPath()
fmt.Printf(" 配置文件: %s\n", configPath)
fmt.Println(" 环境变量: 待统计")
fmt.Println(" CLI 参数: 待统计")
fmt.Println()
} }

View File

@@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@@ -171,7 +172,9 @@ func TestConfig_Validate(t *testing.T) {
err := cfg.Validate() err := cfg.Validate()
if tt.wantErr { if tt.wantErr {
assert.Error(t, err) assert.Error(t, err)
if err != nil {
assert.Contains(t, err.Error(), tt.errMsg) assert.Contains(t, err.Error(), tt.errMsg)
}
} else { } else {
assert.NoError(t, err) assert.NoError(t, err)
} }
@@ -233,7 +236,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
configPath := filepath.Join(dir, "config.yaml") configPath := filepath.Join(dir, "config.yaml")
data, err := yaml.Marshal(cfg) data, err := yaml.Marshal(cfg)
require.NoError(t, err) require.NoError(t, err)
err = os.WriteFile(configPath, data, 0644) err = os.WriteFile(configPath, data, 0o600)
require.NoError(t, err) require.NoError(t, err)
// 加载配置 // 加载配置
@@ -302,7 +305,7 @@ func TestPrintSummary(t *testing.T) {
t.Run("SQLite模式摘要", func(t *testing.T) { t.Run("SQLite模式摘要", func(t *testing.T) {
cfg := DefaultConfig() cfg := DefaultConfig()
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
cfg.PrintSummary() cfg.PrintSummary(zap.NewNop())
}) })
}) })
t.Run("MySQL模式摘要", func(t *testing.T) { t.Run("MySQL模式摘要", func(t *testing.T) {
@@ -313,7 +316,7 @@ func TestPrintSummary(t *testing.T) {
cfg.Database.User = "nex" cfg.Database.User = "nex"
cfg.Database.DBName = "nex" cfg.Database.DBName = "nex"
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
cfg.PrintSummary() cfg.PrintSummary(zap.NewNop())
}) })
}) })
} }

View File

@@ -29,8 +29,8 @@ type Model struct {
// UsageStats 用量统计 // UsageStats 用量统计
type UsageStats struct { type UsageStats struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"` ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
ProviderID string `gorm:"not null;index" json:"provider_id"` ProviderID string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"provider_id"`
ModelName string `gorm:"not null;index" json:"model_name"` ModelName string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"model_name"`
RequestCount int `gorm:"default:0" json:"request_count"` RequestCount int `gorm:"default:0" json:"request_count"`
Date time.Time `gorm:"type:date;not null;uniqueIndex:idx_provider_model_date" json:"date"` Date time.Time `gorm:"type:date;not null;uniqueIndex:idx_provider_model_date" json:"date"`
} }
@@ -47,4 +47,3 @@ func (Model) TableName() string {
func (UsageStats) TableName() string { func (UsageStats) TableName() string {
return "usage_stats" return "usage_stats"
} }

View File

@@ -141,7 +141,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
Message: err.Message, Message: err.Message,
}, },
} }
body, _ := json.Marshal(errMsg) body, marshalErr := json.Marshal(errMsg)
if marshalErr != nil {
return []byte(`{"type":"error","error":{"type":"internal_error","message":"internal error"}}`), statusCode
}
return body, statusCode return body, statusCode
} }
@@ -235,7 +238,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
return "", nil, err return "", nil, err
} }
rewriteFunc := func(newModel string) ([]byte, error) { rewriteFunc := func(newModel string) ([]byte, error) {
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m) return json.Marshal(m)
} }
return current, rewriteFunc, nil return current, rewriteFunc, nil
@@ -269,7 +276,11 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
switch ifaceType { switch ifaceType {
case conversion.InterfaceTypeChat: case conversion.InterfaceTypeChat:
// Chat 响应必须有 model 字段,存在则改写,不存在则添加 // Chat 响应必须有 model 字段,存在则改写,不存在则添加
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m) return json.Marshal(m)
default: default:
return body, nil return body, nil

View File

@@ -2,6 +2,7 @@ package anthropic
import ( import (
"encoding/json" "encoding/json"
"errors"
"testing" "testing"
"nex/backend/internal/conversion" "nex/backend/internal/conversion"
@@ -48,6 +49,28 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
} }
} }
func TestAdapter_APIReferenceNativePaths(t *testing.T) {
a := NewAdapter()
// docs/api_reference/anthropic defines messages and models under /v1.
tests := []struct {
path string
expected conversion.InterfaceType
}{
{"/v1/messages", conversion.InterfaceTypeChat},
{"/v1/models", conversion.InterfaceTypeModels},
{"/v1/models/claude-sonnet-4-5", conversion.InterfaceTypeModelInfo},
{"/messages", conversion.InterfaceTypePassthrough},
{"/models", conversion.InterfaceTypePassthrough},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
assert.Equal(t, tt.expected, a.DetectInterfaceType(tt.path))
})
}
}
func TestAdapter_BuildUrl(t *testing.T) { func TestAdapter_BuildUrl(t *testing.T) {
a := NewAdapter() a := NewAdapter()
@@ -141,8 +164,8 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
t.Run("解码嵌入请求", func(t *testing.T) { t.Run("解码嵌入请求", func(t *testing.T) {
_, err := a.DecodeEmbeddingRequest([]byte(`{}`)) _, err := a.DecodeEmbeddingRequest([]byte(`{}`))
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
@@ -150,24 +173,24 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3") provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
_, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider) _, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider)
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.True(t, errors.As(err, &convErr))
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
t.Run("解码嵌入响应", func(t *testing.T) { t.Run("解码嵌入响应", func(t *testing.T) {
_, err := a.DecodeEmbeddingResponse([]byte(`{}`)) _, err := a.DecodeEmbeddingResponse([]byte(`{}`))
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
t.Run("编码嵌入响应", func(t *testing.T) { t.Run("编码嵌入响应", func(t *testing.T) {
_, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{}) _, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{})
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
} }
@@ -178,8 +201,8 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
t.Run("解码重排序请求", func(t *testing.T) { t.Run("解码重排序请求", func(t *testing.T) {
_, err := a.DecodeRerankRequest([]byte(`{}`)) _, err := a.DecodeRerankRequest([]byte(`{}`))
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
@@ -187,24 +210,24 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3") provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
_, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider) _, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider)
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
t.Run("解码重排序响应", func(t *testing.T) { t.Run("解码重排序响应", func(t *testing.T) {
_, err := a.DecodeRerankResponse([]byte(`{}`)) _, err := a.DecodeRerankResponse([]byte(`{}`))
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
t.Run("编码重排序响应", func(t *testing.T) { t.Run("编码重排序响应", func(t *testing.T) {
_, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{}) _, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{})
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
} }

View File

@@ -28,7 +28,10 @@ func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
var canonicalMsgs []canonical.CanonicalMessage var canonicalMsgs []canonical.CanonicalMessage
for _, msg := range req.Messages { for _, msg := range req.Messages {
decoded := decodeMessage(msg) decoded, err := decodeMessage(msg)
if err != nil {
return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析消息内容失败").WithCause(err)
}
canonicalMsgs = append(canonicalMsgs, decoded...) canonicalMsgs = append(canonicalMsgs, decoded...)
} }
@@ -94,10 +97,13 @@ func decodeSystem(system any) any {
} }
// decodeMessage 解码 Anthropic 消息 // decodeMessage 解码 Anthropic 消息
func decodeMessage(msg Message) []canonical.CanonicalMessage { func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) {
switch msg.Role { switch msg.Role {
case "user": case "user":
blocks := decodeContentBlocks(msg.Content) blocks, err := decodeContentBlocks(msg.Content)
if err != nil {
return nil, err
}
var toolResults []canonical.ContentBlock var toolResults []canonical.ContentBlock
var others []canonical.ContentBlock var others []canonical.ContentBlock
for _, b := range blocks { for _, b := range blocks {
@@ -117,58 +123,83 @@ func decodeMessage(msg Message) []canonical.CanonicalMessage {
if len(result) == 0 { if len(result) == 0 {
result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}}) result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}})
} }
return result return result, nil
case "assistant": case "assistant":
blocks := decodeContentBlocks(msg.Content) blocks, err := decodeContentBlocks(msg.Content)
if err != nil {
return nil, err
}
if len(blocks) == 0 { if len(blocks) == 0 {
blocks = append(blocks, canonical.NewTextBlock("")) blocks = append(blocks, canonical.NewTextBlock(""))
} }
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}} return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil
} }
return nil return nil, nil
} }
// decodeContentBlocks 解码内容块列表 // decodeContentBlocks 解码内容块列表
func decodeContentBlocks(content any) []canonical.ContentBlock { func decodeContentBlocks(content any) ([]canonical.ContentBlock, error) {
switch v := content.(type) { switch v := content.(type) {
case string: case string:
return []canonical.ContentBlock{canonical.NewTextBlock(v)} return []canonical.ContentBlock{canonical.NewTextBlock(v)}, nil
case []any: case []any:
var blocks []canonical.ContentBlock var blocks []canonical.ContentBlock
for _, item := range v { for _, item := range v {
if m, ok := item.(map[string]any); ok { if m, ok := item.(map[string]any); ok {
block := decodeSingleContentBlock(m) block, err := decodeSingleContentBlock(m)
if err != nil {
return nil, err
}
if block != nil { if block != nil {
blocks = append(blocks, *block) blocks = append(blocks, *block)
} }
} }
} }
if len(blocks) > 0 { if len(blocks) > 0 {
return blocks return blocks, nil
} }
return []canonical.ContentBlock{canonical.NewTextBlock("")} return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
case nil: case nil:
return []canonical.ContentBlock{canonical.NewTextBlock("")} return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
default: default:
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))} return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}, nil
} }
} }
// decodeSingleContentBlock 解码单个内容块 // decodeSingleContentBlock 解码单个内容块
func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock { func decodeSingleContentBlock(m map[string]any) (*canonical.ContentBlock, error) {
t, _ := m["type"].(string) t, ok := m["type"].(string)
if !ok {
return nil, nil
}
switch t { switch t {
case "text": case "text":
text, _ := m["text"].(string) text, ok := m["text"].(string)
return &canonical.ContentBlock{Type: "text", Text: text} if !ok {
text = ""
}
return &canonical.ContentBlock{Type: "text", Text: text}, nil
case "tool_use": case "tool_use":
id, _ := m["id"].(string) id, ok := m["id"].(string)
name, _ := m["name"].(string) if !ok {
input, _ := json.Marshal(m["input"]) id = ""
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input} }
name, ok := m["name"].(string)
if !ok {
name = ""
}
input, err := json.Marshal(m["input"])
if err != nil {
return nil, err
}
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}, nil
case "tool_result": case "tool_result":
toolUseID, _ := m["tool_use_id"].(string) toolUseID, ok := m["tool_use_id"].(string)
if !ok {
toolUseID = ""
}
isErr := false isErr := false
if ie, ok := m["is_error"].(bool); ok { if ie, ok := m["is_error"].(bool); ok {
isErr = ie isErr = ie
@@ -179,7 +210,11 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
case string: case string:
content = json.RawMessage(fmt.Sprintf("%q", cv)) content = json.RawMessage(fmt.Sprintf("%q", cv))
default: default:
content, _ = json.Marshal(cv) encoded, err := json.Marshal(cv)
if err != nil {
return nil, err
}
content = encoded
} }
} else { } else {
content = json.RawMessage(`""`) content = json.RawMessage(`""`)
@@ -189,15 +224,18 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
ToolUseID: toolUseID, ToolUseID: toolUseID,
Content: content, Content: content,
IsError: &isErr, IsError: &isErr,
} }, nil
case "thinking": case "thinking":
thinking, _ := m["thinking"].(string) thinking, ok := m["thinking"].(string)
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking} if !ok {
thinking = ""
}
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}, nil
case "redacted_thinking": case "redacted_thinking":
// 丢弃 // 丢弃
return nil return nil, nil
} }
return nil return nil, nil
} }
// decodeTools 解码工具定义 // decodeTools 解码工具定义
@@ -232,7 +270,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
return canonical.NewToolChoiceAny() return canonical.NewToolChoiceAny()
} }
case map[string]any: case map[string]any:
t, _ := v["type"].(string) t, ok := v["type"].(string)
if !ok {
return nil
}
switch t { switch t {
case "auto": case "auto":
return canonical.NewToolChoiceAuto() return canonical.NewToolChoiceAuto()
@@ -241,7 +282,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
case "any": case "any":
return canonical.NewToolChoiceAny() return canonical.NewToolChoiceAny()
case "tool": case "tool":
name, _ := v["name"].(string) name, ok := v["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name) return canonical.NewToolChoiceNamed(name)
} }
} }

View File

@@ -33,7 +33,8 @@ func TestEncodeRequest_Basic(t *testing.T) {
assert.Equal(t, true, result["stream"]) assert.Equal(t, true, result["stream"])
assert.Equal(t, float64(1024), result["max_tokens"]) assert.Equal(t, float64(1024), result["max_tokens"])
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 1) assert.Len(t, msgs, 1)
} }
@@ -55,17 +56,20 @@ func TestEncodeRequest_ToolMergeIntoUser(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
require.True(t, ok)
// tool 消息应被合并到相邻 user 消息 // tool 消息应被合并到相邻 user 消息
foundToolResult := false foundToolResult := false
for _, m := range msgs { for _, m := range msgs {
msgMap := m.(map[string]any) msgMap, ok := m.(map[string]any)
require.True(t, ok)
if msgMap["role"] == "user" { if msgMap["role"] == "user" {
content, ok := msgMap["content"].([]any) content, ok := msgMap["content"].([]any)
if ok { if ok {
for _, c := range content { for _, c := range content {
block := c.(map[string]any) block, ok := c.(map[string]any)
require.True(t, ok)
if block["type"] == "tool_result" { if block["type"] == "tool_result" {
foundToolResult = true foundToolResult = true
} }
@@ -93,8 +97,10 @@ func TestEncodeRequest_FirstUserGuarantee(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
firstMsg := msgs[0].(map[string]any) require.True(t, ok)
firstMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "user", firstMsg["role"]) assert.Equal(t, "user", firstMsg["role"])
} }
@@ -140,9 +146,11 @@ func TestEncodeResponse_Basic(t *testing.T) {
assert.Equal(t, "assistant", result["role"]) assert.Equal(t, "assistant", result["role"])
assert.Equal(t, "end_turn", result["stop_reason"]) assert.Equal(t, "end_turn", result["stop_reason"])
content := result["content"].([]any) content, ok := result["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 1) assert.Len(t, content, 1)
block := content[0].(map[string]any) block, ok := content[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "text", block["type"]) assert.Equal(t, "text", block["type"])
assert.Equal(t, "你好", block["text"]) assert.Equal(t, "你好", block["text"])
} }
@@ -160,10 +168,12 @@ func TestEncodeModelsResponse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
data := result["data"].([]any) data, ok := result["data"].([]any)
require.True(t, ok)
assert.Len(t, data, 1) assert.Len(t, data, 1)
model := data[0].(map[string]any) model, ok := data[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "claude-3-opus", model["id"]) assert.Equal(t, "claude-3-opus", model["id"])
// created 应为 RFC3339 格式 // created 应为 RFC3339 格式
createdAt, ok := model["created_at"].(string) createdAt, ok := model["created_at"].(string)
@@ -280,11 +290,14 @@ func TestEncodeRequest_ConsecutiveRoleMerge(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 1) assert.Len(t, msgs, 1)
userMsg := msgs[0].(map[string]any) userMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "user", userMsg["role"]) assert.Equal(t, "user", userMsg["role"])
content := userMsg["content"].([]any) content, ok := userMsg["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 2) assert.Len(t, content, 2)
} }
@@ -321,7 +334,8 @@ func TestEncodeResponse_ReasoningTokens(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any) usage, ok := result["usage"].(map[string]any)
require.True(t, ok)
_, hasReasoning := usage["reasoning_tokens"] _, hasReasoning := usage["reasoning_tokens"]
assert.False(t, hasReasoning) assert.False(t, hasReasoning)
} }
@@ -341,9 +355,11 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
content := result["content"].([]any) content, ok := result["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 1) assert.Len(t, content, 1)
block := content[0].(map[string]any) block, ok := content[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "tool_use", block["type"]) assert.Equal(t, "tool_use", block["type"])
assert.Equal(t, "tool_1", block["id"]) assert.Equal(t, "tool_1", block["id"])
assert.Equal(t, "search", block["name"]) assert.Equal(t, "search", block["name"])

View File

@@ -28,7 +28,7 @@ func NewStreamDecoder() *StreamDecoder {
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
data := rawChunk data := rawChunk
if len(d.utf8Remainder) > 0 { if len(d.utf8Remainder) > 0 {
data = append(d.utf8Remainder, rawChunk...) data = append(append([]byte{}, d.utf8Remainder...), rawChunk...)
d.utf8Remainder = nil d.utf8Remainder = nil
} }
@@ -50,9 +50,10 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
for _, line := range strings.Split(text, "\n") { for _, line := range strings.Split(text, "\n") {
line = strings.TrimRight(line, "\r") line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "event: ") { switch {
case strings.HasPrefix(line, "event: "):
eventType = strings.TrimPrefix(line, "event: ") eventType = strings.TrimPrefix(line, "event: ")
} else if strings.HasPrefix(line, "data: ") { case strings.HasPrefix(line, "data: "):
eventData = strings.TrimPrefix(line, "data: ") eventData = strings.TrimPrefix(line, "data: ")
if eventType != "" && eventData != "" { if eventType != "" && eventData != "" {
chunkEvents := d.processEvent(eventType, []byte(eventData)) chunkEvents := d.processEvent(eventType, []byte(eventData))
@@ -60,8 +61,8 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
} }
eventType = "" eventType = ""
eventData = "" eventData = ""
} else if line == "" { case line == "":
// SSE 事件分隔符 continue
} }
} }

View File

@@ -51,15 +51,23 @@ func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent)
if event.Message != nil { if event.Message != nil {
msg := map[string]any{ msg := map[string]any{
"id": event.Message.ID, "id": event.Message.ID,
"model": event.Message.Model, "type": "message",
"role": "assistant", "role": "assistant",
"content": []any{},
"model": event.Message.Model,
"stop_reason": nil,
"stop_sequence": nil,
} }
if event.Message.Usage != nil { if event.Message.Usage != nil {
usage := map[string]any{ msg["usage"] = map[string]any{
"input_tokens": event.Message.Usage.InputTokens, "input_tokens": event.Message.Usage.InputTokens,
"output_tokens": event.Message.Usage.OutputTokens, "output_tokens": event.Message.Usage.OutputTokens,
} }
msg["usage"] = usage } else {
msg["usage"] = map[string]any{
"input_tokens": 0,
"output_tokens": 0,
}
} }
payload["message"] = msg payload["message"] = msg
} }
@@ -147,6 +155,10 @@ func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent)
payload["usage"] = map[string]any{ payload["usage"] = map[string]any{
"output_tokens": event.Usage.OutputTokens, "output_tokens": event.Usage.OutputTokens,
} }
} else {
payload["usage"] = map[string]any{
"output_tokens": 0,
}
} }
return e.marshalEvent("message_delta", payload) return e.marshalEvent("message_delta", payload)
} }

View File

@@ -21,8 +21,55 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
s := string(chunks[0]) s := string(chunks[0])
assert.True(t, strings.HasPrefix(s, "event: message_start\n")) assert.True(t, strings.HasPrefix(s, "event: message_start\n"))
assert.Contains(t, s, "data: ") assert.Contains(t, s, "data: ")
assert.Contains(t, s, "msg_1")
assert.Contains(t, s, "claude-3") var payload map[string]any
lines := strings.Split(s, "\n")
for _, l := range lines {
if strings.HasPrefix(l, "data: ") {
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
break
}
}
msg, ok := payload["message"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "msg_1", msg["id"])
assert.Equal(t, "message", msg["type"])
assert.Equal(t, "assistant", msg["role"])
assert.Equal(t, []any{}, msg["content"])
assert.Equal(t, "claude-3", msg["model"])
assert.Nil(t, msg["stop_reason"])
assert.Nil(t, msg["stop_sequence"])
usage, okU := msg["usage"].(map[string]any)
require.True(t, okU)
assert.Equal(t, float64(0), usage["input_tokens"])
assert.Equal(t, float64(0), usage["output_tokens"])
}
func TestStreamEncoder_MessageStart_WithUsage(t *testing.T) {
e := NewStreamEncoder()
event := canonical.NewMessageStartEventWithUsage("msg_2", "gpt-4", &canonical.CanonicalUsage{InputTokens: 100, OutputTokens: 50})
chunks := e.EncodeEvent(event)
require.Len(t, chunks, 1)
s := string(chunks[0])
var payload map[string]any
lines := strings.Split(s, "\n")
for _, l := range lines {
if strings.HasPrefix(l, "data: ") {
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
break
}
}
msg, ok := payload["message"].(map[string]any)
require.True(t, ok)
usage, okU := msg["usage"].(map[string]any)
require.True(t, okU)
assert.Equal(t, float64(100), usage["input_tokens"])
assert.Equal(t, float64(50), usage["output_tokens"])
} }
func TestStreamEncoder_ContentBlockDelta(t *testing.T) { func TestStreamEncoder_ContentBlockDelta(t *testing.T) {
@@ -80,7 +127,8 @@ func TestStreamEncoder_ContentBlockStart_Text(t *testing.T) {
break break
} }
} }
cb := payload["content_block"].(map[string]any) cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "text", cb["type"]) assert.Equal(t, "text", cb["type"])
} }
@@ -107,7 +155,8 @@ func TestStreamEncoder_ContentBlockStart_ToolUse(t *testing.T) {
break break
} }
} }
cb := payload["content_block"].(map[string]any) cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "tool_use", cb["type"]) assert.Equal(t, "tool_use", cb["type"])
assert.Equal(t, "toolu_1", cb["id"]) assert.Equal(t, "toolu_1", cb["id"])
assert.Equal(t, "search", cb["name"]) assert.Equal(t, "search", cb["name"])
@@ -131,7 +180,8 @@ func TestStreamEncoder_ContentBlockStart_Thinking(t *testing.T) {
break break
} }
} }
cb := payload["content_block"].(map[string]any) cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "thinking", cb["type"]) assert.Equal(t, "thinking", cb["type"])
} }
@@ -173,8 +223,13 @@ func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
break break
} }
} }
delta := payload["delta"].(map[string]any) delta, okd := payload["delta"].(map[string]any)
require.True(t, okd)
assert.Equal(t, "end_turn", delta["stop_reason"]) assert.Equal(t, "end_turn", delta["stop_reason"])
usage, oku := payload["usage"].(map[string]any)
require.True(t, oku, "message_delta SHALL always include usage")
assert.Equal(t, float64(0), usage["output_tokens"])
} }
func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) { func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
@@ -199,7 +254,8 @@ func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
break break
} }
} }
u := payload["usage"].(map[string]any) u, oku := payload["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(88), u["output_tokens"]) assert.Equal(t, float64(88), u["output_tokens"])
} }

View File

@@ -173,13 +173,15 @@ func TestDecodeMessage_UserWithOnlyToolResults(t *testing.T) {
} }
func TestDecodeContentBlocks_Nil(t *testing.T) { func TestDecodeContentBlocks_Nil(t *testing.T) {
blocks := decodeContentBlocks(nil) blocks, err := decodeContentBlocks(nil)
require.NoError(t, err)
assert.Len(t, blocks, 1) assert.Len(t, blocks, 1)
assert.Equal(t, "", blocks[0].Text) assert.Equal(t, "", blocks[0].Text)
} }
func TestDecodeContentBlocks_String(t *testing.T) { func TestDecodeContentBlocks_String(t *testing.T) {
blocks := decodeContentBlocks("hello") blocks, err := decodeContentBlocks("hello")
require.NoError(t, err)
assert.Len(t, blocks, 1) assert.Len(t, blocks, 1)
assert.Equal(t, "hello", blocks[0].Text) assert.Equal(t, "hello", blocks[0].Text)
} }
@@ -217,8 +219,10 @@ func TestEncodeToolChoice(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := encodeToolChoice(tt.choice) result := encodeToolChoice(tt.choice)
assert.Equal(t, tt.want["type"], result.(map[string]any)["type"]) r, ok := result.(map[string]any)
assert.Equal(t, tt.want["name"], result.(map[string]any)["name"]) require.True(t, ok)
assert.Equal(t, tt.want["type"], r["type"])
assert.Equal(t, tt.want["name"], r["name"])
}) })
} }
} }
@@ -315,12 +319,15 @@ func TestEncodeRequest_WithTools(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
tools := result["tools"].([]any) tools, okt := result["tools"].([]any)
require.True(t, okt)
assert.Len(t, tools, 1) assert.Len(t, tools, 1)
tool := tools[0].(map[string]any) tool, okt2 := tools[0].(map[string]any)
require.True(t, okt2)
assert.Equal(t, "search", tool["name"]) assert.Equal(t, "search", tool["name"])
assert.Equal(t, "Search things", tool["description"]) assert.Equal(t, "Search things", tool["description"])
tc := result["tool_choice"].(map[string]any) tc, oktc := result["tool_choice"].(map[string]any)
require.True(t, oktc)
assert.Equal(t, "auto", tc["type"]) assert.Equal(t, "auto", tc["type"])
} }
@@ -366,7 +373,8 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any) usage, oku := result["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(100), usage["input_tokens"]) assert.Equal(t, float64(100), usage["input_tokens"])
assert.Equal(t, float64(30), usage["cache_read_input_tokens"]) assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
assert.Equal(t, float64(10), usage["cache_creation_input_tokens"]) assert.Equal(t, float64(10), usage["cache_creation_input_tokens"])

View File

@@ -3,10 +3,14 @@ package conversion
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/conversion/canonical"
pkglogger "nex/backend/pkg/logger"
) )
// HTTPRequestSpec HTTP 请求规格 // HTTPRequestSpec HTTP 请求规格
@@ -33,13 +37,10 @@ type ConversionEngine struct {
// NewConversionEngine 创建转换引擎 // NewConversionEngine 创建转换引擎
func NewConversionEngine(registry AdapterRegistry, logger *zap.Logger) *ConversionEngine { func NewConversionEngine(registry AdapterRegistry, logger *zap.Logger) *ConversionEngine {
if logger == nil {
logger = zap.L()
}
return &ConversionEngine{ return &ConversionEngine{
registry: registry, registry: registry,
middlewareChain: NewMiddlewareChain(), middlewareChain: NewMiddlewareChain(),
logger: logger, logger: pkglogger.WithModule(logger, "conversion.engine"),
} }
} }
@@ -72,7 +73,7 @@ func (e *ConversionEngine) IsPassthrough(clientProtocol, providerProtocol string
// ConvertHttpRequest 转换 HTTP 请求 // ConvertHttpRequest 转换 HTTP 请求
func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtocol, providerProtocol string, provider *TargetProvider) (*HTTPRequestSpec, error) { func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtocol, providerProtocol string, provider *TargetProvider) (*HTTPRequestSpec, error) {
nativePath := spec.URL nativePath, rawQuery := splitRequestPath(spec.URL)
if e.IsPassthrough(clientProtocol, providerProtocol) { if e.IsPassthrough(clientProtocol, providerProtocol) {
providerAdapter, err := e.registry.Get(providerProtocol) providerAdapter, err := e.registry.Get(providerProtocol)
@@ -90,15 +91,18 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
rewrittenBody, err = providerAdapter.RewriteRequestModelName(spec.Body, provider.ModelName, interfaceType) rewrittenBody, err = providerAdapter.RewriteRequestModelName(spec.Body, provider.ModelName, interfaceType)
if err != nil { if err != nil {
e.logger.Warn("Smart Passthrough 改写请求失败,使用原始请求体", e.logger.Warn("Smart Passthrough 改写请求失败,使用原始请求体",
zap.String("error", err.Error()), zap.Error(err),
zap.String("interface", string(interfaceType))) zap.String("interface", string(interfaceType)))
rewrittenBody = spec.Body rewrittenBody = spec.Body
} }
} }
} }
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
providerURL = appendRawQuery(providerURL, rawQuery)
return &HTTPRequestSpec{ return &HTTPRequestSpec{
URL: provider.BaseURL + nativePath, URL: joinBaseURL(provider.BaseURL, providerURL),
Method: spec.Method, Method: spec.Method,
Headers: providerAdapter.BuildHeaders(provider), Headers: providerAdapter.BuildHeaders(provider),
Body: rewrittenBody, Body: rewrittenBody,
@@ -115,7 +119,8 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
} }
interfaceType := clientAdapter.DetectInterfaceType(nativePath) interfaceType := clientAdapter.DetectInterfaceType(nativePath)
providerUrl := providerAdapter.BuildUrl(nativePath, interfaceType) providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
providerURL = appendRawQuery(providerURL, rawQuery)
providerHeaders := providerAdapter.BuildHeaders(provider) providerHeaders := providerAdapter.BuildHeaders(provider)
providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body) providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
if err != nil { if err != nil {
@@ -123,7 +128,7 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
} }
return &HTTPRequestSpec{ return &HTTPRequestSpec{
URL: provider.BaseURL + providerUrl, URL: joinBaseURL(provider.BaseURL, providerURL),
Method: spec.Method, Method: spec.Method,
Headers: providerHeaders, Headers: providerHeaders,
Body: providerBody, Body: providerBody,
@@ -135,25 +140,22 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
if e.IsPassthrough(clientProtocol, providerProtocol) { if e.IsPassthrough(clientProtocol, providerProtocol) {
// Smart Passthrough: 同协议时最小化改写 model 字段 // Smart Passthrough: 同协议时最小化改写 model 字段
if modelOverride != "" && len(spec.Body) > 0 { if modelOverride != "" && len(spec.Body) > 0 {
adapter, err := e.registry.Get(clientProtocol) adapter, getErr := e.registry.Get(clientProtocol)
if err != nil { if getErr == nil {
return &spec, nil rewrittenBody, rewriteErr := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
} if rewriteErr != nil {
rewrittenBody, err := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
if err != nil {
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体", e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
zap.String("error", err.Error()), zap.Error(rewriteErr),
zap.String("interface", string(interfaceType))) zap.String("interface", string(interfaceType)))
return &spec, nil } else {
}
return &HTTPResponseSpec{ return &HTTPResponseSpec{
StatusCode: spec.StatusCode, StatusCode: spec.StatusCode,
Headers: spec.Headers, Headers: spec.Headers,
Body: rewrittenBody, Body: rewrittenBody,
}, nil }, nil
} }
}
}
return &spec, nil return &spec, nil
} }
@@ -183,12 +185,11 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
if e.IsPassthrough(clientProtocol, providerProtocol) { if e.IsPassthrough(clientProtocol, providerProtocol) {
// Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段 // Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段
if modelOverride != "" { if modelOverride != "" {
adapter, err := e.registry.Get(clientProtocol) adapter, getErr := e.registry.Get(clientProtocol)
if err != nil { if getErr == nil {
return NewPassthroughStreamConverter(), nil
}
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
} }
}
return NewPassthroughStreamConverter(), nil return NewPassthroughStreamConverter(), nil
} }
@@ -203,7 +204,7 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
ctx := ConversionContext{ ctx := ConversionContext{
ConversionID: uuid.New().String(), ConversionID: uuid.New().String(),
InterfaceType: InterfaceTypeChat, InterfaceType: interfaceType,
Timestamp: time.Now(), Timestamp: time.Now(),
} }
@@ -273,7 +274,7 @@ func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clie
func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) { func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
canonicalReq, err := clientAdapter.DecodeRequest(body) canonicalReq, err := clientAdapter.DecodeRequest(body)
if err != nil { if err != nil {
return nil, NewConversionError(ErrorCodeJSONParseError, "解码请求失败").WithCause(err) return nil, NewRequestJSONParseError("解码请求失败", err)
} }
ctx := NewConversionContext(InterfaceTypeChat) ctx := NewConversionContext(InterfaceTypeChat)
@@ -281,6 +282,9 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
if err != nil { if err != nil {
return nil, err return nil, err
} }
if containsUnsupportedMultimodal(canonicalReq) {
return nil, NewConversionError(ErrorCodeUnsupportedMultimodal, "跨协议暂不支持多模态内容")
}
encoded, err := providerAdapter.EncodeRequest(canonicalReq, provider) encoded, err := providerAdapter.EncodeRequest(canonicalReq, provider)
if err != nil { if err != nil {
@@ -292,7 +296,7 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) { func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
canonicalResp, err := providerAdapter.DecodeResponse(body) canonicalResp, err := providerAdapter.DecodeResponse(body)
if err != nil { if err != nil {
return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(err) return nil, NewResponseJSONParseError("解码响应失败", err)
} }
if modelOverride != "" { if modelOverride != "" {
canonicalResp.Model = modelOverride canonicalResp.Model = modelOverride
@@ -307,12 +311,12 @@ func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapte
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
models, err := providerAdapter.DecodeModelsResponse(body) models, err := providerAdapter.DecodeModelsResponse(body)
if err != nil { if err != nil {
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error())) e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.Error(err))
return body, nil return body, nil
} }
encoded, err := clientAdapter.EncodeModelsResponse(models) encoded, err := clientAdapter.EncodeModelsResponse(models)
if err != nil { if err != nil {
e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.String("error", err.Error())) e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.Error(err))
return body, nil return body, nil
} }
return encoded, nil return encoded, nil
@@ -321,12 +325,12 @@ func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdap
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
info, err := providerAdapter.DecodeModelInfoResponse(body) info, err := providerAdapter.DecodeModelInfoResponse(body)
if err != nil { if err != nil {
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error())) e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
return body, nil return body, nil
} }
encoded, err := clientAdapter.EncodeModelInfoResponse(info) encoded, err := clientAdapter.EncodeModelInfoResponse(info)
if err != nil { if err != nil {
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error())) e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
return body, nil return body, nil
} }
return encoded, nil return encoded, nil
@@ -335,7 +339,7 @@ func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerA
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) { func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
req, err := clientAdapter.DecodeEmbeddingRequest(body) req, err := clientAdapter.DecodeEmbeddingRequest(body)
if err != nil { if err != nil {
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error())) e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.Error(err))
return body, nil return body, nil
} }
return providerAdapter.EncodeEmbeddingRequest(req, provider) return providerAdapter.EncodeEmbeddingRequest(req, provider)
@@ -344,7 +348,7 @@ func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter P
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) { func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
resp, err := providerAdapter.DecodeEmbeddingResponse(body) resp, err := providerAdapter.DecodeEmbeddingResponse(body)
if err != nil { if err != nil {
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error())) e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.Error(err))
return body, nil return body, nil
} }
if modelOverride != "" { if modelOverride != "" {
@@ -356,21 +360,22 @@ func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerA
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) { func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
req, err := clientAdapter.DecodeRerankRequest(body) req, err := clientAdapter.DecodeRerankRequest(body)
if err != nil { if err != nil {
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error())) e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.Error(err))
return body, nil return body, nil
} }
return providerAdapter.EncodeRerankRequest(req, provider) return providerAdapter.EncodeRerankRequest(req, provider)
} }
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) { func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
resp, err := providerAdapter.DecodeRerankResponse(body) resp, decodeErr := providerAdapter.DecodeRerankResponse(body)
if err != nil { if decodeErr == nil {
return body, nil
}
if modelOverride != "" { if modelOverride != "" {
resp.Model = modelOverride resp.Model = modelOverride
} }
return clientAdapter.EncodeRerankResponse(resp) return clientAdapter.EncodeRerankResponse(resp)
}
return body, nil
} }
// DetectInterfaceType 检测接口类型 // DetectInterfaceType 检测接口类型
@@ -379,6 +384,7 @@ func (e *ConversionEngine) DetectInterfaceType(nativePath, clientProtocol string
if err != nil { if err != nil {
return InterfaceTypePassthrough, err return InterfaceTypePassthrough, err
} }
nativePath, _ = splitRequestPath(nativePath)
return adapter.DetectInterfaceType(nativePath), nil return adapter.DetectInterfaceType(nativePath), nil
} }
@@ -392,9 +398,56 @@ func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol stri
"type": "internal_error", "type": "internal_error",
}, },
} }
body, _ := json.Marshal(fallback) body, marshalErr := json.Marshal(fallback)
if marshalErr == nil {
return body, 500, nil return body, 500, nil
} }
return []byte(`{"error":{"message":"internal error","type":"internal_error"}}`), 500, nil
}
body, statusCode := adapter.EncodeError(err) body, statusCode := adapter.EncodeError(err)
return body, statusCode, nil return body, statusCode, nil
} }
func splitRequestPath(rawPath string) (string, string) {
path, query, found := strings.Cut(rawPath, "?")
if !found {
return rawPath, ""
}
return path, query
}
func appendRawQuery(path, rawQuery string) string {
if rawQuery == "" {
return path
}
if strings.Contains(path, "?") {
return path + "&" + rawQuery
}
return path + "?" + rawQuery
}
func joinBaseURL(baseURL, path string) string {
if baseURL == "" {
return path
}
if path == "" {
return baseURL
}
return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/")
}
func containsUnsupportedMultimodal(req *canonical.CanonicalRequest) bool {
if req == nil {
return false
}
for _, msg := range req.Messages {
for _, block := range msg.Content {
switch block.Type {
case "image", "audio", "video", "file":
return true
}
}
}
return false
}

View File

@@ -0,0 +1,63 @@
package conversion_test
import (
"testing"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)
func TestConvertHttpRequest_SameProtocolUsesAdapterBuildURL(t *testing.T) {
tests := []struct {
name string
adapter conversion.ProtocolAdapter
clientProtocol string
providerProtocol string
baseURL string
nativePath string
expectedURL string
body []byte
}{
{
name: "openai base url includes version path",
adapter: openai.NewAdapter(),
clientProtocol: "openai",
providerProtocol: "openai",
baseURL: "http://example.com/v1",
nativePath: "/chat/completions",
expectedURL: "http://example.com/v1/chat/completions",
body: []byte(`{"model":"gpt-4","messages":[]}`),
},
{
name: "anthropic native path keeps v1",
adapter: anthropic.NewAdapter(),
clientProtocol: "anthropic",
providerProtocol: "anthropic",
baseURL: "http://example.com",
nativePath: "/v1/messages",
expectedURL: "http://example.com/v1/messages",
body: []byte(`{"model":"claude","messages":[]}`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, zap.NewNop())
require.NoError(t, registry.Register(tt.adapter))
out, err := engine.ConvertHttpRequest(conversion.HTTPRequestSpec{
URL: tt.nativePath,
Method: "POST",
Body: tt.body,
}, tt.clientProtocol, tt.providerProtocol, conversion.NewTargetProvider(tt.baseURL, "key", "upstream-model"))
require.NoError(t, err)
assert.Equal(t, tt.expectedURL, out.URL)
})
}
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap"
) )
func TestConversionError_WithProviderProtocol(t *testing.T) { func TestConversionError_WithProviderProtocol(t *testing.T) {
@@ -39,7 +40,7 @@ func TestConversionError_FullBuilder(t *testing.T) {
func TestEngine_Use(t *testing.T) { func TestEngine_Use(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
called := false called := false
engine.Use(&testMiddleware{fn: func(req *canonical.CanonicalRequest, cp, pp string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) { engine.Use(&testMiddleware{fn: func(req *canonical.CanonicalRequest, cp, pp string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) {
called = true called = true
@@ -66,7 +67,7 @@ func TestEngine_Use(t *testing.T) {
func TestConvertHttpRequest_DecodeError(t *testing.T) { func TestConvertHttpRequest_DecodeError(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false) clientAdapter := newMockAdapter("client", false)
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) { clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
return nil, errors.New("decode failed") return nil, errors.New("decode failed")
@@ -82,7 +83,7 @@ func TestConvertHttpRequest_DecodeError(t *testing.T) {
func TestConvertHttpRequest_EncodeError(t *testing.T) { func TestConvertHttpRequest_EncodeError(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("client", false)) _ = engine.RegisterAdapter(newMockAdapter("client", false))
providerAdapter := newMockAdapter("provider", false) providerAdapter := newMockAdapter("provider", false)
providerAdapter.encodeReqFn = func(req *canonical.CanonicalRequest, p *TargetProvider) ([]byte, error) { providerAdapter.encodeReqFn = func(req *canonical.CanonicalRequest, p *TargetProvider) ([]byte, error) {
@@ -98,7 +99,7 @@ func TestConvertHttpRequest_EncodeError(t *testing.T) {
func TestConvertHttpResponse_CrossProtocol(t *testing.T) { func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false) clientAdapter := newMockAdapter("client", false)
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) { clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
@@ -121,7 +122,7 @@ func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
func TestConvertHttpResponse_DecodeError(t *testing.T) { func TestConvertHttpResponse_DecodeError(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
providerAdapter := newMockAdapter("provider", false) providerAdapter := newMockAdapter("provider", false)
providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) { providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
return nil, errors.New("decode error") return nil, errors.New("decode error")
@@ -135,7 +136,7 @@ func TestConvertHttpResponse_DecodeError(t *testing.T) {
func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) { func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false) clientAdapter := newMockAdapter("client", false)
clientAdapter.ifaceType = InterfaceTypeEmbeddings clientAdapter.ifaceType = InterfaceTypeEmbeddings
@@ -158,7 +159,7 @@ func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) {
func TestConvertHttpRequest_RerankInterface(t *testing.T) { func TestConvertHttpRequest_RerankInterface(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false) clientAdapter := newMockAdapter("client", false)
clientAdapter.ifaceType = InterfaceTypeRerank clientAdapter.ifaceType = InterfaceTypeRerank
@@ -178,7 +179,7 @@ func TestConvertHttpRequest_RerankInterface(t *testing.T) {
func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) { func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false) clientAdapter := newMockAdapter("client", false)
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true} clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
@@ -196,7 +197,7 @@ func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
func TestConvertHttpResponse_RerankInterface(t *testing.T) { func TestConvertHttpResponse_RerankInterface(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false) clientAdapter := newMockAdapter("client", false)
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true} clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
@@ -214,7 +215,7 @@ func TestConvertHttpResponse_RerankInterface(t *testing.T) {
func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) { func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false) clientAdapter := newMockAdapter("client", false)
clientAdapter.ifaceType = InterfaceTypeModels clientAdapter.ifaceType = InterfaceTypeModels
providerAdapter := newMockAdapter("provider", false) providerAdapter := newMockAdapter("provider", false)
@@ -232,7 +233,7 @@ func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
func TestConvertHttpResponse_ModelsInterface(t *testing.T) { func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false) clientAdapter := newMockAdapter("client", false)
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModels: true} clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModels: true}
providerAdapter := newMockAdapter("provider", false) providerAdapter := newMockAdapter("provider", false)
@@ -249,7 +250,7 @@ func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) { func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false) clientAdapter := newMockAdapter("client", false)
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModelInfo: true} clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModelInfo: true}
providerAdapter := newMockAdapter("provider", false) providerAdapter := newMockAdapter("provider", false)
@@ -324,7 +325,7 @@ var _ = json.Marshal
func TestConvertEmbeddingBody_DecodeError(t *testing.T) { func TestConvertEmbeddingBody_DecodeError(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false) clientAdapter := newMockAdapter("client", false)
clientAdapter.decodeEmbeddingReqFn = func(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) { clientAdapter.decodeEmbeddingReqFn = func(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
@@ -344,7 +345,7 @@ func TestConvertEmbeddingBody_DecodeError(t *testing.T) {
func TestConvertRerankBody_DecodeError(t *testing.T) { func TestConvertRerankBody_DecodeError(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false) clientAdapter := newMockAdapter("client", false)
clientAdapter.decodeRerankReqFn = func(raw []byte) (*canonical.CanonicalRerankRequest, error) { clientAdapter.decodeRerankReqFn = func(raw []byte) (*canonical.CanonicalRerankRequest, error) {
@@ -364,7 +365,7 @@ func TestConvertRerankBody_DecodeError(t *testing.T) {
func TestConvertBody_UnknownInterfaceType(t *testing.T) { func TestConvertBody_UnknownInterfaceType(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false) clientAdapter := newMockAdapter("client", false)
providerAdapter := newMockAdapter("provider", false) providerAdapter := newMockAdapter("provider", false)

View File

@@ -2,6 +2,7 @@ package conversion
import ( import (
"encoding/json" "encoding/json"
"strings"
"testing" "testing"
"nex/backend/internal/conversion/canonical" "nex/backend/internal/conversion/canonical"
@@ -190,7 +191,9 @@ func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel str
// noopStreamDecoder 空流式解码器 // noopStreamDecoder 空流式解码器
type noopStreamDecoder struct{} type noopStreamDecoder struct{}
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { return nil } func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
return nil
}
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil } func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }
// noopStreamEncoder 空流式编码器 // noopStreamEncoder 空流式编码器
@@ -203,7 +206,7 @@ func (e *noopStreamEncoder) Flush() [][]byte
func TestNewConversionEngine(t *testing.T) { func TestNewConversionEngine(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
assert.NotNil(t, engine) assert.NotNil(t, engine)
assert.Equal(t, registry, engine.GetRegistry()) assert.Equal(t, registry, engine.GetRegistry())
} }
@@ -211,7 +214,7 @@ func TestNewConversionEngine(t *testing.T) {
func TestNewConversionEngine_LoggerInjection(t *testing.T) { func TestNewConversionEngine_LoggerInjection(t *testing.T) {
t.Run("nil_logger_uses_global", func(t *testing.T) { t.Run("nil_logger_uses_global", func(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
assert.NotNil(t, engine.logger) assert.NotNil(t, engine.logger)
}) })
@@ -219,13 +222,14 @@ func TestNewConversionEngine_LoggerInjection(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
customLogger := zap.NewNop() customLogger := zap.NewNop()
engine := NewConversionEngine(registry, customLogger) engine := NewConversionEngine(registry, customLogger)
assert.Equal(t, customLogger, engine.logger) assert.NotNil(t, engine.logger)
assert.Contains(t, engine.logger.Name(), "conversion.engine")
}) })
} }
func TestRegisterAdapter(t *testing.T) { func TestRegisterAdapter(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
adapter := newMockAdapter("test-proto", true) adapter := newMockAdapter("test-proto", true)
err := engine.RegisterAdapter(adapter) err := engine.RegisterAdapter(adapter)
@@ -237,7 +241,7 @@ func TestRegisterAdapter(t *testing.T) {
func TestIsPassthrough_SameProtocol(t *testing.T) { func TestIsPassthrough_SameProtocol(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
adapter := newMockAdapter("openai", true) adapter := newMockAdapter("openai", true)
_ = engine.RegisterAdapter(adapter) _ = engine.RegisterAdapter(adapter)
@@ -246,7 +250,7 @@ func TestIsPassthrough_SameProtocol(t *testing.T) {
func TestIsPassthrough_DifferentProtocol(t *testing.T) { func TestIsPassthrough_DifferentProtocol(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("openai", true)) _ = engine.RegisterAdapter(newMockAdapter("openai", true))
_ = engine.RegisterAdapter(newMockAdapter("anthropic", true)) _ = engine.RegisterAdapter(newMockAdapter("anthropic", true))
@@ -255,7 +259,7 @@ func TestIsPassthrough_DifferentProtocol(t *testing.T) {
func TestIsPassthrough_NoPassthrough(t *testing.T) { func TestIsPassthrough_NoPassthrough(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("custom", false)) _ = engine.RegisterAdapter(newMockAdapter("custom", false))
assert.False(t, engine.IsPassthrough("custom", "custom")) assert.False(t, engine.IsPassthrough("custom", "custom"))
@@ -263,7 +267,7 @@ func TestIsPassthrough_NoPassthrough(t *testing.T) {
func TestDetectInterfaceType(t *testing.T) { func TestDetectInterfaceType(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
adapter := newMockAdapter("test", true) adapter := newMockAdapter("test", true)
adapter.ifaceType = InterfaceTypeChat adapter.ifaceType = InterfaceTypeChat
_ = engine.RegisterAdapter(adapter) _ = engine.RegisterAdapter(adapter)
@@ -275,7 +279,7 @@ func TestDetectInterfaceType(t *testing.T) {
func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) { func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
_, err := engine.DetectInterfaceType("/v1/chat", "nonexistent") _, err := engine.DetectInterfaceType("/v1/chat", "nonexistent")
assert.Error(t, err) assert.Error(t, err)
@@ -283,25 +287,39 @@ func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
func TestConvertHttpRequest_Passthrough(t *testing.T) { func TestConvertHttpRequest_Passthrough(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("openai", true)) openaiAdapter := &buildURLMockAdapter{
mockProtocolAdapter: newMockAdapter("openai", true),
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
if interfaceType == InterfaceTypeChat {
return "/chat/completions"
}
return nativePath
},
}
openaiAdapter.ifaceType = InterfaceTypeChat
openaiAdapter.supportsIface[InterfaceTypeChat] = true
openaiAdapter.rewriteReqFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
return []byte(`{"model":"` + newModel + `","messages":[{"role":"user","content":"hi"}]}`), nil
}
_ = engine.RegisterAdapter(openaiAdapter)
provider := NewTargetProvider("https://api.openai.com/v1", "sk-test", "gpt-4") provider := NewTargetProvider("https://api.openai.com/v1", "sk-test", "gpt-4")
spec := HTTPRequestSpec{ spec := HTTPRequestSpec{
URL: "/chat/completions", URL: "/v1/chat/completions",
Method: "POST", Method: "POST",
Body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`), Body: []byte(`{"model":"openai/gpt-4","messages":[{"role":"user","content":"hi"}]}`),
} }
result, err := engine.ConvertHttpRequest(spec, "openai", "openai", provider) result, err := engine.ConvertHttpRequest(spec, "openai", "openai", provider)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL) assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
assert.Equal(t, spec.Body, result.Body) assert.JSONEq(t, `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`, string(result.Body))
} }
func TestConvertHttpRequest_CrossProtocol(t *testing.T) { func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client-proto", false) clientAdapter := newMockAdapter("client-proto", false)
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) { clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
@@ -331,9 +349,80 @@ func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
assert.NotNil(t, result.Body) assert.NotNil(t, result.Body)
} }
func TestConvertHttpRequest_UsesProviderAdapterBuildURL(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, zap.NewNop())
openaiAdapter := &buildURLMockAdapter{
mockProtocolAdapter: newMockAdapter("openai", true),
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
if interfaceType == InterfaceTypeChat {
return "/chat/completions"
}
return nativePath
},
}
openaiAdapter.ifaceType = InterfaceTypeChat
openaiAdapter.supportsIface[InterfaceTypeChat] = true
openaiAdapter.rewriteReqFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
return []byte(`{"model":"` + newModel + `"}`), nil
}
require.NoError(t, registry.Register(openaiAdapter))
anthropicAdapter := &buildURLMockAdapter{
mockProtocolAdapter: newMockAdapter("anthropic", false),
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
if interfaceType == InterfaceTypeChat {
return "/v1/messages"
}
return nativePath
},
}
anthropicAdapter.ifaceType = InterfaceTypeChat
anthropicAdapter.supportsIface[InterfaceTypeChat] = true
require.NoError(t, registry.Register(anthropicAdapter))
t.Run("OpenAI to Anthropic", func(t *testing.T) {
provider := NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
spec := HTTPRequestSpec{
URL: "/v1/chat/completions",
Method: "POST",
Body: []byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"max_tokens":16}`),
}
result, err := engine.ConvertHttpRequest(spec, "openai", "anthropic", provider)
require.NoError(t, err)
assert.Equal(t, "https://api.anthropic.com/v1/messages", result.URL)
})
t.Run("Anthropic to OpenAI", func(t *testing.T) {
provider := NewTargetProvider("https://api.openai.com/v1", "key", "gpt-4")
spec := HTTPRequestSpec{
URL: "/v1/messages",
Method: "POST",
Body: []byte(`{"model":"p1/claude-3","max_tokens":16,"messages":[{"role":"user","content":"hi"}]}`),
}
result, err := engine.ConvertHttpRequest(spec, "anthropic", "openai", provider)
require.NoError(t, err)
assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
})
}
type buildURLMockAdapter struct {
*mockProtocolAdapter
buildURLFn func(string, InterfaceType) string
}
func (m *buildURLMockAdapter) BuildUrl(nativePath string, interfaceType InterfaceType) string {
if m.buildURLFn != nil {
return m.buildURLFn(nativePath, interfaceType)
}
return m.mockProtocolAdapter.BuildUrl(nativePath, interfaceType)
}
func TestConvertHttpResponse_Passthrough(t *testing.T) { func TestConvertHttpResponse_Passthrough(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("openai", true)) _ = engine.RegisterAdapter(newMockAdapter("openai", true))
spec := HTTPResponseSpec{ spec := HTTPResponseSpec{
@@ -349,7 +438,7 @@ func TestConvertHttpResponse_Passthrough(t *testing.T) {
func TestCreateStreamConverter_Passthrough(t *testing.T) { func TestCreateStreamConverter_Passthrough(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("openai", true)) _ = engine.RegisterAdapter(newMockAdapter("openai", true))
converter, err := engine.CreateStreamConverter("openai", "openai", "", InterfaceTypeChat) converter, err := engine.CreateStreamConverter("openai", "openai", "", InterfaceTypeChat)
@@ -360,7 +449,7 @@ func TestCreateStreamConverter_Passthrough(t *testing.T) {
func TestCreateStreamConverter_Canonical(t *testing.T) { func TestCreateStreamConverter_Canonical(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("client", false)) _ = engine.RegisterAdapter(newMockAdapter("client", false))
_ = engine.RegisterAdapter(newMockAdapter("provider", false)) _ = engine.RegisterAdapter(newMockAdapter("provider", false))
@@ -372,7 +461,7 @@ func TestCreateStreamConverter_Canonical(t *testing.T) {
func TestEncodeError(t *testing.T) { func TestEncodeError(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("openai", true)) _ = engine.RegisterAdapter(newMockAdapter("openai", true))
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误") convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
@@ -384,7 +473,7 @@ func TestEncodeError(t *testing.T) {
func TestEncodeError_NonExistentProtocol(t *testing.T) { func TestEncodeError_NonExistentProtocol(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误") convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
body, statusCode, err := engine.EncodeError(convErr, "nonexistent") body, statusCode, err := engine.EncodeError(convErr, "nonexistent")
@@ -417,7 +506,7 @@ func TestRegistry_GetNonExistent(t *testing.T) {
func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) { func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false) clientAdapter := newMockAdapter("client", false)
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) { clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
@@ -446,7 +535,7 @@ func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) {
func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) { func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
// 使用真实 OpenAI adapter 验证 Smart Passthrough 改写 // 使用真实 OpenAI adapter 验证 Smart Passthrough 改写
openaiAdapter := newMockAdapter("openai", true) openaiAdapter := newMockAdapter("openai", true)
@@ -476,7 +565,7 @@ func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) {
func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) { func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
openaiAdapter := newMockAdapter("openai", true) openaiAdapter := newMockAdapter("openai", true)
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) { openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
@@ -495,18 +584,19 @@ func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
_, ok := converter.(*SmartPassthroughStreamConverter) _, ok := converter.(*SmartPassthroughStreamConverter)
assert.True(t, ok) assert.True(t, ok)
// 验证 chunk 改写 // 验证 SSE frame 中的 data JSON 被改写
chunks := converter.ProcessChunk([]byte(`{"model":"gpt-4","choices":[]}`)) chunks := converter.ProcessChunk([]byte(`data: {"model":"gpt-4","choices":[]}` + "\n\n"))
require.Len(t, chunks, 1) require.Len(t, chunks, 1)
var resp map[string]interface{} var resp map[string]interface{}
require.NoError(t, json.Unmarshal(chunks[0], &resp)) payload := strings.TrimPrefix(strings.TrimSpace(string(chunks[0])), "data: ")
require.NoError(t, json.Unmarshal([]byte(payload), &resp))
assert.Equal(t, "openai/gpt-4", resp["model"]) assert.Equal(t, "openai/gpt-4", resp["model"])
} }
func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) { func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
// provider adapter 解码出含 model 的流式事件 // provider adapter 解码出含 model 的流式事件
providerAdapter := newMockAdapter("provider", false) providerAdapter := newMockAdapter("provider", false)
@@ -560,7 +650,7 @@ func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) {
func TestCreateStreamConverter_ModelOverride_CrossProtocol_Empty(t *testing.T) { func TestCreateStreamConverter_ModelOverride_CrossProtocol_Empty(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil) engine := NewConversionEngine(registry, zap.NewNop())
providerAdapter := newMockAdapter("provider", false) providerAdapter := newMockAdapter("provider", false)
providerAdapter.streamDecoderFn = func() StreamDecoder { providerAdapter.streamDecoderFn = func() StreamDecoder {
@@ -614,6 +704,7 @@ func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.Canonical
} }
return nil return nil
} }
func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent { func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
if d.flushFn != nil { if d.flushFn != nil {
return d.flushFn() return d.flushFn()
@@ -633,6 +724,7 @@ func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEve
} }
return nil return nil
} }
func (e *engineTestStreamEncoder) Flush() [][]byte { func (e *engineTestStreamEncoder) Flush() [][]byte {
if e.flushFn != nil { if e.flushFn != nil {
return e.flushFn() return e.flushFn()

View File

@@ -17,6 +17,13 @@ const (
ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION" ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE" ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE"
ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED" ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
ErrorCodeUnsupportedMultimodal ErrorCode = "UNSUPPORTED_MULTIMODAL"
)
const (
ErrorDetailPhase = "phase"
ErrorPhaseRequest = "request"
ErrorPhaseResponse = "response"
) )
// ConversionError 协议转换错误 // ConversionError 协议转换错误
@@ -39,6 +46,20 @@ func NewConversionError(code ErrorCode, message string) *ConversionError {
} }
} }
// NewRequestJSONParseError 创建请求 JSON 解析错误。
func NewRequestJSONParseError(message string, cause error) *ConversionError {
return NewConversionError(ErrorCodeJSONParseError, message).
WithDetail(ErrorDetailPhase, ErrorPhaseRequest).
WithCause(cause)
}
// NewResponseJSONParseError 创建响应 JSON 解析错误。
func NewResponseJSONParseError(message string, cause error) *ConversionError {
return NewConversionError(ErrorCodeJSONParseError, message).
WithDetail(ErrorDetailPhase, ErrorPhaseResponse).
WithCause(cause)
}
// WithClientProtocol 设置客户端协议 // WithClientProtocol 设置客户端协议
func (e *ConversionError) WithClientProtocol(protocol string) *ConversionError { func (e *ConversionError) WithClientProtocol(protocol string) *ConversionError {
e.ClientProtocol = protocol e.ClientProtocol = protocol

View File

@@ -29,27 +29,27 @@ func (a *Adapter) SupportsPassthrough() bool { return true }
// DetectInterfaceType 根据路径检测接口类型 // DetectInterfaceType 根据路径检测接口类型
func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType { func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType {
switch { switch {
case nativePath == "/chat/completions": case nativePath == "/v1/chat/completions":
return conversion.InterfaceTypeChat return conversion.InterfaceTypeChat
case nativePath == "/models": case nativePath == "/v1/models":
return conversion.InterfaceTypeModels return conversion.InterfaceTypeModels
case isModelInfoPath(nativePath): case isModelInfoPath(nativePath):
return conversion.InterfaceTypeModelInfo return conversion.InterfaceTypeModelInfo
case nativePath == "/embeddings": case nativePath == "/v1/embeddings":
return conversion.InterfaceTypeEmbeddings return conversion.InterfaceTypeEmbeddings
case nativePath == "/rerank": case nativePath == "/v1/rerank":
return conversion.InterfaceTypeRerank return conversion.InterfaceTypeRerank
default: default:
return conversion.InterfaceTypePassthrough return conversion.InterfaceTypePassthrough
} }
} }
// isModelInfoPath 判断是否为模型详情路径(/models/{id},允许 id 含 / // isModelInfoPath 判断是否为模型详情路径(/v1/models/{id},允许 id 含 /
func isModelInfoPath(path string) bool { func isModelInfoPath(path string) bool {
if !strings.HasPrefix(path, "/models/") { if !strings.HasPrefix(path, "/v1/models/") {
return false return false
} }
suffix := path[len("/models/"):] suffix := path[len("/v1/models/"):]
return suffix != "" return suffix != ""
} }
@@ -60,6 +60,11 @@ func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.Interface
return "/chat/completions" return "/chat/completions"
case conversion.InterfaceTypeModels: case conversion.InterfaceTypeModels:
return "/models" return "/models"
case conversion.InterfaceTypeModelInfo:
if modelID, err := a.ExtractUnifiedModelID(nativePath); err == nil {
return "/models/" + modelID
}
return nativePath
case conversion.InterfaceTypeEmbeddings: case conversion.InterfaceTypeEmbeddings:
return "/embeddings" return "/embeddings"
case conversion.InterfaceTypeRerank: case conversion.InterfaceTypeRerank:
@@ -138,7 +143,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
Code: string(err.Code), Code: string(err.Code),
}, },
} }
body, _ := json.Marshal(errMsg) body, marshalErr := json.Marshal(errMsg)
if marshalErr != nil {
return []byte(`{"error":{"message":"internal error","type":"internal_error","code":"INTERNAL_ERROR"}}`), statusCode
}
return body, statusCode return body, statusCode
} }
@@ -218,12 +226,12 @@ func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse)
return encodeRerankResponse(resp) return encodeRerankResponse(resp)
} }
// ExtractUnifiedModelID 从路径中提取统一模型 ID/models/{provider_id}/{model_name} // ExtractUnifiedModelID 从路径中提取统一模型 ID/v1/models/{provider_id}/{model_name}
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) { func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
if !strings.HasPrefix(nativePath, "/models/") { if !strings.HasPrefix(nativePath, "/v1/models/") {
return "", fmt.Errorf("不是模型详情路径: %s", nativePath) return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
} }
suffix := nativePath[len("/models/"):] suffix := nativePath[len("/v1/models/"):]
if suffix == "" { if suffix == "" {
return "", fmt.Errorf("路径缺少模型 ID") return "", fmt.Errorf("路径缺少模型 ID")
} }
@@ -248,7 +256,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
return "", nil, err return "", nil, err
} }
rewriteFunc := func(newModel string) ([]byte, error) { rewriteFunc := func(newModel string) ([]byte, error) {
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m) return json.Marshal(m)
} }
return current, rewriteFunc, nil return current, rewriteFunc, nil
@@ -282,12 +294,20 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
switch ifaceType { switch ifaceType {
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings: case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加 // Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m) return json.Marshal(m)
case conversion.InterfaceTypeRerank: case conversion.InterfaceTypeRerank:
// Rerank 响应:存在 model 字段则改写,不存在则不添加 // Rerank 响应:存在 model 字段则改写,不存在则不添加
if _, exists := m["model"]; exists { if _, exists := m["model"]; exists {
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
} }
return json.Marshal(m) return json.Marshal(m)
default: default:

View File

@@ -28,11 +28,11 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
path string path string
expected conversion.InterfaceType expected conversion.InterfaceType
}{ }{
{"聊天补全", "/chat/completions", conversion.InterfaceTypeChat}, {"聊天补全", "/v1/chat/completions", conversion.InterfaceTypeChat},
{"模型列表", "/models", conversion.InterfaceTypeModels}, {"模型列表", "/v1/models", conversion.InterfaceTypeModels},
{"模型详情", "/models/gpt-4", conversion.InterfaceTypeModelInfo}, {"模型详情", "/v1/models/openai/gpt-4", conversion.InterfaceTypeModelInfo},
{"嵌入接口", "/embeddings", conversion.InterfaceTypeEmbeddings}, {"嵌入接口", "/v1/embeddings", conversion.InterfaceTypeEmbeddings},
{"重排序接口", "/rerank", conversion.InterfaceTypeRerank}, {"重排序接口", "/v1/rerank", conversion.InterfaceTypeRerank},
{"未知路径", "/unknown", conversion.InterfaceTypePassthrough}, {"未知路径", "/unknown", conversion.InterfaceTypePassthrough},
} }
@@ -44,6 +44,27 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
} }
} }
func TestAdapter_OldPathsBecomePassthrough(t *testing.T) {
a := NewAdapter()
tests := []struct {
path string
expected conversion.InterfaceType
}{
{"/chat/completions", conversion.InterfaceTypePassthrough},
{"/models", conversion.InterfaceTypePassthrough},
{"/models/gpt-4.1", conversion.InterfaceTypePassthrough},
{"/embeddings", conversion.InterfaceTypePassthrough},
{"/rerank", conversion.InterfaceTypePassthrough},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
assert.Equal(t, tt.expected, a.DetectInterfaceType(tt.path))
})
}
}
func TestAdapter_BuildUrl(t *testing.T) { func TestAdapter_BuildUrl(t *testing.T) {
a := NewAdapter() a := NewAdapter()
@@ -53,10 +74,12 @@ func TestAdapter_BuildUrl(t *testing.T) {
interfaceType conversion.InterfaceType interfaceType conversion.InterfaceType
expected string expected string
}{ }{
{"聊天", "/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"}, {"聊天", "/v1/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"},
{"模型", "/models", conversion.InterfaceTypeModels, "/models"}, {"模型", "/v1/models", conversion.InterfaceTypeModels, "/models"},
{"嵌入", "/embeddings", conversion.InterfaceTypeEmbeddings, "/embeddings"}, {"模型详情", "/v1/models/openai/gpt-4", conversion.InterfaceTypeModelInfo, "/models/openai/gpt-4"},
{"重排序", "/rerank", conversion.InterfaceTypeRerank, "/rerank"}, {"复杂模型详情", "/v1/models/azure/accounts/org/models/gpt-4", conversion.InterfaceTypeModelInfo, "/models/azure/accounts/org/models/gpt-4"},
{"嵌入", "/v1/embeddings", conversion.InterfaceTypeEmbeddings, "/embeddings"},
{"重排序", "/v1/rerank", conversion.InterfaceTypeRerank, "/rerank"},
{"默认透传", "/other", conversion.InterfaceTypePassthrough, "/other"}, {"默认透传", "/other", conversion.InterfaceTypePassthrough, "/other"},
} }
@@ -118,12 +141,12 @@ func TestIsModelInfoPath(t *testing.T) {
path string path string
expected bool expected bool
}{ }{
{"model_info", "/models/gpt-4", true}, {"model_info", "/v1/models/openai/gpt-4", true},
{"model_info_with_dots", "/models/gpt-4.1-preview", true}, {"model_info_with_dots", "/v1/models/openai/gpt-4.1-preview", true},
{"models_list", "/models", false}, {"models_list", "/v1/models", false},
{"nested_path", "/models/gpt-4/versions", true}, {"nested_path", "/v1/models/azure/accounts/org-123/models/gpt-4", true},
{"empty_suffix", "/models/", false}, {"empty_suffix", "/v1/models/", false},
{"unrelated", "/chat/completions", false}, {"unrelated", "/v1/chat/completions", false},
{"partial_prefix", "/model", false}, {"partial_prefix", "/model", false},
} }
@@ -134,6 +157,27 @@ func TestIsModelInfoPath(t *testing.T) {
} }
} }
func TestAdapter_ExtractUnifiedModelID(t *testing.T) {
a := NewAdapter()
t.Run("标准路径", func(t *testing.T) {
modelID, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
require.NoError(t, err)
assert.Equal(t, "openai/gpt-4", modelID)
})
t.Run("复杂路径", func(t *testing.T) {
modelID, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
require.NoError(t, err)
assert.Equal(t, "azure/accounts/org/models/gpt-4", modelID)
})
t.Run("非模型详情路径报错", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/models")
require.Error(t, err)
})
}
func TestAdapter_EncodeError_InvalidInput(t *testing.T) { func TestAdapter_EncodeError_InvalidInput(t *testing.T) {
a := NewAdapter() a := NewAdapter()
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效") convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效")

View File

@@ -18,35 +18,35 @@ func TestExtractUnifiedModelID(t *testing.T) {
a := NewAdapter() a := NewAdapter()
t.Run("standard_path", func(t *testing.T) { t.Run("standard_path", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/models/openai/gpt-4") id, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "openai/gpt-4", id) assert.Equal(t, "openai/gpt-4", id)
}) })
t.Run("multi_segment_path", func(t *testing.T) { t.Run("multi_segment_path", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/models/azure/accounts/org/models/gpt-4") id, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "azure/accounts/org/models/gpt-4", id) assert.Equal(t, "azure/accounts/org/models/gpt-4", id)
}) })
t.Run("single_segment", func(t *testing.T) { t.Run("single_segment", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/models/gpt-4") id, err := a.ExtractUnifiedModelID("/v1/models/gpt-4")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "gpt-4", id) assert.Equal(t, "gpt-4", id)
}) })
t.Run("non_model_path", func(t *testing.T) { t.Run("non_model_path", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/chat/completions") _, err := a.ExtractUnifiedModelID("/v1/chat/completions")
require.Error(t, err) require.Error(t, err)
}) })
t.Run("empty_suffix", func(t *testing.T) { t.Run("empty_suffix", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/models/") _, err := a.ExtractUnifiedModelID("/v1/models/")
require.Error(t, err) require.Error(t, err)
}) })
t.Run("models_list_no_slash", func(t *testing.T) { t.Run("models_list_no_slash", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/models") _, err := a.ExtractUnifiedModelID("/v1/models")
require.Error(t, err) require.Error(t, err)
}) })
@@ -344,12 +344,12 @@ func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
path string path string
expected bool expected bool
}{ }{
{"simple_model_id", "/models/gpt-4", true}, {"simple_model_id", "/v1/models/gpt-4", true},
{"unified_model_id_with_slash", "/models/openai/gpt-4", true}, {"unified_model_id_with_slash", "/v1/models/openai/gpt-4", true},
{"models_list", "/models", false}, {"models_list", "/v1/models", false},
{"models_list_trailing_slash", "/models/", false}, {"models_list_trailing_slash", "/v1/models/", false},
{"chat_completions", "/chat/completions", false}, {"chat_completions", "/v1/chat/completions", false},
{"deeply_nested", "/models/azure/eastus/deployments/my-dept/models/gpt-4", true}, {"deeply_nested", "/v1/models/azure/eastus/deployments/my-dept/models/gpt-4", true},
} }
for _, tt := range tests { for _, tt := range tests {

View File

@@ -215,10 +215,16 @@ func decodeUserContent(content any) []canonical.ContentBlock {
var blocks []canonical.ContentBlock var blocks []canonical.ContentBlock
for _, item := range v { for _, item := range v {
if m, ok := item.(map[string]any); ok { if m, ok := item.(map[string]any); ok {
t, _ := m["type"].(string) t, ok := m["type"].(string)
if !ok {
continue
}
switch t { switch t {
case "text": case "text":
text, _ := m["text"].(string) text, ok := m["text"].(string)
if !ok {
text = ""
}
blocks = append(blocks, canonical.NewTextBlock(text)) blocks = append(blocks, canonical.NewTextBlock(text))
case "image_url": case "image_url":
blocks = append(blocks, canonical.ContentBlock{Type: "image"}) blocks = append(blocks, canonical.ContentBlock{Type: "image"})
@@ -256,13 +262,22 @@ func decodeContentParts(content any) []contentPart {
var result []contentPart var result []contentPart
for _, item := range parts { for _, item := range parts {
if m, ok := item.(map[string]any); ok { if m, ok := item.(map[string]any); ok {
t, _ := m["type"].(string) t, ok := m["type"].(string)
if !ok {
continue
}
switch t { switch t {
case "text": case "text":
text, _ := m["text"].(string) text, ok := m["text"].(string)
if !ok {
text = ""
}
result = append(result, contentPart{Type: "text", Text: text}) result = append(result, contentPart{Type: "text", Text: text})
case "refusal": case "refusal":
refusal, _ := m["refusal"].(string) refusal, ok := m["refusal"].(string)
if !ok {
refusal = ""
}
result = append(result, contentPart{Type: "refusal", Refusal: refusal}) result = append(result, contentPart{Type: "refusal", Refusal: refusal})
} }
} }
@@ -307,21 +322,33 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
return canonical.NewToolChoiceAny() return canonical.NewToolChoiceAny()
} }
case map[string]any: case map[string]any:
t, _ := v["type"].(string) t, ok := v["type"].(string)
if !ok {
return nil
}
switch t { switch t {
case "function": case "function":
if fn, ok := v["function"].(map[string]any); ok { if fn, ok := v["function"].(map[string]any); ok {
name, _ := fn["name"].(string) name, ok := fn["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name) return canonical.NewToolChoiceNamed(name)
} }
case "custom": case "custom":
if custom, ok := v["custom"].(map[string]any); ok { if custom, ok := v["custom"].(map[string]any); ok {
name, _ := custom["name"].(string) name, ok := custom["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name) return canonical.NewToolChoiceNamed(name)
} }
case "allowed_tools": case "allowed_tools":
if at, ok := v["allowed_tools"].(map[string]any); ok { if at, ok := v["allowed_tools"].(map[string]any); ok {
mode, _ := at["mode"].(string) mode, ok := at["mode"].(string)
if !ok {
mode = ""
}
if mode == "required" { if mode == "required" {
return canonical.NewToolChoiceAny() return canonical.NewToolChoiceAny()
} }

View File

@@ -45,9 +45,11 @@ func TestEncodeRequest_SystemInjection(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 2) assert.Len(t, msgs, 2)
firstMsg := msgs[0].(map[string]any) firstMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "system", firstMsg["role"]) assert.Equal(t, "system", firstMsg["role"])
assert.Equal(t, "你是助手", firstMsg["content"]) assert.Equal(t, "你是助手", firstMsg["content"])
} }
@@ -72,12 +74,15 @@ func TestEncodeRequest_ToolCalls(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
assistantMsg := msgs[0].(map[string]any) require.True(t, ok)
assistantMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
toolCalls, ok := assistantMsg["tool_calls"].([]any) toolCalls, ok := assistantMsg["tool_calls"].([]any)
require.True(t, ok) require.True(t, ok)
assert.Len(t, toolCalls, 1) assert.Len(t, toolCalls, 1)
tc := toolCalls[0].(map[string]any) tc, ok := toolCalls[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "call_1", tc["id"]) assert.Equal(t, "call_1", tc["id"])
} }
@@ -115,9 +120,12 @@ func TestEncodeResponse_Basic(t *testing.T) {
assert.Equal(t, "resp-1", result["id"]) assert.Equal(t, "resp-1", result["id"])
assert.Equal(t, "chat.completion", result["object"]) assert.Equal(t, "chat.completion", result["object"])
choices := result["choices"].([]any) choices, ok := result["choices"].([]any)
choice := choices[0].(map[string]any) require.True(t, ok)
msg := choice["message"].(map[string]any) choice, ok := choices[0].(map[string]any)
require.True(t, ok)
msg, ok := choice["message"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "你好", msg["content"]) assert.Equal(t, "你好", msg["content"])
assert.Equal(t, "stop", choice["finish_reason"]) assert.Equal(t, "stop", choice["finish_reason"])
} }
@@ -137,8 +145,12 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any) choices, okc := result["choices"].([]any)
msg := choices[0].(map[string]any)["message"].(map[string]any) require.True(t, okc)
msgMap, okm := choices[0].(map[string]any)
require.True(t, okm)
msg, okmsg := msgMap["message"].(map[string]any)
require.True(t, okmsg)
tcs, ok := msg["tool_calls"].([]any) tcs, ok := msg["tool_calls"].([]any)
require.True(t, ok) require.True(t, ok)
assert.Len(t, tcs, 1) assert.Len(t, tcs, 1)
@@ -158,7 +170,8 @@ func TestEncodeModelsResponse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "list", result["object"]) assert.Equal(t, "list", result["object"])
data := result["data"].([]any) data, okd := result["data"].([]any)
require.True(t, okd)
assert.Len(t, data, 2) assert.Len(t, data, 2)
} }
@@ -317,8 +330,12 @@ func TestEncodeResponse_Thinking(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any) choices, okch := result["choices"].([]any)
msg := choices[0].(map[string]any)["message"].(map[string]any) require.True(t, okch)
msgMap, okmm := choices[0].(map[string]any)
require.True(t, okmm)
msg, okmsg := msgMap["message"].(map[string]any)
require.True(t, okmsg)
assert.Equal(t, "回答", msg["content"]) assert.Equal(t, "回答", msg["content"])
assert.Equal(t, "思考过程", msg["reasoning_content"]) assert.Equal(t, "思考过程", msg["reasoning_content"])
} }

View File

@@ -27,8 +27,12 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
data := strings.TrimPrefix(s, "data: ") data := strings.TrimPrefix(s, "data: ")
data = strings.TrimRight(data, "\n") data = strings.TrimRight(data, "\n")
require.NoError(t, json.Unmarshal([]byte(data), &payload)) require.NoError(t, json.Unmarshal([]byte(data), &payload))
choices := payload["choices"].([]any) choices, okch := payload["choices"].([]any)
delta := choices[0].(map[string]any)["delta"].(map[string]any) require.True(t, okch)
msgMap, okmm := choices[0].(map[string]any)
require.True(t, okmm)
delta, okd := msgMap["delta"].(map[string]any)
require.True(t, okd)
assert.Equal(t, "assistant", delta["role"]) assert.Equal(t, "assistant", delta["role"])
} }

View File

@@ -177,7 +177,8 @@ func TestEncodeRerankResponse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "rerank-1", result["model"]) assert.Equal(t, "rerank-1", result["model"])
results := result["results"].([]any) results, okr := result["results"].([]any)
require.True(t, okr)
assert.Len(t, results, 1) assert.Len(t, results, 1)
} }
@@ -373,7 +374,8 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any) usage, oku := result["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(100), usage["prompt_tokens"]) assert.Equal(t, float64(100), usage["prompt_tokens"])
ptd, ok := usage["prompt_tokens_details"].(map[string]any) ptd, ok := usage["prompt_tokens_details"].(map[string]any)
require.True(t, ok) require.True(t, ok)
@@ -412,8 +414,10 @@ func TestEncodeResponse_StopReasons(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any) choices, okch := result["choices"].([]any)
choice := choices[0].(map[string]any) require.True(t, okch)
choice, okc := choices[0].(map[string]any)
require.True(t, okc)
assert.Equal(t, tt.want, choice["finish_reason"]) assert.Equal(t, tt.want, choice["finish_reason"])
}) })
} }

View File

@@ -1,6 +1,11 @@
package conversion package conversion
import "nex/backend/internal/conversion/canonical" import (
"bytes"
"strings"
"nex/backend/internal/conversion/canonical"
)
// StreamDecoder 流式解码器接口 // StreamDecoder 流式解码器接口
type StreamDecoder interface { type StreamDecoder interface {
@@ -39,11 +44,12 @@ func (c *PassthroughStreamConverter) Flush() [][]byte {
} }
// SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器 // SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器
// 逐 chunk 改写 model 字段 // 按 SSE frame 改写 data JSON 中的 model 字段
type SmartPassthroughStreamConverter struct { type SmartPassthroughStreamConverter struct {
adapter ProtocolAdapter adapter ProtocolAdapter
modelOverride string modelOverride string
interfaceType InterfaceType interfaceType InterfaceType
buffer []byte
} }
// NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器 // NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器
@@ -55,24 +61,45 @@ func NewSmartPassthroughStreamConverter(adapter ProtocolAdapter, modelOverride s
} }
} }
// ProcessChunk 改写 chunk 中的 model 字段 // ProcessChunk 按 SSE frame 改写 data JSON 中的 model 字段
func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte { func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
if len(rawChunk) == 0 { if len(rawChunk) == 0 {
return nil return nil
} }
rewrittenChunk, err := c.adapter.RewriteResponseModelName(rawChunk, c.modelOverride, c.interfaceType) c.buffer = append(c.buffer, rawChunk...)
if err != nil { frames, rest := splitSSEFrames(c.buffer)
// 改写失败,返回原始 chunk c.buffer = rest
return [][]byte{rawChunk}
}
return [][]byte{rewrittenChunk} result := make([][]byte, 0, len(frames))
for _, frame := range frames {
result = append(result, c.rewriteFrame(frame))
}
return result
} }
// Flush 无缓冲数据 func (c *SmartPassthroughStreamConverter) rewriteFrame(frame []byte) []byte {
payload, ok := sseFrameDataPayload(frame)
if !ok || strings.TrimSpace(payload) == "[DONE]" {
return frame
}
rewrittenPayload, err := c.adapter.RewriteResponseModelName([]byte(payload), c.modelOverride, c.interfaceType)
if err != nil {
return frame
}
return rebuildSSEFrameWithData(frame, string(rewrittenPayload))
}
// Flush 输出未形成完整 frame 的剩余数据
func (c *SmartPassthroughStreamConverter) Flush() [][]byte { func (c *SmartPassthroughStreamConverter) Flush() [][]byte {
if len(c.buffer) == 0 {
return nil return nil
}
frame := append([]byte(nil), c.buffer...)
c.buffer = nil
return [][]byte{c.rewriteFrame(frame)}
} }
// CanonicalStreamConverter 跨协议规范流式转换器 // CanonicalStreamConverter 跨协议规范流式转换器
@@ -153,3 +180,86 @@ func (c *CanonicalStreamConverter) applyModelOverride(event *canonical.Canonical
event.Message.Model = c.modelOverride event.Message.Model = c.modelOverride
} }
} }
func splitSSEFrames(data []byte) ([][]byte, []byte) {
var frames [][]byte
for len(data) > 0 {
idx, sepLen := findSSEFrameSeparator(data)
if idx < 0 {
break
}
end := idx + sepLen
frames = append(frames, append([]byte(nil), data[:end]...))
data = data[end:]
}
return frames, data
}
func findSSEFrameSeparator(data []byte) (int, int) {
lf := bytes.Index(data, []byte("\n\n"))
crlf := bytes.Index(data, []byte("\r\n\r\n"))
switch {
case lf < 0 && crlf < 0:
return -1, 0
case lf < 0:
return crlf, 4
case crlf < 0:
return lf, 2
case crlf <= lf:
return crlf, 4
default:
return lf, 2
}
}
func sseFrameDataPayload(frame []byte) (string, bool) {
text := strings.TrimRight(string(frame), "\r\n")
lines := strings.Split(text, "\n")
var dataLines []string
for _, line := range lines {
line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "data:") {
value := strings.TrimPrefix(line, "data:")
if strings.HasPrefix(value, " ") {
value = value[1:]
}
dataLines = append(dataLines, value)
}
}
if len(dataLines) == 0 {
return "", false
}
return strings.Join(dataLines, "\n"), true
}
func rebuildSSEFrameWithData(frame []byte, data string) []byte {
lineEnding, separator := sseLineEnding(frame)
text := strings.TrimRight(string(frame), "\r\n")
lines := strings.Split(text, "\n")
out := make([]string, 0, len(lines)+1)
dataWritten := false
for _, line := range lines {
line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "data:") {
if !dataWritten {
for _, dataLine := range strings.Split(data, "\n") {
out = append(out, "data: "+dataLine)
}
dataWritten = true
}
continue
}
out = append(out, line)
}
if !dataWritten {
out = append(out, "data: "+data)
}
return []byte(strings.Join(out, lineEnding) + separator)
}
func sseLineEnding(frame []byte) (string, string) {
if bytes.Contains(frame, []byte("\r\n")) {
return "\r\n", "\r\n\r\n"
}
return "\n", "\n\n"
}

View File

@@ -2,7 +2,6 @@ package database
import ( import (
"fmt" "fmt"
"log"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
@@ -12,22 +11,24 @@ import (
"gorm.io/driver/mysql" "gorm.io/driver/mysql"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger"
"nex/backend/internal/config" "nex/backend/internal/config"
pkglogger "nex/backend/pkg/logger"
) )
func Init(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) { func Init(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
db, err := initDB(cfg) moduleLogger := pkglogger.WithModule(zapLogger, "database")
db, err := initDB(cfg, moduleLogger)
if err != nil { if err != nil {
return nil, fmt.Errorf("初始化数据库失败: %w", err) return nil, fmt.Errorf("初始化数据库失败: %w", err)
} }
if err := runMigrations(db, cfg.Driver); err != nil { if err := runMigrations(db, cfg.Driver, moduleLogger); err != nil {
return nil, fmt.Errorf("数据库迁移失败: %w", err) return nil, fmt.Errorf("数据库迁移失败: %w", err)
} }
configurePool(db, cfg) configurePool(db, cfg, moduleLogger)
return db, nil return db, nil
} }
@@ -40,36 +41,42 @@ func Close(db *gorm.DB) {
sqlDB.Close() sqlDB.Close()
} }
func initDB(cfg *config.DatabaseConfig) (*gorm.DB, error) { func initDB(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
gormLogger := pkglogger.NewGormLogger(zapLogger)
gormConfig := &gorm.Config{ gormConfig := &gorm.Config{
Logger: logger.Default.LogMode(logger.Info), Logger: gormLogger,
} }
switch cfg.Driver { switch cfg.Driver {
case "mysql": case "mysql":
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local", dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName) cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName)
if zapLogger != nil {
zapLogger.Info("连接 MySQL 数据库",
zap.String("host", cfg.Host),
zap.Int("port", cfg.Port),
zap.String("database", cfg.DBName))
}
return gorm.Open(mysql.Open(dsn), gormConfig) return gorm.Open(mysql.Open(dsn), gormConfig)
default: default:
dbDir := filepath.Dir(cfg.Path) dbDir := filepath.Dir(cfg.Path)
if err := os.MkdirAll(dbDir, 0755); err != nil { if err := os.MkdirAll(dbDir, 0o755); err != nil {
return nil, fmt.Errorf("创建数据库目录失败: %w", err) return nil, fmt.Errorf("创建数据库目录失败: %w", err)
} }
if zapLogger != nil {
zapLogger.Info("连接 SQLite 数据库", zap.String("path", cfg.Path))
}
return gorm.Open(sqlite.Open(cfg.Path), gormConfig) return gorm.Open(sqlite.Open(cfg.Path), gormConfig)
} }
} }
func runMigrations(db *gorm.DB, driver string) error { func runMigrations(db *gorm.DB, driver string, zapLogger *zap.Logger) error {
sqlDB, err := db.DB() sqlDB, err := db.DB()
if err != nil { if err != nil {
return err return err
} }
migrationsDir := getMigrationsDir(driver)
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
}
gooseDialect := "sqlite3" gooseDialect := "sqlite3"
migrationsSubDir := "sqlite" migrationsSubDir := "sqlite"
if driver == "mysql" { if driver == "mysql" {
@@ -77,19 +84,33 @@ func runMigrations(db *gorm.DB, driver string) error {
migrationsSubDir = "mysql" migrationsSubDir = "mysql"
} }
goose.SetDialect(gooseDialect) migrationsDir := getMigrationsDir(driver)
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
}
if zapLogger != nil {
zapLogger.Info("执行数据库迁移",
zap.String("dialect", gooseDialect),
zap.String("dir", migrationsSubDir))
}
if err := goose.SetDialect(gooseDialect); err != nil {
return err
}
if err := goose.Up(sqlDB, migrationsDir); err != nil { if err := goose.Up(sqlDB, migrationsDir); err != nil {
return err return err
} }
log.Printf("使用 %s 方言执行迁移,目录: %s", gooseDialect, migrationsSubDir)
return nil return nil
} }
func configurePool(db *gorm.DB, cfg *config.DatabaseConfig) { func configurePool(db *gorm.DB, cfg *config.DatabaseConfig, zapLogger *zap.Logger) {
if cfg.Driver == "sqlite" { if cfg.Driver == "sqlite" {
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil { if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
log.Printf("警告: 启用 WAL 模式失败: %v", err) if zapLogger != nil {
zapLogger.Warn("启用 WAL 模式失败", zap.Error(err))
}
} }
} }
@@ -101,8 +122,12 @@ func configurePool(db *gorm.DB, cfg *config.DatabaseConfig) {
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns) sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime) sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime)
log.Printf("数据库连接池配置: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v", if zapLogger != nil {
cfg.MaxIdleConns, cfg.MaxOpenConns, cfg.ConnMaxLifetime) zapLogger.Info("数据库连接池配置",
zap.Int("max_idle_conns", cfg.MaxIdleConns),
zap.Int("max_open_conns", cfg.MaxOpenConns),
zap.Duration("conn_max_lifetime", cfg.ConnMaxLifetime))
}
} }
func getMigrationsDir(driver string) string { func getMigrationsDir(driver string) string {

View File

@@ -4,10 +4,11 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"nex/backend/internal/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap"
"nex/backend/internal/config"
) )
func TestInit_SQLite(t *testing.T) { func TestInit_SQLite(t *testing.T) {
@@ -20,7 +21,8 @@ func TestInit_SQLite(t *testing.T) {
ConnMaxLifetime: 0, ConnMaxLifetime: 0,
} }
db, err := Init(cfg, nil) zapLogger := zap.NewNop()
db, err := Init(cfg, zapLogger)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, db) require.NotNil(t, db)
defer Close(db) defer Close(db)
@@ -40,7 +42,8 @@ func TestClose(t *testing.T) {
ConnMaxLifetime: 0, ConnMaxLifetime: 0,
} }
db, err := Init(cfg, nil) zapLogger := zap.NewNop()
db, err := Init(cfg, zapLogger)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, db) require.NotNil(t, db)

View File

@@ -13,4 +13,3 @@ type Provider struct {
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
} }

View File

@@ -6,13 +6,13 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
) )
func TestProviderHandler_CreateProvider_Success(t *testing.T) { func TestProviderHandler_CreateProvider_Success(t *testing.T) {

View File

@@ -9,23 +9,22 @@ import (
"strings" "strings"
"testing" "testing"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"gorm.io/gorm" "gorm.io/gorm"
"nex/backend/internal/domain"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
"nex/backend/tests/mocks"
) )
func init() { func init() {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
} }
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) { func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()

View File

@@ -5,9 +5,10 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"go.uber.org/zap" "go.uber.org/zap"
pkglogger "nex/backend/pkg/logger"
) )
// Logging 日志中间件
func Logging(logger *zap.Logger) gin.HandlerFunc { func Logging(logger *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
start := time.Now() start := time.Now()
@@ -15,12 +16,16 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
query := c.Request.URL.RawQuery query := c.Request.URL.RawQuery
requestID, _ := c.Get(RequestIDKey) requestID, _ := c.Get(RequestIDKey)
logger.Info("请求开始", var requestIDStr string
zap.String("method", c.Request.Method), if id, ok := requestID.(string); ok {
zap.String("path", path), requestIDStr = id
zap.String("query", query), }
zap.String("client_ip", c.ClientIP()), logger.Debug("请求开始",
zap.Any("request_id", requestID), pkglogger.Method(c.Request.Method),
pkglogger.Path(path),
pkglogger.Query(query),
pkglogger.ClientIP(c.ClientIP()),
pkglogger.RequestID(requestIDStr),
) )
c.Next() c.Next()
@@ -28,13 +33,13 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
latency := time.Since(start) latency := time.Since(start)
statusCode := c.Writer.Status() statusCode := c.Writer.Status()
logger.Info("请求结束", logger.Debug("请求结束",
zap.Int("status", statusCode), pkglogger.StatusCode(statusCode),
zap.String("method", c.Request.Method), pkglogger.Method(c.Request.Method),
zap.String("path", path), pkglogger.Path(path),
zap.Duration("latency", latency), pkglogger.Latency(latency),
zap.Int("body_size", c.Writer.Size()), pkglogger.BodySize(c.Writer.Size()),
zap.Any("request_id", requestID), pkglogger.RequestID(requestIDStr),
) )
} }
} }

View File

@@ -7,6 +7,8 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.uber.org/zap" "go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
) )
func init() { func init() {
@@ -65,6 +67,61 @@ func TestLogging(t *testing.T) {
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
} }
func TestLogging_DoesNotLogLifecycleAtInfoLevel(t *testing.T) {
core, logs := observer.New(zapcore.InfoLevel)
logger := zap.New(core)
w := serveLoggingRequest(logger)
assert.Equal(t, 200, w.Code)
assert.Empty(t, logs.FilterMessage("请求开始").All())
assert.Empty(t, logs.FilterMessage("请求结束").All())
}
func TestLogging_LogsLifecycleAtDebugLevel(t *testing.T) {
core, logs := observer.New(zapcore.DebugLevel)
logger := zap.New(core)
w := serveLoggingRequest(logger)
assert.Equal(t, 200, w.Code)
startLogs := logs.FilterMessage("请求开始").All()
endLogs := logs.FilterMessage("请求结束").All()
if assert.Len(t, startLogs, 1) {
fields := startLogs[0].ContextMap()
assert.Equal(t, "GET", fields["method"])
assert.Equal(t, "/test", fields["path"])
assert.Equal(t, "key=value", fields["query"])
assert.Equal(t, "existing-id-123", fields["request_id"])
assert.NotEmpty(t, fields["client_ip"])
}
if assert.Len(t, endLogs, 1) {
fields := endLogs[0].ContextMap()
assert.Equal(t, int64(200), fields["status"])
assert.Equal(t, "GET", fields["method"])
assert.Equal(t, "/test", fields["path"])
assert.Equal(t, int64(2), fields["body_size"])
assert.Equal(t, "existing-id-123", fields["request_id"])
assert.Contains(t, fields, "latency")
}
}
func serveLoggingRequest(logger *zap.Logger) *httptest.ResponseRecorder {
r := gin.New()
r.Use(RequestID())
r.Use(Logging(logger))
r.GET("/test", func(c *gin.Context) {
c.String(200, "ok")
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/test?key=value", nil)
req.Header.Set("X-Request-ID", "existing-id-123")
r.ServeHTTP(w, req)
return w
}
func TestRecovery_NoPanic(t *testing.T) { func TestRecovery_NoPanic(t *testing.T) {
logger := zap.NewNop() logger := zap.NewNop()

View File

@@ -4,13 +4,13 @@ import (
"errors" "errors"
"net/http" "net/http"
"nex/backend/internal/domain"
"nex/backend/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/service"
) )
// ModelHandler 模型管理处理器 // ModelHandler 模型管理处理器
@@ -58,13 +58,13 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
err := h.modelService.Create(model) err := h.modelService.Create(model)
if err != nil { if err != nil {
if err == appErrors.ErrProviderNotFound { if errors.Is(err, appErrors.ErrProviderNotFound) {
c.JSON(http.StatusBadRequest, gin.H{ c.JSON(http.StatusBadRequest, gin.H{
"error": "供应商不存在", "error": "供应商不存在",
}) })
return return
} }
if err == appErrors.ErrDuplicateModel { if errors.Is(err, appErrors.ErrDuplicateModel) {
c.JSON(http.StatusConflict, gin.H{ c.JSON(http.StatusConflict, gin.H{
"error": "同一供应商下模型名称已存在", "error": "同一供应商下模型名称已存在",
"code": appErrors.ErrDuplicateModel.Code, "code": appErrors.ErrDuplicateModel.Code,
@@ -101,7 +101,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
model, err := h.modelService.Get(id) model, err := h.modelService.Get(id)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到", "error": "模型未找到",
}) })
@@ -166,7 +166,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) {
err := h.modelService.Delete(id) err := h.modelService.Delete(id)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到", "error": "模型未找到",
}) })

View File

@@ -4,13 +4,13 @@ import (
"errors" "errors"
"net/http" "net/http"
"nex/backend/internal/domain"
"nex/backend/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/service"
) )
// ProviderHandler 供应商管理处理器 // ProviderHandler 供应商管理处理器
@@ -55,7 +55,7 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
err := h.providerService.Create(provider) err := h.providerService.Create(provider)
if err != nil { if err != nil {
if err == appErrors.ErrInvalidProviderID { if errors.Is(err, appErrors.ErrInvalidProviderID) {
c.JSON(http.StatusBadRequest, gin.H{ c.JSON(http.StatusBadRequest, gin.H{
"error": appErrors.ErrInvalidProviderID.Message, "error": appErrors.ErrInvalidProviderID.Message,
"code": appErrors.ErrInvalidProviderID.Code, "code": appErrors.ErrInvalidProviderID.Code,
@@ -86,7 +86,7 @@ func (h *ProviderHandler) GetProvider(c *gin.Context) {
provider, err := h.providerService.Get(id) provider, err := h.providerService.Get(id)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到", "error": "供应商未找到",
}) })
@@ -113,7 +113,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
err := h.providerService.Update(id, req) err := h.providerService.Update(id, req)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到", "error": "供应商未找到",
}) })
@@ -145,7 +145,7 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
err := h.providerService.Delete(id) err := h.providerService.Delete(id)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到", "error": "供应商未找到",
}) })

View File

@@ -3,19 +3,23 @@ package handler
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"errors"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"nex/backend/internal/conversion" "nex/backend/internal/conversion"
"nex/backend/internal/conversion/canonical" "nex/backend/internal/conversion/canonical"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/provider" "nex/backend/internal/provider"
"nex/backend/internal/service" "nex/backend/internal/service"
appErrors "nex/backend/pkg/errors"
"nex/backend/pkg/modelid" "nex/backend/pkg/modelid"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
pkglogger "nex/backend/pkg/logger"
) )
// ProxyHandler 统一代理处理器 // ProxyHandler 统一代理处理器
@@ -29,14 +33,14 @@ type ProxyHandler struct {
} }
// NewProxyHandler 创建统一代理处理器 // NewProxyHandler 创建统一代理处理器
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService) *ProxyHandler { func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService, logger *zap.Logger) *ProxyHandler {
return &ProxyHandler{ return &ProxyHandler{
engine: engine, engine: engine,
client: client, client: client,
routingService: routingService, routingService: routingService,
providerService: providerService, providerService: providerService,
statsService: statsService, statsService: statsService,
logger: zap.L(), logger: pkglogger.WithModule(logger, "handler.proxy"),
} }
} }
@@ -45,7 +49,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
// 从 URL 提取 clientProtocol: /{protocol}/v1/... // 从 URL 提取 clientProtocol: /{protocol}/v1/...
clientProtocol := c.Param("protocol") clientProtocol := c.Param("protocol")
if clientProtocol == "" { if clientProtocol == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少协议前缀"}) h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "缺少协议前缀")
return return
} }
@@ -55,12 +59,13 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
path = "/" + path path = "/" + path
} }
nativePath := path nativePath := path
requestPath := appendRawQuery(nativePath, c.Request.URL.RawQuery)
// 获取 client adapter // 获取 client adapter
registry := h.engine.GetRegistry() registry := h.engine.GetRegistry()
clientAdapter, err := registry.Get(clientProtocol) clientAdapter, err := registry.Get(clientProtocol)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol}) h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
return return
} }
@@ -77,7 +82,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
if ifaceType == conversion.InterfaceTypeModelInfo { if ifaceType == conversion.InterfaceTypeModelInfo {
unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath) unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的模型 ID 格式"}) h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的模型 ID 格式")
return return
} }
h.handleModelInfo(c, unifiedID, clientAdapter) h.handleModelInfo(c, unifiedID, clientAdapter)
@@ -87,40 +92,50 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
// 读取请求体 // 读取请求体
body, err := io.ReadAll(c.Request.Body) body, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"}) h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "读取请求体失败")
return return
} }
// 解析统一模型 ID使用 adapter.ExtractModelName
var providerID, modelName string
if len(body) > 0 {
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
if err == nil && unifiedID != "" {
pid, mn, err := modelid.ParseUnifiedModelID(unifiedID)
if err == nil {
providerID = pid
modelName = mn
}
}
}
// 构建输入 HTTPRequestSpec // 构建输入 HTTPRequestSpec
inSpec := conversion.HTTPRequestSpec{ inSpec := conversion.HTTPRequestSpec{
URL: nativePath, URL: requestPath,
Method: c.Request.Method, Method: c.Request.Method,
Headers: extractHeaders(c), Headers: extractHeaders(c),
Body: body, Body: body,
} }
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
// 只有 adapter 明确适配的接口才提取 model。未知接口不做通用 model 猜测。
if len(body) == 0 || !supportsModelExtraction(ifaceType) {
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
if err != nil {
if isInvalidJSONError(err) {
h.writeProxyError(c, http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误")
return
}
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
if err != nil {
// 原始模型名兼容透传:非统一模型 ID 不参与路由。
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
if providerID == "" || modelName == "" {
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
// 路由 // 路由
routeResult, err := h.routingService.RouteByModelName(providerID, modelName) routeResult, err := h.routingService.RouteByModelName(providerID, modelName)
if err != nil { if err != nil {
// GET 请求或无法提取 model 时,直接转发到上游 h.writeRouteError(c, err)
if len(body) == 0 || modelName == "" {
h.forwardPassthrough(c, inSpec, clientProtocol)
return
}
h.writeError(c, err, clientProtocol)
return return
} }
@@ -140,9 +155,6 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
routeResult.Model.ModelName, // 上游模型名,用于请求改写 routeResult.Model.ModelName, // 上游模型名,用于请求改写
) )
// 判断是否流式
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
// 计算统一模型 ID用于响应覆写 // 计算统一模型 ID用于响应覆写
unifiedModelID := routeResult.Model.UnifiedModelID() unifiedModelID := routeResult.Model.UnifiedModelID()
@@ -153,12 +165,34 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
} }
} }
func supportsModelExtraction(ifaceType conversion.InterfaceType) bool {
switch ifaceType {
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
return true
default:
return false
}
}
func isInvalidJSONError(err error) bool {
var syntaxErr *json.SyntaxError
var typeErr *json.UnmarshalTypeError
return errors.As(err, &syntaxErr) || errors.As(err, &typeErr)
}
func appendRawQuery(path, rawQuery string) string {
if rawQuery == "" {
return path
}
return path + "?" + rawQuery
}
// handleNonStream 处理非流式请求 // handleNonStream 处理非流式请求
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) { func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
// 转换请求 // 转换请求
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider) outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
if err != nil { if err != nil {
h.logger.Error("转换请求失败", zap.String("error", err.Error())) h.logger.Error("转换请求失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol) h.writeConversionError(c, err, clientProtocol)
return return
} }
@@ -166,31 +200,27 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
// 发送请求 // 发送请求
resp, err := h.client.Send(c.Request.Context(), *outSpec) resp, err := h.client.Send(c.Request.Context(), *outSpec)
if err != nil { if err != nil {
h.logger.Error("发送请求失败", zap.String("error", err.Error())) h.logger.Error("发送请求失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol) h.writeUpstreamUnavailable(c, err)
return
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
h.writeUpstreamResponse(c, *resp)
return return
} }
// 转换响应,传入 modelOverride跨协议场景覆写 model 字段) // 转换响应,传入 modelOverride跨协议场景覆写 model 字段)
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID) convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
if err != nil { if err != nil {
h.logger.Error("转换响应失败", zap.String("error", err.Error())) h.logger.Error("转换响应失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol) h.writeConversionError(c, err, clientProtocol)
return return
} }
// 设置响应头 h.writeConvertedResponse(c, *convertedResp)
for k, v := range convertedResp.Headers {
c.Header(k, v)
}
if c.GetHeader("Content-Type") == "" {
c.Header("Content-Type", "application/json")
}
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
go func() { go func() {
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) _ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
}() }()
} }
@@ -203,15 +233,23 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
return return
} }
// 创建流式转换器,传入 modelOverride跨协议场景覆写 model 字段) // 发送流式请求
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType) streamResp, err := h.client.SendStream(c.Request.Context(), *outSpec)
if err != nil { if err != nil {
h.writeConversionError(c, err, clientProtocol) h.writeUpstreamUnavailable(c, err)
return
}
if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
StatusCode: streamResp.StatusCode,
Headers: streamResp.Headers,
Body: streamResp.Body,
})
return return
} }
// 发送流式请求 // 创建流式转换器,传入 modelOverride跨协议场景覆写 model 字段)
eventChan, err := h.client.SendStream(c.Request.Context(), *outSpec) streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
if err != nil { if err != nil {
h.writeConversionError(c, err, clientProtocol) h.writeConversionError(c, err, clientProtocol)
return return
@@ -222,37 +260,61 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
c.Header("Connection", "keep-alive") c.Header("Connection", "keep-alive")
writer := bufio.NewWriter(c.Writer) writer := bufio.NewWriter(c.Writer)
flushed := false
for event := range eventChan { for event := range streamResp.Events {
if event.Error != nil { if event.Error != nil {
h.logger.Error("流读取错误", zap.String("error", event.Error.Error())) h.logger.Error("流读取错误", zap.Error(event.Error))
break break
} }
if event.Done { if event.Done {
// flush 转换器 // flush 转换器
chunks := streamConverter.Flush() chunks := streamConverter.Flush()
for _, chunk := range chunks { if err := h.writeStreamChunks(writer, chunks); err != nil {
writer.Write(chunk) h.logger.Warn("流式响应写回失败", zap.Error(err))
writer.Flush()
} }
flushed = true
break break
} }
chunks := streamConverter.ProcessChunk(event.Data) chunks := streamConverter.ProcessChunk(event.Data)
for _, chunk := range chunks { if err := h.writeStreamChunks(writer, chunks); err != nil {
writer.Write(chunk) h.logger.Warn("流式响应写回失败", zap.Error(err))
writer.Flush() break
}
}
if !flushed {
chunks := streamConverter.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("流式响应写回失败", zap.Error(err))
} }
} }
go func() { go func() {
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) _ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
}() }()
} }
func (h *ProxyHandler) writeStreamChunks(writer *bufio.Writer, chunks [][]byte) error {
for _, chunk := range chunks {
if _, err := writer.Write(chunk); err != nil {
return err
}
if err := writer.Flush(); err != nil {
return err
}
}
return nil
}
// isStreamRequest 判断是否流式请求 // isStreamRequest 判断是否流式请求
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool { func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol) ifaceType, err := h.engine.DetectInterfaceType(nativePath, clientProtocol)
if err != nil {
return false
}
if ifaceType != conversion.InterfaceTypeChat { if ifaceType != conversion.InterfaceTypeChat {
return false return false
} }
@@ -271,8 +333,8 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
// 从数据库查询所有启用的模型 // 从数据库查询所有启用的模型
models, err := h.providerService.ListEnabledModels() models, err := h.providerService.ListEnabledModels()
if err != nil { if err != nil {
h.logger.Error("查询启用模型失败", zap.String("error", err.Error())) h.logger.Error("查询启用模型失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"}) h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "查询模型失败")
return return
} }
@@ -293,8 +355,8 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
// 使用 adapter 编码返回 // 使用 adapter 编码返回
body, err := adapter.EncodeModelsResponse(modelList) body, err := adapter.EncodeModelsResponse(modelList)
if err != nil { if err != nil {
h.logger.Error("编码 Models 响应失败", zap.String("error", err.Error())) h.logger.Error("编码 Models 响应失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"}) h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
return return
} }
@@ -306,17 +368,14 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
// 解析统一模型 ID // 解析统一模型 ID
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID) providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{ h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的统一模型 ID 格式")
"error": "无效的统一模型 ID 格式",
"code": "INVALID_MODEL_ID",
})
return return
} }
// 从数据库查询模型 // 从数据库查询模型
model, err := h.providerService.GetModelByProviderAndName(providerID, modelName) model, err := h.providerService.GetModelByProviderAndName(providerID, modelName)
if err != nil { if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "模型未找到"}) h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", "模型未找到")
return return
} }
@@ -331,42 +390,104 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
// 使用 adapter 编码返回 // 使用 adapter 编码返回
body, err := adapter.EncodeModelInfoResponse(modelInfo) body, err := adapter.EncodeModelInfoResponse(modelInfo)
if err != nil { if err != nil {
h.logger.Error("编码 ModelInfo 响应失败", zap.String("error", err.Error())) h.logger.Error("编码 ModelInfo 响应失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"}) h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
return return
} }
c.Data(http.StatusOK, "application/json", body) c.Data(http.StatusOK, "application/json", body)
} }
// writeConversionError 写入转换错误 // writeConversionError 写入网关层转换错误
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) { func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
if convErr, ok := err.(*conversion.ConversionError); ok { var convErr *conversion.ConversionError
body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol) if errors.As(err, &convErr) {
c.Data(statusCode, "application/json", body) statusCode, code, message := mapConversionError(convErr)
h.writeProxyError(c, statusCode, code, message)
return return
} }
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", err.Error())
} }
// writeError 写入路由错误 func mapConversionError(err *conversion.ConversionError) (int, string, string) {
func (h *ProxyHandler) writeError(c *gin.Context, err error, clientProtocol string) { switch err.Code {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) case conversion.ErrorCodeJSONParseError:
if phase, ok := err.Details[conversion.ErrorDetailPhase].(string); ok && phase == conversion.ErrorPhaseRequest {
return http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误"
}
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
case conversion.ErrorCodeInvalidInput,
conversion.ErrorCodeMissingRequiredField,
conversion.ErrorCodeProtocolConstraint:
return http.StatusBadRequest, "INVALID_REQUEST", err.Message
case conversion.ErrorCodeInterfaceNotSupported:
return http.StatusBadRequest, "UNSUPPORTED_INTERFACE", err.Message
case conversion.ErrorCodeUnsupportedMultimodal:
return http.StatusBadRequest, "UNSUPPORTED_MULTIMODAL", err.Message
default:
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
}
}
func (h *ProxyHandler) writeRouteError(c *gin.Context, err error) {
if appErr, ok := appErrors.AsAppError(err); ok {
switch appErr.Code {
case appErrors.ErrModelNotFound.Code, appErrors.ErrModelDisabled.Code:
h.writeProxyError(c, appErr.HTTPStatus, "MODEL_NOT_FOUND", appErr.Message)
case appErrors.ErrProviderNotFound.Code, appErrors.ErrProviderDisabled.Code:
h.writeProxyError(c, appErr.HTTPStatus, "PROVIDER_NOT_FOUND", appErr.Message)
default:
h.writeProxyError(c, appErr.HTTPStatus, "INVALID_REQUEST", appErr.Message)
}
return
}
h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", err.Error())
}
func (h *ProxyHandler) writeUpstreamUnavailable(c *gin.Context, err error) {
h.logger.Error("上游不可达", zap.Error(err))
h.writeProxyError(c, http.StatusBadGateway, "UPSTREAM_UNAVAILABLE", "上游服务不可达")
}
func (h *ProxyHandler) writeProxyError(c *gin.Context, status int, code, message string) {
c.JSON(status, gin.H{
"error": message,
"code": code,
})
}
func (h *ProxyHandler) writeConvertedResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
for k, v := range resp.Headers {
c.Header(k, v)
}
contentType := headerValue(resp.Headers, "Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, resp.Body)
}
func (h *ProxyHandler) writeUpstreamResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
for k, v := range filterHopByHopHeaders(resp.Headers) {
c.Header(k, v)
}
contentType := headerValue(resp.Headers, "Content-Type")
c.Data(resp.StatusCode, contentType, resp.Body)
} }
// forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求) // forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求)
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string) { func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string, ifaceType conversion.InterfaceType, isStream bool) {
registry := h.engine.GetRegistry() registry := h.engine.GetRegistry()
adapter, err := registry.Get(clientProtocol) adapter, err := registry.Get(clientProtocol)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol}) h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
return return
} }
providers, err := h.providerService.List() providers, err := h.providerService.List()
if err != nil || len(providers) == 0 { if err != nil || len(providers) == 0 {
h.logger.Warn("无可用供应商转发 GET 请求", zap.String("path", inSpec.URL)) h.logger.Warn("无可用供应商转发请求", zap.String("path", inSpec.URL))
c.JSON(http.StatusNotFound, gin.H{"error": "没有可用的供应商。请先创建供应商和模型。"}) h.writeProxyError(c, http.StatusNotFound, "PROVIDER_NOT_FOUND", "没有可用的供应商。请先创建供应商和模型。")
return return
} }
@@ -376,19 +497,18 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
providerProtocol = "openai" providerProtocol = "openai"
} }
ifaceType := adapter.DetectInterfaceType(inSpec.URL)
targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "") targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "")
var outSpec *conversion.HTTPRequestSpec var outSpec *conversion.HTTPRequestSpec
if clientProtocol == providerProtocol { if clientProtocol == providerProtocol {
upstreamURL := p.BaseURL + inSpec.URL upstreamPath := adapter.BuildUrl(stripRawQuery(inSpec.URL), ifaceType)
upstreamPath = appendRawQuery(upstreamPath, rawQueryFromPath(inSpec.URL))
headers := adapter.BuildHeaders(targetProvider) headers := adapter.BuildHeaders(targetProvider)
if _, ok := headers["Content-Type"]; !ok { if _, ok := headers["Content-Type"]; !ok {
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
} }
outSpec = &conversion.HTTPRequestSpec{ outSpec = &conversion.HTTPRequestSpec{
URL: upstreamURL, URL: joinBaseURL(p.BaseURL, upstreamPath),
Method: inSpec.Method, Method: inSpec.Method,
Headers: headers, Headers: headers,
Body: inSpec.Body, Body: inSpec.Body,
@@ -401,9 +521,18 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
} }
} }
if isStream {
h.forwardStream(c, *outSpec, clientProtocol, providerProtocol, ifaceType)
return
}
resp, err := h.client.Send(c.Request.Context(), *outSpec) resp, err := h.client.Send(c.Request.Context(), *outSpec)
if err != nil { if err != nil {
h.writeConversionError(c, err, clientProtocol) h.writeUpstreamUnavailable(c, err)
return
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
h.writeUpstreamResponse(c, *resp)
return return
} }
@@ -413,13 +542,111 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
return return
} }
for k, v := range convertedResp.Headers { h.writeConvertedResponse(c, *convertedResp)
c.Header(k, v) }
func (h *ProxyHandler) forwardStream(c *gin.Context, outSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, ifaceType conversion.InterfaceType) {
streamResp, err := h.client.SendStream(c.Request.Context(), outSpec)
if err != nil {
h.writeUpstreamUnavailable(c, err)
return
} }
if c.GetHeader("Content-Type") == "" { if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
c.Header("Content-Type", "application/json") h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
StatusCode: streamResp.StatusCode,
Headers: streamResp.Headers,
Body: streamResp.Body,
})
return
} }
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, "", ifaceType)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
writer := bufio.NewWriter(c.Writer)
flushed := false
for event := range streamResp.Events {
if event.Error != nil {
h.logger.Error("透传流读取错误", zap.Error(event.Error))
break
}
if event.Done {
chunks := streamConverter.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
}
flushed = true
break
}
chunks := streamConverter.ProcessChunk(event.Data)
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
break
}
}
if !flushed {
chunks := streamConverter.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
}
}
}
func stripRawQuery(path string) string {
pathOnly, _, _ := strings.Cut(path, "?")
return pathOnly
}
func rawQueryFromPath(path string) string {
_, rawQuery, found := strings.Cut(path, "?")
if !found {
return ""
}
return rawQuery
}
func joinBaseURL(baseURL, path string) string {
return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/")
}
func headerValue(headers map[string]string, key string) string {
for k, v := range headers {
if strings.EqualFold(k, key) {
return v
}
}
return ""
}
func filterHopByHopHeaders(headers map[string]string) map[string]string {
if len(headers) == 0 {
return nil
}
hopByHop := map[string]struct{}{
"connection": {},
"transfer-encoding": {},
"keep-alive": {},
"proxy-authenticate": {},
"proxy-authorization": {},
"te": {},
"trailer": {},
"upgrade": {},
}
filtered := make(map[string]string, len(headers))
for k, v := range headers {
if _, skip := hopByHop[strings.ToLower(k)]; skip {
continue
}
filtered[k] = v
}
return filtered
} }
// extractHeaders 从 Gin context 提取请求头 // extractHeaders 从 Gin context 提取请求头

View File

@@ -5,33 +5,34 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"nex/backend/internal/conversion" "nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic" "nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai" "nex/backend/internal/conversion/openai"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/provider" "nex/backend/internal/provider"
appErrors "nex/backend/pkg/errors"
"nex/backend/tests/mocks" "nex/backend/tests/mocks"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"go.uber.org/zap"
appErrors "nex/backend/pkg/errors"
) )
func init() { func init() {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
} }
func setupProxyEngine(t *testing.T) *conversion.ConversionEngine { func setupProxyEngine(t *testing.T) *conversion.ConversionEngine {
t.Helper() t.Helper()
registry := conversion.NewMemoryRegistry() registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, nil) engine := conversion.NewConversionEngine(registry, zap.NewNop())
require.NoError(t, registry.Register(openai.NewAdapter())) require.NoError(t, registry.Register(openai.NewAdapter()))
require.NoError(t, registry.Register(anthropic.NewAdapter())) require.NoError(t, registry.Register(anthropic.NewAdapter()))
return engine return engine
@@ -44,6 +45,7 @@ func newTestProxyHandler(engine *conversion.ConversionEngine, client *mocks.Mock
routingSvc, routingSvc,
providerSvc, providerSvc,
statsSvc, statsSvc,
zap.NewNop(),
) )
} }
@@ -72,7 +74,7 @@ func TestProxyHandler_HandleProxy_NonStreamSuccess(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
@@ -91,8 +93,8 @@ func TestProxyHandler_HandleProxy_NonStreamSuccess(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -108,20 +110,20 @@ func TestProxyHandler_HandleProxy_RoutingError_WithBody(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(nil, appErrors.ErrModelNotFound) routingSvc.EXPECT().RouteByModelName("unknown", "model").Return(nil, appErrors.ErrModelNotFound)
providerSvc := mocks.NewMockProviderService(ctrl) providerSvc := mocks.NewMockProviderService(ctrl)
providerSvc.EXPECT().List().Return(nil, nil)
client := mocks.NewMockProviderClient(ctrl) client := mocks.NewMockProviderClient(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl) statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc) h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"unknown","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 404, w.Code) assert.Equal(t, 404, w.Code)
assert.Contains(t, w.Body.String(), "MODEL_NOT_FOUND")
} }
func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) { func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
@@ -130,7 +132,7 @@ func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
@@ -144,11 +146,12 @@ func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 500, w.Code) assert.Equal(t, 502, w.Code)
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
} }
func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) { func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
@@ -157,7 +160,7 @@ func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
@@ -171,11 +174,12 @@ func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 500, w.Code) assert.Equal(t, 502, w.Code)
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
} }
func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) { func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
@@ -184,12 +188,12 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
client := mocks.NewMockProviderClient(ctrl) client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) { client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
ch := make(chan provider.StreamEvent, 10) ch := make(chan provider.StreamEvent, 10)
go func() { go func() {
defer close(ch) defer close(ch)
@@ -198,7 +202,7 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")} ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
ch <- provider.StreamEvent{Done: true} ch <- provider.StreamEvent{Done: true}
}() }()
return ch, nil return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
}) })
providerSvc := mocks.NewMockProviderService(ctrl) providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl) statsSvc := mocks.NewMockStatsService(ctrl)
@@ -207,13 +211,14 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type")) assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type"))
assert.Contains(t, w.Body.String(), "Hello") assert.Contains(t, w.Body.String(), "Hello")
assert.Contains(t, w.Body.String(), "p1/gpt-4")
} }
func TestProxyHandler_HandleProxy_StreamError(t *testing.T) { func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
@@ -222,12 +227,12 @@ func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
client := mocks.NewMockProviderClient(ctrl) client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) { client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
return nil, context.DeadlineExceeded return nil, context.DeadlineExceeded
}) })
providerSvc := mocks.NewMockProviderService(ctrl) providerSvc := mocks.NewMockProviderService(ctrl)
@@ -236,11 +241,12 @@ func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 500, w.Code) assert.Equal(t, 502, w.Code)
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
} }
func TestProxyHandler_ForwardPassthrough_GET(t *testing.T) { func TestProxyHandler_ForwardPassthrough_GET(t *testing.T) {
@@ -260,8 +266,8 @@ func TestProxyHandler_ForwardPassthrough_GET(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -281,11 +287,11 @@ func TestProxyHandler_ForwardPassthrough_UnsupportedProtocol(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "unknown"}, {Key: "path", Value: "/models"}} c.Params = gin.Params{{Key: "protocol", Value: "unknown"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/unknown/models", nil) c.Request = httptest.NewRequest("GET", "/unknown/models", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 400, w.Code) assert.Equal(t, 404, w.Code)
} }
func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) { func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) {
@@ -303,8 +309,8 @@ func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -328,7 +334,7 @@ func TestProxyHandler_HandleProxy_ProviderProtocolDefault(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
@@ -346,8 +352,8 @@ func TestProxyHandler_HandleProxy_ProviderProtocolDefault(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -370,6 +376,7 @@ func TestProxyHandler_WriteConversionError_NonConversionError(t *testing.T) {
h.writeConversionError(c, context.DeadlineExceeded, "openai") h.writeConversionError(c, context.DeadlineExceeded, "openai")
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
assert.JSONEq(t, `{"error":"context deadline exceeded","code":"CONVERSION_FAILED"}`, w.Body.String())
} }
func TestProxyHandler_WriteConversionError_ConversionError(t *testing.T) { func TestProxyHandler_WriteConversionError_ConversionError(t *testing.T) {
@@ -389,7 +396,40 @@ func TestProxyHandler_WriteConversionError_ConversionError(t *testing.T) {
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "bad request") convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "bad request")
h.writeConversionError(c, convErr, "openai") h.writeConversionError(c, convErr, "openai")
assert.Equal(t, 500, w.Code) assert.Equal(t, 400, w.Code)
assert.JSONEq(t, `{"error":"bad request","code":"INVALID_REQUEST"}`, w.Body.String())
}
func TestProxyHandler_WriteConversionError_JSONPhase(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
client := mocks.NewMockProviderClient(ctrl)
routingSvc := mocks.NewMockRoutingService(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
t.Run("request json parse error", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
h.writeConversionError(c, conversion.NewRequestJSONParseError("解码请求失败", context.Canceled), "openai")
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.JSONEq(t, `{"error":"请求体 JSON 格式错误","code":"INVALID_JSON"}`, w.Body.String())
})
t.Run("response json parse error", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
h.writeConversionError(c, conversion.NewResponseJSONParseError("解码响应失败", context.Canceled), "openai")
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.JSONEq(t, `{"error":"解码响应失败","code":"CONVERSION_FAILED"}`, w.Body.String())
})
} }
func TestProxyHandler_HandleProxy_EmptyBody(t *testing.T) { func TestProxyHandler_HandleProxy_EmptyBody(t *testing.T) {
@@ -409,8 +449,8 @@ func TestProxyHandler_HandleProxy_EmptyBody(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -422,19 +462,19 @@ func TestProxyHandler_HandleStream_MidStreamError(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
client := mocks.NewMockProviderClient(ctrl) client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) { client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
ch := make(chan provider.StreamEvent, 10) ch := make(chan provider.StreamEvent, 10)
go func() { go func() {
defer close(ch) defer close(ch)
ch <- provider.StreamEvent{Data: []byte("data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n")} ch <- provider.StreamEvent{Data: []byte("data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n")}
ch <- provider.StreamEvent{Error: fmt.Errorf("connection reset by peer")} ch <- provider.StreamEvent{Error: fmt.Errorf("connection reset by peer")}
}() }()
return ch, nil return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
}) })
providerSvc := mocks.NewMockProviderService(ctrl) providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl) statsSvc := mocks.NewMockStatsService(ctrl)
@@ -443,8 +483,8 @@ func TestProxyHandler_HandleStream_MidStreamError(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -459,12 +499,12 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
client := mocks.NewMockProviderClient(ctrl) client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) { client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
ch := make(chan provider.StreamEvent, 10) ch := make(chan provider.StreamEvent, 10)
go func() { go func() {
defer close(ch) defer close(ch)
@@ -472,7 +512,7 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")} ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
ch <- provider.StreamEvent{Done: true} ch <- provider.StreamEvent{Done: true}
}() }()
return ch, nil return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
}) })
providerSvc := mocks.NewMockProviderService(ctrl) providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl) statsSvc := mocks.NewMockStatsService(ctrl)
@@ -481,8 +521,8 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -499,12 +539,12 @@ func TestProxyHandler_HandleStream_CreateStreamConverterError(t *testing.T) {
defer ctrl.Finish() defer ctrl.Finish()
registry := conversion.NewMemoryRegistry() registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, nil) engine := conversion.NewConversionEngine(registry, zap.NewNop())
err := registry.Register(openai.NewAdapter()) err := registry.Register(openai.NewAdapter())
require.NoError(t, err) require.NoError(t, err)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
@@ -515,8 +555,8 @@ func TestProxyHandler_HandleStream_CreateStreamConverterError(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
@@ -527,11 +567,11 @@ func TestProxyHandler_HandleStream_ConvertRequestError(t *testing.T) {
defer ctrl.Finish() defer ctrl.Finish()
registry := conversion.NewMemoryRegistry() registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, nil) engine := conversion.NewConversionEngine(registry, zap.NewNop())
require.NoError(t, registry.Register(openai.NewAdapter())) require.NoError(t, registry.Register(openai.NewAdapter()))
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
@@ -542,8 +582,8 @@ func TestProxyHandler_HandleStream_ConvertRequestError(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
@@ -554,12 +594,12 @@ func TestProxyHandler_HandleNonStream_ConvertResponseError(t *testing.T) {
defer ctrl.Finish() defer ctrl.Finish()
registry := conversion.NewMemoryRegistry() registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, nil) engine := conversion.NewConversionEngine(registry, zap.NewNop())
require.NoError(t, registry.Register(openai.NewAdapter())) require.NoError(t, registry.Register(openai.NewAdapter()))
require.NoError(t, registry.Register(anthropic.NewAdapter())) require.NoError(t, registry.Register(anthropic.NewAdapter()))
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "anthropic", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "anthropic", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "claude-3", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "claude-3", Enabled: true},
}, nil) }, nil)
@@ -577,8 +617,8 @@ func TestProxyHandler_HandleNonStream_ConvertResponseError(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"claude-3","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
@@ -590,7 +630,7 @@ func TestProxyHandler_HandleNonStream_ResponseHeaders(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
@@ -609,8 +649,8 @@ func TestProxyHandler_HandleNonStream_ResponseHeaders(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -623,7 +663,7 @@ func TestProxyHandler_ForwardPassthrough_CrossProtocol(t *testing.T) {
defer ctrl.Finish() defer ctrl.Finish()
registry := conversion.NewMemoryRegistry() registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, nil) engine := conversion.NewConversionEngine(registry, zap.NewNop())
require.NoError(t, registry.Register(openai.NewAdapter())) require.NoError(t, registry.Register(openai.NewAdapter()))
anthropicAdapter := anthropic.NewAdapter() anthropicAdapter := anthropic.NewAdapter()
@@ -641,8 +681,8 @@ func TestProxyHandler_ForwardPassthrough_CrossProtocol(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -665,8 +705,8 @@ func TestProxyHandler_ForwardPassthrough_NoBody_NoModel(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -689,10 +729,10 @@ func TestIsStreamRequest_EdgeCases(t *testing.T) {
path string path string
expected bool expected bool
}{ }{
{"stream at end of JSON", `{"messages":[],"stream":true}`, "/chat/completions", true}, {"stream at end of JSON", `{"messages":[],"stream":true}`, "/v1/chat/completions", true},
{"stream with spaces", `{"stream" : true}`, "/chat/completions", true}, {"stream with spaces", `{"stream" : true}`, "/v1/chat/completions", true},
{"stream embedded in string value", `{"model":"stream:true"}`, "/chat/completions", false}, {"stream embedded in string value", `{"model":"stream:true"}`, "/v1/chat/completions", false},
{"empty body", "", "/chat/completions", false}, {"empty body", "", "/v1/chat/completions", false},
{"stream true embeddings", `{"model":"text-emb","stream":true}`, "/v1/embeddings", false}, {"stream true embeddings", `{"model":"text-emb","stream":true}`, "/v1/embeddings", false},
} }
@@ -719,8 +759,9 @@ func TestProxyHandler_WriteError_RouteError(t *testing.T) {
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil) c.Request = httptest.NewRequest("POST", "/", nil)
h.writeError(c, fmt.Errorf("model not found"), "openai") h.writeRouteError(c, fmt.Errorf("model not found"))
assert.Equal(t, 404, w.Code) assert.Equal(t, 404, w.Code)
assert.Contains(t, w.Body.String(), "MODEL_NOT_FOUND")
} }
func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) { func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) {
@@ -740,8 +781,8 @@ func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -764,35 +805,35 @@ func TestIsStreamRequest(t *testing.T) {
name: "stream true", name: "stream true",
body: []byte(`{"model": "gpt-4", "stream": true}`), body: []byte(`{"model": "gpt-4", "stream": true}`),
clientProtocol: "openai", clientProtocol: "openai",
nativePath: "/chat/completions", nativePath: "/v1/chat/completions",
expected: true, expected: true,
}, },
{ {
name: "stream false", name: "stream false",
body: []byte(`{"model": "gpt-4", "stream": false}`), body: []byte(`{"model": "gpt-4", "stream": false}`),
clientProtocol: "openai", clientProtocol: "openai",
nativePath: "/chat/completions", nativePath: "/v1/chat/completions",
expected: false, expected: false,
}, },
{ {
name: "no stream field", name: "no stream field",
body: []byte(`{"model": "gpt-4"}`), body: []byte(`{"model": "gpt-4"}`),
clientProtocol: "openai", clientProtocol: "openai",
nativePath: "/chat/completions", nativePath: "/v1/chat/completions",
expected: false, expected: false,
}, },
{ {
name: "invalid json", name: "invalid json",
body: []byte(`{invalid}`), body: []byte(`{invalid}`),
clientProtocol: "openai", clientProtocol: "openai",
nativePath: "/chat/completions", nativePath: "/v1/chat/completions",
expected: false, expected: false,
}, },
{ {
name: "not chat endpoint", name: "not chat endpoint",
body: []byte(`{"model": "gpt-4", "stream": true}`), body: []byte(`{"model": "gpt-4", "stream": true}`),
clientProtocol: "openai", clientProtocol: "openai",
nativePath: "/models", nativePath: "/v1/models",
expected: false, expected: false,
}, },
{ {
@@ -830,8 +871,8 @@ func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -842,7 +883,8 @@ func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
require.True(t, ok) require.True(t, ok)
assert.Len(t, data, 2) assert.Len(t, data, 2)
first := data[0].(map[string]interface{}) first, ok2 := data[0].(map[string]interface{})
require.True(t, ok2)
assert.Equal(t, "openai/gpt-4", first["id"]) assert.Equal(t, "openai/gpt-4", first["id"])
} }
@@ -860,8 +902,8 @@ func TestProxyHandler_HandleProxy_ModelInfo_LocalQuery(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/openai/gpt-4"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models/openai/gpt-4"}}
c.Request = httptest.NewRequest("GET", "/openai/models/openai/gpt-4", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models/openai/gpt-4", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -894,8 +936,8 @@ func TestProxyHandler_HandleProxy_Models_EmptySuffix_ForwardPassthrough(t *testi
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models/"}}
c.Request = httptest.NewRequest("GET", "/openai/models/", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models/", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -916,7 +958,7 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
client := mocks.NewMockProviderClient(ctrl) client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) { client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
var req map[string]interface{} var req map[string]interface{}
json.Unmarshal(spec.Body, &req) require.NoError(t, json.Unmarshal(spec.Body, &req))
assert.Equal(t, "gpt-4", req["model"]) assert.Equal(t, "gpt-4", req["model"])
return &conversion.HTTPResponseSpec{ return &conversion.HTTPResponseSpec{
@@ -932,8 +974,8 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -970,8 +1012,8 @@ func TestProxyHandler_HandleProxy_CrossProtocol_NonStream_UnifiedID(t *testing.T
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -992,7 +1034,7 @@ func TestProxyHandler_HandleProxy_CrossProtocol_Stream_UnifiedID(t *testing.T) {
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true},
}, nil) }, nil)
client := mocks.NewMockProviderClient(ctrl) client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) { client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
ch := make(chan provider.StreamEvent, 10) ch := make(chan provider.StreamEvent, 10)
go func() { go func() {
defer close(ch) defer close(ch)
@@ -1010,7 +1052,7 @@ data: {"type":"message_stop"}
`)} `)}
ch <- provider.StreamEvent{Done: true} ch <- provider.StreamEvent{Done: true}
}() }()
return ch, nil return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
}) })
providerSvc := mocks.NewMockProviderService(ctrl) providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl) statsSvc := mocks.NewMockStatsService(ctrl)
@@ -1019,8 +1061,8 @@ data: {"type":"message_stop"}
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}],"stream":true}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -1057,8 +1099,8 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_Fidelity(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}],"custom_param":"should_be_preserved"}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}],"custom_param":"should_be_preserved"}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -1088,8 +1130,8 @@ func TestProxyHandler_HandleProxy_UnifiedID_ModelNotFound(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 404, w.Code) assert.Equal(t, 404, w.Code)
@@ -1098,3 +1140,314 @@ func TestProxyHandler_HandleProxy_UnifiedID_ModelNotFound(t *testing.T) {
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Contains(t, resp, "error") assert.Contains(t, resp, "error")
} }
func TestProxyHandler_HandleProxy_OpenAIAndAnthropicNativePaths(t *testing.T) {
tests := []struct {
name string
protocol string
path string
requestPath string
baseURL string
expectedURL string
body string
responseBody string
responseModel string
}{
{
name: "openai path keeps v1 after gateway prefix",
protocol: "openai",
path: "/v1/chat/completions",
requestPath: "/openai/v1/chat/completions",
baseURL: "https://api.test.com/v1",
expectedURL: "https://api.test.com/v1/chat/completions",
body: `{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`,
responseBody: `{"id":"r1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`,
responseModel: "p1/gpt-4",
},
{
name: "anthropic path keeps v1 after gateway prefix",
protocol: "anthropic",
path: "/v1/messages",
requestPath: "/anthropic/v1/messages",
baseURL: "https://api.anthropic.test",
expectedURL: "https://api.anthropic.test/v1/messages",
body: `{"model":"p1/gpt-4","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`,
responseBody: `{"id":"msg-1","type":"message","role":"assistant","model":"gpt-4","content":[{"type":"text","text":"ok"}],"stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`,
responseModel: "p1/gpt-4",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: tt.baseURL, Protocol: tt.protocol, Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
assert.Equal(t, tt.expectedURL, spec.URL)
return &conversion.HTTPResponseSpec{
StatusCode: http.StatusOK,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(tt.responseBody),
}, nil
})
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
statsSvc.EXPECT().Record(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: tt.protocol}, {Key: "path", Value: tt.path}}
c.Request = httptest.NewRequest("POST", tt.requestPath, bytes.NewReader([]byte(tt.body)))
h.HandleProxy(c)
require.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), tt.responseModel)
})
}
}
func TestProxyHandler_UpstreamNon2xx_Passthrough(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).Return(&conversion.HTTPResponseSpec{
StatusCode: http.StatusTooManyRequests,
Headers: map[string]string{
"Content-Type": "application/json",
"X-Upstream-Error": "rate-limit",
"Transfer-Encoding": "chunked",
},
Body: []byte(`{"error":{"message":"rate limited"}}`),
}, nil)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
require.Equal(t, http.StatusTooManyRequests, w.Code)
assert.JSONEq(t, `{"error":{"message":"rate limited"}}`, w.Body.String())
assert.Equal(t, "rate-limit", w.Header().Get("X-Upstream-Error"))
assert.Empty(t, w.Header().Get("Transfer-Encoding"))
}
func TestProxyHandler_StreamUpstreamNon2xx_Passthrough(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).Return(&provider.StreamResponse{
StatusCode: http.StatusServiceUnavailable,
Headers: map[string]string{"Content-Type": "application/json", "Connection": "close"},
Body: []byte(`{"error":"upstream down"}`),
}, nil)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
require.Equal(t, http.StatusServiceUnavailable, w.Code)
assert.JSONEq(t, `{"error":"upstream down"}`, w.Body.String())
assert.Empty(t, w.Header().Get("Connection"))
}
func TestFilterHopByHopHeaders(t *testing.T) {
filtered := filterHopByHopHeaders(map[string]string{
"Connection": "close",
"Transfer-Encoding": "chunked",
"Keep-Alive": "timeout=5",
"Proxy-Authenticate": "Basic",
"Proxy-Authorization": "Basic token",
"TE": "trailers",
"Trailer": "Expires",
"Upgrade": "websocket",
"Content-Type": "application/json",
"X-Request-ID": "req-1",
})
assert.Equal(t, map[string]string{
"Content-Type": "application/json",
"X-Request-ID": "req-1",
}, filtered)
}
func TestProxyHandler_UnknownInterface_DoesNotGuessModel(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
providerSvc.EXPECT().List().Return([]domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
assert.Equal(t, "https://api.test.com/v1/unknown?trace=1", spec.URL)
assert.JSONEq(t, `{"model":"p1/gpt-4","payload":true}`, string(spec.Body))
return &conversion.HTTPResponseSpec{
StatusCode: http.StatusOK,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"ok":true}`),
}, nil
})
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/unknown"}}
c.Request = httptest.NewRequest("POST", "/openai/unknown?trace=1", bytes.NewReader([]byte(`{"model":"p1/gpt-4","payload":true}`)))
h.HandleProxy(c)
require.Equal(t, http.StatusOK, w.Code)
assert.JSONEq(t, `{"ok":true}`, w.Body.String())
}
func TestProxyHandler_InvalidJSON_UsesGatewayError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
client := mocks.NewMockProviderClient(ctrl)
routingSvc := mocks.NewMockRoutingService(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":`)))
h.HandleProxy(c)
require.Equal(t, http.StatusBadRequest, w.Code)
assert.JSONEq(t, `{"error":"请求体 JSON 格式错误","code":"INVALID_JSON"}`, w.Body.String())
}
func TestProxyHandler_CrossProtocolMultimodal_Unsupported(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("anthropic_p", "claude").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "anthropic_p", Name: "Anthropic", APIKey: "sk-test", BaseURL: "https://api.anthropic.test", Protocol: "anthropic", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
body := []byte(`{"model":"anthropic_p/claude","messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader(body))
h.HandleProxy(c)
require.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, w.Body.String(), "UNSUPPORTED_MULTIMODAL")
}
func TestProxyHandler_SameProtocolMultimodal_SmartPassthrough(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
assert.Contains(t, string(spec.Body), "image_url")
assert.Contains(t, string(spec.Body), `"model":"gpt-4"`)
return &conversion.HTTPResponseSpec{
StatusCode: http.StatusOK,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"id":"r1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`),
}, nil
})
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
statsSvc.EXPECT().Record(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
body := []byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader(body))
h.HandleProxy(c)
require.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "p1/gpt-4")
}
func TestProxyHandler_RawStreamPassthrough_PreservesSSEFrames(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
providerSvc.EXPECT().List().Return([]domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
assert.Contains(t, string(spec.Body), `"model":"gpt-4"`)
ch := make(chan provider.StreamEvent, 3)
go func() {
defer close(ch)
ch <- provider.StreamEvent{Data: []byte("data: {\"model\":\"gpt-4\",\"choices\":[]}\n\n")}
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
ch <- provider.StreamEvent{Done: true}
}()
return &provider.StreamResponse{StatusCode: http.StatusOK, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
})
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
require.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "data: {\"model\":\"gpt-4\",\"choices\":[]}\n\ndata: [DONE]\n\n", w.Body.String())
}

View File

@@ -5,9 +5,9 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/gin-gonic/gin"
"nex/backend/internal/service" "nex/backend/internal/service"
"github.com/gin-gonic/gin"
) )
// StatsHandler 统计处理器 // StatsHandler 统计处理器

View File

@@ -0,0 +1,26 @@
package handler
import (
"net/http"
"nex/backend/pkg/buildinfo"
"github.com/gin-gonic/gin"
)
// VersionHandler 提供后端构建版本信息。
type VersionHandler struct{}
// NewVersionHandler 创建版本信息处理器。
func NewVersionHandler() *VersionHandler {
return &VersionHandler{}
}
// GetVersion 返回构建注入的版本元数据。
func (h *VersionHandler) GetVersion(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"version": buildinfo.Version(),
"commit": buildinfo.Commit(),
"build_time": buildinfo.BuildTime(),
})
}

View File

@@ -0,0 +1,31 @@
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestVersionHandler_GetVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
h := NewVersionHandler()
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/api/version", nil)
h.GetVersion(c)
assert.Equal(t, http.StatusOK, w.Code)
var result map[string]string
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "dev", result["version"])
assert.Equal(t, "unknown", result["commit"])
assert.Equal(t, "unknown", result["build_time"])
}

View File

@@ -8,6 +8,7 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"strings"
"syscall" "syscall"
"time" "time"
@@ -15,6 +16,7 @@ import (
"nex/backend/internal/conversion" "nex/backend/internal/conversion"
pkgErrors "nex/backend/pkg/errors" pkgErrors "nex/backend/pkg/errors"
pkglogger "nex/backend/pkg/logger"
) )
// StreamConfig 流式处理配置 // StreamConfig 流式处理配置
@@ -42,6 +44,14 @@ type StreamEvent struct {
Done bool Done bool
} }
// StreamResponse 表示上游流式 HTTP 响应。
type StreamResponse struct {
StatusCode int
Headers map[string]string
Body []byte
Events <-chan StreamEvent
}
// Client 协议无关的供应商客户端 // Client 协议无关的供应商客户端
type Client struct { type Client struct {
httpClient *http.Client httpClient *http.Client
@@ -50,19 +60,20 @@ type Client struct {
} }
// ProviderClient 供应商客户端接口 // ProviderClient 供应商客户端接口
//
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks //go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
type ProviderClient interface { type ProviderClient interface {
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*StreamResponse, error)
} }
// NewClient 创建供应商客户端 // NewClient 创建供应商客户端
func NewClient() *Client { func NewClient(logger *zap.Logger) *Client {
return &Client{ return &Client{
httpClient: &http.Client{ httpClient: &http.Client{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
}, },
logger: zap.L(), logger: pkglogger.WithModule(logger, "provider.client"),
streamCfg: DefaultStreamConfig(), streamCfg: DefaultStreamConfig(),
} }
} }
@@ -114,7 +125,7 @@ func (c *Client) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*co
} }
// SendStream 发送流式请求 // SendStream 发送流式请求
func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error) { func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*StreamResponse, error) {
var bodyReader io.Reader var bodyReader io.Reader
if len(spec.Body) > 0 { if len(spec.Body) > 0 {
bodyReader = bytes.NewReader(spec.Body) bodyReader = bytes.NewReader(spec.Body)
@@ -137,20 +148,29 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec
return nil, pkgErrors.ErrRequestSend.WithCause(err) return nil, pkgErrors.ErrRequestSend.WithCause(err)
} }
if resp.StatusCode != http.StatusOK { respHeaders := extractResponseHeaders(resp.Header)
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
defer resp.Body.Close() defer resp.Body.Close()
cancel() cancel()
errBody, _ := io.ReadAll(resp.Body) errBody, readErr := io.ReadAll(resp.Body)
if len(errBody) > 0 { if readErr != nil {
return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody)) return nil, pkgErrors.ErrResponseRead.WithCause(readErr)
} }
return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode) return &StreamResponse{
StatusCode: resp.StatusCode,
Headers: respHeaders,
Body: errBody,
}, nil
} }
eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize) eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize)
go c.readStream(streamCtx, cancel, resp.Body, eventChan) go c.readStream(streamCtx, cancel, resp.Body, eventChan)
return eventChan, nil return &StreamResponse{
StatusCode: resp.StatusCode,
Headers: respHeaders,
Events: eventChan,
}, nil
} }
// readStream 读取 SSE 流 // readStream 读取 SSE 流
@@ -183,10 +203,10 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
if isNetworkError(err) { if isNetworkError(err) {
c.logger.Error("流网络错误", zap.String("error", err.Error())) c.logger.Error("流网络错误", zap.Error(err))
eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)} eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
} else { } else {
c.logger.Error("流读取错误", zap.String("error", err.Error())) c.logger.Error("流读取错误", zap.Error(err))
eventChan <- StreamEvent{Error: fmt.Errorf("读取错误: %w", err)} eventChan <- StreamEvent{Error: fmt.Errorf("读取错误: %w", err)}
} }
return return
@@ -203,15 +223,17 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
} }
for { for {
idx := bytes.Index(dataBuf, []byte("\n\n")) idx, sepLen := findSSEFrameSeparator(dataBuf)
if idx == -1 { if idx == -1 {
break break
} }
rawEvent := dataBuf[:idx] frameEnd := idx + sepLen
dataBuf = dataBuf[idx+2:] rawEvent := append([]byte(nil), dataBuf[:frameEnd]...)
dataBuf = dataBuf[frameEnd:]
if bytes.Contains(rawEvent, []byte("data: [DONE]")) { if isSSEDoneFrame(rawEvent) {
eventChan <- StreamEvent{Data: rawEvent}
eventChan <- StreamEvent{Done: true} eventChan <- StreamEvent{Done: true}
return return
} }
@@ -220,11 +242,66 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
} }
if err == io.EOF { if err == io.EOF {
if len(dataBuf) > 0 {
eventChan <- StreamEvent{Data: dataBuf}
}
return return
} }
} }
} }
func isSSEDoneFrame(frame []byte) bool {
payload, ok := sseFrameDataPayload(frame)
return ok && strings.TrimSpace(payload) == "[DONE]"
}
func sseFrameDataPayload(frame []byte) (string, bool) {
text := strings.TrimRight(string(frame), "\r\n")
lines := strings.Split(text, "\n")
var dataLines []string
for _, line := range lines {
line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "data:") {
value := strings.TrimPrefix(line, "data:")
if strings.HasPrefix(value, " ") {
value = value[1:]
}
dataLines = append(dataLines, value)
}
}
if len(dataLines) == 0 {
return "", false
}
return strings.Join(dataLines, "\n"), true
}
func extractResponseHeaders(header http.Header) map[string]string {
respHeaders := make(map[string]string)
for k, vs := range header {
if len(vs) > 0 {
respHeaders[k] = vs[0]
}
}
return respHeaders
}
func findSSEFrameSeparator(data []byte) (int, int) {
lf := bytes.Index(data, []byte("\n\n"))
crlf := bytes.Index(data, []byte("\r\n\r\n"))
switch {
case lf < 0 && crlf < 0:
return -1, 0
case lf < 0:
return crlf, 4
case crlf < 0:
return lf, 2
case crlf <= lf:
return crlf, 4
default:
return lf, 2
}
}
// isNetworkError 判断是否为网络相关错误 // isNetworkError 判断是否为网络相关错误
func isNetworkError(err error) bool { func isNetworkError(err error) bool {
if err == nil { if err == nil {

View File

@@ -13,12 +13,13 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap"
"nex/backend/internal/conversion" "nex/backend/internal/conversion"
) )
func TestNewClient(t *testing.T) { func TestNewClient(t *testing.T) {
client := NewClient() client := NewClient(zap.NewNop())
require.NotNil(t, client) require.NotNil(t, client)
assert.NotNil(t, client.httpClient) assert.NotNil(t, client.httpClient)
assert.Equal(t, 4096, client.streamCfg.InitialBufferSize) assert.Equal(t, 4096, client.streamCfg.InitialBufferSize)
@@ -40,11 +41,12 @@ func TestClient_Send_Success(t *testing.T) {
assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization")) assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"id":"test","model":"gpt-4"}`)) _, err := w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
require.NoError(t, err)
})) }))
defer server.Close() defer server.Close()
client := NewClient() client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{ spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions", URL: server.URL + "/v1/chat/completions",
Method: "POST", Method: "POST",
@@ -64,11 +66,12 @@ func TestClient_Send_Success(t *testing.T) {
func TestClient_Send_ErrorResponse(t *testing.T) { func TestClient_Send_ErrorResponse(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"error":{"message":"Invalid API key"}}`)) _, err := w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
require.NoError(t, err)
})) }))
defer server.Close() defer server.Close()
client := NewClient() client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{ spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions", URL: server.URL + "/v1/chat/completions",
Method: "POST", Method: "POST",
@@ -82,7 +85,7 @@ func TestClient_Send_ErrorResponse(t *testing.T) {
} }
func TestClient_Send_ConnectionError(t *testing.T) { func TestClient_Send_ConnectionError(t *testing.T) {
client := NewClient() client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{ spec := conversion.HTTPRequestSpec{
URL: "http://localhost:1/v1/chat/completions", URL: "http://localhost:1/v1/chat/completions",
Method: "POST", Method: "POST",
@@ -99,7 +102,7 @@ func TestClient_SendStream_CreatesChannel(t *testing.T) {
})) }))
defer server.Close() defer server.Close()
client := NewClient() client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{ spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions", URL: server.URL + "/v1/chat/completions",
Method: "POST", Method: "POST",
@@ -107,11 +110,13 @@ func TestClient_SendStream_CreatesChannel(t *testing.T) {
Body: []byte(`{}`), Body: []byte(`{}`),
} }
eventChan, err := client.SendStream(context.Background(), spec) streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, eventChan) require.NotNil(t, streamResp)
require.Equal(t, http.StatusOK, streamResp.StatusCode)
require.NotNil(t, streamResp.Events)
for range eventChan { for range streamResp.Events {
} }
} }
@@ -121,7 +126,7 @@ func TestClient_SendStream_ErrorResponse(t *testing.T) {
})) }))
defer server.Close() defer server.Close()
client := NewClient() client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{ spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions", URL: server.URL + "/v1/chat/completions",
Method: "POST", Method: "POST",
@@ -129,8 +134,10 @@ func TestClient_SendStream_ErrorResponse(t *testing.T) {
Body: []byte(`{}`), Body: []byte(`{}`),
} }
_, err := client.SendStream(context.Background(), spec) streamResp, err := client.SendStream(context.Background(), spec)
assert.Error(t, err) require.NoError(t, err)
require.NotNil(t, streamResp)
assert.Equal(t, http.StatusInternalServerError, streamResp.StatusCode)
} }
func TestClient_SendStream_SSEEvents(t *testing.T) { func TestClient_SendStream_SSEEvents(t *testing.T) {
@@ -139,18 +146,21 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
require.True(t, ok) require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n")) _, err := w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n")) _, err = w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n")) _, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
})) }))
defer server.Close() defer server.Close()
client := NewClient() client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{ spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions", URL: server.URL + "/v1/chat/completions",
Method: "POST", Method: "POST",
@@ -158,24 +168,73 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
Body: []byte(`{"model":"gpt-4","messages":[],"stream":true}`), Body: []byte(`{"model":"gpt-4","messages":[],"stream":true}`),
} }
eventChan, err := client.SendStream(context.Background(), spec) streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, streamResp)
var dataEvents [][]byte var dataEvents [][]byte
var doneEvents int var doneEvents int
for event := range eventChan { for event := range streamResp.Events {
if event.Done { switch {
case event.Done:
doneEvents++ doneEvents++
} else if event.Error != nil { case event.Error != nil:
t.Fatalf("unexpected error: %v", event.Error) t.Fatalf("unexpected error: %v", event.Error)
} else { default:
dataEvents = append(dataEvents, event.Data) dataEvents = append(dataEvents, event.Data)
} }
} }
assert.Equal(t, 2, len(dataEvents), "expected exactly 2 data events from SSE stream") assert.Equal(t, 3, len(dataEvents), "expected 2 data frames plus DONE frame from SSE stream")
assert.Contains(t, string(dataEvents[0]), "Hello") assert.Contains(t, string(dataEvents[0]), "Hello")
assert.Contains(t, string(dataEvents[1]), "World") assert.Contains(t, string(dataEvents[1]), "World")
assert.Contains(t, string(dataEvents[2]), "[DONE]")
assert.Equal(t, 1, doneEvents)
assert.Contains(t, string(dataEvents[0]), "\n\n")
}
func TestClient_SendStream_DONEOnlyWhenDataPayloadEqualsDone(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
_, err := w.Write([]byte("data: {\"text\":\"data: [DONE] is plain text\"}\n\n"))
require.NoError(t, err)
flusher.Flush()
_, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush()
}))
defer server.Close()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
Body: []byte(`{}`),
}
streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
require.NotNil(t, streamResp)
var dataEvents [][]byte
var doneEvents int
for event := range streamResp.Events {
switch {
case event.Done:
doneEvents++
case event.Error != nil:
t.Fatalf("unexpected error: %v", event.Error)
default:
dataEvents = append(dataEvents, event.Data)
}
}
require.Len(t, dataEvents, 2)
assert.Contains(t, string(dataEvents[0]), "plain text")
assert.Contains(t, string(dataEvents[1]), "[DONE]")
assert.Equal(t, 1, doneEvents) assert.Equal(t, 1, doneEvents)
} }
@@ -188,7 +247,7 @@ func TestClient_SendStream_ContextCancellation(t *testing.T) {
defer server.Close() defer server.Close()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
client := NewClient() client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{ spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions", URL: server.URL + "/v1/chat/completions",
Method: "POST", Method: "POST",
@@ -196,13 +255,13 @@ func TestClient_SendStream_ContextCancellation(t *testing.T) {
Body: []byte(`{}`), Body: []byte(`{}`),
} }
eventChan, err := client.SendStream(ctx, spec) streamResp, err := client.SendStream(ctx, spec)
require.NoError(t, err) require.NoError(t, err)
cancel() cancel()
var gotError bool var gotError bool
for event := range eventChan { for event := range streamResp.Events {
if event.Error != nil { if event.Error != nil {
gotError = true gotError = true
} }
@@ -214,11 +273,12 @@ func TestClient_Send_EmptyBody(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method) assert.Equal(t, "GET", r.Method)
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"result":"ok"}`)) _, err := w.Write([]byte(`{"result":"ok"}`))
require.NoError(t, err)
})) }))
defer server.Close() defer server.Close()
client := NewClient() client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{ spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/models", URL: server.URL + "/v1/models",
Method: "GET", Method: "GET",
@@ -237,16 +297,18 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
require.True(t, ok) require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\n")) _, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n")) _, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
})) }))
defer server.Close() defer server.Close()
client := NewClient() client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{ spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions", URL: server.URL + "/v1/chat/completions",
Method: "POST", Method: "POST",
@@ -254,21 +316,22 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
Body: []byte(`{}`), Body: []byte(`{}`),
} }
eventChan, err := client.SendStream(context.Background(), spec) streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err) require.NoError(t, err)
var dataCount int var dataCount int
var doneCount int var doneCount int
for event := range eventChan { for event := range streamResp.Events {
if event.Done { switch {
case event.Done:
doneCount++ doneCount++
} else if event.Error != nil { case event.Error != nil:
t.Fatalf("unexpected error: %v", event.Error) t.Fatalf("unexpected error: %v", event.Error)
} else { default:
dataCount++ dataCount++
} }
} }
assert.Equal(t, 1, dataCount, "expected exactly 1 data event from slow SSE") assert.Equal(t, 2, dataCount, "expected 1 data frame plus DONE frame from slow SSE")
assert.Equal(t, 1, doneCount, "expected exactly 1 done event from slow SSE") assert.Equal(t, 1, doneCount, "expected exactly 1 done event from slow SSE")
} }
@@ -278,16 +341,18 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
require.True(t, ok) require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n")) _, err := w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n")) _, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
})) }))
defer server.Close() defer server.Close()
client := NewClient() client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{ spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions", URL: server.URL + "/v1/chat/completions",
Method: "POST", Method: "POST",
@@ -295,19 +360,19 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
Body: []byte(`{}`), Body: []byte(`{}`),
} }
eventChan, err := client.SendStream(context.Background(), spec) streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err) require.NoError(t, err)
var dataEvents int var dataEvents int
var doneEvents int var doneEvents int
for event := range eventChan { for event := range streamResp.Events {
if event.Done { if event.Done {
doneEvents++ doneEvents++
} else { } else {
dataEvents++ dataEvents++
} }
} }
assert.Equal(t, 2, dataEvents, "expected exactly 2 data events from split SSE") assert.Equal(t, 3, dataEvents, "expected 2 data frames plus DONE frame from split SSE")
assert.Equal(t, 1, doneEvents) assert.Equal(t, 1, doneEvents)
} }
@@ -363,19 +428,20 @@ func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
require.True(t, ok) require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\n")) _, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
if hijacker, ok := w.(http.Hijacker); ok { if hijacker, ok := w.(http.Hijacker); ok {
conn, _, _ := hijacker.Hijack() conn, _, _ := hijacker.Hijack()
if conn != nil { if conn != nil {
conn.Close() require.NoError(t, conn.Close())
} }
} }
})) }))
defer server.Close() defer server.Close()
client := NewClient() client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{ spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions", URL: server.URL + "/v1/chat/completions",
Method: "POST", Method: "POST",
@@ -383,11 +449,11 @@ func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
Body: []byte(`{}`), Body: []byte(`{}`),
} }
eventChan, err := client.SendStream(context.Background(), spec) streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err) require.NoError(t, err)
var gotData bool var gotData bool
for event := range eventChan { for event := range streamResp.Events {
if event.Error != nil { if event.Error != nil {
} else if !event.Done { } else if !event.Done {
gotData = true gotData = true

View File

@@ -3,10 +3,11 @@ package repository
import ( import (
"time" "time"
"gorm.io/gorm"
"nex/backend/internal/config" "nex/backend/internal/config"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
) )

View File

@@ -3,13 +3,13 @@ package repository
import ( import (
"testing" "testing"
"nex/backend/internal/domain"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/gorm" "gorm.io/gorm"
testHelpers "nex/backend/tests" testHelpers "nex/backend/tests"
"nex/backend/internal/domain"
) )
func setupTestDB(t *testing.T) *gorm.DB { func setupTestDB(t *testing.T) *gorm.DB {

View File

@@ -1,13 +1,13 @@
package repository package repository
import ( import (
"errors"
"time" "time"
"gorm.io/gorm"
"nex/backend/internal/config" "nex/backend/internal/config"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"gorm.io/gorm"
"gorm.io/gorm/clause"
) )
type statsRepository struct { type statsRepository struct {
@@ -19,50 +19,46 @@ func NewStatsRepository(db *gorm.DB) StatsRepository {
} }
func (r *statsRepository) Record(providerID, modelName string) error { func (r *statsRepository) Record(providerID, modelName string) error {
today := time.Now().Format("2006-01-02") now := time.Now()
todayTime, _ := time.Parse("2006-01-02", today) todayTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
return r.db.Transaction(func(tx *gorm.DB) error { stats := config.UsageStats{
var stats config.UsageStats
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
providerID, modelName, todayTime).First(&stats).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
stats = config.UsageStats{
ProviderID: providerID, ProviderID: providerID,
ModelName: modelName, ModelName: modelName,
RequestCount: 1, RequestCount: 1,
Date: todayTime, Date: todayTime,
} }
return tx.Create(&stats).Error
} else if err != nil {
return err
}
return tx.Model(&stats).Update("request_count", gorm.Expr("request_count + 1")).Error return r.db.Clauses(clause.OnConflict{
}) Columns: []clause.Column{
{Name: "provider_id"},
{Name: "model_name"},
{Name: "date"},
},
DoUpdates: clause.Assignments(map[string]interface{}{
"request_count": gorm.Expr("request_count + 1"),
}),
}).Create(&stats).Error
} }
func (r *statsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error { func (r *statsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
return r.db.Transaction(func(tx *gorm.DB) error { stats := config.UsageStats{
var stats config.UsageStats
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
providerID, modelName, date).First(&stats).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return tx.Create(&config.UsageStats{
ProviderID: providerID, ProviderID: providerID,
ModelName: modelName, ModelName: modelName,
RequestCount: delta, RequestCount: delta,
Date: date, Date: date,
}).Error
} else if err != nil {
return err
} }
return tx.Model(&stats). return r.db.Clauses(clause.OnConflict{
Update("request_count", gorm.Expr("request_count + ?", delta)).Error Columns: []clause.Column{
}) {Name: "provider_id"},
{Name: "model_name"},
{Name: "date"},
},
DoUpdates: clause.Assignments(map[string]interface{}{
"request_count": gorm.Expr("request_count + ?", delta),
}),
}).Create(&stats).Error
} }
func (r *statsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) { func (r *statsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {

View File

@@ -1,11 +1,15 @@
package service package service
import ( import (
"github.com/google/uuid" "errors"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"github.com/google/uuid"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors"
) )
type modelService struct { type modelService struct {
@@ -108,8 +112,12 @@ func (s *modelService) Delete(id string) error {
func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error { func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error {
existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName) existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil // 未找到,不重复 return nil // 未找到,不重复
} }
return err
}
if excludeID != "" && existing.ID == excludeID { if excludeID != "" && existing.ID == excludeID {
return nil // 排除自身 return nil // 排除自身
} }

View File

@@ -3,10 +3,10 @@ package service
import ( import (
"strings" "strings"
"nex/backend/pkg/modelid"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"nex/backend/pkg/modelid"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
) )

View File

@@ -4,10 +4,12 @@ import (
"strings" "strings"
"sync" "sync"
"go.uber.org/zap"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"go.uber.org/zap"
pkglogger "nex/backend/pkg/logger"
) )
type RoutingCache struct { type RoutingCache struct {
@@ -27,13 +29,15 @@ func NewRoutingCache(
return &RoutingCache{ return &RoutingCache{
modelRepo: modelRepo, modelRepo: modelRepo,
providerRepo: providerRepo, providerRepo: providerRepo,
logger: logger, logger: pkglogger.WithModule(logger, "service.routing_cache"),
} }
} }
func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) { func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
if v, ok := c.providers.Load(id); ok { if v, ok := c.providers.Load(id); ok {
return v.(*domain.Provider), nil if provider, ok := v.(*domain.Provider); ok {
return provider, nil
}
} }
provider, err := c.providerRepo.GetByID(id) provider, err := c.providerRepo.GetByID(id)
@@ -42,7 +46,9 @@ func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
} }
if v, ok := c.providers.Load(id); ok { if v, ok := c.providers.Load(id); ok {
return v.(*domain.Provider), nil if provider, ok := v.(*domain.Provider); ok {
return provider, nil
}
} }
c.providers.Store(id, provider) c.providers.Store(id, provider)
@@ -53,7 +59,9 @@ func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, er
key := providerID + "/" + modelName key := providerID + "/" + modelName
if v, ok := c.models.Load(key); ok { if v, ok := c.models.Load(key); ok {
return v.(*domain.Model), nil if model, ok := v.(*domain.Model); ok {
return model, nil
}
} }
model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName) model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName)
@@ -62,7 +70,9 @@ func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, er
} }
if v, ok := c.models.Load(key); ok { if v, ok := c.models.Load(key); ok {
return v.(*domain.Model), nil if model, ok := v.(*domain.Model); ok {
return model, nil
}
} }
c.models.Store(key, model) c.models.Store(key, model)
@@ -96,7 +106,12 @@ func (c *RoutingCache) invalidateModelsByProvider(providerID string) {
prefix := providerID + "/" prefix := providerID + "/"
count := 0 count := 0
c.models.Range(func(key, value interface{}) bool { c.models.Range(func(key, value interface{}) bool {
if strings.HasPrefix(key.(string), prefix) { keyStr, ok := key.(string)
if !ok {
return true
}
if strings.HasPrefix(keyStr, prefix) {
c.models.Delete(key) c.models.Delete(key)
count++ count++
} }

View File

@@ -5,11 +5,11 @@ import (
"sync" "sync"
"testing" "testing"
"nex/backend/internal/domain"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/domain"
) )
type mockModelRepo struct { type mockModelRepo struct {
@@ -189,7 +189,8 @@ func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) {
var openaiCount, anthropicCount int var openaiCount, anthropicCount int
cache.models.Range(func(key, value interface{}) bool { cache.models.Range(func(key, value interface{}) bool {
if key.(string) == "anthropic/claude" { keyStr, ok := key.(string)
if ok && keyStr == "anthropic/claude" {
anthropicCount++ anthropicCount++
} }
return true return true

View File

@@ -1,9 +1,8 @@
package service package service
import ( import (
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain" "nex/backend/internal/domain"
appErrors "nex/backend/pkg/errors"
) )
type routingService struct { type routingService struct {

View File

@@ -3,11 +3,12 @@ package service
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
) )
func TestProviderService_Update(t *testing.T) { func TestProviderService_Update(t *testing.T) {
@@ -119,7 +120,7 @@ func TestModelService_Delete_NotFound(t *testing.T) {
func TestStatsService_Aggregate_Default(t *testing.T) { func TestStatsService_Aggregate_Default(t *testing.T) {
statsRepo := repository.NewStatsRepository(nil) statsRepo := repository.NewStatsRepository(nil)
buffer := NewStatsBuffer(statsRepo, nil) buffer := NewStatsBuffer(statsRepo, zap.NewNop())
svc := NewStatsService(statsRepo, buffer) svc := NewStatsService(statsRepo, buffer)
stats := []domain.UsageStats{ stats := []domain.UsageStats{
@@ -132,7 +133,9 @@ func TestStatsService_Aggregate_Default(t *testing.T) {
totalCount := 0 totalCount := 0
for _, r := range result { for _, r := range result {
totalCount += r["request_count"].(int) count, ok := r["request_count"].(int)
require.True(t, ok)
totalCount += count
} }
assert.Equal(t, 15, totalCount) assert.Equal(t, 15, totalCount)
} }

View File

@@ -5,6 +5,9 @@ import (
"testing" "testing"
"time" "time"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -13,8 +16,6 @@ import (
testHelpers "nex/backend/tests" testHelpers "nex/backend/tests"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
) )
@@ -318,7 +319,8 @@ func TestStatsService_Aggregate_ByModel(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db := setupServiceTestDB(t) db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db) statsRepo := repository.NewStatsRepository(db)
buffer := NewStatsBuffer(statsRepo, nil); svc := NewStatsService(statsRepo, buffer) buffer := NewStatsBuffer(statsRepo, zap.NewNop())
svc := NewStatsService(statsRepo, buffer)
result := svc.Aggregate(tt.stats, "model") result := svc.Aggregate(tt.stats, "model")
@@ -379,7 +381,8 @@ func TestStatsService_Aggregate_ByDate(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db := setupServiceTestDB(t) db := setupServiceTestDB(t)
statsRepo := repository.NewStatsRepository(db) statsRepo := repository.NewStatsRepository(db)
buffer := NewStatsBuffer(statsRepo, nil); svc := NewStatsService(statsRepo, buffer) buffer := NewStatsBuffer(statsRepo, zap.NewNop())
svc := NewStatsService(statsRepo, buffer)
result := svc.Aggregate(tt.stats, "date") result := svc.Aggregate(tt.stats, "date")

View File

@@ -6,9 +6,11 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"nex/backend/internal/repository"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/repository" pkglogger "nex/backend/pkg/logger"
) )
type StatsBuffer struct { type StatsBuffer struct {
@@ -46,7 +48,7 @@ func NewStatsBuffer(
) *StatsBuffer { ) *StatsBuffer {
b := &StatsBuffer{ b := &StatsBuffer{
statsRepo: statsRepo, statsRepo: statsRepo,
logger: logger, logger: pkglogger.WithModule(logger, "service.stats_buffer"),
flushInterval: 5 * time.Second, flushInterval: 5 * time.Second,
flushThreshold: 100, flushThreshold: 100,
stopCh: make(chan struct{}), stopCh: make(chan struct{}),
@@ -66,13 +68,21 @@ func (b *StatsBuffer) Increment(providerID, modelName string) {
var counter *int64 var counter *int64
if v, ok := b.counters.Load(key); ok { if v, ok := b.counters.Load(key); ok {
counter = v.(*int64) if existing, ok := v.(*int64); ok {
counter = existing
} else {
return
}
} else { } else {
val := int64(0) val := int64(0)
counter = &val counter = &val
actual, loaded := b.counters.LoadOrStore(key, counter) actual, loaded := b.counters.LoadOrStore(key, counter)
if loaded { if loaded {
counter = actual.(*int64) existing, ok := actual.(*int64)
if !ok {
return
}
counter = existing
} }
} }
@@ -116,13 +126,20 @@ func (b *StatsBuffer) flush() {
var entries []statEntry var entries []statEntry
b.counters.Range(func(key, value interface{}) bool { b.counters.Range(func(key, value interface{}) bool {
keyStr := key.(string) keyStr, ok := key.(string)
if !ok {
return true
}
parts := strings.Split(keyStr, "/") parts := strings.Split(keyStr, "/")
if len(parts) != 3 { if len(parts) != 3 {
return true return true
} }
counter := value.(*int64) counter, ok := value.(*int64)
if !ok {
return true
}
count := atomic.SwapInt64(counter, 0) count := atomic.SwapInt64(counter, 0)
if count > 0 { if count > 0 {
@@ -142,8 +159,17 @@ func (b *StatsBuffer) flush() {
success := 0 success := 0
for _, entry := range entries { for _, entry := range entries {
date, _ := time.Parse("2006-01-02", entry.date) date, err := time.Parse("2006-01-02", entry.date)
err := b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count)) if err != nil {
b.logger.Error("解析统计日期失败",
zap.String("provider_id", entry.providerID),
zap.String("model_name", entry.modelName),
zap.String("date", entry.date),
zap.Error(err))
continue
}
err = b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count))
if err != nil { if err != nil {
b.logger.Error("批量更新统计失败", b.logger.Error("批量更新统计失败",
zap.String("provider_id", entry.providerID), zap.String("provider_id", entry.providerID),
@@ -153,9 +179,11 @@ func (b *StatsBuffer) flush() {
key := entry.providerID + "/" + entry.modelName + "/" + entry.date key := entry.providerID + "/" + entry.modelName + "/" + entry.date
if v, ok := b.counters.Load(key); ok { if v, ok := b.counters.Load(key); ok {
counter := v.(*int64) counter, ok := v.(*int64)
if ok {
atomic.AddInt64(counter, entry.count) atomic.AddInt64(counter, entry.count)
} }
}
} else { } else {
success++ success++
} }

View File

@@ -7,10 +7,10 @@ import (
"testing" "testing"
"time" "time"
"nex/backend/internal/domain"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/domain"
) )
type mockStatsRepo struct { type mockStatsRepo struct {
@@ -58,8 +58,10 @@ func TestStatsBuffer_Increment(t *testing.T) {
var count int64 var count int64
buffer.counters.Range(func(key, value interface{}) bool { buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64) counter, ok := value.(*int64)
if ok {
count += atomic.LoadInt64(counter) count += atomic.LoadInt64(counter)
}
return true return true
}) })
assert.Equal(t, int64(3), count) assert.Equal(t, int64(3), count)
@@ -82,8 +84,10 @@ func TestStatsBuffer_ConcurrentIncrement(t *testing.T) {
var count int64 var count int64
buffer.counters.Range(func(key, value interface{}) bool { buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64) counter, ok := value.(*int64)
if ok {
count = atomic.LoadInt64(counter) count = atomic.LoadInt64(counter)
}
return true return true
}) })
assert.Equal(t, int64(100), count) assert.Equal(t, int64(100), count)
@@ -161,8 +165,10 @@ func TestStatsBuffer_SwapInt64(t *testing.T) {
var beforeCount int64 var beforeCount int64
buffer.counters.Range(func(key, value interface{}) bool { buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64) counter, ok := value.(*int64)
if ok {
beforeCount = atomic.LoadInt64(counter) beforeCount = atomic.LoadInt64(counter)
}
return true return true
}) })
assert.Equal(t, int64(2), beforeCount) assert.Equal(t, int64(2), beforeCount)
@@ -171,8 +177,10 @@ func TestStatsBuffer_SwapInt64(t *testing.T) {
var afterCount int64 var afterCount int64
buffer.counters.Range(func(key, value interface{}) bool { buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64) counter, ok := value.(*int64)
if ok {
afterCount = atomic.LoadInt64(counter) afterCount = atomic.LoadInt64(counter)
}
return true return true
}) })
assert.Equal(t, int64(0), afterCount) assert.Equal(t, int64(0), afterCount)
@@ -190,8 +198,10 @@ func TestStatsBuffer_FailRetry(t *testing.T) {
var count int64 var count int64
buffer.counters.Range(func(key, value interface{}) bool { buffer.counters.Range(func(key, value interface{}) bool {
counter := value.(*int64) counter, ok := value.(*int64)
if ok {
count = atomic.LoadInt64(counter) count = atomic.LoadInt64(counter)
}
return true return true
}) })
assert.Equal(t, int64(2), count) assert.Equal(t, int64(2), count)

View File

@@ -0,0 +1,22 @@
package buildinfo
var (
version = "dev"
commit = "unknown"
buildTime = "unknown"
)
// Version 返回构建注入的版本号。
func Version() string {
return version
}
// Commit 返回构建注入的 git commit。
func Commit() string {
return commit
}
// BuildTime 返回构建注入的构建时间。
func BuildTime() string {
return buildTime
}

View File

@@ -0,0 +1,17 @@
package buildinfo
import "testing"
func TestDefaults(t *testing.T) {
if Version() == "" {
t.Fatal("Version() 不应为空")
}
if Commit() == "" {
t.Fatal("Commit() 不应为空")
}
if BuildTime() == "" {
t.Fatal("BuildTime() 不应为空")
}
}

View File

@@ -1,6 +1,7 @@
package errors package errors
import ( import (
stderrors "errors"
"fmt" "fmt"
"net/http" "net/http"
) )
@@ -70,22 +71,11 @@ func AsAppError(err error) (*AppError, bool) {
if err == nil { if err == nil {
return nil, false return nil, false
} }
var appErr *AppError
if ok := is(err, &appErr); ok {
return appErr, true
}
return nil, false
}
func is(err error, target interface{}) bool { var appErr *AppError
// 简单的类型断言 if !stderrors.As(err, &appErr) {
if e, ok := err.(*AppError); ok { return nil, false
// 直接赋值
switch t := target.(type) {
case **AppError:
*t = e
return true
} }
}
return false return appErr, true
} }

View File

@@ -104,7 +104,8 @@ func TestPredefinedErrors(t *testing.T) {
func TestAsAppError(t *testing.T) { func TestAsAppError(t *testing.T) {
t.Run("nil输入", func(t *testing.T) { t.Run("nil输入", func(t *testing.T) {
_, ok := AsAppError(nil) appErr, ok := AsAppError(nil)
assert.Nil(t, appErr)
assert.False(t, ok) assert.False(t, ok)
}) })
@@ -122,7 +123,8 @@ func TestAsAppError(t *testing.T) {
}) })
t.Run("非AppError类型", func(t *testing.T) { t.Run("非AppError类型", func(t *testing.T) {
_, ok := AsAppError(errors.New("普通错误")) appErr, ok := AsAppError(errors.New("普通错误"))
assert.Nil(t, appErr)
assert.False(t, ok) assert.False(t, ok)
}) })
} }

View File

@@ -1,13 +1,20 @@
package logger package logger
import "go.uber.org/zap" import (
"context"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type ctxKey struct{}
const requestIDKey = "request_id"
// WithRequestID 向 logger 添加 request_id 字段
func WithRequestID(logger *zap.Logger, requestID string) *zap.Logger { func WithRequestID(logger *zap.Logger, requestID string) *zap.Logger {
return logger.With(zap.String("request_id", requestID)) return logger.With(zap.String(requestIDKey, requestID))
} }
// WithContext 向 logger 添加多个自定义字段
func WithContext(logger *zap.Logger, fields map[string]interface{}) *zap.Logger { func WithContext(logger *zap.Logger, fields map[string]interface{}) *zap.Logger {
zapFields := make([]zap.Field, 0, len(fields)) zapFields := make([]zap.Field, 0, len(fields))
for k, v := range fields { for k, v := range fields {
@@ -15,3 +22,37 @@ func WithContext(logger *zap.Logger, fields map[string]interface{}) *zap.Logger
} }
return logger.With(zapFields...) return logger.With(zapFields...)
} }
func RequestIDFromGinContext(c *gin.Context) zap.Field {
requestID, exists := c.Get("request_id")
if !exists {
return zap.Skip()
}
if id, ok := requestID.(string); ok {
return RequestID(id)
}
return zap.Skip()
}
func RequestIDFromContext(ctx context.Context) zap.Field {
requestID := ctx.Value(ctxKey{})
if requestID == nil {
return zap.Skip()
}
if id, ok := requestID.(string); ok {
return RequestID(id)
}
return zap.Skip()
}
func ContextWithRequestID(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, ctxKey{}, requestID)
}
func LoggerFromContext(ctx context.Context, baseLogger *zap.Logger) *zap.Logger {
field := RequestIDFromContext(ctx)
if field == zap.Skip() {
return baseLogger
}
return baseLogger.With(field)
}

Some files were not shown because too many files have changed in this diff Show More