1
0

46 Commits

Author SHA1 Message Date
235efb0e62 chore: 版本升迁 v0.1.1 2026-05-05 04:39:31 +08:00
6b1af27ea2 fix: 移除 version-bump 的工作区干净检查 2026-05-05 04:39:21 +08:00
32f48777f3 feat: make version-bump 默认 BUMP=patch,无需显式传参 2026-05-05 04:32:38 +08:00
bc7a7c6e81 feat: 迁移 versionctl 为独立模块并新增 make version-bump 命令
- 将 backend/cmd/versionctl 和 backend/pkg/projectversion 迁移至独立 versionctl/ Go 模块
- 新增 bump 子命令支持 major/minor/patch 和指定版本号,含版本倒退防护
- 新增 make version-bump 编排完整升迁流程(bump + sync + check + commit + tag)
- 更新所有引用路径:根 Makefile、backend/Makefile、release.yml、.golangci.yml
- 新增 versionctl/.golangci.yml(精简配置)和 Makefile(lint/test/coverage)
- 根 Makefile lint/test 集成 versionctl 模块
- 同步 openspec specs:新增 version-bump spec,更新 release-pipeline spec
2026-05-05 04:18:10 +08:00
3cd0458c2c fix: 修复 golangci-lint 报告的 gosec/gocyclo/forbidigo 问题 2026-05-05 03:35:20 +08:00
8eea30ea11 feat: 统一品牌标识、关于页面三卡片布局与版本诊断功能
- 统一品牌为 Nex:侧边栏、托盘 tooltip、HTML 标题、favicon (PNG 替代 SVG)
- 重构关于页面为三卡片布局(品牌/版本/链接),版本状态 Tag 绝对定位右上角
- 新增 GET /api/version 后端接口,返回 version/commit/build_time
- 新增前端版本一致性诊断:匹配/不匹配/不可判断三种状态
- 同步 delta specs 到主 specs 并归档变更
2026-05-05 03:28:22 +08:00
9e33e570af fix: 降低请求生命周期日志级别 2026-05-05 01:54:53 +08:00
7653385838 fix: 加固发布流水线运行环境
修复 Windows 发布作业在 MSYS2 环境下无法访问 Go 工具链的问题。

为三平台发布增加工具链预检并升级 release workflow 运行时兼容性,减少版本检查噪音和 CI 告警。
2026-05-05 01:27:38 +08:00
2c401f7ae6 chore: streamline workspace make workflows
Clarify product-level server and desktop commands while moving backend-only maintenance tasks into backend/Makefile. This keeps root automation focused on core flows and aligns the main OpenSpec specs with the new command boundaries.
2026-04-28 17:44:23 +08:00
a9972360c2 feat: 增加版本化构建与发布流程
引入 VERSION 作为统一版本源,避免前端、后端、桌面打包和发布资产之间的版本漂移。
新增 tag 驱动的 Draft Release 流程与版本化资产命名,使本地演进和 GitHub 发布共享同一套约束。
2026-04-28 14:20:27 +08:00
b00fa4dcee chore: 调整开源许可证为 Apache 2.0
新增根目录 Apache 2.0 许可证文件,并同步更新仓库文档与前端包元数据声明。
2026-04-27 11:24:33 +08:00
92525b39c3 chore: 将 assets 图标文件迁移到 Git LFS 2026-04-27 10:34:38 +08:00
38a2555c7b fix: Anthropic 流式编码器补全 message_start/message_delta 必填字段
跨协议流式转换时,Anthropic 客户端 Zod 校验因 SSE 事件缺少必填字段报错。
由 Anthropic encoder 层(而非 OpenAI decoder 层)负责补全协议默认值,保持权责分离。

- encodeMessageStart 补全 type/content/stop_reason/stop_sequence,usage nil 时输出零值
- encodeMessageDelta usage nil 时输出零值
- 更新相关测试覆盖新增行为
2026-04-26 23:27:34 +08:00
9622d44aac fix: 完善转换代理行为 2026-04-26 21:48:17 +08:00
155244433f fix: 完善桌面应用图标打包
统一 macOS 图标命名为 icon.icns。\n补充 Linux hicolor 图标资源。\n修复 Windows make 构建兼容性并为 exe 嵌入图标资源。\n清理旧版图标说明与不再使用的 SVG 源文件。
2026-04-26 12:24:03 +08:00
2c043c6cf7 fix: 修正 conversion 代理路径和错误边界 2026-04-25 23:12:54 +08:00
f5c82b6980 chore: 合并 dev-test-config 到 master 2026-04-24 23:23:12 +08:00
9105a36097 feat: 将"关于"从系统托盘原生对话框迁移到前端页面
移除系统托盘右键菜单中的"关于"选项及各平台原生对话框实现,
在前端新增 /about 路由和关于页面展示品牌信息,侧边栏增加关于导航入口
2026-04-24 23:17:22 +08:00
f1ee646ca4 fix: 修复 TestSaveAndLoadConfig 测试隔离问题,使用临时目录替代真实用户配置 2026-04-24 22:59:26 +08:00
b9b487c591 chore: 统一项目编辑器配置,移至仓库根目录 2026-04-24 22:28:01 +08:00
4c62c071fb fix: 修复 macOS 桌面应用打包与元数据
将 macOS 桌面应用改为通用二进制并动态写入最低系统版本,避免 Intel Mac 无法启动。统一桌面应用名称与托盘展示,并补充测试确保相关行为稳定。
2026-04-24 19:35:51 +08:00
b2e9dd8b7f refactor: 合并 macOS 桌面打包流程
将 macOS .app 打包直接并入 desktop-build-mac,减少重复的桌面构建入口。\n\n同时移除未实现或已废弃的 desktop-package-* 命令和独立打包脚本,降低维护成本。
2026-04-24 19:02:21 +08:00
d143c5f3df fix: 补齐前端生成物忽略并消除构建告警
统一 Git、ESLint、Prettier 对测试和构建生成物的忽略规则,避免本地产物导致 frontend-build 失败。

补齐表单 effect 依赖,移除无关告警,让前端构建链路恢复稳定。
2026-04-24 18:53:53 +08:00
4eebdfb8db chore: 合并 dev-code-format-frontend 到 master 2026-04-24 18:21:27 +08:00
b517946585 chore: 合并 dev-code-backend-format 到 master 2026-04-24 18:20:12 +08:00
4ddae6be74 refactor: 优化提示词文档,增强智能合并与规范审查能力 2026-04-24 18:12:23 +08:00
195762ff97 fix: 修复后端配置加载测试失败
- 修复 viper SafeWriteConfig 与 SetConfigFile 不兼容问题
  - 将 SafeWriteConfig() 替换为 SafeWriteConfigAs(configPath)
  - 绕过 viper 的 configPaths 检查
- 调整 Makefile 测试命令分类
  - backend-test: 仅运行后端核心测试
  - backend-test-all: 运行全部后端测试(含 desktop)
  - desktop-test: 单独运行桌面应用测试
- 同步 config-management 和 test-coverage 规范
2026-04-24 14:06:03 +08:00
bcf5ca89e5 refactor: Makefile 前端命令自动安装依赖 2026-04-24 13:50:51 +08:00
365943e4c4 feat: 前端集成 Prettier 代码格式化 2026-04-24 13:40:53 +08:00
4c6b49099d feat: 配置 golangci-lint 静态分析并修复存量违规
- 新增 backend/.golangci.yml 配置 12 个 linter(forbidigo、errorlint、errcheck、staticcheck、revive、gocritic、gosec、bodyclose、noctx、nilerr、goimports、gocyclo)
- 新增 lefthook.yml 配置 pre-commit hook 自动运行 lint
- 修复存量代码违规:errors.Is/As 替换、zap.Error 替换、import 排序、errcheck 修复
- 更新 README 补充编码规范说明
- 归档 backend-code-lint 变更
2026-04-24 13:01:48 +08:00
4c78ab6cc8 chore: 新增 backend-code-lint 变更计划,更新 .gitignore 2026-04-24 00:19:56 +08:00
52007c9461 feat: 前端 ESLint 规则增强,自动检测 LLM 编码违规
- 启用 TanStack Query flat/recommended(7 条规则)
- 新增 no-console(允许 warn/error)、consistent-type-imports(inline 风格)、no-non-null-assertion 规则
- 新增自定义规则 no-hardcoded-color-in-style,检测 JSX style 中硬编码颜色值
- 将 ESLint 检查集成到 build 命令(tsc -b && eslint . && vite build)
- 修复现有代码中的 lint 违规(import 顺序、type import 风格、unused vars)
- 使用 @typescript-eslint/rule-tester 编写自定义规则集成测试
2026-04-23 22:47:32 +08:00
086dd1fed7 refactor: 重构智能合并提示词,优化交互体验并增加语义审查能力
- 三阶段模型:计划审批(1次) → 自动合并(仅异常中断) → 汇总收尾(1次),减少冗余介入
- 新增语义审查:合并后自动分析代码冗余/过时模式/未集成基础设施/风格不一致
- 完善安全锚点体系:全局锚点+分支级锚点,分级回退机制
- 修复多个逻辑缺陷:stash含未跟踪文件、分支匹配过宽、远端分支拉取重名、语义修复交叉污染等
- 明确提问工具选项和交互方式,消除歧义引导
2026-04-23 22:35:24 +08:00
1d7e839b49 Merge branch 'dev-frontend-encrypt' into master 2026-04-23 18:42:42 +08:00
fa7babf13b chore: 归档 fix-windows-desktop-packaging 变更 2026-04-23 18:40:23 +08:00
280099b89c refactor: 后端日志系统重构
- 新增模块化日志器(pkg/logger/module.go)
- 新增 GORM 日志适配器
- 统一日志入口,移除所有 zap.L() 全局 logger 调用
- 字段标准化
- 启动阶段使用结构化日志
- 更新所有相关测试
2026-04-23 18:37:51 +08:00
0a92a25451 feat: 前端生产构建添加代码混淆
- 集成 vite-plugin-javascript-obfuscator 插件
- 配置中等偏高强度混淆策略(变量名、字符串、对象键、数字)
- 仅生产构建时启用,不影响开发体验
- 仅混淆业务代码,排除第三方库
- 不生成 Source Map
- 新增 frontend-obfuscation 规范
2026-04-23 18:23:07 +08:00
8c075194e5 fix: 修复合并后代码质量问题
- 修正 Makefile 迁移目录路径(sqlite3 → sqlite)
- 统一 database.go 日志风格(log.Printf → zapLogger)
- 修复 config.go validator 标签大小写
- 修复 database_test.go 测试使用 nil logger
- 移除未使用的 log 导入
2026-04-23 16:58:01 +08:00
53e477d383 Merge branch 'dev-mysql-support' into master
- 新增 MySQL 数据库驱动支持,支持跨设备数据同步
- 新增 MySQL 专项测试能力(并发、约束、迁移)
- 重构迁移目录结构:migrations/sqlite 和 migrations/mysql
- 修复 statsRepo 并发竞态条件,使用 upsert 保证原子性
- Makefile 合并:保留完整命令体系 + 新增 MySQL 测试命令
2026-04-23 16:31:29 +08:00
1522c87c74 fix: 修复 statsRepo 并发竞态条件,使用 upsert 保证原子性
- 使用 GORM clause.OnConflict 替代事务包装
- Record 和 BatchUpdate 方法改用 upsert 模式
- 修复 UsageStats 的 GORM struct tag,确保 AutoMigrate 创建正确的 UNIQUE 约束
- 更新 usage-statistics spec 以反映 upsert 操作

MySQL 并发测试验证:10 并发调用 → request_count = 10
2026-04-23 15:54:56 +08:00
e0d05c9869 refactor: Makefile 命名规范化,新增顶层便捷命令
统一命名规范为 <namespace>-<action>[-<variant>] 格式:
- 重命名 desktop-mac/win/linux → desktop-build-mac/win/linux
- 重命名 backend-migrate-* → backend-db-*
- 重命名 frontend-build-desktop → desktop-prepare-frontend
- 重命名 embedfs-prepare → desktop-prepare-embedfs
- 重命名 package-macos → desktop-package-mac

新增顶层便捷命令:
- dev: 并行启动开发环境
- build: 构建所有产物
- test: 运行所有测试
- lint: 检查所有代码
- clean: 清理所有构建产物
2026-04-23 12:30:02 +08:00
5b401e29cb feat: 新增 MySQL 专项测试能力
- 新增 backend/tests/mysql/ 目录,包含 Docker Compose 配置和测试文件
- 新增 Makefile 命令: test-mysql, test-mysql-up, test-mysql-down, test-mysql-quick
- 使用 build tag 控制测试启用,默认不运行
- 测试覆盖: 迁移正确性、外键约束、UNIQUE 约束、并发写入
- 发现 statsRepo.Record 存在并发 bug(检查-然后-操作竞态条件)
2026-04-23 12:25:55 +08:00
65ac9f740a refactor: 桌面应用对话框代码拆分为平台专用文件
- 新增 dialog_windows.go、dialog_darwin.go、dialog_linux.go
- 使用 Go 构建标签实现条件编译
- 修复跨平台编译错误(syscall.NewLazyDLL 在 macOS/Linux 未定义)
- 实现 Linux 多工具降级策略(zenity → kdialog → notify-send → xmessage → stderr)
- 实现 macOS AppleScript 字符转义
- 更新 messagebox_test.go 构建标签
- 更新 desktop-app spec 新增 Linux 降级策略和 macOS 字符转义规范
2026-04-23 11:47:48 +08:00
58ebcaa299 refactor(scripts): 开发分支初始化脚本从 bash 重构为 Python,增强跨平台支持和用户体验 2026-04-23 10:43:59 +08:00
b3258e76df perf: 前端打包产物优化——路由级懒加载和 vendor 分包
- 使用 React.lazy() + Suspense 实现路由级代码分割
- 配置 manualChunks 将 react/tdesign/recharts 拆分为独立 vendor chunk
- 页面组件改为 export default 以支持动态导入
- 新增 bundle-optimization 规范,更新 frontend 导航规范
2026-04-23 00:26:54 +08:00
64dc66afa6 fix: Windows 桌面应用打包问题修复
- 删除通用 desktop target,重命名 platform targets 为简短形式 (desktop-mac/win/linux)
- 构建产物文件名统一为 nex-{os}-{arch}[.exe] 格式
- Windows 托盘图标使用 .ico 格式(运行时按平台选择)
- Windows 原生对话框使用 user32.MessageBoxW 替代 msg * 命令
- 更新 README.md 和 package-macos.sh 中的引用
- 添加单元测试覆盖 MessageBoxW 封装和图标选择逻辑
- 同步更新 desktop-app spec 规范文档
2026-04-22 23:20:39 +08:00
272 changed files with 13923 additions and 3912 deletions

7
.editorconfig Normal file
View File

@@ -0,0 +1,7 @@
root = true
[*]
end_of_line = lf
trim_trailing_whitespace = true
insert_final_newline = true
charset = utf-8

9
.gitattributes vendored Normal file
View File

@@ -0,0 +1,9 @@
* text=auto eol=lf
assets/*.png filter=lfs diff=lfs merge=lfs -text
assets/**/*.png filter=lfs diff=lfs merge=lfs -text
assets/*.icns filter=lfs diff=lfs merge=lfs -text
assets/**/*.icns filter=lfs diff=lfs merge=lfs -text
assets/*.ico filter=lfs diff=lfs merge=lfs -text
assets/**/*.ico filter=lfs diff=lfs merge=lfs -text
frontend/public/*.png filter=lfs diff=lfs merge=lfs -text
frontend/public/**/*.png filter=lfs diff=lfs merge=lfs -text

210
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,210 @@
name: Release
on:
push:
tags:
- 'v*.*.*'
permissions:
contents: read
jobs:
prepare:
name: Prepare Release
runs-on: ubuntu-latest
permissions:
contents: read
outputs:
version: ${{ steps.version.outputs.version }}
steps:
- name: Checkout
uses: actions/checkout@v5
- name: Setup Go
uses: actions/setup-go@v6
with:
go-version-file: go.work
cache-dependency-path: |
backend/go.sum
versionctl/go.sum
- name: Verify tag and VERSION
id: version
run: |
version=$(go run ./versionctl print)
go run ./versionctl verify-tag "${GITHUB_REF_NAME}"
printf 'version=%s\n' "$version" >> "$GITHUB_OUTPUT"
build-linux:
name: Build Linux Assets
needs: prepare
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- name: Checkout
uses: actions/checkout@v5
- name: Setup Go
uses: actions/setup-go@v6
with:
go-version-file: go.work
cache-dependency-path: |
backend/go.sum
versionctl/go.sum
- name: Setup Bun
uses: oven-sh/setup-bun@v2
- name: Install Linux desktop build dependencies
run: |
sudo apt-get update
sudo apt-get install -y libayatana-appindicator3-dev libgtk-3-dev
- name: Preflight Linux release toolchain
run: |
set -euo pipefail
command -v go
go version
command -v bun
bun --version
command -v gcc
gcc --version
command -v pkg-config
pkg-config --modversion ayatana-appindicator3-0.1
pkg-config --modversion gtk+-3.0
- name: Build Linux release assets
run: make release-assets-linux
- name: Upload Linux release assets
uses: actions/upload-artifact@v4
with:
name: release-linux
path: build/release/*
build-windows:
name: Build Windows Assets
needs: prepare
runs-on: windows-latest
permissions:
contents: read
steps:
- name: Checkout
uses: actions/checkout@v5
- name: Setup Go
uses: actions/setup-go@v6
with:
go-version-file: go.work
cache-dependency-path: |
backend/go.sum
versionctl/go.sum
- name: Setup Bun
uses: oven-sh/setup-bun@v2
- name: Setup MSYS2 toolchain
uses: msys2/setup-msys2@v2
with:
msystem: MINGW64
path-type: inherit
update: true
install: >-
make
mingw-w64-x86_64-gcc
- name: Preflight Windows release toolchain
shell: msys2 {0}
run: |
set -euo pipefail
command -v go
go version
command -v bun
bun --version
command -v make
make --version
command -v gcc
gcc --version
command -v windres
windres --version
if command -v powershell.exe >/dev/null 2>&1; then
powershell.exe -NoProfile -Command '$PSVersionTable.PSVersion.ToString()'
else
command -v powershell
powershell -NoProfile -Command '$PSVersionTable.PSVersion.ToString()'
fi
- name: Build Windows release assets
shell: msys2 {0}
run: make release-assets-windows
- name: Upload Windows release assets
uses: actions/upload-artifact@v4
with:
name: release-windows
path: build/release/*
build-macos:
name: Build macOS Assets
needs: prepare
runs-on: macos-latest
permissions:
contents: read
steps:
- name: Checkout
uses: actions/checkout@v5
- name: Setup Go
uses: actions/setup-go@v6
with:
go-version-file: go.work
cache-dependency-path: |
backend/go.sum
versionctl/go.sum
- name: Setup Bun
uses: oven-sh/setup-bun@v2
- name: Preflight macOS release toolchain
run: |
set -euo pipefail
command -v go
go version
command -v bun
bun --version
command -v ditto
xcrun --find lipo
xcrun --find vtool
- name: Build macOS release assets
run: make release-assets-macos
- name: Upload macOS release assets
uses: actions/upload-artifact@v4
with:
name: release-macos
path: build/release/*
draft-release:
name: Create Draft Release
needs: [prepare, build-linux, build-windows, build-macos]
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- name: Download release assets
uses: actions/download-artifact@v4
with:
pattern: release-*
merge-multiple: true
path: dist
- name: Publish draft release
uses: softprops/action-gh-release@v2
with:
name: v${{ needs.prepare.outputs.version }}
tag_name: ${{ github.ref_name }}
draft: true
files: |
dist/*

5
.gitignore vendored
View File

@@ -401,13 +401,16 @@ cython_debug/
# Custom
.claude
.opencode
.codex
openspec/changes/archive
temp
.agents
skills-lock.json
.worktrees
!scripts/build/
backend/bin
# Embedfs generated
embedfs/assets/
embedfs/frontend-dist/
embedfs/frontend-dist/
backend/cmd/desktop/rsrc_windows_*.syso

3
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,3 @@
{
"files.eol": "\n"
}

184
LICENSE Normal file
View File

@@ -0,0 +1,184 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction, and
distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by the copyright
owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all other entities
that control, are controlled by, or are under common control with that entity.
For the purposes of this definition, "control" means (i) the power, direct or
indirect, to cause the direction or management of such entity, whether by
contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity exercising
permissions granted by this License.
"Source" form shall mean the preferred form for making modifications, including
but not limited to software source code, documentation source, and configuration
files.
"Object" form shall mean any form resulting from mechanical transformation or
translation of a Source form, including but not limited to compiled object code,
generated documentation, and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or Object form,
made available under the License, as indicated by a copyright notice that is
included in or attached to the work (an example is provided in the Appendix
below).
"Derivative Works" shall mean any work, whether in Source or Object form, that
is based on (or derived from) the Work and for which the editorial revisions,
annotations, elaborations, or other modifications represent, as a whole, an
original work of authorship. For the purposes of this License, Derivative Works
shall not include works that remain separable from, or merely link (or bind by
name) to the interfaces of, the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including the original version
of the Work and any modifications or additions to that Work or Derivative Works
thereof, that is intentionally submitted to Licensor for inclusion in the Work
by the copyright owner or by an individual or Legal Entity authorized to submit
on behalf of the copyright owner. For the purposes of this definition,
"submitted" means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems, and
issue tracking systems that are managed by, or on behalf of, the Licensor for
the purpose of discussing and improving the Work, but excluding communication
that is conspicuously marked or otherwise designated in writing by the copyright
owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf
of whom a Contribution has been received by Licensor and subsequently
incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of this
License, each Contributor hereby grants to You a perpetual, worldwide,
non-exclusive, no-charge, royalty-free, irrevocable copyright license to
reproduce, prepare Derivative Works of, publicly display, publicly perform,
sublicense, and distribute the Work and such Derivative Works in Source or
Object form.
3. Grant of Patent License. Subject to the terms and conditions of this License,
each Contributor hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section) patent
license to make, have made, use, offer to sell, sell, import, and otherwise
transfer the Work, where such license applies only to those patent claims
licensable by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s) with the Work
to which such Contribution(s) was submitted. If You institute patent litigation
against any entity (including a cross-claim or counterclaim in a lawsuit)
alleging that the Work or a Contribution incorporated within the Work
constitutes direct or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate as of the date
such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the Work or
Derivative Works thereof in any medium, with or without modifications, and in
Source or Object form, provided that You meet the following conditions:
(a) You must give any other recipients of the Work or Derivative Works a copy of
this License; and
(b) You must cause any modified files to carry prominent notices stating that
You changed the files; and
(c) You must retain, in the Source form of any Derivative Works that You
distribute, all copyright, patent, trademark, and attribution notices from the
Source form of the Work, excluding those notices that do not pertain to any part
of the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its distribution, then
any Derivative Works that You distribute must include a readable copy of the
attribution notices contained within such NOTICE file, excluding those notices
that do not pertain to any part of the Derivative Works, in at least one of the
following places: within a NOTICE text file distributed as part of the
Derivative Works; within the Source form or documentation, if provided along
with the Derivative Works; or, within a display generated by the Derivative
Works, if and wherever such third-party notices normally appear. The contents of
the NOTICE file are for informational purposes only and do not modify the
License. You may add Your own attribution notices within Derivative Works that
You distribute, alongside or as an addendum to the NOTICE text from the Work,
provided that such additional attribution notices cannot be construed as
modifying the License.
You may add Your own copyright statement to Your modifications and may provide
additional or different license terms and conditions for use, reproduction, or
distribution of Your modifications, or for any such Derivative Works as a whole,
provided Your use, reproduction, and distribution of the Work otherwise complies
with the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise, any
Contribution intentionally submitted for inclusion in the Work by You to the
Licensor shall be under the terms and conditions of this License, without any
additional terms or conditions. Notwithstanding the above, nothing herein shall
supersede or modify the terms of any separate license agreement you may have
executed with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade names,
trademarks, service marks, or product names of the Licensor, except as required
for reasonable and customary use in describing the origin of the Work and
reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in
writing, Licensor provides the Work (and each Contributor provides its
Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied, including, without limitation, any warranties
or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any risks
associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory, whether in
tort (including negligence), contract, or otherwise, unless required by
applicable law (such as deliberate and grossly negligent acts) or agreed to in
writing, shall any Contributor be liable to You for damages, including any
direct, indirect, special, incidental, or consequential damages of any character
arising as a result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill, work stoppage,
computer failure or malfunction, or any and all other commercial damages or
losses), even if such Contributor has been advised of the possibility of such
damages.
9. Accepting Warranty or Additional Liability. While redistributing the Work or
Derivative Works thereof, You may choose to offer, and charge a fee for,
acceptance of support, warranty, indemnity, or other liability obligations
and/or rights consistent with this License. However, in accepting such
obligations, You may act only on Your own behalf and on Your sole
responsibility, not on behalf of any other Contributor, and only if You agree to
indemnify, defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason of your
accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following boilerplate
notice, with the fields enclosed by brackets "[]" replaced with your own
identifying information. (Don't include the brackets!) The text should be
enclosed in the appropriate comment syntax for the file format. We also
recommend that a file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier identification within
third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

318
Makefile
View File

@@ -1,119 +1,251 @@
.PHONY: all clean \
backend-build backend-run backend-test backend-test-unit backend-test-integration backend-test-coverage \
backend-lint backend-deps backend-generate \
backend-migrate-up backend-migrate-down backend-migrate-status backend-migrate-create \
frontend-build frontend-dev frontend-test frontend-test-watch frontend-test-coverage frontend-test-e2e frontend-lint \
desktop desktop-darwin desktop-windows desktop-linux package-macos
.PHONY: \
lint test clean \
version-sync version-check version-bump \
server-run server-build server-lint server-test server-clean \
desktop-build-mac desktop-build-win desktop-build-linux \
desktop-lint desktop-test desktop-clean \
release-assets-linux release-assets-windows release-assets-macos \
_backend-lint _backend-test _backend-clean _backend-build \
_versionctl-lint _versionctl-test \
_frontend-install _frontend-build _frontend-check _frontend-test _frontend-dev _frontend-clean \
_desktop-test _desktop-clean _desktop-prepare-frontend _desktop-prepare-embedfs _desktop-prepare-windows-resource \
_server-run-backend _server-run-frontend
# Delay shell lookups until a target needs them, then cache the result for this make run.
lazy_shell = $(or $($(1)),$(eval $(1) := $(shell $(2)))$($(1)))
VERSION = $(call lazy_shell,_VERSION,go run ./versionctl print)
GIT_COMMIT ?= $(call lazy_shell,_GIT_COMMIT,git rev-parse --short HEAD 2>/dev/null || printf 'unknown')
BUILD_TIME ?= $(call lazy_shell,_BUILD_TIME,date -u +"%Y-%m-%dT%H:%M:%SZ")
GO_LDFLAGS = -X nex/backend/pkg/buildinfo.version=$(VERSION) -X nex/backend/pkg/buildinfo.commit=$(GIT_COMMIT) -X nex/backend/pkg/buildinfo.buildTime=$(BUILD_TIME)
GO_LDFLAGS_WIN = $(GO_LDFLAGS) -H=windowsgui
RELEASE_DIR := build/release
SERVER_LINUX_ASSET = $(call lazy_shell,_SERVER_LINUX_ASSET,go run ./versionctl asset-name server linux amd64)
SERVER_WINDOWS_ASSET = $(call lazy_shell,_SERVER_WINDOWS_ASSET,go run ./versionctl asset-name server windows amd64)
SERVER_DARWIN_AMD64_ASSET = $(call lazy_shell,_SERVER_DARWIN_AMD64_ASSET,go run ./versionctl asset-name server darwin amd64)
SERVER_DARWIN_ARM64_ASSET = $(call lazy_shell,_SERVER_DARWIN_ARM64_ASSET,go run ./versionctl asset-name server darwin arm64)
DESKTOP_LINUX_ASSET = $(call lazy_shell,_DESKTOP_LINUX_ASSET,go run ./versionctl asset-name desktop linux)
DESKTOP_WINDOWS_ASSET = $(call lazy_shell,_DESKTOP_WINDOWS_ASSET,go run ./versionctl asset-name desktop windows)
DESKTOP_MACOS_ASSET = $(call lazy_shell,_DESKTOP_MACOS_ASSET,go run ./versionctl asset-name desktop macos)
# ============================================
# 后端
# 全局命令
# ============================================
all: backend-build
lint: _backend-lint _frontend-check _versionctl-lint
@printf 'Lint complete\n'
backend-build:
cd backend && go build -o bin/server ./cmd/server
test: _backend-test _frontend-test _desktop-test _versionctl-test
@printf 'All tests passed\n'
backend-run:
cd backend && go run ./cmd/server
backend-test:
cd backend && go test ./... -v
backend-test-unit:
cd backend && go test ./internal/... ./pkg/... -v
backend-test-integration:
cd backend && go test ./tests/... -v
backend-test-coverage:
cd backend && go test ./... -coverprofile=coverage.out
cd backend && go tool cover -html=coverage.out -o coverage.html
@echo "Coverage report generated: backend/coverage.html"
backend-lint:
cd backend && go tool golangci-lint run ./...
backend-deps:
cd backend && go mod tidy
backend-generate:
cd backend && go generate ./...
DB_DRIVER ?= sqlite3
DB_DSN ?= $(DB_PATH)
backend-migrate-up:
cd backend && goose -dir migrations/$(DB_DRIVER) $(DB_DRIVER) "$(DB_DSN)" up
backend-migrate-down:
cd backend && goose -dir migrations/$(DB_DRIVER) $(DB_DRIVER) "$(DB_DSN)" down
backend-migrate-status:
cd backend && goose -dir migrations/$(DB_DRIVER) $(DB_DRIVER) "$(DB_DSN)" status
backend-migrate-create:
@read -p "Migration name: " name; \
cd backend && goose -dir migrations/sqlite create $$name sql; \
cd backend && goose -dir migrations/mysql create $$name sql
clean: _backend-clean _frontend-clean _desktop-clean
@printf 'Clean complete\n'
# ============================================
# 前端
# 版本管理
# ============================================
frontend-build:
cd frontend && bun install && bun run build
version-sync:
go run ./versionctl sync
frontend-dev:
cd frontend && bun dev
version-check:
go run ./versionctl check
frontend-test:
cd frontend && bun run test
frontend-test-watch:
cd frontend && bun run test:watch
frontend-test-coverage:
cd frontend && bun run test:coverage
frontend-test-e2e:
cd frontend && bun run test:e2e
frontend-lint:
cd frontend && bun run lint
version-bump: BUMP ?= patch
version-bump:
$(eval _BUMP_ARG := $(if $(SET_VERSION),$(SET_VERSION),$(BUMP)))
$(eval _NEW_VERSION := $(shell go run ./versionctl bump $(_BUMP_ARG)))
git add VERSION frontend/
git commit -m "chore: 版本升迁 v$(_NEW_VERSION)"
git tag "v$(_NEW_VERSION)"
@printf '版本升迁完成: v%s\n' "$(_NEW_VERSION)"
# ============================================
# 桌面应用
# Server 模式
# ============================================
desktop: frontend-build-desktop embedfs-prepare
cd backend && CGO_ENABLED=1 go build -o ../build/nex ./cmd/desktop
server-run:
@$(MAKE) -j2 _server-run-backend _server-run-frontend
frontend-build-desktop:
cd frontend && cp .env.desktop .env.production.local && bun install && bun run build && rm -f .env.production.local
server-build: version-check _backend-build _frontend-build
@printf 'Server build complete\n'
embedfs-prepare:
server-lint: _backend-lint _frontend-check
@printf 'Server lint complete\n'
server-test: _backend-test _frontend-test
@printf 'Server tests passed\n'
server-clean: _backend-clean _frontend-clean
@printf 'Server artifacts cleaned\n'
_server-run-backend:
@$(MAKE) -C backend run
_server-run-frontend: _frontend-install
cd frontend && bun run dev
# ============================================
# Desktop 模式
# ============================================
desktop-build-mac: version-check _desktop-prepare-frontend _desktop-prepare-embedfs
@printf 'Building macOS desktop...\n'
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-mac-arm64 ./cmd/desktop
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-mac-amd64 ./cmd/desktop
lipo -create build/nex-mac-arm64 build/nex-mac-amd64 -output build/nex-mac-universal
@printf 'Packaging macOS app bundle...\n'
mkdir -p build/Nex.app/Contents/MacOS build/Nex.app/Contents/Resources
cp build/nex-mac-universal build/Nex.app/Contents/MacOS/nex
@if [ -f assets/icon.icns ]; then \
cp assets/icon.icns build/Nex.app/Contents/Resources/; \
else \
printf 'Missing assets/icon.icns\n'; \
fi
@MIN_MACOS_VERSION=$$(vtool -show-build build/nex-mac-universal | awk '/minos / {print $$2; exit}'); \
if [ -z "$$MIN_MACOS_VERSION" ]; then \
printf 'Unable to read macOS minimum version\n'; \
exit 1; \
fi; \
go run ./versionctl macos-plist "$$MIN_MACOS_VERSION" > build/Nex.app/Contents/Info.plist
chmod +x build/Nex.app/Contents/MacOS/nex
@printf 'macOS desktop build complete\n'
desktop-build-win: version-check _desktop-prepare-frontend _desktop-prepare-embedfs _desktop-prepare-windows-resource
@printf 'Building Windows desktop...\n'
ifeq ($(OS),Windows_NT)
powershell -NoProfile -Command "New-Item -ItemType Directory -Path 'build' -Force | Out-Null"
cd backend && set "CGO_ENABLED=1"&& set "GOOS=windows"&& set "GOARCH=amd64"&& go build -ldflags "$(GO_LDFLAGS_WIN)" -o ../build/nex-win-amd64.exe ./cmd/desktop
else
mkdir -p build
cd backend && CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS_WIN)" -o ../build/nex-win-amd64.exe ./cmd/desktop
endif
@printf 'Windows desktop build complete\n'
desktop-build-linux: version-check _desktop-prepare-frontend _desktop-prepare-embedfs
@printf 'Building Linux desktop...\n'
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-linux-amd64 ./cmd/desktop
@printf 'Linux desktop build complete\n'
desktop-lint: _backend-lint _frontend-check
@printf 'Desktop lint complete\n'
desktop-test: _desktop-test
@printf 'Desktop tests passed\n'
desktop-clean: _desktop-clean
@printf 'Desktop artifacts cleaned\n'
_desktop-test:
cd backend && go test ./cmd/desktop/... -v
_desktop-clean:
rm -rf build/ embedfs/assets embedfs/frontend-dist backend/cmd/desktop/rsrc_windows_amd64.syso
_desktop-prepare-frontend: _frontend-install
@printf 'Preparing frontend for desktop...\n'
ifeq ($(OS),Windows_NT)
powershell -NoProfile -Command "Copy-Item -LiteralPath 'frontend/.env.desktop' -Destination 'frontend/.env.production.local' -Force"
cd frontend && bun run build
powershell -NoProfile -Command "Remove-Item -LiteralPath 'frontend/.env.production.local' -Force -ErrorAction SilentlyContinue"
else
cd frontend && cp .env.desktop .env.production.local
cd frontend && bun run build
rm -f frontend/.env.production.local
endif
_desktop-prepare-embedfs:
@printf 'Preparing embedded filesystem...\n'
ifeq ($(OS),Windows_NT)
powershell -NoProfile -Command "Remove-Item -LiteralPath 'embedfs/assets' -Recurse -Force -ErrorAction SilentlyContinue; Remove-Item -LiteralPath 'embedfs/frontend-dist' -Recurse -Force -ErrorAction SilentlyContinue; Copy-Item -LiteralPath 'assets' -Destination 'embedfs/assets' -Recurse; Copy-Item -LiteralPath 'frontend/dist' -Destination 'embedfs/frontend-dist' -Recurse"
else
rm -rf embedfs/assets embedfs/frontend-dist
cp -r assets embedfs/assets
cp -r frontend/dist embedfs/frontend-dist
endif
desktop-darwin: frontend-build-desktop embedfs-prepare
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -o ../build/nex-darwin-arm64 ./cmd/desktop
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -o ../build/nex-darwin-amd64 ./cmd/desktop
desktop-windows: frontend-build-desktop embedfs-prepare
cd backend && CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -ldflags "-H=windowsgui" -o ../build/nex-windows-amd64.exe ./cmd/desktop
desktop-linux: frontend-build-desktop embedfs-prepare
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -o ../build/nex-linux-amd64 ./cmd/desktop
package-macos:
./scripts/build/package-macos.sh
_desktop-prepare-windows-resource:
@printf 'Preparing Windows executable icon...\n'
ifeq ($(OS),Windows_NT)
cd backend/cmd/desktop && windres -O coff -F pe-x86-64 -i icon_windows.rc -o rsrc_windows_amd64.syso
else
@if command -v x86_64-w64-mingw32-windres >/dev/null 2>&1; then \
cd backend/cmd/desktop && x86_64-w64-mingw32-windres -O coff -F pe-x86-64 -i icon_windows.rc -o rsrc_windows_amd64.syso; \
elif command -v windres >/dev/null 2>&1; then \
cd backend/cmd/desktop && windres -O coff -F pe-x86-64 -i icon_windows.rc -o rsrc_windows_amd64.syso; \
else \
printf 'Missing windres for Windows icon resource generation\n'; \
exit 1; \
fi
endif
# ============================================
# 清理
# 发布资产
# ============================================
clean:
rm -rf backend/bin/ backend/coverage.out backend/coverage.html
rm -rf build/
release-assets-linux: version-check desktop-build-linux
rm -rf "$(RELEASE_DIR)"
mkdir -p "$(RELEASE_DIR)"
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-linux-amd64 ./cmd/server
tar -C build -czf "$(RELEASE_DIR)/$(SERVER_LINUX_ASSET)" nex-server-linux-amd64
tar -C build -czf "$(RELEASE_DIR)/$(DESKTOP_LINUX_ASSET)" nex-linux-amd64
release-assets-windows: version-check desktop-build-win
ifeq ($(OS),Windows_NT)
powershell -NoProfile -Command "Remove-Item -LiteralPath '$(RELEASE_DIR)' -Recurse -Force -ErrorAction SilentlyContinue; New-Item -ItemType Directory -Path '$(RELEASE_DIR)' -Force | Out-Null"
cd backend && set "CGO_ENABLED=1"&& set "GOOS=windows"&& set "GOARCH=amd64"&& go build -ldflags "$(GO_LDFLAGS_WIN)" -o ../build/nex-server-win-amd64.exe ./cmd/server
powershell -NoProfile -Command "Compress-Archive -LiteralPath 'build/nex-server-win-amd64.exe' -DestinationPath '$(RELEASE_DIR)/$(SERVER_WINDOWS_ASSET)' -Force"
powershell -NoProfile -Command "Compress-Archive -LiteralPath 'build/nex-win-amd64.exe' -DestinationPath '$(RELEASE_DIR)/$(DESKTOP_WINDOWS_ASSET)' -Force"
else
@printf 'release-assets-windows requires Windows\n'
@exit 1
endif
release-assets-macos: version-check desktop-build-mac
rm -rf "$(RELEASE_DIR)"
mkdir -p "$(RELEASE_DIR)"
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-darwin-amd64 ./cmd/server
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-darwin-arm64 ./cmd/server
tar -C build -czf "$(RELEASE_DIR)/$(SERVER_DARWIN_AMD64_ASSET)" nex-server-darwin-amd64
tar -C build -czf "$(RELEASE_DIR)/$(SERVER_DARWIN_ARM64_ASSET)" nex-server-darwin-arm64
ditto -c -k --keepParent build/Nex.app "$(RELEASE_DIR)/$(DESKTOP_MACOS_ASSET)"
# ============================================
# 共享 helper targets
# ============================================
_backend-build:
@$(MAKE) -C backend build
_backend-lint:
@$(MAKE) -C backend lint
_backend-test:
@$(MAKE) -C backend test
_backend-clean:
@$(MAKE) -C backend clean
_versionctl-lint:
@$(MAKE) -C versionctl lint
_versionctl-test:
@$(MAKE) -C versionctl test
_frontend-install:
cd frontend && bun install
_frontend-build: _frontend-install
cd frontend && bun run build
_frontend-check: _frontend-install
cd frontend && bun run check
_frontend-test: _frontend-install
cd frontend && bun run test
_frontend-dev: _frontend-install
cd frontend && bun run dev
_frontend-clean:
rm -rf frontend/dist frontend/.next frontend/coverage frontend/playwright-report frontend/test-results frontend/tsconfig.tsbuildinfo

184
README.md
View File

@@ -36,13 +36,9 @@ nex/
├── assets/ # 应用资源
│ ├── icon.png # 托盘图标
│ ├── AppIcon.icns # macOS 应用图标
│ ├── icon.icns # macOS 应用图标
│ └── icon.ico # Windows 应用图标
├── scripts/ # 构建脚本
│ └── build/
│ └── package-macos.sh # macOS .app 打包脚本
└── README.md # 本文件
```
@@ -51,7 +47,7 @@ nex/
- **双协议支持**:同时支持 OpenAI 和 Anthropic 协议
- **跨协议转换**Hub-and-Spoke 架构实现 OpenAI ↔ Anthropic 双向转换
- **统一模型 ID**`provider_id/model_name` 格式全局唯一标识模型(如 `openai/gpt-4`
- **Smart Passthrough**:同协议请求零序列化开销,仅改写 model 字段
- **Smart Passthrough**:同协议请求跳过 Canonical 全量转换,仅在 JSON 层改写 model 字段
- **流式响应**:完整支持 SSE 流式传输,包括跨协议流式转换
- **Function Calling**支持工具调用Tools
- **Thinking / Reasoning**:支持 OpenAI `reasoning_effort` 和 Anthropic `thinking` 配置
@@ -67,11 +63,25 @@ nex/
- **HTTP 框架**: Gin
- **ORM**: GORM
- **数据库**: SQLite / MySQL
- **日志**: zap + lumberjack结构化日志 + 日志轮转)
- **日志**: zap + lumberjack结构化日志 + 日志轮转 + 模块标识
- **配置**: Viper + pflag多层配置CLI > 环境变量 > 配置文件 > 默认值)
- **验证**: go-playground/validator/v10
- **迁移**: goose
#### 日志模块标识规范
每个模块通过依赖注入获取带模块标识的 logger日志输出格式为 `[module.name]`
```
Console: INFO [handler.proxy] 处理请求 method=POST path=/v1/chat
JSON: {"level":"info","logger":"handler.proxy","msg":"处理请求","method":"POST"}
```
模块命名规范:
- 单一职责包:`database``config`
- 多实体包:`handler.proxy``service.provider`
- 子包:`handler.middleware`
### 前端
- **运行时**: Bun
- **构建工具**: Vite
@@ -81,7 +91,7 @@ nex/
- **图表库**: Recharts
- **路由**: React Router v7
- **数据获取**: TanStack Query v5
- **样式**: SCSS Modules
- **样式**: TDesign 组件 props 优先TDesign tokens 次之SCSS 作为兜底补充
- **测试**: Vitest + React Testing Library + Playwright
## 快速开始
@@ -91,22 +101,18 @@ nex/
**构建桌面应用**
```bash
# 当前平台
make desktop
# macOS (arm64 + amd64)
make desktop-darwin
make package-macos # 打包为 .app
# macOS (arm64 + amd64并打包为 .app)
make desktop-build-mac
# Windows
make desktop-windows
make desktop-build-win
# Linux
make desktop-linux
make desktop-build-linux
```
**使用桌面应用**
- 双击启动应用macOS: Nex.appWindows: nex.exeLinux: nex
- 双击启动应用macOS: Nex.appWindows: nex-win-amd64.exeLinux: nex-linux-amd64
- 系统托盘图标出现,浏览器自动打开管理界面
- 点击托盘图标显示菜单,可打开管理界面或退出
- 关闭浏览器后服务继续运行,可通过托盘重新打开
@@ -123,50 +129,54 @@ make desktop-linux
- Xfce: 需要 libappindicator
- 其他支持 StatusNotifierItem 规范的环境
### CLI 模式
#### 后端
### Server 模式(前后端分离)
```bash
cd backend
go mod download
go run cmd/server/main.go
make server-run
```
后端服务将在 `http://localhost:9826` 启动。首次启动会自动:
`make server-run` 会并行启动:
- 后端服务:`http://localhost:9826`
- 前端开发服务器:`http://localhost:5173`
前端请求会继续通过 Vite proxy 转发到后端。后端首次启动会自动:
- 创建配置文件 `~/.nex/config.yaml`
- 初始化数据库 `~/.nex/config.db`
- 运行数据库迁移
- 创建日志目录 `~/.nex/log/`
### 前端
**构建 server 模式产物**
```bash
cd frontend
bun install
bun dev
make server-build
```
前端开发服务器将在 `http://localhost:5173` 启动API 请求通过 Vite proxy 转发到后端。
## API 接口
### 代理接口(对外部应用)
代理接口统一使用 `/{protocol}/*path` 路由格式,模型 ID 使用 `provider_id/model_name` 格式(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough最小化 JSON 改写保持参数保真;跨协议请求走完整 decode/encode 转换。
代理接口统一使用 `/{protocol}/*path` 路由格式,模型 ID 使用 `provider_id/model_name` 格式(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough最小化 JSON 改写保持未改写字段的 JSON 内容和类型不变;跨协议请求走完整 decode/encode 转换。
**OpenAI 协议**`protocol=openai`
- `POST /openai/chat/completions` - 对话补全
- `GET /openai/models` - 模型列表(本地数据库聚合)
- `GET /openai/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
- `POST /openai/embeddings` - 嵌入
- `POST /openai/rerank` - 重排序
- `POST /openai/v1/chat/completions` - 对话补全
- `GET /openai/v1/models` - 模型列表(本地数据库聚合)
- `GET /openai/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
- `POST /openai/v1/embeddings` - 嵌入
- `POST /openai/v1/rerank` - 重排序
**Anthropic 协议**`protocol=anthropic`
- `POST /anthropic/v1/messages` - 消息对话
- `GET /anthropic/v1/models` - 模型列表(本地数据库聚合)
- `GET /anthropic/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
路径边界:网关只剥离第一段协议前缀,剩余路径保持协议原生形态交给 adapter。OpenAI adapter 接收 `/v1/chat/completions``/v1/models``/v1/embeddings``/v1/rerank`,并在构建上游 URL 时去掉 `/v1`Anthropic adapter 接收 `/v1/messages``/v1/models`。因此 OpenAI 供应商 `base_url` 配置到版本路径一级(如 `https://api.openai.com/v1`Anthropic 供应商 `base_url` 配置到域名级(如 `https://api.anthropic.com`)。
代理错误边界:网关层错误统一返回 `{"error":"...","code":"..."}`,例如 `INVALID_JSON``MODEL_NOT_FOUND``CONVERSION_FAILED``UPSTREAM_UNAVAILABLE`。只要上游已经返回 HTTP 响应,非 2xx 的 status、过滤 hop-by-hop header 后的 headers 和 body 会直接透传,不包装为应用错误或协议错误。
模型路由边界:只有 adapter 明确适配的接口会解析请求体中的 `model` 并使用统一模型 ID 路由;未知接口即使包含顶层 `model` 也按无 model 透传处理。
流式边界:同协议无响应 model 改写时原样透传 SSE frame 和 `[DONE]`;同协议需要响应 model 改写时只解析 SSE frame 的 `data` JSON 并改写 `model`;跨协议流式仍走 provider decoder → Canonical stream event → client encoder。
### 管理接口(对前端)
#### 供应商管理
@@ -189,6 +199,9 @@ bun dev
查询参数支持:`provider_id``model_name``start_date``end_date``group_by`
#### 版本信息
- `GET /api/version` - 获取后端构建版本信息(`version``commit``build_time`),用于前端 About 页面诊断前后端版本一致性
## 配置
配置支持多种方式,优先级为:**CLI 参数 > 环境变量 > 配置文件 > 默认值**
@@ -262,25 +275,100 @@ export NEX_DATABASE_DBNAME=nex
## 测试
```bash
make backend-test # 后端测试
make backend-test-coverage # 后端覆盖率
make frontend-test # 前端测试
make frontend-test-e2e # 前端 E2E 测试
# 全局默认测试(不含 MySQL 和前端 E2E
make test
# 产品级测试
make server-test
make desktop-test
```
backend 分类测试、MySQL 专项测试和前端 E2E 测试请分别查看 `backend/README.md``frontend/README.md`
## 开发
```bash
make backend-build # 构建后端
make backend-run # 运行后端
make backend-lint # 后端代码检查
make backend-migrate-up # 数据库迁移
# 首次克隆后安装 Git hooks
lefthook install
make frontend-build # 构建前端
make frontend-dev # 前端开发模式
make frontend-lint # 前端代码检查
# 全局命令
make lint # 前后端共享检查
make test # 默认全量测试(不含 MySQL/E2E
make clean # 清理所有构建产物和测试报告
# server 模式
make server-run # 并行启动后端和前端开发服务
make server-build # 构建 backend/bin/server 和 frontend/dist
make server-lint # server 模式检查
make server-test # server 模式测试
make server-clean # 清理 server 模式产物
# desktop 模式
make desktop-build-mac # 构建 macOS 桌面应用
make desktop-build-win # 构建 Windows 桌面应用
make desktop-build-linux # 构建 Linux 桌面应用
make desktop-lint # desktop 模式检查
make desktop-test # desktop 专属测试
make desktop-clean # 清理 desktop 产物
```
## 版本与发布
### 统一版本源
- 仓库根目录 `VERSION` 是全仓唯一版本源,格式固定为 `x.y.z`
- `frontend/package.json` 和前端 `.env.*` 中的 `VITE_APP_VERSION` 由仓库工具同步,不能手工漂移
### 本地版本演进
```bash
# 递增版本(自动 sync + check + commit + tag
make version-bump BUMP=minor
# 或指定具体版本号
make version-bump SET_VERSION=1.0.0
# 推送到远程
git push --follow-tags
```
手动同步和校验:
```bash
make version-sync
make version-check
```
### 本地生成发布资产
```bash
# Linux: server + desktop
make release-assets-linux
# Windows: server + desktop需在 Windows 环境执行)
make release-assets-windows
# macOS: darwin-amd64 server、darwin-arm64 server、desktop universal
make release-assets-macos
```
生成的版本化发布资产位于 `build/release/`
### GitHub Draft Release
- 推送 `vX.Y.Z` tag 后,`.github/workflows/release.yml` 会自动执行发布流水线
- 三个平台 job 会在正式构建前先检查 `go``bun` 和各自的平台打包工具链,缺失时快速失败并在日志中输出诊断信息
- Windows 发布 job 在 `MSYS2 / MINGW64` shell 中执行,并继承 `setup-go` / `setup-bun` 准备好的工具链路径
- 流水线会先校验 tag 与 `VERSION` 一致,再构建以下资产并上传到 GitHub Draft Release
- Linux server
- Windows server
- darwin-amd64 server
- darwin-arm64 server
- Linux desktop
- Windows desktop
- macOS desktop universal
- Release 默认以 Draft 形式创建,需人工检查后再公开发布
## 开发规范
详见各子项目的 README.md
@@ -289,4 +377,4 @@ make frontend-lint # 前端代码检查
## 许可证
MIT
Apache License 2.0

1
VERSION Normal file
View File

@@ -0,0 +1 @@
0.1.1

Binary file not shown.

View File

@@ -1,64 +0,0 @@
# Assets
应用资源文件目录。
## 文件说明
| 文件 | 用途 | 尺寸 | 格式 |
|------|------|------|------|
| `icon.svg` | 源图标 | 64x64 | SVG |
| `icon.png` | 托盘图标 | 64x64 | PNG |
| `AppIcon.icns` | macOS 应用图标 | 多尺寸 | ICNS |
| `icon.ico` | Windows 应用图标 | 256x256 | ICO |
## 替换图标
### 1. 准备图标
推荐使用 SVG 格式的源图标,尺寸至少 256x256。
### 2. 生成各平台图标
**托盘图标 (PNG)**
```bash
magick your-icon.svg -resize 64x64 icon.png
```
**macOS 应用图标 (ICNS)**
```bash
mkdir icon.iconset
magick your-icon.svg -resize 16x16 icon.iconset/icon_16x16.png
magick your-icon.svg -resize 32x32 icon.iconset/icon_16x16@2x.png
magick your-icon.svg -resize 32x32 icon.iconset/icon_32x32.png
magick your-icon.svg -resize 64x64 icon.iconset/icon_32x32@2x.png
magick your-icon.svg -resize 128x128 icon.iconset/icon_128x128.png
magick your-icon.svg -resize 256x256 icon.iconset/icon_128x128@2x.png
iconutil -c icns icon.iconset -o AppIcon.icns
rm -rf icon.iconset
```
**Windows 应用图标 (ICO)**
```bash
magick your-icon.svg -resize 256x256 icon.ico
```
### 3. 替换文件
将生成的文件放入此目录,然后重新构建桌面应用:
```bash
./scripts/build/build-darwin-arm64.sh
```
## macOS Template 图标
macOS 支持 Template 图标,自动适配深浅色模式:
- 使用黑色 + 透明设计
- 文件名以 `Template` 结尾(如 `iconTemplate.png`
- 黑色在深色模式下自动变为白色
## 设计建议
- 托盘图标应简洁,在小尺寸下清晰可辨
- 避免过多细节和文字
- 使用高对比度颜色
- macOS 建议使用 Template 图标风格

BIN
assets/icon.icns LFS Normal file

Binary file not shown.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 264 KiB

After

Width:  |  Height:  |  Size: 130 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.0 KiB

After

Width:  |  Height:  |  Size: 131 B

View File

@@ -1,13 +0,0 @@
<svg width="64" height="64" viewBox="0 0 64 64" xmlns="http://www.w3.org/2000/svg">
<rect width="64" height="64" rx="12" fill="#4A90D9"/>
<polygon points="32,8 52,20 52,44 32,56 12,44 12,20" fill="none" stroke="white" stroke-width="3"/>
<circle cx="32" cy="32" r="6" fill="white"/>
<line x1="32" y1="32" x2="20" y2="20" stroke="white" stroke-width="2"/>
<line x1="32" y1="32" x2="44" y2="20" stroke="white" stroke-width="2"/>
<line x1="32" y1="32" x2="20" y2="44" stroke="white" stroke-width="2"/>
<line x1="32" y1="32" x2="44" y2="44" stroke="white" stroke-width="2"/>
<circle cx="20" cy="20" r="3" fill="white"/>
<circle cx="44" cy="20" r="3" fill="white"/>
<circle cx="20" cy="44" r="3" fill="white"/>
<circle cx="44" cy="44" r="3" fill="white"/>
</svg>

Before

Width:  |  Height:  |  Size: 779 B

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

91
backend/.golangci.yml Normal file
View File

@@ -0,0 +1,91 @@
run:
timeout: 5m
tests: true
linters:
disable-all: true
enable:
- forbidigo
- errorlint
- errcheck
- staticcheck
- revive
- gocritic
- gosec
- bodyclose
- noctx
- nilerr
- goimports
- gocyclo
linters-settings:
errcheck:
check-blank: true
check-type-assertions: true
exclude-functions:
- fmt.Fprintf
forbidigo:
analyze-types: true
forbid:
- p: '^fmt\.Print.*$'
msg: 使用 zap logger不要直接输出到 stdout/stderr
- p: '^fmt\.Fprint.*$'
msg: 使用 zap logger不要直接输出到 stdout/stderr
- p: '^log\.(Print|Println|Printf|Fatal|Fatalln|Fatalf|Panic|Panicln|Panicf)$'
msg: 使用 zap logger不要使用标准库 log
- p: '^zap\.L$'
msg: 通过依赖注入传递 *zap.Logger不要使用全局 logger
- p: '^zap\.S$'
msg: 不使用 Sugar logger
revive:
rules:
- name: exported
- name: var-naming
- name: indent-error-flow
- name: error-strings
- name: error-return
- name: blank-imports
- name: context-as-argument
- name: unexported-return
goimports:
local-prefixes: nex/backend
gocyclo:
min-complexity: 10
issues:
exclude-dirs:
- tests/mocks
exclude-generated: true
exclude-rules:
- path: '(_test\.go|tests/)'
linters:
- forbidigo
- path: '(_test\.go|tests/)'
linters:
- errcheck
source: '(^\s*_\s*=|,\s*_)'
- path: 'tests/integration/e2e_conversion_test\.go'
linters:
- errcheck
- path: '(_test\.go|tests/)'
linters:
- revive
text: '^exported:'
- path: '(_test\.go|tests/)'
linters:
- gosec
text: 'G(101|401|501)'
- path: '(_test\.go|tests/)'
linters:
- gocyclo
text: 'cyclomatic complexity (1[1-9]|20) of .* is high \(> 10\)'
- linters:
- revive
text: '(that stutters|BuildUrl should be BuildURL|ConvertHttpRequest should be ConvertHTTPRequest|ConvertHttpResponse should be ConvertHTTPResponse)'
- path: 'internal/conversion/.*\.go'
linters:
- gocyclo
- gocritic
- path: '(internal/provider/client\.go|internal/service/model_service_impl\.go|internal/service/stats_buffer\.go|internal/handler/proxy_handler\.go|cmd/(desktop|server)/main\.go)'
linters:
- gocyclo

97
backend/Makefile Normal file
View File

@@ -0,0 +1,97 @@
.PHONY: \
build run \
test test-unit test-integration test-coverage \
lint clean \
migrate-up migrate-down migrate-status migrate-create \
mysql-up mysql-down mysql-test mysql-test-quick
VERSION := $(shell go run ../versionctl print)
GIT_COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || printf 'unknown')
BUILD_TIME ?= $(shell date -u +"%Y-%m-%dT%H:%M:%SZ")
GO_LDFLAGS := -X nex/backend/pkg/buildinfo.version=$(VERSION) -X nex/backend/pkg/buildinfo.commit=$(GIT_COMMIT) -X nex/backend/pkg/buildinfo.buildTime=$(BUILD_TIME)
DB_DRIVER ?= sqlite3
DB_DSN ?= $(HOME)/.nex/config.db
ifeq ($(DB_DRIVER),mysql)
GOOSE_DIR := migrations/mysql
GOOSE_DRIVER := mysql
else ifeq ($(DB_DRIVER),sqlite3)
GOOSE_DIR := migrations/sqlite
GOOSE_DRIVER := sqlite3
else
$(error unsupported DB_DRIVER '$(DB_DRIVER)', use sqlite3 or mysql)
endif
build:
go build -ldflags "$(GO_LDFLAGS)" -o bin/server ./cmd/server
run:
go run -ldflags "$(GO_LDFLAGS)" ./cmd/server
test:
go test ./internal/... ./pkg/... ./tests/... ./cmd/server/... -v
test-unit:
go test ./internal/... ./pkg/... -v
test-integration:
go test ./tests/... -v
test-coverage:
go test ./... -coverprofile=coverage.out
go tool cover -html=coverage.out -o coverage.html
@printf 'Coverage report generated: backend/coverage.html\n'
lint:
go tool golangci-lint run ./...
clean:
rm -rf bin/ coverage.out coverage.html
migrate-up:
@printf 'Running database migration up...\n'
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" up
migrate-down:
@printf 'Running database migration down...\n'
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" down
migrate-status:
@printf 'Checking database migration status...\n'
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" status
migrate-create:
@printf 'Migration name: '; \
read name; \
goose -dir migrations/sqlite create $$name sql; \
goose -dir migrations/mysql create $$name sql
mysql-up:
@printf 'Starting MySQL test container...\n'
cd tests/mysql && docker-compose up -d
@printf 'Waiting for MySQL to be ready...\n'
@for i in $$(seq 1 30); do \
if docker exec nex-mysql-test mysqladmin ping -h localhost -u root -ptestpass --silent 2>/dev/null; then \
printf 'MySQL is ready\n'; \
exit 0; \
fi; \
printf 'Waiting... (%s/30)\n' $$i; \
sleep 1; \
done; \
printf 'MySQL failed to start\n'; \
exit 1
mysql-down:
@printf 'Stopping MySQL test container...\n'
cd tests/mysql && docker-compose down -v
mysql-test:
@set -e; \
$(MAKE) mysql-up; \
trap '$(MAKE) mysql-down' EXIT; \
go test -tags=mysql ./tests/mysql/... -v -count=1
mysql-test-quick:
@printf 'Running MySQL tests without container management...\n'
go test -tags=mysql ./tests/mysql/... -v -count=1

View File

@@ -4,21 +4,67 @@ AI 网关后端服务,提供统一的大模型 API 代理接口。
## 功能特性
- 支持 OpenAI 协议(`/openai/v1/...`
- 支持 OpenAI 协议(`/openai/v1/...`,例如 `/openai/v1/chat/completions`
- 支持 Anthropic 协议(`/anthropic/v1/...`
- 支持 Hub-and-Spoke 跨协议双向转换OpenAI ↔ Anthropic
- 同协议透传(零语义损失、零序列化开销
- 同协议透传(跳过 Canonical 全量转换,保持协议语义
- 支持流式响应SSE
- 支持 Function Calling / Tools
- 支持 Thinking / Reasoning
- 支持扩展层接口Models、Embeddings、Rerank
- 多供应商配置和路由
- 用量统计
- 结构化日志zap + lumberjack
- 结构化日志zap + lumberjack + 模块标识
- YAML 配置管理
- 请求验证
- 中间件支持(请求 ID、日志、恢复、CORS
## 日志规范
### 模块标识
每个模块通过依赖注入获取带模块标识的 logger
```go
func NewProxyHandler(..., logger *zap.Logger) *ProxyHandler {
return &ProxyHandler{
logger: pkglogger.WithModule(logger, "handler.proxy"),
}
}
```
输出格式:
- Console: `INFO [handler.proxy] 处理请求 method=POST path=/v1/chat`
- JSON: `{"level":"info","logger":"handler.proxy","msg":"处理请求"}`
### 模块命名规范
| 模块 | 命名 |
|------|------|
| ProxyHandler | `handler.proxy` |
| ProviderHandler | `handler.provider` |
| Provider Client | `provider.client` |
| ConversionEngine | `conversion.engine` |
| RoutingCache | `service.routing_cache` |
| StatsBuffer | `service.stats_buffer` |
| Database | `database` |
### 标准字段
使用 `pkg/logger/field.go` 中定义的字段构造函数:
```go
logger.Debug("请求开始",
pkglogger.Method("POST"),
pkglogger.Path("/v1/chat"),
pkglogger.RequestID("xxx"),
)
```
### GORM 日志
GORM 日志自动桥接到 zapSQL 查询映射到 Debug 级别。
## 技术栈
- **语言**: Go 1.26+
@@ -105,9 +151,13 @@ backend/
│ │ ├── errors.go
│ │ └── wrap.go
│ ├── logger/ # 日志系统
│ │ ├── logger.go
│ │ ├── rotate.go
│ │ ── context.go
│ │ ├── logger.go # 核心初始化
│ │ ├── field.go # 标准字段定义
│ │ ── module.go # 模块日志器
│ │ ├── context.go # Context 辅助函数
│ │ ├── gorm.go # GORM 适配器
│ │ ├── minimal.go # 最小化 logger
│ │ └── rotate.go # 日志轮转
│ ├── modelid/ # 统一模型 ID 工具包
│ │ ├── model_id.go
│ │ └── model_id_test.go
@@ -170,7 +220,7 @@ OpenAI Response ← Canonical Response ← Anthropic Response
### Smart Passthrough 机制
同协议请求走 Smart Passthrough 路径,**零序列化开销**
同协议请求走 Smart Passthrough 路径,不进入 Canonical 全量转换
```
1. 检测 clientProtocol == providerProtocol
@@ -179,12 +229,14 @@ OpenAI Response ← Canonical Response ← Anthropic Response
4. 响应中仅改写 model 字段upstream_model_name → unified_id
```
Smart Passthrough 保持未改写 JSON 字段的内容和类型不变,不承诺保留原始字节顺序、空白或对象字段顺序。
### 流式转换器层次
```
StreamConverter (接口)
├── PassthroughStreamConverter # 直接透传,无任何处理
├── SmartPassthroughStreamConverter # 同协议 + 逐 chunk 改写 model
├── SmartPassthroughStreamConverter # 同协议 + 按 SSE frame 改写 data JSON model
└── CanonicalStreamConverter # 跨协议完整转换decode → encode
```
@@ -251,6 +303,7 @@ StreamConverter (接口)
| `PROTOCOL_CONSTRAINT_VIOLATION` | 协议约束违反 |
| `ENCODING_FAILURE` | 编码失败 |
| `INTERFACE_NOT_SUPPORTED` | 接口不支持(如 Anthropic Embeddings |
| `UNSUPPORTED_MULTIMODAL` | 跨协议暂不支持多模态内容 |
### AppError 预定义错误
@@ -384,24 +437,37 @@ docker run -d -e NEX_SERVER_PORT=9000 -e NEX_LOG_LEVEL=info nex-server
## 测试
```bash
# 运行所有测试
# 运行 backend 默认测试
make test
# 分类测试
make test-unit
make test-integration
# 生成覆盖率报告
make test-coverage
# MySQL 专项测试
make mysql-up
make mysql-down
make mysql-test
make mysql-test-quick
```
## 数据库迁移
```bash
# 使用 Makefile
make migrate-up DB_PATH=~/.nex/config.db
make migrate-down DB_PATH=~/.nex/config.db
make migrate-status DB_PATH=~/.nex/config.db
make migrate-up DB_DSN=~/.nex/config.db
make migrate-down DB_DSN=~/.nex/config.db
make migrate-status DB_DSN=~/.nex/config.db
# 创建新迁移
make migrate-create
# MySQL 迁移
make migrate-up DB_DRIVER=mysql DB_DSN='user:pass@tcp(localhost:3306)/nex'
# 或直接使用 goose
goose -dir migrations sqlite3 ~/.nex/config.db up
```
@@ -410,15 +476,15 @@ goose -dir migrations sqlite3 ~/.nex/config.db up
### 代理接口
使用 `/{protocol}/v1/{path}` URL 前缀路由
使用 `/{protocol}/*path` URL 前缀路由。网关只剥离第一段协议前缀,不在 Handler 中统一添加或移除 `/v1`;剩余 path 是协议原生 nativePath由对应 adapter 识别和组合上游 URL。
#### OpenAI 协议
```
POST /openai/chat/completions
GET /openai/models
POST /openai/embeddings
POST /openai/rerank
POST /openai/v1/chat/completions
GET /openai/v1/models
POST /openai/v1/embeddings
POST /openai/v1/rerank
```
#### Anthropic 协议
@@ -428,10 +494,20 @@ POST /anthropic/v1/messages
GET /anthropic/v1/models
```
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传,零序列化开销
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传或 Smart Passthrough跳过 Canonical 全量转换
**统一模型 ID**:代理请求中的 `model` 字段使用 `provider_id/model_name` 格式(如 `openai/gpt-4`),网关据此路由到对应供应商。同协议时自动改写为上游 `model_name`,跨协议时通过全量转换处理。
**base_url 约定**
- OpenAI 供应商配置到版本路径一级,例如 `https://api.openai.com/v1`;当客户端请求 `/openai/v1/chat/completions`OpenAI adapter 会把 nativePath `/v1/chat/completions` 映射为上游 path `/chat/completions`,最终 URL 为 `https://api.openai.com/v1/chat/completions`
- Anthropic 供应商配置到域名级,例如 `https://api.anthropic.com`
**模型提取边界**:只有 adapter 明确适配的 Chat、Embeddings、Rerank 等接口会提取 `model` 并尝试统一模型 ID 路由。未知接口不做顶层 `model` 猜测,直接按无 model 透传。
**流式透传边界**:同协议无响应 model 改写时 raw passthrough保留 SSE frame 边界和 `[DONE]`;同协议需要改写时按 SSE frame 解析 `data` JSON仅改写 `model`;跨协议继续使用 StreamDecoder → CanonicalStreamConverter → StreamEncoder。
**错误边界**:网关层代理错误返回 `{"error":"...","code":"..."}`。已收到上游 HTTP 响应时,非 2xx status、过滤 hop-by-hop header 后的 headers 和 body 直接透传;没有收到上游响应的连接/DNS/TLS/超时错误返回 `UPSTREAM_UNAVAILABLE`
### 管理接口
#### 供应商管理
@@ -459,7 +535,7 @@ GET /anthropic/v1/models
- Anthropic 协议:配置到域名,不包含版本路径,如 `https://api.anthropic.com`
**对外 URL 格式**
- OpenAI 协议:`/{protocol}/{endpoint}`,如 `/openai/chat/completions``/openai/models``/openai/embeddings`
- OpenAI 协议:`/{protocol}/v1/{endpoint}`,如 `/openai/v1/chat/completions``/openai/v1/models``/openai/v1/embeddings`
- Anthropic 协议:`/{protocol}/v1/{endpoint}`,如 `/anthropic/v1/messages``/anthropic/v1/models`
#### 模型管理
@@ -501,6 +577,20 @@ GET /anthropic/v1/models
查询参数:`provider_id``model_name``start_date`YYYY-MM-DD`end_date``group_by`provider/model/date
#### 版本信息
- `GET /api/version` - 获取后端构建版本信息
响应字段来源于构建阶段注入的 `buildinfo` 元数据:
```json
{
"version": "0.1.0",
"commit": "abc1234",
"build_time": "2026-05-05T00:00:00Z"
}
```
#### 健康检查
- `GET /health` - 返回 `{"status": "ok"}`
@@ -508,9 +598,12 @@ GET /anthropic/v1/models
## 开发
```bash
make build # 构建
make lint # 代码检查
make deps # 整理依赖
make build # 构建 backend/bin/server
make run # 运行后端服务
make lint # 代码检查
make clean # 清理 backend 构建产物
go mod tidy # 整理依赖
go generate ./... # 刷新 mock 等生成代码
```
环境要求Go 1.26 或更高版本
@@ -559,6 +652,7 @@ err := v.Validate(myStruct)
- **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节
- **字符串拼接**:使用 `strings.Join`,不手写循环拼接
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配lint 强约束errorlint 禁止 `err ==` 直接比较和 `err.(*T)` 直接断言)
- **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()`
- **日志 error 字段**:使用 `zap.Error(err)`,不使用 `zap.String("error", err.Error())` 手工字符串化
- **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片

View File

@@ -0,0 +1,25 @@
//go:build darwin
package main
import (
"fmt"
"os/exec"
"strings"
"go.uber.org/zap"
)
func showError(title, message string) {
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`,
escapeAppleScript(message), escapeAppleScript(title))
if err := exec.Command("osascript", "-e", script).Run(); err != nil {
dialogLogger().Warn("显示错误对话框失败", zap.Error(err))
}
}
func escapeAppleScript(s string) string {
s = strings.ReplaceAll(s, "\\", "\\\\")
s = strings.ReplaceAll(s, "\"", "\\\"")
return s
}

View File

@@ -0,0 +1,67 @@
//go:build linux
package main
import (
"fmt"
"os/exec"
"sync"
)
type dialogToolType int
const (
toolNone dialogToolType = iota
toolZenity
toolKdialog
toolNotifySend
toolXmessage
)
var (
dialogTool dialogToolType
dialogToolOnce sync.Once
)
func init() {
dialogToolOnce.Do(detectDialogTool)
}
func detectDialogTool() {
tools := []struct {
name string
typ dialogToolType
}{
{"zenity", toolZenity},
{"kdialog", toolKdialog},
{"notify-send", toolNotifySend},
{"xmessage", toolXmessage},
}
for _, tool := range tools {
if _, err := exec.LookPath(tool.name); err == nil {
dialogTool = tool.typ
return
}
}
dialogTool = toolNone
}
func showError(title, message string) {
switch dialogTool {
case toolZenity:
exec.Command("zenity", "--error",
fmt.Sprintf("--title=%s", title),
fmt.Sprintf("--text=%s", message)).Run()
case toolKdialog:
exec.Command("kdialog", "--error", message, "--title", title).Run()
case toolNotifySend:
exec.Command("notify-send", "-u", "critical", title, message).Run()
case toolXmessage:
exec.Command("xmessage", "-center",
fmt.Sprintf("%s: %s", title, message)).Run()
default:
dialogLogger().Error("无法显示错误对话框")
}
}

View File

@@ -0,0 +1,15 @@
package main
import (
"go.uber.org/zap"
pkgLogger "nex/backend/pkg/logger"
)
func dialogLogger() *zap.Logger {
if zapLogger != nil {
return zapLogger
}
return pkgLogger.NewMinimal()
}

View File

@@ -0,0 +1,62 @@
//go:build windows
package main
import (
"errors"
"fmt"
"syscall"
"unsafe"
"go.uber.org/zap"
)
const (
mbIconError = 0x10
mbIconInformation = 0x40
)
var (
user32 = syscall.NewLazyDLL("user32.dll")
procMessageBoxW = user32.NewProc("MessageBoxW")
callMessageBoxW = func(hwnd, text, caption, flags uintptr) (uintptr, error) {
ret, _, err := procMessageBoxW.Call(hwnd, text, caption, flags)
return ret, err
}
)
func showError(title, message string) {
if err := messageBox(title, message, mbIconError); err != nil {
if zapLogger != nil {
zapLogger.Warn("显示错误对话框失败", zap.Error(err))
}
}
}
func messageBox(title, message string, flags uint) error {
titlePtr, err := syscall.UTF16PtrFromString(title)
if err != nil {
return err
}
messagePtr, err := syscall.UTF16PtrFromString(message)
if err != nil {
return err
}
ret, callErr := callMessageBoxW(
0,
uintptr(unsafe.Pointer(messagePtr)),
uintptr(unsafe.Pointer(titlePtr)),
uintptr(flags),
)
if ret != 0 {
return nil
}
if callErr != nil && !errors.Is(callErr, syscall.Errno(0)) {
return callErr
}
return fmt.Errorf("MessageBoxW 调用失败")
}

View File

@@ -0,0 +1,33 @@
package main
import (
"runtime"
"testing"
"nex/embedfs"
)
func TestIconSelection_Windows(t *testing.T) {
if runtime.GOOS != "windows" {
t.Skip("图标格式选择测试仅在 Windows 上运行")
}
if err := testIconLoad("assets/icon.ico"); err != nil {
t.Fatalf("Windows 应加载 .ico 文件: %v", err)
}
}
func TestIconSelection_NonWindows(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("图标格式选择测试在非 Windows 平台运行")
}
if err := testIconLoad("assets/icon.png"); err != nil {
t.Fatalf("非 Windows 平台应加载 .png 文件: %v", err)
}
}
func testIconLoad(path string) error {
_, err := embedfs.Assets.ReadFile(path)
return err
}

View File

@@ -0,0 +1 @@
1 ICON "../../../assets/icon.ico"

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"io/fs"
"log"
"net"
"net/http"
"os"
@@ -14,10 +13,7 @@ import (
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/getlantern/systray"
"github.com/gofrs/flock"
"go.uber.org/zap"
"nex/embedfs"
"nex/backend/internal/config"
"nex/backend/internal/conversion"
@@ -29,9 +25,14 @@ import (
"nex/backend/internal/provider"
"nex/backend/internal/repository"
"nex/backend/internal/service"
pkgLogger "nex/backend/pkg/logger"
"nex/backend/pkg/buildinfo"
"nex/embedfs"
"github.com/getlantern/systray"
"github.com/gin-gonic/gin"
"github.com/gofrs/flock"
"go.uber.org/zap"
pkgLogger "nex/backend/pkg/logger"
)
var (
@@ -44,25 +45,32 @@ var (
func main() {
port := 9826
minimalLogger := pkgLogger.NewMinimal()
singleLock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway.lock"))
if err := singleLock.Lock(); err != nil {
showError("Nex Gateway", "已有 Nex 实例运行")
minimalLogger.Error("已有 Nex 实例运行")
showError(appName, "已有 Nex 实例运行")
os.Exit(1)
}
defer singleLock.Unlock()
defer func() {
if err := singleLock.Unlock(); err != nil {
minimalLogger.Warn("释放实例锁失败", zap.Error(err))
}
}()
if err := checkPortAvailable(port); err != nil {
showError("Nex Gateway", err.Error())
os.Exit(1)
minimalLogger.Error("端口不可用", zap.Error(err))
showError(appName, err.Error())
return
}
cfg, err := config.LoadConfig()
if err != nil {
showError("Nex Gateway", fmt.Sprintf("加载配置失败: %v", err))
os.Exit(1)
minimalLogger.Fatal("加载配置失败", zap.Error(err))
}
zapLogger, err = pkgLogger.New(pkgLogger.Config{
zapLogger, err = pkgLogger.Upgrade(minimalLogger, pkgLogger.Config{
Level: cfg.Log.Level,
Path: cfg.Log.Path,
MaxSize: cfg.Log.MaxSize,
@@ -71,15 +79,19 @@ func main() {
Compress: cfg.Log.Compress,
})
if err != nil {
showError("Nex Gateway", fmt.Sprintf("初始化日志失败: %v", err))
os.Exit(1)
minimalLogger.Fatal("初始化日志失败", zap.Error(err))
}
defer zapLogger.Sync()
defer func() {
if err := zapLogger.Sync(); err != nil {
minimalLogger.Warn("同步日志失败", zap.Error(err))
}
}()
cfg.PrintSummary(zapLogger)
db, err := database.Init(&cfg.Database, zapLogger)
if err != nil {
showError("Nex Gateway", fmt.Sprintf("初始化数据库失败: %v", err))
os.Exit(1)
zapLogger.Fatal("初始化数据库失败", zap.Error(err))
}
defer database.Close(db)
@@ -105,19 +117,20 @@ func main() {
registry := conversion.NewMemoryRegistry()
if err := registry.Register(openai.NewAdapter()); err != nil {
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error()))
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.Error(err))
}
if err := registry.Register(anthropic.NewAdapter()); err != nil {
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.String("error", err.Error()))
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.Error(err))
}
engine := conversion.NewConversionEngine(registry, zapLogger)
providerClient := provider.NewClient()
providerClient := provider.NewClient(zapLogger)
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService, zapLogger)
providerHandler := handler.NewProviderHandler(providerService)
modelHandler := handler.NewModelHandler(modelService)
statsHandler := handler.NewStatsHandler(statsService)
versionHandler := handler.NewVersionHandler()
gin.SetMode(gin.ReleaseMode)
r := gin.New()
@@ -127,7 +140,7 @@ func main() {
r.Use(middleware.Logging(zapLogger))
r.Use(middleware.CORS())
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler, versionHandler)
setupStaticFiles(r)
server = &http.Server{
@@ -140,24 +153,30 @@ func main() {
shutdownCtx, shutdownCancel = context.WithCancel(context.Background())
go func() {
zapLogger.Info("AI Gateway 启动", zap.String("addr", server.Addr))
zapLogger.Info("AI Gateway 启动",
zap.String("addr", server.Addr),
zap.String("version", buildinfo.Version()),
zap.String("commit", buildinfo.Commit()),
zap.String("build_time", buildinfo.BuildTime()))
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error()))
zapLogger.Fatal("服务器启动失败", zap.Error(err))
}
}()
go func() {
time.Sleep(500 * time.Millisecond)
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
zapLogger.Warn("无法打开浏览器", zap.String("error", err.Error()))
zapLogger.Warn("无法打开浏览器", zap.Error(err))
}
}()
setupSystray(port)
}
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
r.Any("/v1/*path", proxyHandler.HandleProxy)
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler, versionHandler *handler.VersionHandler) {
r.Any("/openai/*path", withProtocol("openai", proxyHandler.HandleProxy))
r.Any("/anthropic/*path", withProtocol("anthropic", proxyHandler.HandleProxy))
r.GET("/api/version", versionHandler.GetVersion)
providers := r.Group("/api/providers")
{
@@ -188,12 +207,26 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
})
}
func setupStaticFiles(r *gin.Engine) {
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
if err != nil {
zapLogger.Fatal("无法加载前端资源", zap.String("error", err.Error()))
func withProtocol(protocol string, next gin.HandlerFunc) gin.HandlerFunc {
return func(c *gin.Context) {
c.Params = append(c.Params, gin.Param{Key: "protocol", Value: protocol})
next(c)
}
}
func setupStaticFiles(r *gin.Engine) {
distFS, err := frontendDistFS()
if err != nil {
zapLogger.Fatal("无法加载前端资源", zap.Error(err))
}
setupStaticFilesWithFS(r, distFS)
}
func frontendDistFS() (fs.FS, error) {
return fs.Sub(embedfs.FrontendDist, "frontend-dist")
}
func setupStaticFilesWithFS(r *gin.Engine, distFS fs.FS) {
getContentType := func(path string) string {
if strings.HasSuffix(path, ".js") {
return "application/javascript"
@@ -226,20 +259,23 @@ func setupStaticFiles(r *gin.Engine) {
c.Data(200, getContentType(filepath), data)
})
r.GET("/favicon.svg", func(c *gin.Context) {
data, err := fs.ReadFile(distFS, "favicon.svg")
r.GET("/icon.png", func(c *gin.Context) {
data, err := fs.ReadFile(distFS, "icon.png")
if err != nil {
c.Status(404)
return
}
c.Data(200, "image/svg+xml", data)
c.Data(200, "image/png", data)
})
r.NoRoute(func(c *gin.Context) {
path := c.Request.URL.Path
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/openai/") ||
strings.HasPrefix(path, "/anthropic/") ||
path == "/openai" ||
path == "/anthropic" ||
strings.HasPrefix(path, "/health") {
c.JSON(404, gin.H{"error": "not found"})
return
@@ -256,13 +292,18 @@ func setupStaticFiles(r *gin.Engine) {
func setupSystray(port int) {
systray.Run(func() {
icon, err := embedfs.Assets.ReadFile("assets/icon.png")
var icon []byte
var err error
if runtime.GOOS == "windows" {
icon, err = embedfs.Assets.ReadFile("assets/icon.ico")
} else {
icon, err = embedfs.Assets.ReadFile("assets/icon.png")
}
if err != nil {
zapLogger.Error("无法加载托盘图标", zap.String("error", err.Error()))
zapLogger.Error("无法加载托盘图标", zap.Error(err))
}
systray.SetIcon(icon)
systray.SetTitle("Nex Gateway")
systray.SetTooltip("AI Gateway")
systray.SetTooltip(appTooltip)
mOpen := systray.AddMenuItem("打开管理界面", "在浏览器中打开")
systray.AddSeparator()
@@ -271,17 +312,15 @@ func setupSystray(port int) {
mPort := systray.AddMenuItem(fmt.Sprintf("端口: %d", port), "")
mPort.Disable()
systray.AddSeparator()
mAbout := systray.AddMenuItem("关于", "")
systray.AddSeparator()
mQuit := systray.AddMenuItem("退出", "停止服务并退出")
go func() {
for {
select {
case <-mOpen.ClickedCh:
openBrowser(fmt.Sprintf("http://localhost:%d", port))
case <-mAbout.ClickedCh:
showAbout()
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
zapLogger.Warn("打开浏览器失败", zap.Error(err))
}
case <-mQuit.ClickedCh:
doShutdown()
systray.Quit()
@@ -300,7 +339,9 @@ func doShutdown() {
if server != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
server.Shutdown(ctx)
if err := server.Shutdown(ctx); err != nil && zapLogger != nil {
zapLogger.Warn("关闭服务器失败", zap.Error(err))
}
}
if shutdownCancel != nil {
@@ -338,8 +379,8 @@ func (s *SingletonLock) Lock() error {
return nil
}
func (s *SingletonLock) Unlock() {
s.flock.Unlock()
func (s *SingletonLock) Unlock() error {
return s.flock.Unlock()
}
func openBrowser(url string) error {
@@ -366,28 +407,3 @@ func openBrowser(url string) error {
return cmd.Start()
}
func showError(title, message string) {
switch runtime.GOOS {
case "darwin":
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`, message, title)
exec.Command("osascript", "-e", script).Run()
case "windows":
exec.Command("msg", "*", message).Run()
case "linux":
exec.Command("zenity", "--error", fmt.Sprintf("--title=%s", title), fmt.Sprintf("--text=%s", message)).Run()
}
}
func showAbout() {
message := "Nex Gateway\n\nAI Gateway - 统一的大模型 API 网关\n\nhttps://github.com/nex/gateway"
switch runtime.GOOS {
case "darwin":
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "关于 Nex Gateway"`, message)
exec.Command("osascript", "-e", script).Run()
case "windows":
exec.Command("msg", "*", message).Run()
case "linux":
exec.Command("zenity", "--info", "--title=关于 Nex Gateway", fmt.Sprintf("--text=%s", message)).Run()
}
}

View File

@@ -0,0 +1,61 @@
//go:build windows
package main
import (
"errors"
"syscall"
"testing"
)
func withMessageBoxW(t *testing.T, fn func(hwnd, text, caption, flags uintptr) (uintptr, error)) {
t.Helper()
old := callMessageBoxW
callMessageBoxW = fn
t.Cleanup(func() {
callMessageBoxW = old
})
}
func TestMessageBoxW_WindowsOnly_InvalidUTF16(t *testing.T) {
err := messageBox("bad\x00title", "测试消息", mbIconInformation)
if err == nil {
t.Fatal("包含 NUL 字符时应该返回错误")
}
}
func TestMessageBoxW_WindowsOnly_SuccessIgnoresLastError(t *testing.T) {
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
return 1, syscall.Errno(123)
})
if err := messageBox("测试标题", "测试消息", mbIconInformation); err != nil {
t.Fatalf("MessageBoxW 返回成功时应忽略 last error: %v", err)
}
}
func TestMessageBoxW_WindowsOnly_FailureUsesReturnValue(t *testing.T) {
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
return 0, syscall.Errno(5)
})
err := messageBox("测试标题", "测试消息", mbIconInformation)
if !errors.Is(err, syscall.Errno(5)) {
t.Fatalf("MessageBoxW 返回 0 时应返回调用错误: %v", err)
}
}
func TestShowError_WindowsBranch(t *testing.T) {
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
return 0, syscall.Errno(5)
})
defer func() {
if recovered := recover(); recovered != nil {
t.Fatalf("showError 不应因 MessageBoxW 失败而 panic: %v", recovered)
}
}()
showError("测试错误", "这是一条测试错误消息")
}

View File

@@ -0,0 +1,9 @@
package main
const (
appName = "Nex"
appTooltip = appName
appDescription = "AI Gateway - 统一的大模型 API 网关"
// #nosec G101 -- 项目官网地址不是凭据
appWebsite = "https://github.com/nex/gateway"
)

View File

@@ -0,0 +1,13 @@
package main
import "testing"
func TestDesktopMetadata(t *testing.T) {
if appName != "Nex" {
t.Fatalf("appName = %q, want %q", appName, "Nex")
}
if appTooltip != appName {
t.Fatalf("appTooltip = %q, want %q", appTooltip, appName)
}
}

View File

@@ -1,6 +1,7 @@
package main
import (
"errors"
"net"
"net/http"
"testing"
@@ -21,19 +22,12 @@ func TestCheckPortAvailable(t *testing.T) {
func TestCheckPortOccupied(t *testing.T) {
port := 19827
listener, err := net.Listen("tcp", ":19827")
listener, err := net.Listen("tcp", ":19827") //nolint:gosec // 需要验证 checkPortAvailable 对通配地址占用的检测行为
if err != nil {
t.Fatalf("无法启动测试服务器: %v", err)
}
defer listener.Close()
go func() {
conn, err := listener.Accept()
if err == nil {
conn.Close()
}
}()
time.Sleep(100 * time.Millisecond)
err = checkPortAvailable(port)
@@ -47,13 +41,19 @@ func TestCheckPortOccupied(t *testing.T) {
func TestCheckPortAvailableAfterClose(t *testing.T) {
port := 19828
listener, err := net.Listen("tcp", ":19828")
listener, err := net.Listen("tcp", "127.0.0.1:19828")
if err != nil {
t.Fatalf("无法启动测试服务器: %v", err)
}
server := &http.Server{}
go server.Serve(listener)
server := &http.Server{ReadHeaderTimeout: time.Second}
defer server.Close()
go func() {
err := server.Serve(listener)
if err != nil && err != http.ErrServerClosed && !errors.Is(err, net.ErrClosed) {
t.Errorf("serve failed: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)

View File

@@ -0,0 +1,44 @@
package main
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"testing/fstest"
"nex/backend/internal/handler"
"github.com/gin-gonic/gin"
)
func TestSetupRoutes_VersionDoesNotFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
setupRoutes(r, &handler.ProxyHandler{}, &handler.ProviderHandler{}, &handler.ModelHandler{}, &handler.StatsHandler{}, handler.NewVersionHandler())
setupStaticFilesWithFS(r, fstest.MapFS{
"index.html": {Data: []byte("<html>fallback</html>")},
})
req := httptest.NewRequest(http.MethodGet, "/api/version", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
}
if contentType := w.Header().Get("Content-Type"); contentType == "text/html; charset=utf-8" {
t.Fatalf("版本接口不应返回 SPA fallback HTML")
}
var result map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
for _, key := range []string{"version", "commit", "build_time"} {
if result[key] == "" {
t.Fatalf("响应缺少 %s 字段: %#v", key, result)
}
}
}

View File

@@ -14,7 +14,11 @@ func TestSingletonLock_FirstLockSuccess(t *testing.T) {
if err := lock.Lock(); err != nil {
t.Fatalf("首次加锁应成功,但返回错误: %v", err)
}
defer lock.Unlock()
defer func() {
if err := lock.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
}()
}
func TestSingletonLock_DuplicateLockFails(t *testing.T) {
@@ -25,12 +29,18 @@ func TestSingletonLock_DuplicateLockFails(t *testing.T) {
if err := lock1.Lock(); err != nil {
t.Fatalf("首次加锁应成功: %v", err)
}
defer lock1.Unlock()
defer func() {
if err := lock1.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
}()
lock2 := NewSingletonLock(lockPath)
err := lock2.Lock()
if err == nil {
lock2.Unlock()
if unlockErr := lock2.Unlock(); unlockErr != nil {
t.Fatalf("解锁失败: %v", unlockErr)
}
t.Fatal("重复加锁应失败,但返回 nil")
}
}
@@ -43,16 +53,22 @@ func TestSingletonLock_UnlockThenRelock(t *testing.T) {
if err := lock1.Lock(); err != nil {
t.Fatalf("首次加锁应成功: %v", err)
}
lock1.Unlock()
if err := lock1.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
lock2 := NewSingletonLock(lockPath)
if err := lock2.Lock(); err != nil {
t.Fatalf("释放后重新加锁应成功: %v", err)
}
lock2.Unlock()
if err := lock2.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
}
func TestSingletonLock_UnlockWithoutLock(t *testing.T) {
lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock"))
lock.Unlock()
if err := lock.Unlock(); err != nil {
t.Fatalf("未加锁时解锁失败: %v", err)
}
}

View File

@@ -1,73 +1,26 @@
package main
import (
"io/fs"
"net/http"
"net/http/httptest"
"strings"
"testing"
"testing/fstest"
"github.com/gin-gonic/gin"
"nex/embedfs"
)
func TestSetupStaticFiles(t *testing.T) {
gin.SetMode(gin.TestMode)
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
distFS, err := frontendDistFS()
if err != nil {
t.Skipf("跳过测试: 前端资源未构建: %v", err)
return
}
getContentType := func(path string) string {
if strings.HasSuffix(path, ".js") {
return "application/javascript"
}
if strings.HasSuffix(path, ".css") {
return "text/css"
}
if strings.HasSuffix(path, ".svg") {
return "image/svg+xml"
}
return "application/octet-stream"
}
r := gin.New()
r.GET("/assets/*filepath", func(c *gin.Context) {
filepath := c.Param("filepath")
data, err := fs.ReadFile(distFS, "assets"+filepath)
if err != nil {
c.Status(404)
return
}
c.Data(200, getContentType(filepath), data)
})
r.GET("/favicon.svg", func(c *gin.Context) {
data, err := fs.ReadFile(distFS, "favicon.svg")
if err != nil {
c.Status(404)
return
}
c.Data(200, "image/svg+xml", data)
})
r.NoRoute(func(c *gin.Context) {
path := c.Request.URL.Path
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/health") {
c.JSON(404, gin.H{"error": "not found"})
return
}
data, err := fs.ReadFile(distFS, "index.html")
if err != nil {
c.Status(500)
return
}
c.Data(200, "text/html; charset=utf-8", data)
})
setupStaticFilesWithFS(r, distFS)
t.Run("API 404", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
@@ -79,6 +32,32 @@ func TestSetupStaticFiles(t *testing.T) {
}
})
t.Run("OpenAI proxy prefix 404", func(t *testing.T) {
req := httptest.NewRequest("GET", "/openai/", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("期望状态码 404, 实际 %d", w.Code)
}
if !strings.Contains(w.Body.String(), "not found") {
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
}
})
t.Run("Anthropic proxy prefix 404", func(t *testing.T) {
req := httptest.NewRequest("GET", "/anthropic/", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("期望状态码 404, 实际 %d", w.Code)
}
if !strings.Contains(w.Body.String(), "not found") {
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
}
})
t.Run("SPA fallback", func(t *testing.T) {
req := httptest.NewRequest("GET", "/providers", nil)
w := httptest.NewRecorder()
@@ -121,3 +100,139 @@ func TestSetupStaticFiles(t *testing.T) {
t.Log("静态文件服务测试通过")
}
func TestSetupStaticFilesWithFS_IconPNG(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
setupStaticFilesWithFS(r, fstest.MapFS{
"icon.png": {Data: []byte("png")},
"index.html": {Data: []byte("<html>fallback</html>")},
})
req := httptest.NewRequest("GET", "/icon.png", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码 200, 实际 %d", w.Code)
}
if w.Header().Get("Content-Type") != "image/png" {
t.Fatalf("期望 Content-Type image/png, 实际 %s", w.Header().Get("Content-Type"))
}
if w.Body.String() != "png" {
t.Fatalf("期望返回 PNG 内容,实际 %q", w.Body.String())
}
}
func TestWithProtocolAndStaticRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
distFS, err := frontendDistFS()
if err != nil {
t.Skipf("跳过测试: 前端资源未构建: %v", err)
return
}
r := gin.New()
var gotProtocol string
var gotPath string
r.Any("/openai/*path", withProtocol("openai", func(c *gin.Context) {
gotProtocol = c.Param("protocol")
gotPath = c.Param("path")
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
}))
r.Any("/anthropic/*path", withProtocol("anthropic", func(c *gin.Context) {
gotProtocol = c.Param("protocol")
gotPath = c.Param("path")
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
}))
setupStaticFilesWithFS(r, distFS)
t.Run("OpenAI route enters proxy handler wrapper", func(t *testing.T) {
gotProtocol = ""
gotPath = ""
req := httptest.NewRequest("POST", "/openai/v1/chat/completions", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("期望状态码 200, 实际 %d", w.Code)
}
if gotProtocol != "openai" {
t.Errorf("期望 protocol=openai, 实际 %s", gotProtocol)
}
if gotPath != "/v1/chat/completions" {
t.Errorf("期望 path=/v1/chat/completions, 实际 %s", gotPath)
}
})
t.Run("Anthropic route enters proxy handler wrapper", func(t *testing.T) {
gotProtocol = ""
gotPath = ""
req := httptest.NewRequest("POST", "/anthropic/v1/messages", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("期望状态码 200, 实际 %d", w.Code)
}
if gotProtocol != "anthropic" {
t.Errorf("期望 protocol=anthropic, 实际 %s", gotProtocol)
}
if gotPath != "/v1/messages" {
t.Errorf("期望 path=/v1/messages, 实际 %s", gotPath)
}
})
t.Run("Static assets are not hijacked", func(t *testing.T) {
gotProtocol = ""
gotPath = ""
req := httptest.NewRequest("GET", "/assets/test.js", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if gotProtocol != "" || gotPath != "" {
t.Errorf("静态资源不应进入代理包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
}
if w.Code == http.StatusOK {
if !strings.HasPrefix(w.Header().Get("Content-Type"), "application/javascript") {
t.Errorf("期望 JS Content-Type, 实际 %s", w.Header().Get("Content-Type"))
}
return
}
if w.Code != http.StatusNotFound {
t.Errorf("期望静态资源返回 200 或 404, 实际 %d", w.Code)
}
})
t.Run("SPA path keeps fallback", func(t *testing.T) {
req := httptest.NewRequest("GET", "/providers", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("期望状态码 200, 实际 %d", w.Code)
}
if !strings.Contains(w.Header().Get("Content-Type"), "text/html") {
t.Errorf("期望返回 HTML实际 %s", w.Header().Get("Content-Type"))
}
})
t.Run("Unknown proxy-like path does not return index html", func(t *testing.T) {
req := httptest.NewRequest("GET", "/openai/unknown", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("显式代理路由应进入代理包装器,实际状态码 %d", w.Code)
}
if gotProtocol != "openai" || gotPath != "/unknown" {
t.Errorf("期望 unknown 代理路径进入 openai 包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
}
})
}

View File

@@ -3,7 +3,6 @@ package main
import (
"context"
"fmt"
"log"
"net/http"
"os"
"os/signal"
@@ -23,18 +22,19 @@ import (
"nex/backend/internal/provider"
"nex/backend/internal/repository"
"nex/backend/internal/service"
"nex/backend/pkg/buildinfo"
pkgLogger "nex/backend/pkg/logger"
)
func main() {
minimalLogger := pkgLogger.NewMinimal()
cfg, err := config.LoadConfig()
if err != nil {
log.Fatalf("加载配置失败: %v", err)
minimalLogger.Fatal("加载配置失败", zap.Error(err))
}
cfg.PrintSummary()
zapLogger, err := pkgLogger.New(pkgLogger.Config{
zapLogger, err := pkgLogger.Upgrade(minimalLogger, pkgLogger.Config{
Level: cfg.Log.Level,
Path: cfg.Log.Path,
MaxSize: cfg.Log.MaxSize,
@@ -43,13 +43,19 @@ func main() {
Compress: cfg.Log.Compress,
})
if err != nil {
log.Fatalf("初始化日志失败: %v", err)
minimalLogger.Fatal("初始化日志失败", zap.Error(err))
}
defer zapLogger.Sync()
defer func() {
if err := zapLogger.Sync(); err != nil {
minimalLogger.Warn("同步日志失败", zap.Error(err))
}
}()
cfg.PrintSummary(zapLogger)
db, err := database.Init(&cfg.Database, zapLogger)
if err != nil {
zapLogger.Fatal("初始化数据库失败", zap.String("error", err.Error()))
zapLogger.Fatal("初始化数据库失败", zap.Error(err))
}
defer database.Close(db)
@@ -74,19 +80,20 @@ func main() {
registry := conversion.NewMemoryRegistry()
if err := registry.Register(openai.NewAdapter()); err != nil {
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error()))
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.Error(err))
}
if err := registry.Register(anthropic.NewAdapter()); err != nil {
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.String("error", err.Error()))
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.Error(err))
}
engine := conversion.NewConversionEngine(registry, zapLogger)
providerClient := provider.NewClient()
providerClient := provider.NewClient(zapLogger)
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService, zapLogger)
providerHandler := handler.NewProviderHandler(providerService)
modelHandler := handler.NewModelHandler(modelService)
statsHandler := handler.NewStatsHandler(statsService)
versionHandler := handler.NewVersionHandler()
gin.SetMode(gin.ReleaseMode)
r := gin.New()
@@ -96,7 +103,7 @@ func main() {
r.Use(middleware.Logging(zapLogger))
r.Use(middleware.CORS())
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler, versionHandler)
srv := &http.Server{
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
@@ -106,9 +113,13 @@ func main() {
}
go func() {
zapLogger.Info("AI Gateway 启动", zap.String("addr", srv.Addr))
zapLogger.Info("AI Gateway 启动",
zap.String("addr", srv.Addr),
zap.String("version", buildinfo.Version()),
zap.String("commit", buildinfo.Commit()),
zap.String("build_time", buildinfo.BuildTime()))
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error()))
zapLogger.Fatal("服务器启动失败", zap.Error(err))
}
}()
@@ -122,7 +133,7 @@ func main() {
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
zapLogger.Fatal("服务器强制关闭", zap.String("error", err.Error()))
zapLogger.Fatal("服务器强制关闭", zap.Error(err))
}
statsBuffer.Stop()
@@ -130,8 +141,9 @@ func main() {
zapLogger.Info("服务器已关闭")
}
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler, versionHandler *handler.VersionHandler) {
r.Any("/:protocol/*path", proxyHandler.HandleProxy)
r.GET("/api/version", versionHandler.GetVersion)
providers := r.Group("/api/providers")
{

View File

@@ -0,0 +1,37 @@
package main
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"nex/backend/internal/handler"
"github.com/gin-gonic/gin"
)
func TestSetupRoutes_Version(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
setupRoutes(r, &handler.ProxyHandler{}, &handler.ProviderHandler{}, &handler.ModelHandler{}, &handler.StatsHandler{}, handler.NewVersionHandler())
req := httptest.NewRequest(http.MethodGet, "/api/version", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
}
var result map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
for _, key := range []string{"version", "commit", "build_time"} {
if result[key] == "" {
t.Fatalf("响应缺少 %s 字段: %#v", key, result)
}
}
}

View File

@@ -1,6 +1,7 @@
package config
import (
"errors"
"fmt"
"os"
"path/filepath"
@@ -11,6 +12,7 @@ import (
"github.com/mitchellh/mapstructure"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
appErrors "nex/backend/pkg/errors"
@@ -57,7 +59,10 @@ type LogConfig struct {
// DefaultConfig returns default config values
func DefaultConfig() *Config {
// Use home dir for default paths
homeDir, _ := os.UserHomeDir()
homeDir, err := os.UserHomeDir()
if err != nil {
homeDir = "."
}
nexDir := filepath.Join(homeDir, ".nex")
return &Config{
@@ -96,7 +101,7 @@ func GetConfigDir() (string, error) {
return "", err
}
configDir := filepath.Join(homeDir, ".nex")
if err := os.MkdirAll(configDir, 0755); err != nil {
if err := os.MkdirAll(configDir, 0o755); err != nil {
return "", err
}
return configDir, nil
@@ -122,7 +127,10 @@ func GetConfigPath() (string, error) {
// setupDefaults 设置默认配置值
func setupDefaults(v *viper.Viper) {
homeDir, _ := os.UserHomeDir()
homeDir, err := os.UserHomeDir()
if err != nil {
homeDir = "."
}
nexDir := filepath.Join(homeDir, ".nex")
v.SetDefault("server.port", 9826)
@@ -176,27 +184,33 @@ func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
// 绑定所有 flag 到 viper
// 注意:必须在设置默认值之后绑定
v.BindPFlag("server.port", flagSet.Lookup("server-port"))
v.BindPFlag("server.read_timeout", flagSet.Lookup("server-read-timeout"))
v.BindPFlag("server.write_timeout", flagSet.Lookup("server-write-timeout"))
bindPFlag(v, "server.port", flagSet.Lookup("server-port"))
bindPFlag(v, "server.read_timeout", flagSet.Lookup("server-read-timeout"))
bindPFlag(v, "server.write_timeout", flagSet.Lookup("server-write-timeout"))
v.BindPFlag("database.driver", flagSet.Lookup("database-driver"))
v.BindPFlag("database.path", flagSet.Lookup("database-path"))
v.BindPFlag("database.host", flagSet.Lookup("database-host"))
v.BindPFlag("database.port", flagSet.Lookup("database-port"))
v.BindPFlag("database.user", flagSet.Lookup("database-user"))
v.BindPFlag("database.password", flagSet.Lookup("database-password"))
v.BindPFlag("database.dbname", flagSet.Lookup("database-dbname"))
v.BindPFlag("database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
v.BindPFlag("database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
v.BindPFlag("database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
bindPFlag(v, "database.driver", flagSet.Lookup("database-driver"))
bindPFlag(v, "database.path", flagSet.Lookup("database-path"))
bindPFlag(v, "database.host", flagSet.Lookup("database-host"))
bindPFlag(v, "database.port", flagSet.Lookup("database-port"))
bindPFlag(v, "database.user", flagSet.Lookup("database-user"))
bindPFlag(v, "database.password", flagSet.Lookup("database-password"))
bindPFlag(v, "database.dbname", flagSet.Lookup("database-dbname"))
bindPFlag(v, "database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
bindPFlag(v, "database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
bindPFlag(v, "database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
v.BindPFlag("log.level", flagSet.Lookup("log-level"))
v.BindPFlag("log.path", flagSet.Lookup("log-path"))
v.BindPFlag("log.max_size", flagSet.Lookup("log-max-size"))
v.BindPFlag("log.max_backups", flagSet.Lookup("log-max-backups"))
v.BindPFlag("log.max_age", flagSet.Lookup("log-max-age"))
v.BindPFlag("log.compress", flagSet.Lookup("log-compress"))
bindPFlag(v, "log.level", flagSet.Lookup("log-level"))
bindPFlag(v, "log.path", flagSet.Lookup("log-path"))
bindPFlag(v, "log.max_size", flagSet.Lookup("log-max-size"))
bindPFlag(v, "log.max_backups", flagSet.Lookup("log-max-backups"))
bindPFlag(v, "log.max_age", flagSet.Lookup("log-max-age"))
bindPFlag(v, "log.compress", flagSet.Lookup("log-compress"))
}
func bindPFlag(v *viper.Viper, key string, flag *pflag.Flag) {
if err := v.BindPFlag(key, flag); err != nil {
panic(fmt.Sprintf("绑定 flag %s 失败: %v", key, err))
}
}
// setupEnv 绑定环境变量
@@ -217,10 +231,17 @@ func setupConfigFile(v *viper.Viper, configPath string) error {
return appErrors.Wrap(appErrors.ErrInternal, err)
}
// 配置文件不存在,创建默认配置文件
if err := v.SafeWriteConfig(); err != nil {
// 忽略写入错误(可能目录已存在等)
writeErr := v.SafeWriteConfigAs(configPath)
if writeErr == nil {
return nil
}
var alreadyExistsErr viper.ConfigFileAlreadyExistsError
if errors.As(writeErr, &alreadyExistsErr) {
return nil
}
return appErrors.Wrap(appErrors.ErrInternal, writeErr)
}
return nil
}
@@ -245,7 +266,9 @@ func LoadConfigFromPath(configPath string) (*Config, error) {
setupFlags(v, flagSet)
// 3. 解析 CLI 参数(忽略错误,因为可能没有参数)
flagSet.Parse(os.Args[1:])
if err := flagSet.Parse(os.Args[1:]); err != nil {
return nil, appErrors.Wrap(appErrors.ErrInvalidRequest, err)
}
// 4. 获取配置文件路径(可能被 --config 参数覆盖)
if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" {
@@ -294,11 +317,11 @@ func SaveConfig(cfg *Config) error {
// Ensure directory exists
dir := filepath.Dir(configPath)
if err := os.MkdirAll(dir, 0755); err != nil {
if err := os.MkdirAll(dir, 0o755); err != nil {
return appErrors.Wrap(appErrors.ErrInternal, err)
}
return os.WriteFile(configPath, data, 0600)
return os.WriteFile(configPath, data, 0o600)
}
// Validate validates the config
@@ -311,22 +334,24 @@ func (c *Config) Validate() error {
}
// PrintSummary 打印配置摘要
func (c *Config) PrintSummary() {
fmt.Println("\nAI Gateway 启动配置")
fmt.Println("==================")
fmt.Printf("服务器端口: %d\n", c.Server.Port)
func (c *Config) PrintSummary(logger *zap.Logger) {
logger.Info("AI Gateway 启动配置",
zap.Int("server_port", c.Server.Port),
zap.String("database_driver", c.Database.Driver),
zap.String("log_level", c.Log.Level),
)
if c.Database.Driver == "mysql" {
fmt.Printf("数据库类型: mysql\n")
fmt.Printf("数据库地址: %s:%d/%s\n", c.Database.Host, c.Database.Port, c.Database.DBName)
logger.Info("数据库配置",
zap.String("driver", "mysql"),
zap.String("host", c.Database.Host),
zap.Int("port", c.Database.Port),
zap.String("database", c.Database.DBName),
)
} else {
fmt.Printf("数据库类型: sqlite\n")
fmt.Printf("数据库路径: %s\n", c.Database.Path)
logger.Info("数据库配置",
zap.String("driver", "sqlite"),
zap.String("path", c.Database.Path),
)
}
fmt.Printf("日志级别: %s\n", c.Log.Level)
fmt.Println("\n配置来源:")
configPath, _ := GetConfigPath()
fmt.Printf(" 配置文件: %s\n", configPath)
fmt.Println(" 环境变量: 待统计")
fmt.Println(" CLI 参数: 待统计")
fmt.Println()
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
)
@@ -171,7 +172,9 @@ func TestConfig_Validate(t *testing.T) {
err := cfg.Validate()
if tt.wantErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
if err != nil {
assert.Contains(t, err.Error(), tt.errMsg)
}
} else {
assert.NoError(t, err)
}
@@ -233,7 +236,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
configPath := filepath.Join(dir, "config.yaml")
data, err := yaml.Marshal(cfg)
require.NoError(t, err)
err = os.WriteFile(configPath, data, 0644)
err = os.WriteFile(configPath, data, 0o600)
require.NoError(t, err)
// 加载配置
@@ -302,7 +305,7 @@ func TestPrintSummary(t *testing.T) {
t.Run("SQLite模式摘要", func(t *testing.T) {
cfg := DefaultConfig()
assert.NotPanics(t, func() {
cfg.PrintSummary()
cfg.PrintSummary(zap.NewNop())
})
})
t.Run("MySQL模式摘要", func(t *testing.T) {
@@ -313,7 +316,7 @@ func TestPrintSummary(t *testing.T) {
cfg.Database.User = "nex"
cfg.Database.DBName = "nex"
assert.NotPanics(t, func() {
cfg.PrintSummary()
cfg.PrintSummary(zap.NewNop())
})
})
}

View File

@@ -6,15 +6,15 @@ import (
// Provider 供应商模型
type Provider struct {
ID string `gorm:"primaryKey" json:"id"`
Name string `gorm:"not null" json:"name"`
APIKey string `gorm:"not null" json:"api_key"`
BaseURL string `gorm:"not null" json:"base_url"`
Protocol string `gorm:"column:protocol;default:'openai'" json:"protocol"`
Enabled bool `gorm:"default:true" json:"enabled"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
ID string `gorm:"primaryKey" json:"id"`
Name string `gorm:"not null" json:"name"`
APIKey string `gorm:"not null" json:"api_key"`
BaseURL string `gorm:"not null" json:"base_url"`
Protocol string `gorm:"column:protocol;default:'openai'" json:"protocol"`
Enabled bool `gorm:"default:true" json:"enabled"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
}
// Model 模型配置id 为 UUID 自动生成UNIQUE(provider_id, model_name)
@@ -29,8 +29,8 @@ type Model struct {
// UsageStats 用量统计
type UsageStats struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
ProviderID string `gorm:"not null;index" json:"provider_id"`
ModelName string `gorm:"not null;index" json:"model_name"`
ProviderID string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"provider_id"`
ModelName string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"model_name"`
RequestCount int `gorm:"default:0" json:"request_count"`
Date time.Time `gorm:"type:date;not null;uniqueIndex:idx_provider_model_date" json:"date"`
}
@@ -47,4 +47,3 @@ func (Model) TableName() string {
func (UsageStats) TableName() string {
return "usage_stats"
}

View File

@@ -141,7 +141,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
Message: err.Message,
},
}
body, _ := json.Marshal(errMsg)
body, marshalErr := json.Marshal(errMsg)
if marshalErr != nil {
return []byte(`{"type":"error","error":{"type":"internal_error","message":"internal error"}}`), statusCode
}
return body, statusCode
}
@@ -235,7 +238,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
return "", nil, err
}
rewriteFunc := func(newModel string) ([]byte, error) {
m["model"], _ = json.Marshal(newModel)
encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m)
}
return current, rewriteFunc, nil
@@ -269,7 +276,11 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
switch ifaceType {
case conversion.InterfaceTypeChat:
// Chat 响应必须有 model 字段,存在则改写,不存在则添加
m["model"], _ = json.Marshal(newModel)
encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m)
default:
return body, nil

View File

@@ -2,6 +2,7 @@ package anthropic
import (
"encoding/json"
"errors"
"testing"
"nex/backend/internal/conversion"
@@ -48,14 +49,36 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
}
}
func TestAdapter_APIReferenceNativePaths(t *testing.T) {
a := NewAdapter()
// docs/api_reference/anthropic defines messages and models under /v1.
tests := []struct {
path string
expected conversion.InterfaceType
}{
{"/v1/messages", conversion.InterfaceTypeChat},
{"/v1/models", conversion.InterfaceTypeModels},
{"/v1/models/claude-sonnet-4-5", conversion.InterfaceTypeModelInfo},
{"/messages", conversion.InterfaceTypePassthrough},
{"/models", conversion.InterfaceTypePassthrough},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
assert.Equal(t, tt.expected, a.DetectInterfaceType(tt.path))
})
}
}
func TestAdapter_BuildUrl(t *testing.T) {
a := NewAdapter()
tests := []struct {
name string
nativePath string
name string
nativePath string
interfaceType conversion.InterfaceType
expected string
expected string
}{
{"聊天", "/v1/messages", conversion.InterfaceTypeChat, "/v1/messages"},
{"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"},
@@ -102,9 +125,9 @@ func TestAdapter_SupportsInterface(t *testing.T) {
a := NewAdapter()
tests := []struct {
name string
name string
interfaceType conversion.InterfaceType
expected bool
expected bool
}{
{"聊天", conversion.InterfaceTypeChat, true},
{"模型", conversion.InterfaceTypeModels, true},
@@ -141,8 +164,8 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
t.Run("解码嵌入请求", func(t *testing.T) {
_, err := a.DecodeEmbeddingRequest([]byte(`{}`))
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
@@ -150,24 +173,24 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
_, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider)
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.True(t, errors.As(err, &convErr))
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
t.Run("解码嵌入响应", func(t *testing.T) {
_, err := a.DecodeEmbeddingResponse([]byte(`{}`))
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
t.Run("编码嵌入响应", func(t *testing.T) {
_, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{})
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
}
@@ -178,8 +201,8 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
t.Run("解码重排序请求", func(t *testing.T) {
_, err := a.DecodeRerankRequest([]byte(`{}`))
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
@@ -187,24 +210,24 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
_, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider)
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
t.Run("解码重排序响应", func(t *testing.T) {
_, err := a.DecodeRerankResponse([]byte(`{}`))
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
t.Run("编码重排序响应", func(t *testing.T) {
_, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{})
require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError)
require.True(t, ok)
var convErr *conversion.ConversionError
require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
})
}

View File

@@ -28,7 +28,10 @@ func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
var canonicalMsgs []canonical.CanonicalMessage
for _, msg := range req.Messages {
decoded := decodeMessage(msg)
decoded, err := decodeMessage(msg)
if err != nil {
return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析消息内容失败").WithCause(err)
}
canonicalMsgs = append(canonicalMsgs, decoded...)
}
@@ -94,10 +97,13 @@ func decodeSystem(system any) any {
}
// decodeMessage 解码 Anthropic 消息
func decodeMessage(msg Message) []canonical.CanonicalMessage {
func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) {
switch msg.Role {
case "user":
blocks := decodeContentBlocks(msg.Content)
blocks, err := decodeContentBlocks(msg.Content)
if err != nil {
return nil, err
}
var toolResults []canonical.ContentBlock
var others []canonical.ContentBlock
for _, b := range blocks {
@@ -117,58 +123,83 @@ func decodeMessage(msg Message) []canonical.CanonicalMessage {
if len(result) == 0 {
result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}})
}
return result
return result, nil
case "assistant":
blocks := decodeContentBlocks(msg.Content)
blocks, err := decodeContentBlocks(msg.Content)
if err != nil {
return nil, err
}
if len(blocks) == 0 {
blocks = append(blocks, canonical.NewTextBlock(""))
}
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil
}
return nil
return nil, nil
}
// decodeContentBlocks 解码内容块列表
func decodeContentBlocks(content any) []canonical.ContentBlock {
func decodeContentBlocks(content any) ([]canonical.ContentBlock, error) {
switch v := content.(type) {
case string:
return []canonical.ContentBlock{canonical.NewTextBlock(v)}
return []canonical.ContentBlock{canonical.NewTextBlock(v)}, nil
case []any:
var blocks []canonical.ContentBlock
for _, item := range v {
if m, ok := item.(map[string]any); ok {
block := decodeSingleContentBlock(m)
block, err := decodeSingleContentBlock(m)
if err != nil {
return nil, err
}
if block != nil {
blocks = append(blocks, *block)
}
}
}
if len(blocks) > 0 {
return blocks
return blocks, nil
}
return []canonical.ContentBlock{canonical.NewTextBlock("")}
return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
case nil:
return []canonical.ContentBlock{canonical.NewTextBlock("")}
return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
default:
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}, nil
}
}
// decodeSingleContentBlock 解码单个内容块
func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
t, _ := m["type"].(string)
func decodeSingleContentBlock(m map[string]any) (*canonical.ContentBlock, error) {
t, ok := m["type"].(string)
if !ok {
return nil, nil
}
switch t {
case "text":
text, _ := m["text"].(string)
return &canonical.ContentBlock{Type: "text", Text: text}
text, ok := m["text"].(string)
if !ok {
text = ""
}
return &canonical.ContentBlock{Type: "text", Text: text}, nil
case "tool_use":
id, _ := m["id"].(string)
name, _ := m["name"].(string)
input, _ := json.Marshal(m["input"])
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}
id, ok := m["id"].(string)
if !ok {
id = ""
}
name, ok := m["name"].(string)
if !ok {
name = ""
}
input, err := json.Marshal(m["input"])
if err != nil {
return nil, err
}
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}, nil
case "tool_result":
toolUseID, _ := m["tool_use_id"].(string)
toolUseID, ok := m["tool_use_id"].(string)
if !ok {
toolUseID = ""
}
isErr := false
if ie, ok := m["is_error"].(bool); ok {
isErr = ie
@@ -179,7 +210,11 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
case string:
content = json.RawMessage(fmt.Sprintf("%q", cv))
default:
content, _ = json.Marshal(cv)
encoded, err := json.Marshal(cv)
if err != nil {
return nil, err
}
content = encoded
}
} else {
content = json.RawMessage(`""`)
@@ -189,15 +224,18 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
ToolUseID: toolUseID,
Content: content,
IsError: &isErr,
}
}, nil
case "thinking":
thinking, _ := m["thinking"].(string)
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}
thinking, ok := m["thinking"].(string)
if !ok {
thinking = ""
}
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}, nil
case "redacted_thinking":
// 丢弃
return nil
return nil, nil
}
return nil
return nil, nil
}
// decodeTools 解码工具定义
@@ -232,7 +270,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
return canonical.NewToolChoiceAny()
}
case map[string]any:
t, _ := v["type"].(string)
t, ok := v["type"].(string)
if !ok {
return nil
}
switch t {
case "auto":
return canonical.NewToolChoiceAuto()
@@ -241,7 +282,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
case "any":
return canonical.NewToolChoiceAny()
case "tool":
name, _ := v["name"].(string)
name, ok := v["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name)
}
}

View File

@@ -182,7 +182,7 @@ func encodeContentBlocks(blocks []canonical.ContentBlock) []map[string]any {
result = append(result, m)
case "tool_result":
m := map[string]any{
"type": "tool_result",
"type": "tool_result",
"tool_use_id": b.ToolUseID,
}
if b.Content != nil {
@@ -335,11 +335,11 @@ func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
}
result := map[string]any{
"id": resp.ID,
"type": "message",
"role": "assistant",
"model": resp.Model,
"content": blocks,
"id": resp.ID,
"type": "message",
"role": "assistant",
"model": resp.Model,
"content": blocks,
"stop_reason": sr,
"stop_sequence": nil,
"usage": usage,

View File

@@ -33,7 +33,8 @@ func TestEncodeRequest_Basic(t *testing.T) {
assert.Equal(t, true, result["stream"])
assert.Equal(t, float64(1024), result["max_tokens"])
msgs := result["messages"].([]any)
msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 1)
}
@@ -55,17 +56,20 @@ func TestEncodeRequest_ToolMergeIntoUser(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any)
msgs, ok := result["messages"].([]any)
require.True(t, ok)
// tool 消息应被合并到相邻 user 消息
foundToolResult := false
for _, m := range msgs {
msgMap := m.(map[string]any)
msgMap, ok := m.(map[string]any)
require.True(t, ok)
if msgMap["role"] == "user" {
content, ok := msgMap["content"].([]any)
if ok {
for _, c := range content {
block := c.(map[string]any)
block, ok := c.(map[string]any)
require.True(t, ok)
if block["type"] == "tool_result" {
foundToolResult = true
}
@@ -93,8 +97,10 @@ func TestEncodeRequest_FirstUserGuarantee(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any)
firstMsg := msgs[0].(map[string]any)
msgs, ok := result["messages"].([]any)
require.True(t, ok)
firstMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "user", firstMsg["role"])
}
@@ -140,9 +146,11 @@ func TestEncodeResponse_Basic(t *testing.T) {
assert.Equal(t, "assistant", result["role"])
assert.Equal(t, "end_turn", result["stop_reason"])
content := result["content"].([]any)
content, ok := result["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 1)
block := content[0].(map[string]any)
block, ok := content[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "text", block["type"])
assert.Equal(t, "你好", block["text"])
}
@@ -160,10 +168,12 @@ func TestEncodeModelsResponse(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
data := result["data"].([]any)
data, ok := result["data"].([]any)
require.True(t, ok)
assert.Len(t, data, 1)
model := data[0].(map[string]any)
model, ok := data[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "claude-3-opus", model["id"])
// created 应为 RFC3339 格式
createdAt, ok := model["created_at"].(string)
@@ -280,11 +290,14 @@ func TestEncodeRequest_ConsecutiveRoleMerge(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any)
msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 1)
userMsg := msgs[0].(map[string]any)
userMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "user", userMsg["role"])
content := userMsg["content"].([]any)
content, ok := userMsg["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 2)
}
@@ -321,7 +334,8 @@ func TestEncodeResponse_ReasoningTokens(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any)
usage, ok := result["usage"].(map[string]any)
require.True(t, ok)
_, hasReasoning := usage["reasoning_tokens"]
assert.False(t, hasReasoning)
}
@@ -341,9 +355,11 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
content := result["content"].([]any)
content, ok := result["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 1)
block := content[0].(map[string]any)
block, ok := content[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "tool_use", block["type"])
assert.Equal(t, "tool_1", block["id"])
assert.Equal(t, "search", block["name"])

View File

@@ -28,7 +28,7 @@ func NewStreamDecoder() *StreamDecoder {
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
data := rawChunk
if len(d.utf8Remainder) > 0 {
data = append(d.utf8Remainder, rawChunk...)
data = append(append([]byte{}, d.utf8Remainder...), rawChunk...)
d.utf8Remainder = nil
}
@@ -50,9 +50,10 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
for _, line := range strings.Split(text, "\n") {
line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "event: ") {
switch {
case strings.HasPrefix(line, "event: "):
eventType = strings.TrimPrefix(line, "event: ")
} else if strings.HasPrefix(line, "data: ") {
case strings.HasPrefix(line, "data: "):
eventData = strings.TrimPrefix(line, "data: ")
if eventType != "" && eventData != "" {
chunkEvents := d.processEvent(eventType, []byte(eventData))
@@ -60,8 +61,8 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
}
eventType = ""
eventData = ""
} else if line == "" {
// SSE 事件分隔符
case line == "":
continue
}
}
@@ -135,7 +136,7 @@ func (d *StreamDecoder) processMessageStart(data []byte) []canonical.CanonicalSt
// processContentBlockStart 处理内容块开始事件
func (d *StreamDecoder) processContentBlockStart(data []byte) []canonical.CanonicalStreamEvent {
var raw struct {
Index int `json:"index"`
Index int `json:"index"`
ContentBlock struct {
Type string `json:"type"`
Text string `json:"text"`

View File

@@ -47,23 +47,23 @@ func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
checkValue string
}{
{
name: "text_delta",
deltaType: "text_delta",
deltaData: map[string]any{"type": "text_delta", "text": "你好"},
name: "text_delta",
deltaType: "text_delta",
deltaData: map[string]any{"type": "text_delta", "text": "你好"},
checkField: "text",
checkValue: "你好",
},
{
name: "input_json_delta",
deltaType: "input_json_delta",
deltaData: map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"},
name: "input_json_delta",
deltaType: "input_json_delta",
deltaData: map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"},
checkField: "partial_json",
checkValue: "{\"key\":",
},
{
name: "thinking_delta",
deltaType: "thinking_delta",
deltaData: map[string]any{"type": "thinking_delta", "thinking": "思考中"},
name: "thinking_delta",
deltaType: "thinking_delta",
deltaData: map[string]any{"type": "thinking_delta", "thinking": "思考中"},
checkField: "thinking",
checkValue: "思考中",
},
@@ -74,7 +74,7 @@ func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
payload := map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": tt.deltaData,
"delta": tt.deltaData,
}
raw := makeAnthropicEvent("content_block_delta", payload)
@@ -298,7 +298,7 @@ func TestStreamDecoder_WebSearchToolResult_Suppressed(t *testing.T) {
"type": "content_block_start",
"index": 3,
"content_block": map[string]any{
"type": "web_search_tool_result",
"type": "web_search_tool_result",
"tool_use_id": "search_1",
},
}
@@ -331,8 +331,8 @@ func TestStreamDecoder_CitationsDelta_Discarded(t *testing.T) {
"type": "content_block_delta",
"index": 0,
"delta": map[string]any{
"type": "citations_delta",
"citation": map[string]any{"title": "ref1"},
"type": "citations_delta",
"citation": map[string]any{"title": "ref1"},
},
}
raw := makeAnthropicEvent("content_block_delta", payload)
@@ -466,7 +466,7 @@ func TestStreamDecoder_MessageDelta_UsageNotAccumulated(t *testing.T) {
},
}
deltaPayload1 := map[string]any{
"type": "message_delta",
"type": "message_delta",
"delta": map[string]any{"stop_reason": "end_turn"},
"usage": map[string]any{"output_tokens": 25},
}
@@ -478,7 +478,7 @@ func TestStreamDecoder_MessageDelta_UsageNotAccumulated(t *testing.T) {
assert.Equal(t, 25, events[0].Usage.OutputTokens)
deltaPayload2 := map[string]any{
"type": "message_delta",
"type": "message_delta",
"delta": map[string]any{"stop_reason": "end_turn"},
"usage": map[string]any{"output_tokens": 30},
}

View File

@@ -50,16 +50,24 @@ func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent)
}
if event.Message != nil {
msg := map[string]any{
"id": event.Message.ID,
"model": event.Message.Model,
"role": "assistant",
"id": event.Message.ID,
"type": "message",
"role": "assistant",
"content": []any{},
"model": event.Message.Model,
"stop_reason": nil,
"stop_sequence": nil,
}
if event.Message.Usage != nil {
usage := map[string]any{
msg["usage"] = map[string]any{
"input_tokens": event.Message.Usage.InputTokens,
"output_tokens": event.Message.Usage.OutputTokens,
}
msg["usage"] = usage
} else {
msg["usage"] = map[string]any{
"input_tokens": 0,
"output_tokens": 0,
}
}
payload["message"] = msg
}
@@ -147,6 +155,10 @@ func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent)
payload["usage"] = map[string]any{
"output_tokens": event.Usage.OutputTokens,
}
} else {
payload["usage"] = map[string]any{
"output_tokens": 0,
}
}
return e.marshalEvent("message_delta", payload)
}

View File

@@ -21,8 +21,55 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
s := string(chunks[0])
assert.True(t, strings.HasPrefix(s, "event: message_start\n"))
assert.Contains(t, s, "data: ")
assert.Contains(t, s, "msg_1")
assert.Contains(t, s, "claude-3")
var payload map[string]any
lines := strings.Split(s, "\n")
for _, l := range lines {
if strings.HasPrefix(l, "data: ") {
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
break
}
}
msg, ok := payload["message"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "msg_1", msg["id"])
assert.Equal(t, "message", msg["type"])
assert.Equal(t, "assistant", msg["role"])
assert.Equal(t, []any{}, msg["content"])
assert.Equal(t, "claude-3", msg["model"])
assert.Nil(t, msg["stop_reason"])
assert.Nil(t, msg["stop_sequence"])
usage, okU := msg["usage"].(map[string]any)
require.True(t, okU)
assert.Equal(t, float64(0), usage["input_tokens"])
assert.Equal(t, float64(0), usage["output_tokens"])
}
func TestStreamEncoder_MessageStart_WithUsage(t *testing.T) {
e := NewStreamEncoder()
event := canonical.NewMessageStartEventWithUsage("msg_2", "gpt-4", &canonical.CanonicalUsage{InputTokens: 100, OutputTokens: 50})
chunks := e.EncodeEvent(event)
require.Len(t, chunks, 1)
s := string(chunks[0])
var payload map[string]any
lines := strings.Split(s, "\n")
for _, l := range lines {
if strings.HasPrefix(l, "data: ") {
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
break
}
}
msg, ok := payload["message"].(map[string]any)
require.True(t, ok)
usage, okU := msg["usage"].(map[string]any)
require.True(t, okU)
assert.Equal(t, float64(100), usage["input_tokens"])
assert.Equal(t, float64(50), usage["output_tokens"])
}
func TestStreamEncoder_ContentBlockDelta(t *testing.T) {
@@ -80,7 +127,8 @@ func TestStreamEncoder_ContentBlockStart_Text(t *testing.T) {
break
}
}
cb := payload["content_block"].(map[string]any)
cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "text", cb["type"])
}
@@ -107,7 +155,8 @@ func TestStreamEncoder_ContentBlockStart_ToolUse(t *testing.T) {
break
}
}
cb := payload["content_block"].(map[string]any)
cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "tool_use", cb["type"])
assert.Equal(t, "toolu_1", cb["id"])
assert.Equal(t, "search", cb["name"])
@@ -131,7 +180,8 @@ func TestStreamEncoder_ContentBlockStart_Thinking(t *testing.T) {
break
}
}
cb := payload["content_block"].(map[string]any)
cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "thinking", cb["type"])
}
@@ -173,8 +223,13 @@ func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
break
}
}
delta := payload["delta"].(map[string]any)
delta, okd := payload["delta"].(map[string]any)
require.True(t, okd)
assert.Equal(t, "end_turn", delta["stop_reason"])
usage, oku := payload["usage"].(map[string]any)
require.True(t, oku, "message_delta SHALL always include usage")
assert.Equal(t, float64(0), usage["output_tokens"])
}
func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
@@ -199,7 +254,8 @@ func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
break
}
}
u := payload["usage"].(map[string]any)
u, oku := payload["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(88), u["output_tokens"])
}

View File

@@ -173,13 +173,15 @@ func TestDecodeMessage_UserWithOnlyToolResults(t *testing.T) {
}
func TestDecodeContentBlocks_Nil(t *testing.T) {
blocks := decodeContentBlocks(nil)
blocks, err := decodeContentBlocks(nil)
require.NoError(t, err)
assert.Len(t, blocks, 1)
assert.Equal(t, "", blocks[0].Text)
}
func TestDecodeContentBlocks_String(t *testing.T) {
blocks := decodeContentBlocks("hello")
blocks, err := decodeContentBlocks("hello")
require.NoError(t, err)
assert.Len(t, blocks, 1)
assert.Equal(t, "hello", blocks[0].Text)
}
@@ -217,8 +219,10 @@ func TestEncodeToolChoice(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := encodeToolChoice(tt.choice)
assert.Equal(t, tt.want["type"], result.(map[string]any)["type"])
assert.Equal(t, tt.want["name"], result.(map[string]any)["name"])
r, ok := result.(map[string]any)
require.True(t, ok)
assert.Equal(t, tt.want["type"], r["type"])
assert.Equal(t, tt.want["name"], r["name"])
})
}
}
@@ -315,12 +319,15 @@ func TestEncodeRequest_WithTools(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
tools := result["tools"].([]any)
tools, okt := result["tools"].([]any)
require.True(t, okt)
assert.Len(t, tools, 1)
tool := tools[0].(map[string]any)
tool, okt2 := tools[0].(map[string]any)
require.True(t, okt2)
assert.Equal(t, "search", tool["name"])
assert.Equal(t, "Search things", tool["description"])
tc := result["tool_choice"].(map[string]any)
tc, oktc := result["tool_choice"].(map[string]any)
require.True(t, oktc)
assert.Equal(t, "auto", tc["type"])
}
@@ -354,9 +361,9 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
StopReason: &sr,
Usage: canonical.CanonicalUsage{
InputTokens: 100,
OutputTokens: 50,
CacheReadTokens: &cacheRead,
InputTokens: 100,
OutputTokens: 50,
CacheReadTokens: &cacheRead,
CacheCreationTokens: &cacheCreation,
},
}
@@ -366,7 +373,8 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any)
usage, oku := result["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(100), usage["input_tokens"])
assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
assert.Equal(t, float64(10), usage["cache_creation_input_tokens"])

View File

@@ -6,22 +6,22 @@ import (
// MessagesRequest Anthropic Messages 请求
type MessagesRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
System any `json:"system,omitempty"`
MaxTokens int `json:"max_tokens"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Metadata *RequestMetadata `json:"metadata,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
OutputConfig *OutputConfig `json:"output_config,omitempty"`
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
Container any `json:"container,omitempty"`
Model string `json:"model"`
Messages []Message `json:"messages"`
System any `json:"system,omitempty"`
MaxTokens int `json:"max_tokens"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Metadata *RequestMetadata `json:"metadata,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
OutputConfig *OutputConfig `json:"output_config,omitempty"`
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
Container any `json:"container,omitempty"`
}
// RequestMetadata 请求元数据
@@ -122,8 +122,8 @@ type ContentBlock struct {
// ResponseUsage 响应用量
type ResponseUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheReadInputTokens *int `json:"cache_read_input_tokens,omitempty"`
CacheCreationInputTokens *int `json:"cache_creation_input_tokens,omitempty"`
}

View File

@@ -38,8 +38,8 @@ type CanonicalEmbeddingResponse struct {
// EmbeddingData 嵌入数据项
type EmbeddingData struct {
Index int `json:"index"`
Embedding any `json:"embedding"` // 根据格式不同可能是 []float64 或 base64 字符串
Index int `json:"index"`
Embedding any `json:"embedding"` // 根据格式不同可能是 []float64 或 base64 字符串
}
// EmbeddingUsage 嵌入用量

View File

@@ -18,17 +18,17 @@ const (
type DeltaType string
const (
DeltaTypeText DeltaType = "text_delta"
DeltaTypeInputJSON DeltaType = "input_json_delta"
DeltaTypeThinking DeltaType = "thinking_delta"
DeltaTypeText DeltaType = "text_delta"
DeltaTypeInputJSON DeltaType = "input_json_delta"
DeltaTypeThinking DeltaType = "thinking_delta"
)
// StreamDelta 流式增量联合体
type StreamDelta struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Type string `json:"type"`
Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
}
// StreamContentBlock 流式内容块联合体
@@ -48,12 +48,12 @@ type CanonicalStreamEvent struct {
Message *StreamMessage `json:"message,omitempty"`
// ContentBlockStartEvent / ContentBlockDeltaEvent / ContentBlockStopEvent
Index *int `json:"index,omitempty"`
Index *int `json:"index,omitempty"`
ContentBlock *StreamContentBlock `json:"content_block,omitempty"`
Delta *StreamDelta `json:"delta,omitempty"`
Delta *StreamDelta `json:"delta,omitempty"`
// MessageDeltaEvent
StopReason *StopReason `json:"stop_reason,omitempty"`
StopReason *StopReason `json:"stop_reason,omitempty"`
Usage *CanonicalUsage `json:"usage,omitempty"`
// ErrorEvent

View File

@@ -40,8 +40,8 @@ type ContentBlock struct {
Text string `json:"text,omitempty"`
// ToolUseBlock
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input json.RawMessage `json:"input,omitempty"`
// ToolResultBlock
@@ -138,43 +138,43 @@ type ThinkingConfig struct {
// OutputFormat 输出格式联合体
type OutputFormat struct {
Type string `json:"type"`
Name string `json:"name,omitempty"`
Schema json.RawMessage `json:"schema,omitempty"`
Strict *bool `json:"strict,omitempty"`
Type string `json:"type"`
Name string `json:"name,omitempty"`
Schema json.RawMessage `json:"schema,omitempty"`
Strict *bool `json:"strict,omitempty"`
}
// CanonicalRequest 规范请求
type CanonicalRequest struct {
Model string `json:"model"`
System any `json:"system,omitempty"` // nil, string, or []SystemBlock
Model string `json:"model"`
System any `json:"system,omitempty"` // nil, string, or []SystemBlock
Messages []CanonicalMessage `json:"messages"`
Tools []CanonicalTool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Parameters RequestParameters `json:"parameters"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Stream bool `json:"stream"`
UserID string `json:"user_id,omitempty"`
OutputFormat *OutputFormat `json:"output_format,omitempty"`
ParallelToolUse *bool `json:"parallel_tool_use,omitempty"`
Tools []CanonicalTool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Parameters RequestParameters `json:"parameters"`
Thinking *ThinkingConfig `json:"thinking,omitempty"`
Stream bool `json:"stream"`
UserID string `json:"user_id,omitempty"`
OutputFormat *OutputFormat `json:"output_format,omitempty"`
ParallelToolUse *bool `json:"parallel_tool_use,omitempty"`
}
// CanonicalUsage 规范用量
type CanonicalUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheReadTokens *int `json:"cache_read_tokens,omitempty"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheReadTokens *int `json:"cache_read_tokens,omitempty"`
CacheCreationTokens *int `json:"cache_creation_tokens,omitempty"`
ReasoningTokens *int `json:"reasoning_tokens,omitempty"`
ReasoningTokens *int `json:"reasoning_tokens,omitempty"`
}
// CanonicalResponse 规范响应
type CanonicalResponse struct {
ID string `json:"id"`
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason *StopReason `json:"stop_reason,omitempty"`
Usage CanonicalUsage `json:"usage"`
ID string `json:"id"`
Model string `json:"model"`
Content []ContentBlock `json:"content"`
StopReason *StopReason `json:"stop_reason,omitempty"`
Usage CanonicalUsage `json:"usage"`
}
// GetSystemString 获取系统消息字符串

View File

@@ -10,9 +10,9 @@ import (
func TestGetSystemString(t *testing.T) {
tests := []struct {
name string
system any
want string
name string
system any
want string
}{
{"string", "hello", "hello"},
{"nil", nil, ""},
@@ -97,11 +97,11 @@ func TestCanonicalRequest_RoundTrip(t *testing.T) {
func TestCanonicalResponse_RoundTrip(t *testing.T) {
sr := StopReasonEndTurn
resp := &CanonicalResponse{
ID: "resp-1",
Model: "gpt-4",
Content: []ContentBlock{NewTextBlock("hello")},
ID: "resp-1",
Model: "gpt-4",
Content: []ContentBlock{NewTextBlock("hello")},
StopReason: &sr,
Usage: CanonicalUsage{InputTokens: 10, OutputTokens: 5},
Usage: CanonicalUsage{InputTokens: 10, OutputTokens: 5},
}
data, err := json.Marshal(resp)

View File

@@ -3,10 +3,14 @@ package conversion
import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"nex/backend/internal/conversion/canonical"
pkglogger "nex/backend/pkg/logger"
)
// HTTPRequestSpec HTTP 请求规格
@@ -33,13 +37,10 @@ type ConversionEngine struct {
// NewConversionEngine 创建转换引擎
func NewConversionEngine(registry AdapterRegistry, logger *zap.Logger) *ConversionEngine {
if logger == nil {
logger = zap.L()
}
return &ConversionEngine{
registry: registry,
middlewareChain: NewMiddlewareChain(),
logger: logger,
logger: pkglogger.WithModule(logger, "conversion.engine"),
}
}
@@ -72,7 +73,7 @@ func (e *ConversionEngine) IsPassthrough(clientProtocol, providerProtocol string
// ConvertHttpRequest 转换 HTTP 请求
func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtocol, providerProtocol string, provider *TargetProvider) (*HTTPRequestSpec, error) {
nativePath := spec.URL
nativePath, rawQuery := splitRequestPath(spec.URL)
if e.IsPassthrough(clientProtocol, providerProtocol) {
providerAdapter, err := e.registry.Get(providerProtocol)
@@ -90,15 +91,18 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
rewrittenBody, err = providerAdapter.RewriteRequestModelName(spec.Body, provider.ModelName, interfaceType)
if err != nil {
e.logger.Warn("Smart Passthrough 改写请求失败,使用原始请求体",
zap.String("error", err.Error()),
zap.Error(err),
zap.String("interface", string(interfaceType)))
rewrittenBody = spec.Body
}
}
}
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
providerURL = appendRawQuery(providerURL, rawQuery)
return &HTTPRequestSpec{
URL: provider.BaseURL + nativePath,
URL: joinBaseURL(provider.BaseURL, providerURL),
Method: spec.Method,
Headers: providerAdapter.BuildHeaders(provider),
Body: rewrittenBody,
@@ -115,7 +119,8 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
}
interfaceType := clientAdapter.DetectInterfaceType(nativePath)
providerUrl := providerAdapter.BuildUrl(nativePath, interfaceType)
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
providerURL = appendRawQuery(providerURL, rawQuery)
providerHeaders := providerAdapter.BuildHeaders(provider)
providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
if err != nil {
@@ -123,7 +128,7 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
}
return &HTTPRequestSpec{
URL: provider.BaseURL + providerUrl,
URL: joinBaseURL(provider.BaseURL, providerURL),
Method: spec.Method,
Headers: providerHeaders,
Body: providerBody,
@@ -135,24 +140,21 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
if e.IsPassthrough(clientProtocol, providerProtocol) {
// Smart Passthrough: 同协议时最小化改写 model 字段
if modelOverride != "" && len(spec.Body) > 0 {
adapter, err := e.registry.Get(clientProtocol)
if err != nil {
return &spec, nil
adapter, getErr := e.registry.Get(clientProtocol)
if getErr == nil {
rewrittenBody, rewriteErr := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
if rewriteErr != nil {
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
zap.Error(rewriteErr),
zap.String("interface", string(interfaceType)))
} else {
return &HTTPResponseSpec{
StatusCode: spec.StatusCode,
Headers: spec.Headers,
Body: rewrittenBody,
}, nil
}
}
rewrittenBody, err := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
if err != nil {
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
zap.String("error", err.Error()),
zap.String("interface", string(interfaceType)))
return &spec, nil
}
return &HTTPResponseSpec{
StatusCode: spec.StatusCode,
Headers: spec.Headers,
Body: rewrittenBody,
}, nil
}
return &spec, nil
}
@@ -183,11 +185,10 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
if e.IsPassthrough(clientProtocol, providerProtocol) {
// Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段
if modelOverride != "" {
adapter, err := e.registry.Get(clientProtocol)
if err != nil {
return NewPassthroughStreamConverter(), nil
adapter, getErr := e.registry.Get(clientProtocol)
if getErr == nil {
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
}
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
}
return NewPassthroughStreamConverter(), nil
}
@@ -202,9 +203,9 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
}
ctx := ConversionContext{
ConversionID: uuid.New().String(),
InterfaceType: InterfaceTypeChat,
Timestamp: time.Now(),
ConversionID: uuid.New().String(),
InterfaceType: interfaceType,
Timestamp: time.Now(),
}
return NewCanonicalStreamConverterWithMiddleware(
@@ -273,7 +274,7 @@ func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clie
func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
canonicalReq, err := clientAdapter.DecodeRequest(body)
if err != nil {
return nil, NewConversionError(ErrorCodeJSONParseError, "解码请求失败").WithCause(err)
return nil, NewRequestJSONParseError("解码请求失败", err)
}
ctx := NewConversionContext(InterfaceTypeChat)
@@ -281,6 +282,9 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
if err != nil {
return nil, err
}
if containsUnsupportedMultimodal(canonicalReq) {
return nil, NewConversionError(ErrorCodeUnsupportedMultimodal, "跨协议暂不支持多模态内容")
}
encoded, err := providerAdapter.EncodeRequest(canonicalReq, provider)
if err != nil {
@@ -292,7 +296,7 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
canonicalResp, err := providerAdapter.DecodeResponse(body)
if err != nil {
return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(err)
return nil, NewResponseJSONParseError("解码响应失败", err)
}
if modelOverride != "" {
canonicalResp.Model = modelOverride
@@ -307,12 +311,12 @@ func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapte
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
models, err := providerAdapter.DecodeModelsResponse(body)
if err != nil {
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.Error(err))
return body, nil
}
encoded, err := clientAdapter.EncodeModelsResponse(models)
if err != nil {
e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.Error(err))
return body, nil
}
return encoded, nil
@@ -321,12 +325,12 @@ func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdap
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
info, err := providerAdapter.DecodeModelInfoResponse(body)
if err != nil {
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
return body, nil
}
encoded, err := clientAdapter.EncodeModelInfoResponse(info)
if err != nil {
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
return body, nil
}
return encoded, nil
@@ -335,7 +339,7 @@ func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerA
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
req, err := clientAdapter.DecodeEmbeddingRequest(body)
if err != nil {
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error()))
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.Error(err))
return body, nil
}
return providerAdapter.EncodeEmbeddingRequest(req, provider)
@@ -344,7 +348,7 @@ func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter P
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
resp, err := providerAdapter.DecodeEmbeddingResponse(body)
if err != nil {
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.Error(err))
return body, nil
}
if modelOverride != "" {
@@ -356,21 +360,22 @@ func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerA
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
req, err := clientAdapter.DecodeRerankRequest(body)
if err != nil {
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error()))
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.Error(err))
return body, nil
}
return providerAdapter.EncodeRerankRequest(req, provider)
}
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
resp, err := providerAdapter.DecodeRerankResponse(body)
if err != nil {
return body, nil
resp, decodeErr := providerAdapter.DecodeRerankResponse(body)
if decodeErr == nil {
if modelOverride != "" {
resp.Model = modelOverride
}
return clientAdapter.EncodeRerankResponse(resp)
}
if modelOverride != "" {
resp.Model = modelOverride
}
return clientAdapter.EncodeRerankResponse(resp)
return body, nil
}
// DetectInterfaceType 检测接口类型
@@ -379,6 +384,7 @@ func (e *ConversionEngine) DetectInterfaceType(nativePath, clientProtocol string
if err != nil {
return InterfaceTypePassthrough, err
}
nativePath, _ = splitRequestPath(nativePath)
return adapter.DetectInterfaceType(nativePath), nil
}
@@ -392,9 +398,56 @@ func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol stri
"type": "internal_error",
},
}
body, _ := json.Marshal(fallback)
return body, 500, nil
body, marshalErr := json.Marshal(fallback)
if marshalErr == nil {
return body, 500, nil
}
return []byte(`{"error":{"message":"internal error","type":"internal_error"}}`), 500, nil
}
body, statusCode := adapter.EncodeError(err)
return body, statusCode, nil
}
func splitRequestPath(rawPath string) (string, string) {
path, query, found := strings.Cut(rawPath, "?")
if !found {
return rawPath, ""
}
return path, query
}
func appendRawQuery(path, rawQuery string) string {
if rawQuery == "" {
return path
}
if strings.Contains(path, "?") {
return path + "&" + rawQuery
}
return path + "?" + rawQuery
}
func joinBaseURL(baseURL, path string) string {
if baseURL == "" {
return path
}
if path == "" {
return baseURL
}
return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/")
}
func containsUnsupportedMultimodal(req *canonical.CanonicalRequest) bool {
if req == nil {
return false
}
for _, msg := range req.Messages {
for _, block := range msg.Content {
switch block.Type {
case "image", "audio", "video", "file":
return true
}
}
}
return false
}

View File

@@ -0,0 +1,63 @@
package conversion_test
import (
"testing"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)
func TestConvertHttpRequest_SameProtocolUsesAdapterBuildURL(t *testing.T) {
tests := []struct {
name string
adapter conversion.ProtocolAdapter
clientProtocol string
providerProtocol string
baseURL string
nativePath string
expectedURL string
body []byte
}{
{
name: "openai base url includes version path",
adapter: openai.NewAdapter(),
clientProtocol: "openai",
providerProtocol: "openai",
baseURL: "http://example.com/v1",
nativePath: "/chat/completions",
expectedURL: "http://example.com/v1/chat/completions",
body: []byte(`{"model":"gpt-4","messages":[]}`),
},
{
name: "anthropic native path keeps v1",
adapter: anthropic.NewAdapter(),
clientProtocol: "anthropic",
providerProtocol: "anthropic",
baseURL: "http://example.com",
nativePath: "/v1/messages",
expectedURL: "http://example.com/v1/messages",
body: []byte(`{"model":"claude","messages":[]}`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, zap.NewNop())
require.NoError(t, registry.Register(tt.adapter))
out, err := engine.ConvertHttpRequest(conversion.HTTPRequestSpec{
URL: tt.nativePath,
Method: "POST",
Body: tt.body,
}, tt.clientProtocol, tt.providerProtocol, conversion.NewTargetProvider(tt.baseURL, "key", "upstream-model"))
require.NoError(t, err)
assert.Equal(t, tt.expectedURL, out.URL)
})
}
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)
func TestConversionError_WithProviderProtocol(t *testing.T) {
@@ -39,7 +40,7 @@ func TestConversionError_FullBuilder(t *testing.T) {
func TestEngine_Use(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
called := false
engine.Use(&testMiddleware{fn: func(req *canonical.CanonicalRequest, cp, pp string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) {
called = true
@@ -66,7 +67,7 @@ func TestEngine_Use(t *testing.T) {
func TestConvertHttpRequest_DecodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
return nil, errors.New("decode failed")
@@ -82,7 +83,7 @@ func TestConvertHttpRequest_DecodeError(t *testing.T) {
func TestConvertHttpRequest_EncodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("client", false))
providerAdapter := newMockAdapter("provider", false)
providerAdapter.encodeReqFn = func(req *canonical.CanonicalRequest, p *TargetProvider) ([]byte, error) {
@@ -98,7 +99,7 @@ func TestConvertHttpRequest_EncodeError(t *testing.T) {
func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
@@ -121,7 +122,7 @@ func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
func TestConvertHttpResponse_DecodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
providerAdapter := newMockAdapter("provider", false)
providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
return nil, errors.New("decode error")
@@ -135,7 +136,7 @@ func TestConvertHttpResponse_DecodeError(t *testing.T) {
func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.ifaceType = InterfaceTypeEmbeddings
@@ -158,7 +159,7 @@ func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) {
func TestConvertHttpRequest_RerankInterface(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.ifaceType = InterfaceTypeRerank
@@ -178,7 +179,7 @@ func TestConvertHttpRequest_RerankInterface(t *testing.T) {
func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
@@ -196,7 +197,7 @@ func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
func TestConvertHttpResponse_RerankInterface(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
@@ -214,7 +215,7 @@ func TestConvertHttpResponse_RerankInterface(t *testing.T) {
func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.ifaceType = InterfaceTypeModels
providerAdapter := newMockAdapter("provider", false)
@@ -232,7 +233,7 @@ func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModels: true}
providerAdapter := newMockAdapter("provider", false)
@@ -249,7 +250,7 @@ func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModelInfo: true}
providerAdapter := newMockAdapter("provider", false)
@@ -324,7 +325,7 @@ var _ = json.Marshal
func TestConvertEmbeddingBody_DecodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.decodeEmbeddingReqFn = func(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
@@ -344,7 +345,7 @@ func TestConvertEmbeddingBody_DecodeError(t *testing.T) {
func TestConvertRerankBody_DecodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.decodeRerankReqFn = func(raw []byte) (*canonical.CanonicalRerankRequest, error) {
@@ -364,7 +365,7 @@ func TestConvertRerankBody_DecodeError(t *testing.T) {
func TestConvertBody_UnknownInterfaceType(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
providerAdapter := newMockAdapter("provider", false)

View File

@@ -2,6 +2,7 @@ package conversion
import (
"encoding/json"
"strings"
"testing"
"nex/backend/internal/conversion/canonical"
@@ -38,8 +39,8 @@ func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
}
}
func (m *mockProtocolAdapter) ProtocolName() string { return m.protocolName }
func (m *mockProtocolAdapter) ProtocolVersion() string { return "1.0" }
func (m *mockProtocolAdapter) ProtocolName() string { return m.protocolName }
func (m *mockProtocolAdapter) ProtocolVersion() string { return "1.0" }
func (m *mockProtocolAdapter) SupportsPassthrough() bool { return m.passthrough }
func (m *mockProtocolAdapter) DetectInterfaceType(nativePath string) InterfaceType {
@@ -190,20 +191,22 @@ func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel str
// noopStreamDecoder 空流式解码器
type noopStreamDecoder struct{}
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { return nil }
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
return nil
}
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }
// noopStreamEncoder 空流式编码器
type noopStreamEncoder struct{}
func (e *noopStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte { return nil }
func (e *noopStreamEncoder) Flush() [][]byte { return nil }
func (e *noopStreamEncoder) Flush() [][]byte { return nil }
// ============ 测试用例 ============
func TestNewConversionEngine(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
assert.NotNil(t, engine)
assert.Equal(t, registry, engine.GetRegistry())
}
@@ -211,7 +214,7 @@ func TestNewConversionEngine(t *testing.T) {
func TestNewConversionEngine_LoggerInjection(t *testing.T) {
t.Run("nil_logger_uses_global", func(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
assert.NotNil(t, engine.logger)
})
@@ -219,13 +222,14 @@ func TestNewConversionEngine_LoggerInjection(t *testing.T) {
registry := NewMemoryRegistry()
customLogger := zap.NewNop()
engine := NewConversionEngine(registry, customLogger)
assert.Equal(t, customLogger, engine.logger)
assert.NotNil(t, engine.logger)
assert.Contains(t, engine.logger.Name(), "conversion.engine")
})
}
func TestRegisterAdapter(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
adapter := newMockAdapter("test-proto", true)
err := engine.RegisterAdapter(adapter)
@@ -237,7 +241,7 @@ func TestRegisterAdapter(t *testing.T) {
func TestIsPassthrough_SameProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
adapter := newMockAdapter("openai", true)
_ = engine.RegisterAdapter(adapter)
@@ -246,7 +250,7 @@ func TestIsPassthrough_SameProtocol(t *testing.T) {
func TestIsPassthrough_DifferentProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
_ = engine.RegisterAdapter(newMockAdapter("anthropic", true))
@@ -255,7 +259,7 @@ func TestIsPassthrough_DifferentProtocol(t *testing.T) {
func TestIsPassthrough_NoPassthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("custom", false))
assert.False(t, engine.IsPassthrough("custom", "custom"))
@@ -263,7 +267,7 @@ func TestIsPassthrough_NoPassthrough(t *testing.T) {
func TestDetectInterfaceType(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
adapter := newMockAdapter("test", true)
adapter.ifaceType = InterfaceTypeChat
_ = engine.RegisterAdapter(adapter)
@@ -275,7 +279,7 @@ func TestDetectInterfaceType(t *testing.T) {
func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
_, err := engine.DetectInterfaceType("/v1/chat", "nonexistent")
assert.Error(t, err)
@@ -283,25 +287,39 @@ func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
func TestConvertHttpRequest_Passthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
engine := NewConversionEngine(registry, zap.NewNop())
openaiAdapter := &buildURLMockAdapter{
mockProtocolAdapter: newMockAdapter("openai", true),
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
if interfaceType == InterfaceTypeChat {
return "/chat/completions"
}
return nativePath
},
}
openaiAdapter.ifaceType = InterfaceTypeChat
openaiAdapter.supportsIface[InterfaceTypeChat] = true
openaiAdapter.rewriteReqFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
return []byte(`{"model":"` + newModel + `","messages":[{"role":"user","content":"hi"}]}`), nil
}
_ = engine.RegisterAdapter(openaiAdapter)
provider := NewTargetProvider("https://api.openai.com/v1", "sk-test", "gpt-4")
spec := HTTPRequestSpec{
URL: "/chat/completions",
URL: "/v1/chat/completions",
Method: "POST",
Body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`),
Body: []byte(`{"model":"openai/gpt-4","messages":[{"role":"user","content":"hi"}]}`),
}
result, err := engine.ConvertHttpRequest(spec, "openai", "openai", provider)
require.NoError(t, err)
assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
assert.Equal(t, spec.Body, result.Body)
assert.JSONEq(t, `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`, string(result.Body))
}
func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client-proto", false)
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
@@ -331,9 +349,80 @@ func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
assert.NotNil(t, result.Body)
}
func TestConvertHttpRequest_UsesProviderAdapterBuildURL(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, zap.NewNop())
openaiAdapter := &buildURLMockAdapter{
mockProtocolAdapter: newMockAdapter("openai", true),
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
if interfaceType == InterfaceTypeChat {
return "/chat/completions"
}
return nativePath
},
}
openaiAdapter.ifaceType = InterfaceTypeChat
openaiAdapter.supportsIface[InterfaceTypeChat] = true
openaiAdapter.rewriteReqFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
return []byte(`{"model":"` + newModel + `"}`), nil
}
require.NoError(t, registry.Register(openaiAdapter))
anthropicAdapter := &buildURLMockAdapter{
mockProtocolAdapter: newMockAdapter("anthropic", false),
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
if interfaceType == InterfaceTypeChat {
return "/v1/messages"
}
return nativePath
},
}
anthropicAdapter.ifaceType = InterfaceTypeChat
anthropicAdapter.supportsIface[InterfaceTypeChat] = true
require.NoError(t, registry.Register(anthropicAdapter))
t.Run("OpenAI to Anthropic", func(t *testing.T) {
provider := NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
spec := HTTPRequestSpec{
URL: "/v1/chat/completions",
Method: "POST",
Body: []byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"max_tokens":16}`),
}
result, err := engine.ConvertHttpRequest(spec, "openai", "anthropic", provider)
require.NoError(t, err)
assert.Equal(t, "https://api.anthropic.com/v1/messages", result.URL)
})
t.Run("Anthropic to OpenAI", func(t *testing.T) {
provider := NewTargetProvider("https://api.openai.com/v1", "key", "gpt-4")
spec := HTTPRequestSpec{
URL: "/v1/messages",
Method: "POST",
Body: []byte(`{"model":"p1/claude-3","max_tokens":16,"messages":[{"role":"user","content":"hi"}]}`),
}
result, err := engine.ConvertHttpRequest(spec, "anthropic", "openai", provider)
require.NoError(t, err)
assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
})
}
type buildURLMockAdapter struct {
*mockProtocolAdapter
buildURLFn func(string, InterfaceType) string
}
func (m *buildURLMockAdapter) BuildUrl(nativePath string, interfaceType InterfaceType) string {
if m.buildURLFn != nil {
return m.buildURLFn(nativePath, interfaceType)
}
return m.mockProtocolAdapter.BuildUrl(nativePath, interfaceType)
}
func TestConvertHttpResponse_Passthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
spec := HTTPResponseSpec{
@@ -349,7 +438,7 @@ func TestConvertHttpResponse_Passthrough(t *testing.T) {
func TestCreateStreamConverter_Passthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
converter, err := engine.CreateStreamConverter("openai", "openai", "", InterfaceTypeChat)
@@ -360,7 +449,7 @@ func TestCreateStreamConverter_Passthrough(t *testing.T) {
func TestCreateStreamConverter_Canonical(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("client", false))
_ = engine.RegisterAdapter(newMockAdapter("provider", false))
@@ -372,7 +461,7 @@ func TestCreateStreamConverter_Canonical(t *testing.T) {
func TestEncodeError(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
@@ -384,7 +473,7 @@ func TestEncodeError(t *testing.T) {
func TestEncodeError_NonExistentProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
body, statusCode, err := engine.EncodeError(convErr, "nonexistent")
@@ -417,7 +506,7 @@ func TestRegistry_GetNonExistent(t *testing.T) {
func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
clientAdapter := newMockAdapter("client", false)
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
@@ -446,7 +535,7 @@ func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) {
func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
// 使用真实 OpenAI adapter 验证 Smart Passthrough 改写
openaiAdapter := newMockAdapter("openai", true)
@@ -476,7 +565,7 @@ func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) {
func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
openaiAdapter := newMockAdapter("openai", true)
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
@@ -495,18 +584,19 @@ func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
_, ok := converter.(*SmartPassthroughStreamConverter)
assert.True(t, ok)
// 验证 chunk 改写
chunks := converter.ProcessChunk([]byte(`{"model":"gpt-4","choices":[]}`))
// 验证 SSE frame 中的 data JSON 被改写
chunks := converter.ProcessChunk([]byte(`data: {"model":"gpt-4","choices":[]}` + "\n\n"))
require.Len(t, chunks, 1)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(chunks[0], &resp))
payload := strings.TrimPrefix(strings.TrimSpace(string(chunks[0])), "data: ")
require.NoError(t, json.Unmarshal([]byte(payload), &resp))
assert.Equal(t, "openai/gpt-4", resp["model"])
}
func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
// provider adapter 解码出含 model 的流式事件
providerAdapter := newMockAdapter("provider", false)
@@ -560,7 +650,7 @@ func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) {
func TestCreateStreamConverter_ModelOverride_CrossProtocol_Empty(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, nil)
engine := NewConversionEngine(registry, zap.NewNop())
providerAdapter := newMockAdapter("provider", false)
providerAdapter.streamDecoderFn = func() StreamDecoder {
@@ -614,6 +704,7 @@ func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.Canonical
}
return nil
}
func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
if d.flushFn != nil {
return d.flushFn()
@@ -633,6 +724,7 @@ func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEve
}
return nil
}
func (e *engineTestStreamEncoder) Flush() [][]byte {
if e.flushFn != nil {
return e.flushFn()

View File

@@ -6,17 +6,24 @@ import "fmt"
type ErrorCode string
const (
ErrorCodeInvalidInput ErrorCode = "INVALID_INPUT"
ErrorCodeMissingRequiredField ErrorCode = "MISSING_REQUIRED_FIELD"
ErrorCodeIncompatibleFeature ErrorCode = "INCOMPATIBLE_FEATURE"
ErrorCodeFieldMappingFailure ErrorCode = "FIELD_MAPPING_FAILURE"
ErrorCodeToolCallParseError ErrorCode = "TOOL_CALL_PARSE_ERROR"
ErrorCodeJSONParseError ErrorCode = "JSON_PARSE_ERROR"
ErrorCodeStreamStateError ErrorCode = "STREAM_STATE_ERROR"
ErrorCodeUTF8DecodeError ErrorCode = "UTF8_DECODE_ERROR"
ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE"
ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
ErrorCodeInvalidInput ErrorCode = "INVALID_INPUT"
ErrorCodeMissingRequiredField ErrorCode = "MISSING_REQUIRED_FIELD"
ErrorCodeIncompatibleFeature ErrorCode = "INCOMPATIBLE_FEATURE"
ErrorCodeFieldMappingFailure ErrorCode = "FIELD_MAPPING_FAILURE"
ErrorCodeToolCallParseError ErrorCode = "TOOL_CALL_PARSE_ERROR"
ErrorCodeJSONParseError ErrorCode = "JSON_PARSE_ERROR"
ErrorCodeStreamStateError ErrorCode = "STREAM_STATE_ERROR"
ErrorCodeUTF8DecodeError ErrorCode = "UTF8_DECODE_ERROR"
ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE"
ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
ErrorCodeUnsupportedMultimodal ErrorCode = "UNSUPPORTED_MULTIMODAL"
)
const (
ErrorDetailPhase = "phase"
ErrorPhaseRequest = "request"
ErrorPhaseResponse = "response"
)
// ConversionError 协议转换错误
@@ -39,6 +46,20 @@ func NewConversionError(code ErrorCode, message string) *ConversionError {
}
}
// NewRequestJSONParseError 创建请求 JSON 解析错误。
func NewRequestJSONParseError(message string, cause error) *ConversionError {
return NewConversionError(ErrorCodeJSONParseError, message).
WithDetail(ErrorDetailPhase, ErrorPhaseRequest).
WithCause(cause)
}
// NewResponseJSONParseError 创建响应 JSON 解析错误。
func NewResponseJSONParseError(message string, cause error) *ConversionError {
return NewConversionError(ErrorCodeJSONParseError, message).
WithDetail(ErrorDetailPhase, ErrorPhaseResponse).
WithCause(cause)
}
// WithClientProtocol 设置客户端协议
func (e *ConversionError) WithClientProtocol(protocol string) *ConversionError {
e.ClientProtocol = protocol

View File

@@ -4,10 +4,10 @@ package conversion
type InterfaceType string
const (
InterfaceTypeChat InterfaceType = "CHAT"
InterfaceTypeModels InterfaceType = "MODELS"
InterfaceTypeModelInfo InterfaceType = "MODEL_INFO"
InterfaceTypeEmbeddings InterfaceType = "EMBEDDINGS"
InterfaceTypeRerank InterfaceType = "RERANK"
InterfaceTypeChat InterfaceType = "CHAT"
InterfaceTypeModels InterfaceType = "MODELS"
InterfaceTypeModelInfo InterfaceType = "MODEL_INFO"
InterfaceTypeEmbeddings InterfaceType = "EMBEDDINGS"
InterfaceTypeRerank InterfaceType = "RERANK"
InterfaceTypePassthrough InterfaceType = "PASSTHROUGH"
)

View File

@@ -29,27 +29,27 @@ func (a *Adapter) SupportsPassthrough() bool { return true }
// DetectInterfaceType 根据路径检测接口类型
func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType {
switch {
case nativePath == "/chat/completions":
case nativePath == "/v1/chat/completions":
return conversion.InterfaceTypeChat
case nativePath == "/models":
case nativePath == "/v1/models":
return conversion.InterfaceTypeModels
case isModelInfoPath(nativePath):
return conversion.InterfaceTypeModelInfo
case nativePath == "/embeddings":
case nativePath == "/v1/embeddings":
return conversion.InterfaceTypeEmbeddings
case nativePath == "/rerank":
case nativePath == "/v1/rerank":
return conversion.InterfaceTypeRerank
default:
return conversion.InterfaceTypePassthrough
}
}
// isModelInfoPath 判断是否为模型详情路径(/models/{id},允许 id 含 /
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id},允许 id 含 /
func isModelInfoPath(path string) bool {
if !strings.HasPrefix(path, "/models/") {
if !strings.HasPrefix(path, "/v1/models/") {
return false
}
suffix := path[len("/models/"):]
suffix := path[len("/v1/models/"):]
return suffix != ""
}
@@ -60,6 +60,11 @@ func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.Interface
return "/chat/completions"
case conversion.InterfaceTypeModels:
return "/models"
case conversion.InterfaceTypeModelInfo:
if modelID, err := a.ExtractUnifiedModelID(nativePath); err == nil {
return "/models/" + modelID
}
return nativePath
case conversion.InterfaceTypeEmbeddings:
return "/embeddings"
case conversion.InterfaceTypeRerank:
@@ -138,7 +143,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
Code: string(err.Code),
},
}
body, _ := json.Marshal(errMsg)
body, marshalErr := json.Marshal(errMsg)
if marshalErr != nil {
return []byte(`{"error":{"message":"internal error","type":"internal_error","code":"INTERNAL_ERROR"}}`), statusCode
}
return body, statusCode
}
@@ -218,12 +226,12 @@ func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse)
return encodeRerankResponse(resp)
}
// ExtractUnifiedModelID 从路径中提取统一模型 ID/models/{provider_id}/{model_name}
// ExtractUnifiedModelID 从路径中提取统一模型 ID/v1/models/{provider_id}/{model_name}
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
if !strings.HasPrefix(nativePath, "/models/") {
if !strings.HasPrefix(nativePath, "/v1/models/") {
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
}
suffix := nativePath[len("/models/"):]
suffix := nativePath[len("/v1/models/"):]
if suffix == "" {
return "", fmt.Errorf("路径缺少模型 ID")
}
@@ -248,7 +256,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
return "", nil, err
}
rewriteFunc := func(newModel string) ([]byte, error) {
m["model"], _ = json.Marshal(newModel)
encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m)
}
return current, rewriteFunc, nil
@@ -282,12 +294,20 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
switch ifaceType {
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
m["model"], _ = json.Marshal(newModel)
encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m)
case conversion.InterfaceTypeRerank:
// Rerank 响应:存在 model 字段则改写,不存在则不添加
if _, exists := m["model"]; exists {
m["model"], _ = json.Marshal(newModel)
encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
}
return json.Marshal(m)
default:

View File

@@ -28,11 +28,11 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
path string
expected conversion.InterfaceType
}{
{"聊天补全", "/chat/completions", conversion.InterfaceTypeChat},
{"模型列表", "/models", conversion.InterfaceTypeModels},
{"模型详情", "/models/gpt-4", conversion.InterfaceTypeModelInfo},
{"嵌入接口", "/embeddings", conversion.InterfaceTypeEmbeddings},
{"重排序接口", "/rerank", conversion.InterfaceTypeRerank},
{"聊天补全", "/v1/chat/completions", conversion.InterfaceTypeChat},
{"模型列表", "/v1/models", conversion.InterfaceTypeModels},
{"模型详情", "/v1/models/openai/gpt-4", conversion.InterfaceTypeModelInfo},
{"嵌入接口", "/v1/embeddings", conversion.InterfaceTypeEmbeddings},
{"重排序接口", "/v1/rerank", conversion.InterfaceTypeRerank},
{"未知路径", "/unknown", conversion.InterfaceTypePassthrough},
}
@@ -44,19 +44,42 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
}
}
func TestAdapter_OldPathsBecomePassthrough(t *testing.T) {
a := NewAdapter()
tests := []struct {
path string
expected conversion.InterfaceType
}{
{"/chat/completions", conversion.InterfaceTypePassthrough},
{"/models", conversion.InterfaceTypePassthrough},
{"/models/gpt-4.1", conversion.InterfaceTypePassthrough},
{"/embeddings", conversion.InterfaceTypePassthrough},
{"/rerank", conversion.InterfaceTypePassthrough},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
assert.Equal(t, tt.expected, a.DetectInterfaceType(tt.path))
})
}
}
func TestAdapter_BuildUrl(t *testing.T) {
a := NewAdapter()
tests := []struct {
name string
nativePath string
name string
nativePath string
interfaceType conversion.InterfaceType
expected string
expected string
}{
{"聊天", "/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"},
{"模型", "/models", conversion.InterfaceTypeModels, "/models"},
{"嵌入", "/embeddings", conversion.InterfaceTypeEmbeddings, "/embeddings"},
{"重排序", "/rerank", conversion.InterfaceTypeRerank, "/rerank"},
{"聊天", "/v1/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"},
{"模型", "/v1/models", conversion.InterfaceTypeModels, "/models"},
{"模型详情", "/v1/models/openai/gpt-4", conversion.InterfaceTypeModelInfo, "/models/openai/gpt-4"},
{"复杂模型详情", "/v1/models/azure/accounts/org/models/gpt-4", conversion.InterfaceTypeModelInfo, "/models/azure/accounts/org/models/gpt-4"},
{"嵌入", "/v1/embeddings", conversion.InterfaceTypeEmbeddings, "/embeddings"},
{"重排序", "/v1/rerank", conversion.InterfaceTypeRerank, "/rerank"},
{"默认透传", "/other", conversion.InterfaceTypePassthrough, "/other"},
}
@@ -92,9 +115,9 @@ func TestAdapter_SupportsInterface(t *testing.T) {
a := NewAdapter()
tests := []struct {
name string
name string
interfaceType conversion.InterfaceType
expected bool
expected bool
}{
{"聊天", conversion.InterfaceTypeChat, true},
{"模型", conversion.InterfaceTypeModels, true},
@@ -118,12 +141,12 @@ func TestIsModelInfoPath(t *testing.T) {
path string
expected bool
}{
{"model_info", "/models/gpt-4", true},
{"model_info_with_dots", "/models/gpt-4.1-preview", true},
{"models_list", "/models", false},
{"nested_path", "/models/gpt-4/versions", true},
{"empty_suffix", "/models/", false},
{"unrelated", "/chat/completions", false},
{"model_info", "/v1/models/openai/gpt-4", true},
{"model_info_with_dots", "/v1/models/openai/gpt-4.1-preview", true},
{"models_list", "/v1/models", false},
{"nested_path", "/v1/models/azure/accounts/org-123/models/gpt-4", true},
{"empty_suffix", "/v1/models/", false},
{"unrelated", "/v1/chat/completions", false},
{"partial_prefix", "/model", false},
}
@@ -134,6 +157,27 @@ func TestIsModelInfoPath(t *testing.T) {
}
}
func TestAdapter_ExtractUnifiedModelID(t *testing.T) {
a := NewAdapter()
t.Run("标准路径", func(t *testing.T) {
modelID, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
require.NoError(t, err)
assert.Equal(t, "openai/gpt-4", modelID)
})
t.Run("复杂路径", func(t *testing.T) {
modelID, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
require.NoError(t, err)
assert.Equal(t, "azure/accounts/org/models/gpt-4", modelID)
})
t.Run("非模型详情路径报错", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/models")
require.Error(t, err)
})
}
func TestAdapter_EncodeError_InvalidInput(t *testing.T) {
a := NewAdapter()
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效")

View File

@@ -18,35 +18,35 @@ func TestExtractUnifiedModelID(t *testing.T) {
a := NewAdapter()
t.Run("standard_path", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/models/openai/gpt-4")
id, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
require.NoError(t, err)
assert.Equal(t, "openai/gpt-4", id)
})
t.Run("multi_segment_path", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/models/azure/accounts/org/models/gpt-4")
id, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
require.NoError(t, err)
assert.Equal(t, "azure/accounts/org/models/gpt-4", id)
})
t.Run("single_segment", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/models/gpt-4")
id, err := a.ExtractUnifiedModelID("/v1/models/gpt-4")
require.NoError(t, err)
assert.Equal(t, "gpt-4", id)
})
t.Run("non_model_path", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/chat/completions")
_, err := a.ExtractUnifiedModelID("/v1/chat/completions")
require.Error(t, err)
})
t.Run("empty_suffix", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/models/")
_, err := a.ExtractUnifiedModelID("/v1/models/")
require.Error(t, err)
})
t.Run("models_list_no_slash", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/models")
_, err := a.ExtractUnifiedModelID("/v1/models")
require.Error(t, err)
})
@@ -344,12 +344,12 @@ func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
path string
expected bool
}{
{"simple_model_id", "/models/gpt-4", true},
{"unified_model_id_with_slash", "/models/openai/gpt-4", true},
{"models_list", "/models", false},
{"models_list_trailing_slash", "/models/", false},
{"chat_completions", "/chat/completions", false},
{"deeply_nested", "/models/azure/eastus/deployments/my-dept/models/gpt-4", true},
{"simple_model_id", "/v1/models/gpt-4", true},
{"unified_model_id_with_slash", "/v1/models/openai/gpt-4", true},
{"models_list", "/v1/models", false},
{"models_list_trailing_slash", "/v1/models/", false},
{"chat_completions", "/v1/chat/completions", false},
{"deeply_nested", "/v1/models/azure/eastus/deployments/my-dept/models/gpt-4", true},
}
for _, tt := range tests {

View File

@@ -215,10 +215,16 @@ func decodeUserContent(content any) []canonical.ContentBlock {
var blocks []canonical.ContentBlock
for _, item := range v {
if m, ok := item.(map[string]any); ok {
t, _ := m["type"].(string)
t, ok := m["type"].(string)
if !ok {
continue
}
switch t {
case "text":
text, _ := m["text"].(string)
text, ok := m["text"].(string)
if !ok {
text = ""
}
blocks = append(blocks, canonical.NewTextBlock(text))
case "image_url":
blocks = append(blocks, canonical.ContentBlock{Type: "image"})
@@ -242,9 +248,9 @@ func decodeUserContent(content any) []canonical.ContentBlock {
// contentPart 内容部分
type contentPart struct {
Type string
Text string
Refusal string
Type string
Text string
Refusal string
}
// decodeContentParts 解码内容部分
@@ -256,13 +262,22 @@ func decodeContentParts(content any) []contentPart {
var result []contentPart
for _, item := range parts {
if m, ok := item.(map[string]any); ok {
t, _ := m["type"].(string)
t, ok := m["type"].(string)
if !ok {
continue
}
switch t {
case "text":
text, _ := m["text"].(string)
text, ok := m["text"].(string)
if !ok {
text = ""
}
result = append(result, contentPart{Type: "text", Text: text})
case "refusal":
refusal, _ := m["refusal"].(string)
refusal, ok := m["refusal"].(string)
if !ok {
refusal = ""
}
result = append(result, contentPart{Type: "refusal", Refusal: refusal})
}
}
@@ -307,21 +322,33 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
return canonical.NewToolChoiceAny()
}
case map[string]any:
t, _ := v["type"].(string)
t, ok := v["type"].(string)
if !ok {
return nil
}
switch t {
case "function":
if fn, ok := v["function"].(map[string]any); ok {
name, _ := fn["name"].(string)
name, ok := fn["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name)
}
case "custom":
if custom, ok := v["custom"].(map[string]any); ok {
name, _ := custom["name"].(string)
name, ok := custom["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name)
}
case "allowed_tools":
if at, ok := v["allowed_tools"].(map[string]any); ok {
mode, _ := at["mode"].(string)
mode, ok := at["mode"].(string)
if !ok {
mode = ""
}
if mode == "required" {
return canonical.NewToolChoiceAny()
}
@@ -443,7 +470,7 @@ func decodeDeprecatedFields(req *ChatCompletionRequest) {
case map[string]any:
if name, ok := v["name"].(string); ok {
req.ToolChoice = map[string]any{
"type": "function",
"type": "function",
"function": map[string]any{"name": name},
}
}

View File

@@ -450,7 +450,7 @@ func encodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte
"object": "list",
"data": data,
"model": resp.Model,
"usage": resp.Usage,
"usage": resp.Usage,
})
}

View File

@@ -45,9 +45,11 @@ func TestEncodeRequest_SystemInjection(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any)
msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 2)
firstMsg := msgs[0].(map[string]any)
firstMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "system", firstMsg["role"])
assert.Equal(t, "你是助手", firstMsg["content"])
}
@@ -72,12 +74,15 @@ func TestEncodeRequest_ToolCalls(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any)
assistantMsg := msgs[0].(map[string]any)
msgs, ok := result["messages"].([]any)
require.True(t, ok)
assistantMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
toolCalls, ok := assistantMsg["tool_calls"].([]any)
require.True(t, ok)
assert.Len(t, toolCalls, 1)
tc := toolCalls[0].(map[string]any)
tc, ok := toolCalls[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "call_1", tc["id"])
}
@@ -100,11 +105,11 @@ func TestEncodeRequest_Thinking(t *testing.T) {
func TestEncodeResponse_Basic(t *testing.T) {
sr := canonical.StopReasonEndTurn
resp := &canonical.CanonicalResponse{
ID: "resp-1",
Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewTextBlock("你好")},
ID: "resp-1",
Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewTextBlock("你好")},
StopReason: &sr,
Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
}
body, err := encodeResponse(resp)
@@ -115,9 +120,12 @@ func TestEncodeResponse_Basic(t *testing.T) {
assert.Equal(t, "resp-1", result["id"])
assert.Equal(t, "chat.completion", result["object"])
choices := result["choices"].([]any)
choice := choices[0].(map[string]any)
msg := choice["message"].(map[string]any)
choices, ok := result["choices"].([]any)
require.True(t, ok)
choice, ok := choices[0].(map[string]any)
require.True(t, ok)
msg, ok := choice["message"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "你好", msg["content"])
assert.Equal(t, "stop", choice["finish_reason"])
}
@@ -126,9 +134,9 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
sr := canonical.StopReasonToolUse
input := json.RawMessage(`{"q":"test"}`)
resp := &canonical.CanonicalResponse{
ID: "resp-2",
Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)},
ID: "resp-2",
Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)},
StopReason: &sr,
}
@@ -137,8 +145,12 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any)
msg := choices[0].(map[string]any)["message"].(map[string]any)
choices, okc := result["choices"].([]any)
require.True(t, okc)
msgMap, okm := choices[0].(map[string]any)
require.True(t, okm)
msg, okmsg := msgMap["message"].(map[string]any)
require.True(t, okmsg)
tcs, ok := msg["tool_calls"].([]any)
require.True(t, ok)
assert.Len(t, tcs, 1)
@@ -158,7 +170,8 @@ func TestEncodeModelsResponse(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "list", result["object"])
data := result["data"].([]any)
data, okd := result["data"].([]any)
require.True(t, okd)
assert.Len(t, data, 2)
}
@@ -317,8 +330,12 @@ func TestEncodeResponse_Thinking(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any)
msg := choices[0].(map[string]any)["message"].(map[string]any)
choices, okch := result["choices"].([]any)
require.True(t, okch)
msgMap, okmm := choices[0].(map[string]any)
require.True(t, okmm)
msg, okmsg := msgMap["message"].(map[string]any)
require.True(t, okmsg)
assert.Equal(t, "回答", msg["content"])
assert.Equal(t, "思考过程", msg["reasoning_content"])
}

View File

@@ -18,9 +18,9 @@ func TestStreamDecoder_BasicText(t *testing.T) {
d := NewStreamDecoder()
chunk := map[string]any{
"id": "chatcmpl-1",
"object": "chat.completion.chunk",
"model": "gpt-4",
"id": "chatcmpl-1",
"object": "chat.completion.chunk",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -56,8 +56,8 @@ func TestStreamDecoder_ToolCalls(t *testing.T) {
idx := 0
chunk := map[string]any{
"id": "chatcmpl-1",
"model": "gpt-4",
"id": "chatcmpl-1",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -98,8 +98,8 @@ func TestStreamDecoder_Thinking(t *testing.T) {
d := NewStreamDecoder()
chunk := map[string]any{
"id": "chatcmpl-1",
"model": "gpt-4",
"id": "chatcmpl-1",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -127,8 +127,8 @@ func TestStreamDecoder_FinishReason(t *testing.T) {
d := NewStreamDecoder()
chunk := map[string]any{
"id": "chatcmpl-1",
"model": "gpt-4",
"id": "chatcmpl-1",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -161,8 +161,8 @@ func TestStreamDecoder_DoneSignal(t *testing.T) {
// 先发送一个文本 chunk
chunk := map[string]any{
"id": "chatcmpl-1",
"model": "gpt-4",
"id": "chatcmpl-1",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -190,8 +190,8 @@ func TestStreamDecoder_RefusalReuse(t *testing.T) {
// 连续两个 refusal delta chunk
for _, text := range []string{"拒绝", "原因"} {
chunk := map[string]any{
"id": "chatcmpl-1",
"model": "gpt-4",
"id": "chatcmpl-1",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -250,8 +250,8 @@ func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
idx0 := 0
chunk1 := map[string]any{
"id": "chatcmpl-mt",
"model": "gpt-4",
"id": "chatcmpl-mt",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -274,8 +274,8 @@ func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
idx1 := 1
chunk2 := map[string]any{
"id": "chatcmpl-mt",
"model": "gpt-4",
"id": "chatcmpl-mt",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -322,8 +322,8 @@ func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
d := NewStreamDecoder()
chunk1 := map[string]any{
"id": "chatcmpl-multi",
"model": "gpt-4",
"id": "chatcmpl-multi",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -332,8 +332,8 @@ func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
},
}
chunk2 := map[string]any{
"id": "chatcmpl-multi",
"model": "gpt-4",
"id": "chatcmpl-multi",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -358,8 +358,8 @@ func TestStreamDecoder_UTF8Truncation(t *testing.T) {
d := NewStreamDecoder()
chunk := map[string]any{
"id": "chatcmpl-utf8",
"model": "gpt-4",
"id": "chatcmpl-utf8",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -390,8 +390,8 @@ func TestStreamDecoder_ToolCallSubsequentDelta(t *testing.T) {
idx := 0
chunk1 := map[string]any{
"id": "chatcmpl-tc",
"model": "gpt-4",
"id": "chatcmpl-tc",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,
@@ -412,8 +412,8 @@ func TestStreamDecoder_ToolCallSubsequentDelta(t *testing.T) {
},
}
chunk2 := map[string]any{
"id": "chatcmpl-tc",
"model": "gpt-4",
"id": "chatcmpl-tc",
"model": "gpt-4",
"choices": []any{
map[string]any{
"index": 0,

View File

@@ -10,9 +10,9 @@ import (
// StreamEncoder OpenAI 流式编码器
type StreamEncoder struct {
bufferedStart *canonical.CanonicalStreamEvent
toolCallIndexMap map[string]int
nextToolCallIndex int
bufferedStart *canonical.CanonicalStreamEvent
toolCallIndexMap map[string]int
nextToolCallIndex int
}
// NewStreamEncoder 创建 OpenAI 流式编码器
@@ -195,8 +195,8 @@ func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent)
func (e *StreamEncoder) encodeDelta(delta map[string]any) [][]byte {
chunk := map[string]any{
"choices": []map[string]any{{
"index": 0,
"delta": delta,
"index": 0,
"delta": delta,
}},
}
return e.marshalChunk(chunk)

View File

@@ -27,8 +27,12 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
data := strings.TrimPrefix(s, "data: ")
data = strings.TrimRight(data, "\n")
require.NoError(t, json.Unmarshal([]byte(data), &payload))
choices := payload["choices"].([]any)
delta := choices[0].(map[string]any)["delta"].(map[string]any)
choices, okch := payload["choices"].([]any)
require.True(t, okch)
msgMap, okmm := choices[0].(map[string]any)
require.True(t, okmm)
delta, okd := msgMap["delta"].(map[string]any)
require.True(t, okd)
assert.Equal(t, "assistant", delta["role"])
}

View File

@@ -177,7 +177,8 @@ func TestEncodeRerankResponse(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "rerank-1", result["model"])
results := result["results"].([]any)
results, okr := result["results"].([]any)
require.True(t, okr)
assert.Len(t, results, 1)
}
@@ -356,9 +357,9 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
reasoning := 20
sr := canonical.StopReasonEndTurn
resp := &canonical.CanonicalResponse{
ID: "r1",
Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
ID: "r1",
Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
StopReason: &sr,
Usage: canonical.CanonicalUsage{
InputTokens: 100,
@@ -373,7 +374,8 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any)
usage, oku := result["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(100), usage["prompt_tokens"])
ptd, ok := usage["prompt_tokens_details"].(map[string]any)
require.True(t, ok)
@@ -412,8 +414,10 @@ func TestEncodeResponse_StopReasons(t *testing.T) {
var result map[string]any
require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any)
choice := choices[0].(map[string]any)
choices, okch := result["choices"].([]any)
require.True(t, okch)
choice, okc := choices[0].(map[string]any)
require.True(t, okc)
assert.Equal(t, tt.want, choice["finish_reason"])
})
}

View File

@@ -4,42 +4,42 @@ import "encoding/json"
// ChatCompletionRequest OpenAI Chat Completion 请求
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
Stop any `json:"stop,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
User string `json:"user,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
N *int `json:"n,omitempty"`
Seed *int `json:"seed,omitempty"`
Logprobs *bool `json:"logprobs,omitempty"`
TopLogprobs *int `json:"top_logprobs,omitempty"`
Model string `json:"model"`
Messages []Message `json:"messages"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
Stop any `json:"stop,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
User string `json:"user,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
N *int `json:"n,omitempty"`
Seed *int `json:"seed,omitempty"`
Logprobs *bool `json:"logprobs,omitempty"`
TopLogprobs *int `json:"top_logprobs,omitempty"`
// 已废弃字段
Functions []FunctionDef `json:"functions,omitempty"`
FunctionCall any `json:"function_call,omitempty"`
Functions []FunctionDef `json:"functions,omitempty"`
FunctionCall any `json:"function_call,omitempty"`
}
// Message OpenAI 消息
type Message struct {
Role string `json:"role"`
Content any `json:"content"`
Name string `json:"name,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Refusal string `json:"refusal,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
Role string `json:"role"`
Content any `json:"content"`
Name string `json:"name,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Refusal string `json:"refusal,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
// 已废弃
FunctionCall *FunctionCallMsg `json:"function_call,omitempty"`
@@ -88,8 +88,8 @@ type FunctionDef struct {
// ResponseFormat OpenAI 响应格式
type ResponseFormat struct {
Type string `json:"type"`
JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"`
Type string `json:"type"`
JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"`
}
// JSONSchemaDef JSON Schema 定义
@@ -118,7 +118,7 @@ type ChatCompletionResponse struct {
// Choice OpenAI 选择项
type Choice struct {
Index int `json:"index"`
Index int `json:"index"`
Message *Message `json:"message,omitempty"`
Delta *Message `json:"delta,omitempty"`
FinishReason *string `json:"finish_reason"`
@@ -127,10 +127,10 @@ type Choice struct {
// Usage OpenAI 用量
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"`
CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
}

View File

@@ -1,6 +1,11 @@
package conversion
import "nex/backend/internal/conversion/canonical"
import (
"bytes"
"strings"
"nex/backend/internal/conversion/canonical"
)
// StreamDecoder 流式解码器接口
type StreamDecoder interface {
@@ -39,11 +44,12 @@ func (c *PassthroughStreamConverter) Flush() [][]byte {
}
// SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器
// 逐 chunk 改写 model 字段
// 按 SSE frame 改写 data JSON 中的 model 字段
type SmartPassthroughStreamConverter struct {
adapter ProtocolAdapter
modelOverride string
interfaceType InterfaceType
buffer []byte
}
// NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器
@@ -55,24 +61,45 @@ func NewSmartPassthroughStreamConverter(adapter ProtocolAdapter, modelOverride s
}
}
// ProcessChunk 改写 chunk 中的 model 字段
// ProcessChunk 按 SSE frame 改写 data JSON 中的 model 字段
func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
if len(rawChunk) == 0 {
return nil
}
rewrittenChunk, err := c.adapter.RewriteResponseModelName(rawChunk, c.modelOverride, c.interfaceType)
if err != nil {
// 改写失败,返回原始 chunk
return [][]byte{rawChunk}
}
c.buffer = append(c.buffer, rawChunk...)
frames, rest := splitSSEFrames(c.buffer)
c.buffer = rest
return [][]byte{rewrittenChunk}
result := make([][]byte, 0, len(frames))
for _, frame := range frames {
result = append(result, c.rewriteFrame(frame))
}
return result
}
// Flush 无缓冲数据
func (c *SmartPassthroughStreamConverter) rewriteFrame(frame []byte) []byte {
payload, ok := sseFrameDataPayload(frame)
if !ok || strings.TrimSpace(payload) == "[DONE]" {
return frame
}
rewrittenPayload, err := c.adapter.RewriteResponseModelName([]byte(payload), c.modelOverride, c.interfaceType)
if err != nil {
return frame
}
return rebuildSSEFrameWithData(frame, string(rewrittenPayload))
}
// Flush 输出未形成完整 frame 的剩余数据
func (c *SmartPassthroughStreamConverter) Flush() [][]byte {
return nil
if len(c.buffer) == 0 {
return nil
}
frame := append([]byte(nil), c.buffer...)
c.buffer = nil
return [][]byte{c.rewriteFrame(frame)}
}
// CanonicalStreamConverter 跨协议规范流式转换器
@@ -153,3 +180,86 @@ func (c *CanonicalStreamConverter) applyModelOverride(event *canonical.Canonical
event.Message.Model = c.modelOverride
}
}
func splitSSEFrames(data []byte) ([][]byte, []byte) {
var frames [][]byte
for len(data) > 0 {
idx, sepLen := findSSEFrameSeparator(data)
if idx < 0 {
break
}
end := idx + sepLen
frames = append(frames, append([]byte(nil), data[:end]...))
data = data[end:]
}
return frames, data
}
func findSSEFrameSeparator(data []byte) (int, int) {
lf := bytes.Index(data, []byte("\n\n"))
crlf := bytes.Index(data, []byte("\r\n\r\n"))
switch {
case lf < 0 && crlf < 0:
return -1, 0
case lf < 0:
return crlf, 4
case crlf < 0:
return lf, 2
case crlf <= lf:
return crlf, 4
default:
return lf, 2
}
}
func sseFrameDataPayload(frame []byte) (string, bool) {
text := strings.TrimRight(string(frame), "\r\n")
lines := strings.Split(text, "\n")
var dataLines []string
for _, line := range lines {
line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "data:") {
value := strings.TrimPrefix(line, "data:")
if strings.HasPrefix(value, " ") {
value = value[1:]
}
dataLines = append(dataLines, value)
}
}
if len(dataLines) == 0 {
return "", false
}
return strings.Join(dataLines, "\n"), true
}
func rebuildSSEFrameWithData(frame []byte, data string) []byte {
lineEnding, separator := sseLineEnding(frame)
text := strings.TrimRight(string(frame), "\r\n")
lines := strings.Split(text, "\n")
out := make([]string, 0, len(lines)+1)
dataWritten := false
for _, line := range lines {
line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "data:") {
if !dataWritten {
for _, dataLine := range strings.Split(data, "\n") {
out = append(out, "data: "+dataLine)
}
dataWritten = true
}
continue
}
out = append(out, line)
}
if !dataWritten {
out = append(out, "data: "+data)
}
return []byte(strings.Join(out, lineEnding) + separator)
}
func sseLineEnding(frame []byte) (string, string) {
if bytes.Contains(frame, []byte("\r\n")) {
return "\r\n", "\r\n\r\n"
}
return "\n", "\n\n"
}

View File

@@ -2,7 +2,6 @@ package database
import (
"fmt"
"log"
"os"
"path/filepath"
"runtime"
@@ -12,22 +11,24 @@ import (
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"nex/backend/internal/config"
pkglogger "nex/backend/pkg/logger"
)
func Init(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
db, err := initDB(cfg)
moduleLogger := pkglogger.WithModule(zapLogger, "database")
db, err := initDB(cfg, moduleLogger)
if err != nil {
return nil, fmt.Errorf("初始化数据库失败: %w", err)
}
if err := runMigrations(db, cfg.Driver); err != nil {
if err := runMigrations(db, cfg.Driver, moduleLogger); err != nil {
return nil, fmt.Errorf("数据库迁移失败: %w", err)
}
configurePool(db, cfg)
configurePool(db, cfg, moduleLogger)
return db, nil
}
@@ -40,36 +41,42 @@ func Close(db *gorm.DB) {
sqlDB.Close()
}
func initDB(cfg *config.DatabaseConfig) (*gorm.DB, error) {
func initDB(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
gormLogger := pkglogger.NewGormLogger(zapLogger)
gormConfig := &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
Logger: gormLogger,
}
switch cfg.Driver {
case "mysql":
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName)
if zapLogger != nil {
zapLogger.Info("连接 MySQL 数据库",
zap.String("host", cfg.Host),
zap.Int("port", cfg.Port),
zap.String("database", cfg.DBName))
}
return gorm.Open(mysql.Open(dsn), gormConfig)
default:
dbDir := filepath.Dir(cfg.Path)
if err := os.MkdirAll(dbDir, 0755); err != nil {
if err := os.MkdirAll(dbDir, 0o755); err != nil {
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
}
if zapLogger != nil {
zapLogger.Info("连接 SQLite 数据库", zap.String("path", cfg.Path))
}
return gorm.Open(sqlite.Open(cfg.Path), gormConfig)
}
}
func runMigrations(db *gorm.DB, driver string) error {
func runMigrations(db *gorm.DB, driver string, zapLogger *zap.Logger) error {
sqlDB, err := db.DB()
if err != nil {
return err
}
migrationsDir := getMigrationsDir(driver)
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
}
gooseDialect := "sqlite3"
migrationsSubDir := "sqlite"
if driver == "mysql" {
@@ -77,19 +84,33 @@ func runMigrations(db *gorm.DB, driver string) error {
migrationsSubDir = "mysql"
}
goose.SetDialect(gooseDialect)
migrationsDir := getMigrationsDir(driver)
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
}
if zapLogger != nil {
zapLogger.Info("执行数据库迁移",
zap.String("dialect", gooseDialect),
zap.String("dir", migrationsSubDir))
}
if err := goose.SetDialect(gooseDialect); err != nil {
return err
}
if err := goose.Up(sqlDB, migrationsDir); err != nil {
return err
}
log.Printf("使用 %s 方言执行迁移,目录: %s", gooseDialect, migrationsSubDir)
return nil
}
func configurePool(db *gorm.DB, cfg *config.DatabaseConfig) {
func configurePool(db *gorm.DB, cfg *config.DatabaseConfig, zapLogger *zap.Logger) {
if cfg.Driver == "sqlite" {
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
log.Printf("警告: 启用 WAL 模式失败: %v", err)
if zapLogger != nil {
zapLogger.Warn("启用 WAL 模式失败", zap.Error(err))
}
}
}
@@ -101,8 +122,12 @@ func configurePool(db *gorm.DB, cfg *config.DatabaseConfig) {
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime)
log.Printf("数据库连接池配置: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v",
cfg.MaxIdleConns, cfg.MaxOpenConns, cfg.ConnMaxLifetime)
if zapLogger != nil {
zapLogger.Info("数据库连接池配置",
zap.Int("max_idle_conns", cfg.MaxIdleConns),
zap.Int("max_open_conns", cfg.MaxOpenConns),
zap.Duration("conn_max_lifetime", cfg.ConnMaxLifetime))
}
}
func getMigrationsDir(driver string) string {

View File

@@ -4,10 +4,11 @@ import (
"path/filepath"
"testing"
"nex/backend/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nex/backend/internal/config"
"go.uber.org/zap"
)
func TestInit_SQLite(t *testing.T) {
@@ -20,7 +21,8 @@ func TestInit_SQLite(t *testing.T) {
ConnMaxLifetime: 0,
}
db, err := Init(cfg, nil)
zapLogger := zap.NewNop()
db, err := Init(cfg, zapLogger)
require.NoError(t, err)
require.NotNil(t, db)
defer Close(db)
@@ -40,7 +42,8 @@ func TestClose(t *testing.T) {
ConnMaxLifetime: 0,
}
db, err := Init(cfg, nil)
zapLogger := zap.NewNop()
db, err := Init(cfg, zapLogger)
require.NoError(t, err)
require.NotNil(t, db)

View File

@@ -13,4 +13,3 @@ type Provider struct {
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}

View File

@@ -6,13 +6,13 @@ import (
"net/http/httptest"
"testing"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
)
func TestProviderHandler_CreateProvider_Success(t *testing.T) {
@@ -24,9 +24,9 @@ func TestProviderHandler_CreateProvider_Success(t *testing.T) {
h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{
"id": "p1",
"name": "Test",
"api_key": "sk-test",
"id": "p1",
"name": "Test",
"api_key": "sk-test",
"base_url": "https://api.test.com",
})
w := httptest.NewRecorder()

View File

@@ -9,23 +9,22 @@ import (
"strings"
"testing"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"gorm.io/gorm"
"nex/backend/internal/domain"
appErrors "nex/backend/pkg/errors"
"nex/backend/tests/mocks"
)
func init() {
gin.SetMode(gin.TestMode)
}
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

View File

@@ -5,9 +5,10 @@ import (
"github.com/gin-gonic/gin"
"go.uber.org/zap"
pkglogger "nex/backend/pkg/logger"
)
// Logging 日志中间件
func Logging(logger *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
@@ -15,12 +16,16 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
query := c.Request.URL.RawQuery
requestID, _ := c.Get(RequestIDKey)
logger.Info("请求开始",
zap.String("method", c.Request.Method),
zap.String("path", path),
zap.String("query", query),
zap.String("client_ip", c.ClientIP()),
zap.Any("request_id", requestID),
var requestIDStr string
if id, ok := requestID.(string); ok {
requestIDStr = id
}
logger.Debug("请求开始",
pkglogger.Method(c.Request.Method),
pkglogger.Path(path),
pkglogger.Query(query),
pkglogger.ClientIP(c.ClientIP()),
pkglogger.RequestID(requestIDStr),
)
c.Next()
@@ -28,13 +33,13 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
latency := time.Since(start)
statusCode := c.Writer.Status()
logger.Info("请求结束",
zap.Int("status", statusCode),
zap.String("method", c.Request.Method),
zap.String("path", path),
zap.Duration("latency", latency),
zap.Int("body_size", c.Writer.Size()),
zap.Any("request_id", requestID),
logger.Debug("请求结束",
pkglogger.StatusCode(statusCode),
pkglogger.Method(c.Request.Method),
pkglogger.Path(path),
pkglogger.Latency(latency),
pkglogger.BodySize(c.Writer.Size()),
pkglogger.RequestID(requestIDStr),
)
}
}

View File

@@ -7,6 +7,8 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
)
func init() {
@@ -65,6 +67,61 @@ func TestLogging(t *testing.T) {
assert.Equal(t, 200, w.Code)
}
func TestLogging_DoesNotLogLifecycleAtInfoLevel(t *testing.T) {
core, logs := observer.New(zapcore.InfoLevel)
logger := zap.New(core)
w := serveLoggingRequest(logger)
assert.Equal(t, 200, w.Code)
assert.Empty(t, logs.FilterMessage("请求开始").All())
assert.Empty(t, logs.FilterMessage("请求结束").All())
}
func TestLogging_LogsLifecycleAtDebugLevel(t *testing.T) {
core, logs := observer.New(zapcore.DebugLevel)
logger := zap.New(core)
w := serveLoggingRequest(logger)
assert.Equal(t, 200, w.Code)
startLogs := logs.FilterMessage("请求开始").All()
endLogs := logs.FilterMessage("请求结束").All()
if assert.Len(t, startLogs, 1) {
fields := startLogs[0].ContextMap()
assert.Equal(t, "GET", fields["method"])
assert.Equal(t, "/test", fields["path"])
assert.Equal(t, "key=value", fields["query"])
assert.Equal(t, "existing-id-123", fields["request_id"])
assert.NotEmpty(t, fields["client_ip"])
}
if assert.Len(t, endLogs, 1) {
fields := endLogs[0].ContextMap()
assert.Equal(t, int64(200), fields["status"])
assert.Equal(t, "GET", fields["method"])
assert.Equal(t, "/test", fields["path"])
assert.Equal(t, int64(2), fields["body_size"])
assert.Equal(t, "existing-id-123", fields["request_id"])
assert.Contains(t, fields, "latency")
}
}
func serveLoggingRequest(logger *zap.Logger) *httptest.ResponseRecorder {
r := gin.New()
r.Use(RequestID())
r.Use(Logging(logger))
r.GET("/test", func(c *gin.Context) {
c.String(200, "ok")
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/test?key=value", nil)
req.Header.Set("X-Request-ID", "existing-id-123")
r.ServeHTTP(w, req)
return w
}
func TestRecovery_NoPanic(t *testing.T) {
logger := zap.NewNop()

View File

@@ -4,13 +4,13 @@ import (
"errors"
"net/http"
"nex/backend/internal/domain"
"nex/backend/internal/service"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/service"
)
// ModelHandler 模型管理处理器
@@ -58,16 +58,16 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
err := h.modelService.Create(model)
if err != nil {
if err == appErrors.ErrProviderNotFound {
if errors.Is(err, appErrors.ErrProviderNotFound) {
c.JSON(http.StatusBadRequest, gin.H{
"error": "供应商不存在",
})
return
}
if err == appErrors.ErrDuplicateModel {
if errors.Is(err, appErrors.ErrDuplicateModel) {
c.JSON(http.StatusConflict, gin.H{
"error": "同一供应商下模型名称已存在",
"code": appErrors.ErrDuplicateModel.Code,
"error": "同一供应商下模型名称已存在",
"code": appErrors.ErrDuplicateModel.Code,
})
return
}
@@ -101,7 +101,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
model, err := h.modelService.Get(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到",
})
@@ -166,7 +166,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) {
err := h.modelService.Delete(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到",
})

View File

@@ -4,13 +4,13 @@ import (
"errors"
"net/http"
"nex/backend/internal/domain"
"nex/backend/internal/service"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/service"
)
// ProviderHandler 供应商管理处理器
@@ -55,7 +55,7 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
err := h.providerService.Create(provider)
if err != nil {
if err == appErrors.ErrInvalidProviderID {
if errors.Is(err, appErrors.ErrInvalidProviderID) {
c.JSON(http.StatusBadRequest, gin.H{
"error": appErrors.ErrInvalidProviderID.Message,
"code": appErrors.ErrInvalidProviderID.Code,
@@ -86,7 +86,7 @@ func (h *ProviderHandler) GetProvider(c *gin.Context) {
provider, err := h.providerService.Get(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到",
})
@@ -113,7 +113,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
err := h.providerService.Update(id, req)
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到",
})
@@ -145,7 +145,7 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
err := h.providerService.Delete(id)
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到",
})

View File

@@ -3,40 +3,44 @@ package handler
import (
"bufio"
"encoding/json"
"errors"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/canonical"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
"nex/backend/internal/service"
appErrors "nex/backend/pkg/errors"
"nex/backend/pkg/modelid"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
pkglogger "nex/backend/pkg/logger"
)
// ProxyHandler 统一代理处理器
type ProxyHandler struct {
engine *conversion.ConversionEngine
client provider.ProviderClient
routingService service.RoutingService
engine *conversion.ConversionEngine
client provider.ProviderClient
routingService service.RoutingService
providerService service.ProviderService
statsService service.StatsService
logger *zap.Logger
statsService service.StatsService
logger *zap.Logger
}
// NewProxyHandler 创建统一代理处理器
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService) *ProxyHandler {
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService, logger *zap.Logger) *ProxyHandler {
return &ProxyHandler{
engine: engine,
client: client,
routingService: routingService,
providerService: providerService,
statsService: statsService,
logger: zap.L(),
logger: pkglogger.WithModule(logger, "handler.proxy"),
}
}
@@ -45,7 +49,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
// 从 URL 提取 clientProtocol: /{protocol}/v1/...
clientProtocol := c.Param("protocol")
if clientProtocol == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少协议前缀"})
h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "缺少协议前缀")
return
}
@@ -55,12 +59,13 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
path = "/" + path
}
nativePath := path
requestPath := appendRawQuery(nativePath, c.Request.URL.RawQuery)
// 获取 client adapter
registry := h.engine.GetRegistry()
clientAdapter, err := registry.Get(clientProtocol)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
return
}
@@ -77,7 +82,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
if ifaceType == conversion.InterfaceTypeModelInfo {
unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的模型 ID 格式"})
h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的模型 ID 格式")
return
}
h.handleModelInfo(c, unifiedID, clientAdapter)
@@ -87,40 +92,50 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "读取请求体失败")
return
}
// 解析统一模型 ID使用 adapter.ExtractModelName
var providerID, modelName string
if len(body) > 0 {
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
if err == nil && unifiedID != "" {
pid, mn, err := modelid.ParseUnifiedModelID(unifiedID)
if err == nil {
providerID = pid
modelName = mn
}
}
}
// 构建输入 HTTPRequestSpec
inSpec := conversion.HTTPRequestSpec{
URL: nativePath,
URL: requestPath,
Method: c.Request.Method,
Headers: extractHeaders(c),
Body: body,
}
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
// 只有 adapter 明确适配的接口才提取 model。未知接口不做通用 model 猜测。
if len(body) == 0 || !supportsModelExtraction(ifaceType) {
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
if err != nil {
if isInvalidJSONError(err) {
h.writeProxyError(c, http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误")
return
}
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
if err != nil {
// 原始模型名兼容透传:非统一模型 ID 不参与路由。
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
if providerID == "" || modelName == "" {
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
// 路由
routeResult, err := h.routingService.RouteByModelName(providerID, modelName)
if err != nil {
// GET 请求或无法提取 model 时,直接转发到上游
if len(body) == 0 || modelName == "" {
h.forwardPassthrough(c, inSpec, clientProtocol)
return
}
h.writeError(c, err, clientProtocol)
h.writeRouteError(c, err)
return
}
@@ -137,12 +152,9 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
targetProvider := conversion.NewTargetProvider(
routeResult.Provider.BaseURL,
routeResult.Provider.APIKey,
routeResult.Model.ModelName, // 上游模型名,用于请求改写
routeResult.Model.ModelName, // 上游模型名,用于请求改写
)
// 判断是否流式
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
// 计算统一模型 ID用于响应覆写
unifiedModelID := routeResult.Model.UnifiedModelID()
@@ -153,12 +165,34 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
}
}
func supportsModelExtraction(ifaceType conversion.InterfaceType) bool {
switch ifaceType {
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
return true
default:
return false
}
}
func isInvalidJSONError(err error) bool {
var syntaxErr *json.SyntaxError
var typeErr *json.UnmarshalTypeError
return errors.As(err, &syntaxErr) || errors.As(err, &typeErr)
}
func appendRawQuery(path, rawQuery string) string {
if rawQuery == "" {
return path
}
return path + "?" + rawQuery
}
// handleNonStream 处理非流式请求
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
// 转换请求
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
if err != nil {
h.logger.Error("转换请求失败", zap.String("error", err.Error()))
h.logger.Error("转换请求失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol)
return
}
@@ -166,31 +200,27 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
// 发送请求
resp, err := h.client.Send(c.Request.Context(), *outSpec)
if err != nil {
h.logger.Error("发送请求失败", zap.String("error", err.Error()))
h.writeConversionError(c, err, clientProtocol)
h.logger.Error("发送请求失败", zap.Error(err))
h.writeUpstreamUnavailable(c, err)
return
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
h.writeUpstreamResponse(c, *resp)
return
}
// 转换响应,传入 modelOverride跨协议场景覆写 model 字段)
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
if err != nil {
h.logger.Error("转换响应失败", zap.String("error", err.Error()))
h.logger.Error("转换响应失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol)
return
}
// 设置响应头
for k, v := range convertedResp.Headers {
c.Header(k, v)
}
if c.GetHeader("Content-Type") == "" {
c.Header("Content-Type", "application/json")
}
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
h.writeConvertedResponse(c, *convertedResp)
go func() {
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
}()
}
@@ -203,15 +233,23 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
return
}
// 创建流式转换器,传入 modelOverride跨协议场景覆写 model 字段)
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
// 发送流式请求
streamResp, err := h.client.SendStream(c.Request.Context(), *outSpec)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
h.writeUpstreamUnavailable(c, err)
return
}
if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
StatusCode: streamResp.StatusCode,
Headers: streamResp.Headers,
Body: streamResp.Body,
})
return
}
// 发送流式请求
eventChan, err := h.client.SendStream(c.Request.Context(), *outSpec)
// 创建流式转换器,传入 modelOverride跨协议场景覆写 model 字段)
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
@@ -222,37 +260,61 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
c.Header("Connection", "keep-alive")
writer := bufio.NewWriter(c.Writer)
flushed := false
for event := range eventChan {
for event := range streamResp.Events {
if event.Error != nil {
h.logger.Error("流读取错误", zap.String("error", event.Error.Error()))
h.logger.Error("流读取错误", zap.Error(event.Error))
break
}
if event.Done {
// flush 转换器
chunks := streamConverter.Flush()
for _, chunk := range chunks {
writer.Write(chunk)
writer.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("流式响应写回失败", zap.Error(err))
}
flushed = true
break
}
chunks := streamConverter.ProcessChunk(event.Data)
for _, chunk := range chunks {
writer.Write(chunk)
writer.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("流式响应写回失败", zap.Error(err))
break
}
}
if !flushed {
chunks := streamConverter.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("流式响应写回失败", zap.Error(err))
}
}
go func() {
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
}()
}
func (h *ProxyHandler) writeStreamChunks(writer *bufio.Writer, chunks [][]byte) error {
for _, chunk := range chunks {
if _, err := writer.Write(chunk); err != nil {
return err
}
if err := writer.Flush(); err != nil {
return err
}
}
return nil
}
// isStreamRequest 判断是否流式请求
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol)
ifaceType, err := h.engine.DetectInterfaceType(nativePath, clientProtocol)
if err != nil {
return false
}
if ifaceType != conversion.InterfaceTypeChat {
return false
}
@@ -271,8 +333,8 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
// 从数据库查询所有启用的模型
models, err := h.providerService.ListEnabledModels()
if err != nil {
h.logger.Error("查询启用模型失败", zap.String("error", err.Error()))
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"})
h.logger.Error("查询启用模型失败", zap.Error(err))
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "查询模型失败")
return
}
@@ -293,8 +355,8 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
// 使用 adapter 编码返回
body, err := adapter.EncodeModelsResponse(modelList)
if err != nil {
h.logger.Error("编码 Models 响应失败", zap.String("error", err.Error()))
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
h.logger.Error("编码 Models 响应失败", zap.Error(err))
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
return
}
@@ -306,17 +368,14 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
// 解析统一模型 ID
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "无效的统一模型 ID 格式",
"code": "INVALID_MODEL_ID",
})
h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的统一模型 ID 格式")
return
}
// 从数据库查询模型
model, err := h.providerService.GetModelByProviderAndName(providerID, modelName)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "模型未找到"})
h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", "模型未找到")
return
}
@@ -331,42 +390,104 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
// 使用 adapter 编码返回
body, err := adapter.EncodeModelInfoResponse(modelInfo)
if err != nil {
h.logger.Error("编码 ModelInfo 响应失败", zap.String("error", err.Error()))
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
h.logger.Error("编码 ModelInfo 响应失败", zap.Error(err))
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
return
}
c.Data(http.StatusOK, "application/json", body)
}
// writeConversionError 写入转换错误
// writeConversionError 写入网关层转换错误
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
if convErr, ok := err.(*conversion.ConversionError); ok {
body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol)
c.Data(statusCode, "application/json", body)
var convErr *conversion.ConversionError
if errors.As(err, &convErr) {
statusCode, code, message := mapConversionError(convErr)
h.writeProxyError(c, statusCode, code, message)
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", err.Error())
}
// writeError 写入路由错误
func (h *ProxyHandler) writeError(c *gin.Context, err error, clientProtocol string) {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
func mapConversionError(err *conversion.ConversionError) (int, string, string) {
switch err.Code {
case conversion.ErrorCodeJSONParseError:
if phase, ok := err.Details[conversion.ErrorDetailPhase].(string); ok && phase == conversion.ErrorPhaseRequest {
return http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误"
}
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
case conversion.ErrorCodeInvalidInput,
conversion.ErrorCodeMissingRequiredField,
conversion.ErrorCodeProtocolConstraint:
return http.StatusBadRequest, "INVALID_REQUEST", err.Message
case conversion.ErrorCodeInterfaceNotSupported:
return http.StatusBadRequest, "UNSUPPORTED_INTERFACE", err.Message
case conversion.ErrorCodeUnsupportedMultimodal:
return http.StatusBadRequest, "UNSUPPORTED_MULTIMODAL", err.Message
default:
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
}
}
func (h *ProxyHandler) writeRouteError(c *gin.Context, err error) {
if appErr, ok := appErrors.AsAppError(err); ok {
switch appErr.Code {
case appErrors.ErrModelNotFound.Code, appErrors.ErrModelDisabled.Code:
h.writeProxyError(c, appErr.HTTPStatus, "MODEL_NOT_FOUND", appErr.Message)
case appErrors.ErrProviderNotFound.Code, appErrors.ErrProviderDisabled.Code:
h.writeProxyError(c, appErr.HTTPStatus, "PROVIDER_NOT_FOUND", appErr.Message)
default:
h.writeProxyError(c, appErr.HTTPStatus, "INVALID_REQUEST", appErr.Message)
}
return
}
h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", err.Error())
}
func (h *ProxyHandler) writeUpstreamUnavailable(c *gin.Context, err error) {
h.logger.Error("上游不可达", zap.Error(err))
h.writeProxyError(c, http.StatusBadGateway, "UPSTREAM_UNAVAILABLE", "上游服务不可达")
}
func (h *ProxyHandler) writeProxyError(c *gin.Context, status int, code, message string) {
c.JSON(status, gin.H{
"error": message,
"code": code,
})
}
func (h *ProxyHandler) writeConvertedResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
for k, v := range resp.Headers {
c.Header(k, v)
}
contentType := headerValue(resp.Headers, "Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, resp.Body)
}
func (h *ProxyHandler) writeUpstreamResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
for k, v := range filterHopByHopHeaders(resp.Headers) {
c.Header(k, v)
}
contentType := headerValue(resp.Headers, "Content-Type")
c.Data(resp.StatusCode, contentType, resp.Body)
}
// forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求)
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string) {
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string, ifaceType conversion.InterfaceType, isStream bool) {
registry := h.engine.GetRegistry()
adapter, err := registry.Get(clientProtocol)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
return
}
providers, err := h.providerService.List()
if err != nil || len(providers) == 0 {
h.logger.Warn("无可用供应商转发 GET 请求", zap.String("path", inSpec.URL))
c.JSON(http.StatusNotFound, gin.H{"error": "没有可用的供应商。请先创建供应商和模型。"})
h.logger.Warn("无可用供应商转发请求", zap.String("path", inSpec.URL))
h.writeProxyError(c, http.StatusNotFound, "PROVIDER_NOT_FOUND", "没有可用的供应商。请先创建供应商和模型。")
return
}
@@ -376,19 +497,18 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
providerProtocol = "openai"
}
ifaceType := adapter.DetectInterfaceType(inSpec.URL)
targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "")
var outSpec *conversion.HTTPRequestSpec
if clientProtocol == providerProtocol {
upstreamURL := p.BaseURL + inSpec.URL
upstreamPath := adapter.BuildUrl(stripRawQuery(inSpec.URL), ifaceType)
upstreamPath = appendRawQuery(upstreamPath, rawQueryFromPath(inSpec.URL))
headers := adapter.BuildHeaders(targetProvider)
if _, ok := headers["Content-Type"]; !ok {
headers["Content-Type"] = "application/json"
}
outSpec = &conversion.HTTPRequestSpec{
URL: upstreamURL,
URL: joinBaseURL(p.BaseURL, upstreamPath),
Method: inSpec.Method,
Headers: headers,
Body: inSpec.Body,
@@ -401,9 +521,18 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
}
}
if isStream {
h.forwardStream(c, *outSpec, clientProtocol, providerProtocol, ifaceType)
return
}
resp, err := h.client.Send(c.Request.Context(), *outSpec)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
h.writeUpstreamUnavailable(c, err)
return
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
h.writeUpstreamResponse(c, *resp)
return
}
@@ -413,13 +542,111 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
return
}
for k, v := range convertedResp.Headers {
c.Header(k, v)
h.writeConvertedResponse(c, *convertedResp)
}
func (h *ProxyHandler) forwardStream(c *gin.Context, outSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, ifaceType conversion.InterfaceType) {
streamResp, err := h.client.SendStream(c.Request.Context(), outSpec)
if err != nil {
h.writeUpstreamUnavailable(c, err)
return
}
if c.GetHeader("Content-Type") == "" {
c.Header("Content-Type", "application/json")
if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
StatusCode: streamResp.StatusCode,
Headers: streamResp.Headers,
Body: streamResp.Body,
})
return
}
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, "", ifaceType)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
writer := bufio.NewWriter(c.Writer)
flushed := false
for event := range streamResp.Events {
if event.Error != nil {
h.logger.Error("透传流读取错误", zap.Error(event.Error))
break
}
if event.Done {
chunks := streamConverter.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
}
flushed = true
break
}
chunks := streamConverter.ProcessChunk(event.Data)
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
break
}
}
if !flushed {
chunks := streamConverter.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
}
}
}
func stripRawQuery(path string) string {
pathOnly, _, _ := strings.Cut(path, "?")
return pathOnly
}
func rawQueryFromPath(path string) string {
_, rawQuery, found := strings.Cut(path, "?")
if !found {
return ""
}
return rawQuery
}
func joinBaseURL(baseURL, path string) string {
return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/")
}
func headerValue(headers map[string]string, key string) string {
for k, v := range headers {
if strings.EqualFold(k, key) {
return v
}
}
return ""
}
func filterHopByHopHeaders(headers map[string]string) map[string]string {
if len(headers) == 0 {
return nil
}
hopByHop := map[string]struct{}{
"connection": {},
"transfer-encoding": {},
"keep-alive": {},
"proxy-authenticate": {},
"proxy-authorization": {},
"te": {},
"trailer": {},
"upgrade": {},
}
filtered := make(map[string]string, len(headers))
for k, v := range headers {
if _, skip := hopByHop[strings.ToLower(k)]; skip {
continue
}
filtered[k] = v
}
return filtered
}
// extractHeaders 从 Gin context 提取请求头

View File

@@ -5,33 +5,34 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
appErrors "nex/backend/pkg/errors"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"go.uber.org/zap"
appErrors "nex/backend/pkg/errors"
)
func init() {
gin.SetMode(gin.TestMode)
}
func setupProxyEngine(t *testing.T) *conversion.ConversionEngine {
t.Helper()
registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, nil)
engine := conversion.NewConversionEngine(registry, zap.NewNop())
require.NoError(t, registry.Register(openai.NewAdapter()))
require.NoError(t, registry.Register(anthropic.NewAdapter()))
return engine
@@ -44,6 +45,7 @@ func newTestProxyHandler(engine *conversion.ConversionEngine, client *mocks.Mock
routingSvc,
providerSvc,
statsSvc,
zap.NewNop(),
)
}
@@ -72,7 +74,7 @@ func TestProxyHandler_HandleProxy_NonStreamSuccess(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
@@ -91,8 +93,8 @@ func TestProxyHandler_HandleProxy_NonStreamSuccess(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -108,20 +110,20 @@ func TestProxyHandler_HandleProxy_RoutingError_WithBody(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(nil, appErrors.ErrModelNotFound)
routingSvc.EXPECT().RouteByModelName("unknown", "model").Return(nil, appErrors.ErrModelNotFound)
providerSvc := mocks.NewMockProviderService(ctrl)
providerSvc.EXPECT().List().Return(nil, nil)
client := mocks.NewMockProviderClient(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"unknown","messages":[{"role":"user","content":"hi"}]}`)))
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 404, w.Code)
assert.Contains(t, w.Body.String(), "MODEL_NOT_FOUND")
}
func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
@@ -130,7 +132,7 @@ func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
@@ -144,11 +146,12 @@ func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 500, w.Code)
assert.Equal(t, 502, w.Code)
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
}
func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
@@ -157,7 +160,7 @@ func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
@@ -171,11 +174,12 @@ func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 500, w.Code)
assert.Equal(t, 502, w.Code)
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
}
func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
@@ -184,12 +188,12 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
ch := make(chan provider.StreamEvent, 10)
go func() {
defer close(ch)
@@ -198,7 +202,7 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
ch <- provider.StreamEvent{Done: true}
}()
return ch, nil
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
})
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
@@ -207,13 +211,14 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type"))
assert.Contains(t, w.Body.String(), "Hello")
assert.Contains(t, w.Body.String(), "p1/gpt-4")
}
func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
@@ -222,12 +227,12 @@ func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
return nil, context.DeadlineExceeded
})
providerSvc := mocks.NewMockProviderService(ctrl)
@@ -236,11 +241,12 @@ func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 500, w.Code)
assert.Equal(t, 502, w.Code)
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
}
func TestProxyHandler_ForwardPassthrough_GET(t *testing.T) {
@@ -260,8 +266,8 @@ func TestProxyHandler_ForwardPassthrough_GET(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -281,11 +287,11 @@ func TestProxyHandler_ForwardPassthrough_UnsupportedProtocol(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "unknown"}, {Key: "path", Value: "/models"}}
c.Params = gin.Params{{Key: "protocol", Value: "unknown"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/unknown/models", nil)
h.HandleProxy(c)
assert.Equal(t, 400, w.Code)
assert.Equal(t, 404, w.Code)
}
func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) {
@@ -303,8 +309,8 @@ func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -328,7 +334,7 @@ func TestProxyHandler_HandleProxy_ProviderProtocolDefault(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
@@ -346,8 +352,8 @@ func TestProxyHandler_HandleProxy_ProviderProtocolDefault(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -370,6 +376,7 @@ func TestProxyHandler_WriteConversionError_NonConversionError(t *testing.T) {
h.writeConversionError(c, context.DeadlineExceeded, "openai")
assert.Equal(t, 500, w.Code)
assert.JSONEq(t, `{"error":"context deadline exceeded","code":"CONVERSION_FAILED"}`, w.Body.String())
}
func TestProxyHandler_WriteConversionError_ConversionError(t *testing.T) {
@@ -389,7 +396,40 @@ func TestProxyHandler_WriteConversionError_ConversionError(t *testing.T) {
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "bad request")
h.writeConversionError(c, convErr, "openai")
assert.Equal(t, 500, w.Code)
assert.Equal(t, 400, w.Code)
assert.JSONEq(t, `{"error":"bad request","code":"INVALID_REQUEST"}`, w.Body.String())
}
func TestProxyHandler_WriteConversionError_JSONPhase(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
client := mocks.NewMockProviderClient(ctrl)
routingSvc := mocks.NewMockRoutingService(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
t.Run("request json parse error", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
h.writeConversionError(c, conversion.NewRequestJSONParseError("解码请求失败", context.Canceled), "openai")
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.JSONEq(t, `{"error":"请求体 JSON 格式错误","code":"INVALID_JSON"}`, w.Body.String())
})
t.Run("response json parse error", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
h.writeConversionError(c, conversion.NewResponseJSONParseError("解码响应失败", context.Canceled), "openai")
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.JSONEq(t, `{"error":"解码响应失败","code":"CONVERSION_FAILED"}`, w.Body.String())
})
}
func TestProxyHandler_HandleProxy_EmptyBody(t *testing.T) {
@@ -409,8 +449,8 @@ func TestProxyHandler_HandleProxy_EmptyBody(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -422,19 +462,19 @@ func TestProxyHandler_HandleStream_MidStreamError(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
ch := make(chan provider.StreamEvent, 10)
go func() {
defer close(ch)
ch <- provider.StreamEvent{Data: []byte("data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n")}
ch <- provider.StreamEvent{Error: fmt.Errorf("connection reset by peer")}
}()
return ch, nil
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
})
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
@@ -443,8 +483,8 @@ func TestProxyHandler_HandleStream_MidStreamError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -459,12 +499,12 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
ch := make(chan provider.StreamEvent, 10)
go func() {
defer close(ch)
@@ -472,7 +512,7 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
ch <- provider.StreamEvent{Done: true}
}()
return ch, nil
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
})
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
@@ -481,8 +521,8 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -499,12 +539,12 @@ func TestProxyHandler_HandleStream_CreateStreamConverterError(t *testing.T) {
defer ctrl.Finish()
registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, nil)
engine := conversion.NewConversionEngine(registry, zap.NewNop())
err := registry.Register(openai.NewAdapter())
require.NoError(t, err)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
@@ -515,8 +555,8 @@ func TestProxyHandler_HandleStream_CreateStreamConverterError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 500, w.Code)
@@ -527,11 +567,11 @@ func TestProxyHandler_HandleStream_ConvertRequestError(t *testing.T) {
defer ctrl.Finish()
registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, nil)
engine := conversion.NewConversionEngine(registry, zap.NewNop())
require.NoError(t, registry.Register(openai.NewAdapter()))
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
@@ -542,8 +582,8 @@ func TestProxyHandler_HandleStream_ConvertRequestError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 500, w.Code)
@@ -554,12 +594,12 @@ func TestProxyHandler_HandleNonStream_ConvertResponseError(t *testing.T) {
defer ctrl.Finish()
registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, nil)
engine := conversion.NewConversionEngine(registry, zap.NewNop())
require.NoError(t, registry.Register(openai.NewAdapter()))
require.NoError(t, registry.Register(anthropic.NewAdapter()))
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "anthropic", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "claude-3", Enabled: true},
}, nil)
@@ -577,8 +617,8 @@ func TestProxyHandler_HandleNonStream_ConvertResponseError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"claude-3","messages":[{"role":"user","content":"hi"}]}`)))
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 500, w.Code)
@@ -590,7 +630,7 @@ func TestProxyHandler_HandleNonStream_ResponseHeaders(t *testing.T) {
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
@@ -609,8 +649,8 @@ func TestProxyHandler_HandleNonStream_ResponseHeaders(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -623,7 +663,7 @@ func TestProxyHandler_ForwardPassthrough_CrossProtocol(t *testing.T) {
defer ctrl.Finish()
registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, nil)
engine := conversion.NewConversionEngine(registry, zap.NewNop())
require.NoError(t, registry.Register(openai.NewAdapter()))
anthropicAdapter := anthropic.NewAdapter()
@@ -641,8 +681,8 @@ func TestProxyHandler_ForwardPassthrough_CrossProtocol(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -665,8 +705,8 @@ func TestProxyHandler_ForwardPassthrough_NoBody_NoModel(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -689,10 +729,10 @@ func TestIsStreamRequest_EdgeCases(t *testing.T) {
path string
expected bool
}{
{"stream at end of JSON", `{"messages":[],"stream":true}`, "/chat/completions", true},
{"stream with spaces", `{"stream" : true}`, "/chat/completions", true},
{"stream embedded in string value", `{"model":"stream:true"}`, "/chat/completions", false},
{"empty body", "", "/chat/completions", false},
{"stream at end of JSON", `{"messages":[],"stream":true}`, "/v1/chat/completions", true},
{"stream with spaces", `{"stream" : true}`, "/v1/chat/completions", true},
{"stream embedded in string value", `{"model":"stream:true"}`, "/v1/chat/completions", false},
{"empty body", "", "/v1/chat/completions", false},
{"stream true embeddings", `{"model":"text-emb","stream":true}`, "/v1/embeddings", false},
}
@@ -719,8 +759,9 @@ func TestProxyHandler_WriteError_RouteError(t *testing.T) {
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
h.writeError(c, fmt.Errorf("model not found"), "openai")
h.writeRouteError(c, fmt.Errorf("model not found"))
assert.Equal(t, 404, w.Code)
assert.Contains(t, w.Body.String(), "MODEL_NOT_FOUND")
}
func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) {
@@ -740,8 +781,8 @@ func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -764,35 +805,35 @@ func TestIsStreamRequest(t *testing.T) {
name: "stream true",
body: []byte(`{"model": "gpt-4", "stream": true}`),
clientProtocol: "openai",
nativePath: "/chat/completions",
nativePath: "/v1/chat/completions",
expected: true,
},
{
name: "stream false",
body: []byte(`{"model": "gpt-4", "stream": false}`),
clientProtocol: "openai",
nativePath: "/chat/completions",
nativePath: "/v1/chat/completions",
expected: false,
},
{
name: "no stream field",
body: []byte(`{"model": "gpt-4"}`),
clientProtocol: "openai",
nativePath: "/chat/completions",
nativePath: "/v1/chat/completions",
expected: false,
},
{
name: "invalid json",
body: []byte(`{invalid}`),
clientProtocol: "openai",
nativePath: "/chat/completions",
nativePath: "/v1/chat/completions",
expected: false,
},
{
name: "not chat endpoint",
body: []byte(`{"model": "gpt-4", "stream": true}`),
clientProtocol: "openai",
nativePath: "/models",
nativePath: "/v1/models",
expected: false,
},
{
@@ -830,8 +871,8 @@ func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -842,7 +883,8 @@ func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
require.True(t, ok)
assert.Len(t, data, 2)
first := data[0].(map[string]interface{})
first, ok2 := data[0].(map[string]interface{})
require.True(t, ok2)
assert.Equal(t, "openai/gpt-4", first["id"])
}
@@ -860,8 +902,8 @@ func TestProxyHandler_HandleProxy_ModelInfo_LocalQuery(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/openai/gpt-4"}}
c.Request = httptest.NewRequest("GET", "/openai/models/openai/gpt-4", nil)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models/openai/gpt-4"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models/openai/gpt-4", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -894,8 +936,8 @@ func TestProxyHandler_HandleProxy_Models_EmptySuffix_ForwardPassthrough(t *testi
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/"}}
c.Request = httptest.NewRequest("GET", "/openai/models/", nil)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models/"}}
c.Request = httptest.NewRequest("GET", "/openai/v1/models/", nil)
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -916,7 +958,7 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
var req map[string]interface{}
json.Unmarshal(spec.Body, &req)
require.NoError(t, json.Unmarshal(spec.Body, &req))
assert.Equal(t, "gpt-4", req["model"])
return &conversion.HTTPResponseSpec{
@@ -932,8 +974,8 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -970,8 +1012,8 @@ func TestProxyHandler_HandleProxy_CrossProtocol_NonStream_UnifiedID(t *testing.T
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}]}`)))
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -992,7 +1034,7 @@ func TestProxyHandler_HandleProxy_CrossProtocol_Stream_UnifiedID(t *testing.T) {
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
ch := make(chan provider.StreamEvent, 10)
go func() {
defer close(ch)
@@ -1010,7 +1052,7 @@ data: {"type":"message_stop"}
`)}
ch <- provider.StreamEvent{Done: true}
}()
return ch, nil
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
})
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
@@ -1019,8 +1061,8 @@ data: {"type":"message_stop"}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -1057,8 +1099,8 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_Fidelity(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}],"custom_param":"should_be_preserved"}`)))
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}],"custom_param":"should_be_preserved"}`)))
h.HandleProxy(c)
assert.Equal(t, 200, w.Code)
@@ -1088,8 +1130,8 @@ func TestProxyHandler_HandleProxy_UnifiedID_ModelNotFound(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
assert.Equal(t, 404, w.Code)
@@ -1098,3 +1140,314 @@ func TestProxyHandler_HandleProxy_UnifiedID_ModelNotFound(t *testing.T) {
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Contains(t, resp, "error")
}
func TestProxyHandler_HandleProxy_OpenAIAndAnthropicNativePaths(t *testing.T) {
tests := []struct {
name string
protocol string
path string
requestPath string
baseURL string
expectedURL string
body string
responseBody string
responseModel string
}{
{
name: "openai path keeps v1 after gateway prefix",
protocol: "openai",
path: "/v1/chat/completions",
requestPath: "/openai/v1/chat/completions",
baseURL: "https://api.test.com/v1",
expectedURL: "https://api.test.com/v1/chat/completions",
body: `{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`,
responseBody: `{"id":"r1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`,
responseModel: "p1/gpt-4",
},
{
name: "anthropic path keeps v1 after gateway prefix",
protocol: "anthropic",
path: "/v1/messages",
requestPath: "/anthropic/v1/messages",
baseURL: "https://api.anthropic.test",
expectedURL: "https://api.anthropic.test/v1/messages",
body: `{"model":"p1/gpt-4","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`,
responseBody: `{"id":"msg-1","type":"message","role":"assistant","model":"gpt-4","content":[{"type":"text","text":"ok"}],"stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`,
responseModel: "p1/gpt-4",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: tt.baseURL, Protocol: tt.protocol, Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
assert.Equal(t, tt.expectedURL, spec.URL)
return &conversion.HTTPResponseSpec{
StatusCode: http.StatusOK,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(tt.responseBody),
}, nil
})
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
statsSvc.EXPECT().Record(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: tt.protocol}, {Key: "path", Value: tt.path}}
c.Request = httptest.NewRequest("POST", tt.requestPath, bytes.NewReader([]byte(tt.body)))
h.HandleProxy(c)
require.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), tt.responseModel)
})
}
}
func TestProxyHandler_UpstreamNon2xx_Passthrough(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).Return(&conversion.HTTPResponseSpec{
StatusCode: http.StatusTooManyRequests,
Headers: map[string]string{
"Content-Type": "application/json",
"X-Upstream-Error": "rate-limit",
"Transfer-Encoding": "chunked",
},
Body: []byte(`{"error":{"message":"rate limited"}}`),
}, nil)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
require.Equal(t, http.StatusTooManyRequests, w.Code)
assert.JSONEq(t, `{"error":{"message":"rate limited"}}`, w.Body.String())
assert.Equal(t, "rate-limit", w.Header().Get("X-Upstream-Error"))
assert.Empty(t, w.Header().Get("Transfer-Encoding"))
}
func TestProxyHandler_StreamUpstreamNon2xx_Passthrough(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).Return(&provider.StreamResponse{
StatusCode: http.StatusServiceUnavailable,
Headers: map[string]string{"Content-Type": "application/json", "Connection": "close"},
Body: []byte(`{"error":"upstream down"}`),
}, nil)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
require.Equal(t, http.StatusServiceUnavailable, w.Code)
assert.JSONEq(t, `{"error":"upstream down"}`, w.Body.String())
assert.Empty(t, w.Header().Get("Connection"))
}
func TestFilterHopByHopHeaders(t *testing.T) {
filtered := filterHopByHopHeaders(map[string]string{
"Connection": "close",
"Transfer-Encoding": "chunked",
"Keep-Alive": "timeout=5",
"Proxy-Authenticate": "Basic",
"Proxy-Authorization": "Basic token",
"TE": "trailers",
"Trailer": "Expires",
"Upgrade": "websocket",
"Content-Type": "application/json",
"X-Request-ID": "req-1",
})
assert.Equal(t, map[string]string{
"Content-Type": "application/json",
"X-Request-ID": "req-1",
}, filtered)
}
func TestProxyHandler_UnknownInterface_DoesNotGuessModel(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
providerSvc.EXPECT().List().Return([]domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
assert.Equal(t, "https://api.test.com/v1/unknown?trace=1", spec.URL)
assert.JSONEq(t, `{"model":"p1/gpt-4","payload":true}`, string(spec.Body))
return &conversion.HTTPResponseSpec{
StatusCode: http.StatusOK,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"ok":true}`),
}, nil
})
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/unknown"}}
c.Request = httptest.NewRequest("POST", "/openai/unknown?trace=1", bytes.NewReader([]byte(`{"model":"p1/gpt-4","payload":true}`)))
h.HandleProxy(c)
require.Equal(t, http.StatusOK, w.Code)
assert.JSONEq(t, `{"ok":true}`, w.Body.String())
}
func TestProxyHandler_InvalidJSON_UsesGatewayError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
client := mocks.NewMockProviderClient(ctrl)
routingSvc := mocks.NewMockRoutingService(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":`)))
h.HandleProxy(c)
require.Equal(t, http.StatusBadRequest, w.Code)
assert.JSONEq(t, `{"error":"请求体 JSON 格式错误","code":"INVALID_JSON"}`, w.Body.String())
}
func TestProxyHandler_CrossProtocolMultimodal_Unsupported(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("anthropic_p", "claude").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "anthropic_p", Name: "Anthropic", APIKey: "sk-test", BaseURL: "https://api.anthropic.test", Protocol: "anthropic", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
body := []byte(`{"model":"anthropic_p/claude","messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader(body))
h.HandleProxy(c)
require.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, w.Body.String(), "UNSUPPORTED_MULTIMODAL")
}
func TestProxyHandler_SameProtocolMultimodal_SmartPassthrough(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
assert.Contains(t, string(spec.Body), "image_url")
assert.Contains(t, string(spec.Body), `"model":"gpt-4"`)
return &conversion.HTTPResponseSpec{
StatusCode: http.StatusOK,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"id":"r1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`),
}, nil
})
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
statsSvc.EXPECT().Record(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
body := []byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader(body))
h.HandleProxy(c)
require.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "p1/gpt-4")
}
func TestProxyHandler_RawStreamPassthrough_PreservesSSEFrames(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
providerSvc.EXPECT().List().Return([]domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
assert.Contains(t, string(spec.Body), `"model":"gpt-4"`)
ch := make(chan provider.StreamEvent, 3)
go func() {
defer close(ch)
ch <- provider.StreamEvent{Data: []byte("data: {\"model\":\"gpt-4\",\"choices\":[]}\n\n")}
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
ch <- provider.StreamEvent{Done: true}
}()
return &provider.StreamResponse{StatusCode: http.StatusOK, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
})
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
require.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "data: {\"model\":\"gpt-4\",\"choices\":[]}\n\ndata: [DONE]\n\n", w.Body.String())
}

View File

@@ -5,9 +5,9 @@ import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"nex/backend/internal/service"
"github.com/gin-gonic/gin"
)
// StatsHandler 统计处理器

View File

@@ -0,0 +1,26 @@
package handler
import (
"net/http"
"nex/backend/pkg/buildinfo"
"github.com/gin-gonic/gin"
)
// VersionHandler 提供后端构建版本信息。
type VersionHandler struct{}
// NewVersionHandler 创建版本信息处理器。
func NewVersionHandler() *VersionHandler {
return &VersionHandler{}
}
// GetVersion 返回构建注入的版本元数据。
func (h *VersionHandler) GetVersion(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"version": buildinfo.Version(),
"commit": buildinfo.Commit(),
"build_time": buildinfo.BuildTime(),
})
}

View File

@@ -0,0 +1,31 @@
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestVersionHandler_GetVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
h := NewVersionHandler()
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/api/version", nil)
h.GetVersion(c)
assert.Equal(t, http.StatusOK, w.Code)
var result map[string]string
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "dev", result["version"])
assert.Equal(t, "unknown", result["commit"])
assert.Equal(t, "unknown", result["build_time"])
}

View File

@@ -8,6 +8,7 @@ import (
"io"
"net"
"net/http"
"strings"
"syscall"
"time"
@@ -15,6 +16,7 @@ import (
"nex/backend/internal/conversion"
pkgErrors "nex/backend/pkg/errors"
pkglogger "nex/backend/pkg/logger"
)
// StreamConfig 流式处理配置
@@ -42,6 +44,14 @@ type StreamEvent struct {
Done bool
}
// StreamResponse 表示上游流式 HTTP 响应。
type StreamResponse struct {
StatusCode int
Headers map[string]string
Body []byte
Events <-chan StreamEvent
}
// Client 协议无关的供应商客户端
type Client struct {
httpClient *http.Client
@@ -50,19 +60,20 @@ type Client struct {
}
// ProviderClient 供应商客户端接口
//
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
type ProviderClient interface {
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error)
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*StreamResponse, error)
}
// NewClient 创建供应商客户端
func NewClient() *Client {
func NewClient(logger *zap.Logger) *Client {
return &Client{
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
logger: zap.L(),
logger: pkglogger.WithModule(logger, "provider.client"),
streamCfg: DefaultStreamConfig(),
}
}
@@ -114,7 +125,7 @@ func (c *Client) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*co
}
// SendStream 发送流式请求
func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error) {
func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*StreamResponse, error) {
var bodyReader io.Reader
if len(spec.Body) > 0 {
bodyReader = bytes.NewReader(spec.Body)
@@ -137,20 +148,29 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec
return nil, pkgErrors.ErrRequestSend.WithCause(err)
}
if resp.StatusCode != http.StatusOK {
respHeaders := extractResponseHeaders(resp.Header)
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
defer resp.Body.Close()
cancel()
errBody, _ := io.ReadAll(resp.Body)
if len(errBody) > 0 {
return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody))
errBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, pkgErrors.ErrResponseRead.WithCause(readErr)
}
return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
return &StreamResponse{
StatusCode: resp.StatusCode,
Headers: respHeaders,
Body: errBody,
}, nil
}
eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize)
go c.readStream(streamCtx, cancel, resp.Body, eventChan)
return eventChan, nil
return &StreamResponse{
StatusCode: resp.StatusCode,
Headers: respHeaders,
Events: eventChan,
}, nil
}
// readStream 读取 SSE 流
@@ -183,10 +203,10 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
if err != nil {
if err != io.EOF {
if isNetworkError(err) {
c.logger.Error("流网络错误", zap.String("error", err.Error()))
c.logger.Error("流网络错误", zap.Error(err))
eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
} else {
c.logger.Error("流读取错误", zap.String("error", err.Error()))
c.logger.Error("流读取错误", zap.Error(err))
eventChan <- StreamEvent{Error: fmt.Errorf("读取错误: %w", err)}
}
return
@@ -203,15 +223,17 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
}
for {
idx := bytes.Index(dataBuf, []byte("\n\n"))
idx, sepLen := findSSEFrameSeparator(dataBuf)
if idx == -1 {
break
}
rawEvent := dataBuf[:idx]
dataBuf = dataBuf[idx+2:]
frameEnd := idx + sepLen
rawEvent := append([]byte(nil), dataBuf[:frameEnd]...)
dataBuf = dataBuf[frameEnd:]
if bytes.Contains(rawEvent, []byte("data: [DONE]")) {
if isSSEDoneFrame(rawEvent) {
eventChan <- StreamEvent{Data: rawEvent}
eventChan <- StreamEvent{Done: true}
return
}
@@ -220,11 +242,66 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
}
if err == io.EOF {
if len(dataBuf) > 0 {
eventChan <- StreamEvent{Data: dataBuf}
}
return
}
}
}
func isSSEDoneFrame(frame []byte) bool {
payload, ok := sseFrameDataPayload(frame)
return ok && strings.TrimSpace(payload) == "[DONE]"
}
func sseFrameDataPayload(frame []byte) (string, bool) {
text := strings.TrimRight(string(frame), "\r\n")
lines := strings.Split(text, "\n")
var dataLines []string
for _, line := range lines {
line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "data:") {
value := strings.TrimPrefix(line, "data:")
if strings.HasPrefix(value, " ") {
value = value[1:]
}
dataLines = append(dataLines, value)
}
}
if len(dataLines) == 0 {
return "", false
}
return strings.Join(dataLines, "\n"), true
}
func extractResponseHeaders(header http.Header) map[string]string {
respHeaders := make(map[string]string)
for k, vs := range header {
if len(vs) > 0 {
respHeaders[k] = vs[0]
}
}
return respHeaders
}
func findSSEFrameSeparator(data []byte) (int, int) {
lf := bytes.Index(data, []byte("\n\n"))
crlf := bytes.Index(data, []byte("\r\n\r\n"))
switch {
case lf < 0 && crlf < 0:
return -1, 0
case lf < 0:
return crlf, 4
case crlf < 0:
return lf, 2
case crlf <= lf:
return crlf, 4
default:
return lf, 2
}
}
// isNetworkError 判断是否为网络相关错误
func isNetworkError(err error) bool {
if err == nil {

View File

@@ -13,12 +13,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"nex/backend/internal/conversion"
)
func TestNewClient(t *testing.T) {
client := NewClient()
client := NewClient(zap.NewNop())
require.NotNil(t, client)
assert.NotNil(t, client.httpClient)
assert.Equal(t, 4096, client.streamCfg.InitialBufferSize)
@@ -40,11 +41,12 @@ func TestClient_Send_Success(t *testing.T) {
assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
_, err := w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
require.NoError(t, err)
}))
defer server.Close()
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
@@ -64,11 +66,12 @@ func TestClient_Send_Success(t *testing.T) {
func TestClient_Send_ErrorResponse(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
_, err := w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
require.NoError(t, err)
}))
defer server.Close()
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
@@ -82,7 +85,7 @@ func TestClient_Send_ErrorResponse(t *testing.T) {
}
func TestClient_Send_ConnectionError(t *testing.T) {
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: "http://localhost:1/v1/chat/completions",
Method: "POST",
@@ -99,7 +102,7 @@ func TestClient_SendStream_CreatesChannel(t *testing.T) {
}))
defer server.Close()
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
@@ -107,11 +110,13 @@ func TestClient_SendStream_CreatesChannel(t *testing.T) {
Body: []byte(`{}`),
}
eventChan, err := client.SendStream(context.Background(), spec)
streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
require.NotNil(t, eventChan)
require.NotNil(t, streamResp)
require.Equal(t, http.StatusOK, streamResp.StatusCode)
require.NotNil(t, streamResp.Events)
for range eventChan {
for range streamResp.Events {
}
}
@@ -121,7 +126,7 @@ func TestClient_SendStream_ErrorResponse(t *testing.T) {
}))
defer server.Close()
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
@@ -129,8 +134,10 @@ func TestClient_SendStream_ErrorResponse(t *testing.T) {
Body: []byte(`{}`),
}
_, err := client.SendStream(context.Background(), spec)
assert.Error(t, err)
streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
require.NotNil(t, streamResp)
assert.Equal(t, http.StatusInternalServerError, streamResp.StatusCode)
}
func TestClient_SendStream_SSEEvents(t *testing.T) {
@@ -139,18 +146,21 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
_, err := w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
require.NoError(t, err)
flusher.Flush()
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
_, err = w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(50 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n"))
_, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(50 * time.Millisecond)
}))
defer server.Close()
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
@@ -158,24 +168,73 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
Body: []byte(`{"model":"gpt-4","messages":[],"stream":true}`),
}
eventChan, err := client.SendStream(context.Background(), spec)
streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
require.NotNil(t, streamResp)
var dataEvents [][]byte
var doneEvents int
for event := range eventChan {
if event.Done {
for event := range streamResp.Events {
switch {
case event.Done:
doneEvents++
} else if event.Error != nil {
case event.Error != nil:
t.Fatalf("unexpected error: %v", event.Error)
} else {
default:
dataEvents = append(dataEvents, event.Data)
}
}
assert.Equal(t, 2, len(dataEvents), "expected exactly 2 data events from SSE stream")
assert.Equal(t, 3, len(dataEvents), "expected 2 data frames plus DONE frame from SSE stream")
assert.Contains(t, string(dataEvents[0]), "Hello")
assert.Contains(t, string(dataEvents[1]), "World")
assert.Contains(t, string(dataEvents[2]), "[DONE]")
assert.Equal(t, 1, doneEvents)
assert.Contains(t, string(dataEvents[0]), "\n\n")
}
func TestClient_SendStream_DONEOnlyWhenDataPayloadEqualsDone(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
_, err := w.Write([]byte("data: {\"text\":\"data: [DONE] is plain text\"}\n\n"))
require.NoError(t, err)
flusher.Flush()
_, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush()
}))
defer server.Close()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
Body: []byte(`{}`),
}
streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
require.NotNil(t, streamResp)
var dataEvents [][]byte
var doneEvents int
for event := range streamResp.Events {
switch {
case event.Done:
doneEvents++
case event.Error != nil:
t.Fatalf("unexpected error: %v", event.Error)
default:
dataEvents = append(dataEvents, event.Data)
}
}
require.Len(t, dataEvents, 2)
assert.Contains(t, string(dataEvents[0]), "plain text")
assert.Contains(t, string(dataEvents[1]), "[DONE]")
assert.Equal(t, 1, doneEvents)
}
@@ -188,7 +247,7 @@ func TestClient_SendStream_ContextCancellation(t *testing.T) {
defer server.Close()
ctx, cancel := context.WithCancel(context.Background())
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
@@ -196,13 +255,13 @@ func TestClient_SendStream_ContextCancellation(t *testing.T) {
Body: []byte(`{}`),
}
eventChan, err := client.SendStream(ctx, spec)
streamResp, err := client.SendStream(ctx, spec)
require.NoError(t, err)
cancel()
var gotError bool
for event := range eventChan {
for event := range streamResp.Events {
if event.Error != nil {
gotError = true
}
@@ -214,11 +273,12 @@ func TestClient_Send_EmptyBody(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"result":"ok"}`))
_, err := w.Write([]byte(`{"result":"ok"}`))
require.NoError(t, err)
}))
defer server.Close()
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/models",
Method: "GET",
@@ -237,16 +297,18 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(100 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n"))
_, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(100 * time.Millisecond)
}))
defer server.Close()
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
@@ -254,21 +316,22 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
Body: []byte(`{}`),
}
eventChan, err := client.SendStream(context.Background(), spec)
streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
var dataCount int
var doneCount int
for event := range eventChan {
if event.Done {
for event := range streamResp.Events {
switch {
case event.Done:
doneCount++
} else if event.Error != nil {
case event.Error != nil:
t.Fatalf("unexpected error: %v", event.Error)
} else {
default:
dataCount++
}
}
assert.Equal(t, 1, dataCount, "expected exactly 1 data event from slow SSE")
assert.Equal(t, 2, dataCount, "expected 1 data frame plus DONE frame from slow SSE")
assert.Equal(t, 1, doneCount, "expected exactly 1 done event from slow SSE")
}
@@ -278,16 +341,18 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(50 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n"))
_, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(50 * time.Millisecond)
}))
defer server.Close()
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
@@ -295,19 +360,19 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
Body: []byte(`{}`),
}
eventChan, err := client.SendStream(context.Background(), spec)
streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
var dataEvents int
var doneEvents int
for event := range eventChan {
for event := range streamResp.Events {
if event.Done {
doneEvents++
} else {
dataEvents++
}
}
assert.Equal(t, 2, dataEvents, "expected exactly 2 data events from split SSE")
assert.Equal(t, 3, dataEvents, "expected 2 data frames plus DONE frame from split SSE")
assert.Equal(t, 1, doneEvents)
}
@@ -363,19 +428,20 @@ func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
require.NoError(t, err)
flusher.Flush()
time.Sleep(50 * time.Millisecond)
if hijacker, ok := w.(http.Hijacker); ok {
conn, _, _ := hijacker.Hijack()
if conn != nil {
conn.Close()
require.NoError(t, conn.Close())
}
}
}))
defer server.Close()
client := NewClient()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
@@ -383,11 +449,11 @@ func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
Body: []byte(`{}`),
}
eventChan, err := client.SendStream(context.Background(), spec)
streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
var gotData bool
for event := range eventChan {
for event := range streamResp.Events {
if event.Error != nil {
} else if !event.Done {
gotData = true

View File

@@ -3,10 +3,11 @@ package repository
import (
"time"
"gorm.io/gorm"
"nex/backend/internal/config"
"nex/backend/internal/domain"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors"
)

View File

@@ -3,13 +3,13 @@ package repository
import (
"testing"
"nex/backend/internal/domain"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
testHelpers "nex/backend/tests"
"nex/backend/internal/domain"
)
func setupTestDB(t *testing.T) *gorm.DB {

View File

@@ -1,13 +1,13 @@
package repository
import (
"errors"
"time"
"gorm.io/gorm"
"nex/backend/internal/config"
"nex/backend/internal/domain"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type statsRepository struct {
@@ -19,50 +19,46 @@ func NewStatsRepository(db *gorm.DB) StatsRepository {
}
func (r *statsRepository) Record(providerID, modelName string) error {
today := time.Now().Format("2006-01-02")
todayTime, _ := time.Parse("2006-01-02", today)
now := time.Now()
todayTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
return r.db.Transaction(func(tx *gorm.DB) error {
var stats config.UsageStats
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
providerID, modelName, todayTime).First(&stats).Error
stats := config.UsageStats{
ProviderID: providerID,
ModelName: modelName,
RequestCount: 1,
Date: todayTime,
}
if errors.Is(err, gorm.ErrRecordNotFound) {
stats = config.UsageStats{
ProviderID: providerID,
ModelName: modelName,
RequestCount: 1,
Date: todayTime,
}
return tx.Create(&stats).Error
} else if err != nil {
return err
}
return tx.Model(&stats).Update("request_count", gorm.Expr("request_count + 1")).Error
})
return r.db.Clauses(clause.OnConflict{
Columns: []clause.Column{
{Name: "provider_id"},
{Name: "model_name"},
{Name: "date"},
},
DoUpdates: clause.Assignments(map[string]interface{}{
"request_count": gorm.Expr("request_count + 1"),
}),
}).Create(&stats).Error
}
func (r *statsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
return r.db.Transaction(func(tx *gorm.DB) error {
var stats config.UsageStats
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
providerID, modelName, date).First(&stats).Error
stats := config.UsageStats{
ProviderID: providerID,
ModelName: modelName,
RequestCount: delta,
Date: date,
}
if errors.Is(err, gorm.ErrRecordNotFound) {
return tx.Create(&config.UsageStats{
ProviderID: providerID,
ModelName: modelName,
RequestCount: delta,
Date: date,
}).Error
} else if err != nil {
return err
}
return tx.Model(&stats).
Update("request_count", gorm.Expr("request_count + ?", delta)).Error
})
return r.db.Clauses(clause.OnConflict{
Columns: []clause.Column{
{Name: "provider_id"},
{Name: "model_name"},
{Name: "date"},
},
DoUpdates: clause.Assignments(map[string]interface{}{
"request_count": gorm.Expr("request_count + ?", delta),
}),
}).Create(&stats).Error
}
func (r *statsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {

View File

@@ -1,11 +1,15 @@
package service
import (
"github.com/google/uuid"
appErrors "nex/backend/pkg/errors"
"errors"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
"github.com/google/uuid"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors"
)
type modelService struct {
@@ -108,7 +112,11 @@ func (s *modelService) Delete(id string) error {
func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error {
existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
if err != nil {
return nil // 未找到,不重复
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil // 未找到,不重复
}
return err
}
if excludeID != "" && existing.ID == excludeID {
return nil // 排除自身

View File

@@ -3,10 +3,10 @@ package service
import (
"strings"
"nex/backend/pkg/modelid"
"nex/backend/internal/domain"
"nex/backend/internal/repository"
"nex/backend/pkg/modelid"
appErrors "nex/backend/pkg/errors"
)

Some files were not shown because too many files have changed in this diff Show More