1
0

33 Commits

Author SHA1 Message Date
235efb0e62 chore: 版本升迁 v0.1.1 2026-05-05 04:39:31 +08:00
6b1af27ea2 fix: 移除 version-bump 的工作区干净检查 2026-05-05 04:39:21 +08:00
32f48777f3 feat: make version-bump 默认 BUMP=patch,无需显式传参 2026-05-05 04:32:38 +08:00
bc7a7c6e81 feat: 迁移 versionctl 为独立模块并新增 make version-bump 命令
- 将 backend/cmd/versionctl 和 backend/pkg/projectversion 迁移至独立 versionctl/ Go 模块
- 新增 bump 子命令支持 major/minor/patch 和指定版本号,含版本倒退防护
- 新增 make version-bump 编排完整升迁流程(bump + sync + check + commit + tag)
- 更新所有引用路径:根 Makefile、backend/Makefile、release.yml、.golangci.yml
- 新增 versionctl/.golangci.yml(精简配置)和 Makefile(lint/test/coverage)
- 根 Makefile lint/test 集成 versionctl 模块
- 同步 openspec specs:新增 version-bump spec,更新 release-pipeline spec
2026-05-05 04:18:10 +08:00
3cd0458c2c fix: 修复 golangci-lint 报告的 gosec/gocyclo/forbidigo 问题 2026-05-05 03:35:20 +08:00
8eea30ea11 feat: 统一品牌标识、关于页面三卡片布局与版本诊断功能
- 统一品牌为 Nex:侧边栏、托盘 tooltip、HTML 标题、favicon (PNG 替代 SVG)
- 重构关于页面为三卡片布局(品牌/版本/链接),版本状态 Tag 绝对定位右上角
- 新增 GET /api/version 后端接口,返回 version/commit/build_time
- 新增前端版本一致性诊断:匹配/不匹配/不可判断三种状态
- 同步 delta specs 到主 specs 并归档变更
2026-05-05 03:28:22 +08:00
9e33e570af fix: 降低请求生命周期日志级别 2026-05-05 01:54:53 +08:00
7653385838 fix: 加固发布流水线运行环境
修复 Windows 发布作业在 MSYS2 环境下无法访问 Go 工具链的问题。

为三平台发布增加工具链预检并升级 release workflow 运行时兼容性,减少版本检查噪音和 CI 告警。
2026-05-05 01:27:38 +08:00
2c401f7ae6 chore: streamline workspace make workflows
Clarify product-level server and desktop commands while moving backend-only maintenance tasks into backend/Makefile. This keeps root automation focused on core flows and aligns the main OpenSpec specs with the new command boundaries.
2026-04-28 17:44:23 +08:00
a9972360c2 feat: 增加版本化构建与发布流程
引入 VERSION 作为统一版本源,避免前端、后端、桌面打包和发布资产之间的版本漂移。
新增 tag 驱动的 Draft Release 流程与版本化资产命名,使本地演进和 GitHub 发布共享同一套约束。
2026-04-28 14:20:27 +08:00
b00fa4dcee chore: 调整开源许可证为 Apache 2.0
新增根目录 Apache 2.0 许可证文件,并同步更新仓库文档与前端包元数据声明。
2026-04-27 11:24:33 +08:00
92525b39c3 chore: 将 assets 图标文件迁移到 Git LFS 2026-04-27 10:34:38 +08:00
38a2555c7b fix: Anthropic 流式编码器补全 message_start/message_delta 必填字段
跨协议流式转换时,Anthropic 客户端 Zod 校验因 SSE 事件缺少必填字段报错。
由 Anthropic encoder 层(而非 OpenAI decoder 层)负责补全协议默认值,保持权责分离。

- encodeMessageStart 补全 type/content/stop_reason/stop_sequence,usage nil 时输出零值
- encodeMessageDelta usage nil 时输出零值
- 更新相关测试覆盖新增行为
2026-04-26 23:27:34 +08:00
9622d44aac fix: 完善转换代理行为 2026-04-26 21:48:17 +08:00
155244433f fix: 完善桌面应用图标打包
统一 macOS 图标命名为 icon.icns。\n补充 Linux hicolor 图标资源。\n修复 Windows make 构建兼容性并为 exe 嵌入图标资源。\n清理旧版图标说明与不再使用的 SVG 源文件。
2026-04-26 12:24:03 +08:00
2c043c6cf7 fix: 修正 conversion 代理路径和错误边界 2026-04-25 23:12:54 +08:00
f5c82b6980 chore: 合并 dev-test-config 到 master 2026-04-24 23:23:12 +08:00
9105a36097 feat: 将"关于"从系统托盘原生对话框迁移到前端页面
移除系统托盘右键菜单中的"关于"选项及各平台原生对话框实现,
在前端新增 /about 路由和关于页面展示品牌信息,侧边栏增加关于导航入口
2026-04-24 23:17:22 +08:00
f1ee646ca4 fix: 修复 TestSaveAndLoadConfig 测试隔离问题,使用临时目录替代真实用户配置 2026-04-24 22:59:26 +08:00
b9b487c591 chore: 统一项目编辑器配置,移至仓库根目录 2026-04-24 22:28:01 +08:00
4c62c071fb fix: 修复 macOS 桌面应用打包与元数据
将 macOS 桌面应用改为通用二进制并动态写入最低系统版本,避免 Intel Mac 无法启动。统一桌面应用名称与托盘展示,并补充测试确保相关行为稳定。
2026-04-24 19:35:51 +08:00
b2e9dd8b7f refactor: 合并 macOS 桌面打包流程
将 macOS .app 打包直接并入 desktop-build-mac,减少重复的桌面构建入口。\n\n同时移除未实现或已废弃的 desktop-package-* 命令和独立打包脚本,降低维护成本。
2026-04-24 19:02:21 +08:00
d143c5f3df fix: 补齐前端生成物忽略并消除构建告警
统一 Git、ESLint、Prettier 对测试和构建生成物的忽略规则,避免本地产物导致 frontend-build 失败。

补齐表单 effect 依赖,移除无关告警,让前端构建链路恢复稳定。
2026-04-24 18:53:53 +08:00
4eebdfb8db chore: 合并 dev-code-format-frontend 到 master 2026-04-24 18:21:27 +08:00
b517946585 chore: 合并 dev-code-backend-format 到 master 2026-04-24 18:20:12 +08:00
4ddae6be74 refactor: 优化提示词文档,增强智能合并与规范审查能力 2026-04-24 18:12:23 +08:00
195762ff97 fix: 修复后端配置加载测试失败
- 修复 viper SafeWriteConfig 与 SetConfigFile 不兼容问题
  - 将 SafeWriteConfig() 替换为 SafeWriteConfigAs(configPath)
  - 绕过 viper 的 configPaths 检查
- 调整 Makefile 测试命令分类
  - backend-test: 仅运行后端核心测试
  - backend-test-all: 运行全部后端测试(含 desktop)
  - desktop-test: 单独运行桌面应用测试
- 同步 config-management 和 test-coverage 规范
2026-04-24 14:06:03 +08:00
bcf5ca89e5 refactor: Makefile 前端命令自动安装依赖 2026-04-24 13:50:51 +08:00
365943e4c4 feat: 前端集成 Prettier 代码格式化 2026-04-24 13:40:53 +08:00
4c6b49099d feat: 配置 golangci-lint 静态分析并修复存量违规
- 新增 backend/.golangci.yml 配置 12 个 linter(forbidigo、errorlint、errcheck、staticcheck、revive、gocritic、gosec、bodyclose、noctx、nilerr、goimports、gocyclo)
- 新增 lefthook.yml 配置 pre-commit hook 自动运行 lint
- 修复存量代码违规:errors.Is/As 替换、zap.Error 替换、import 排序、errcheck 修复
- 更新 README 补充编码规范说明
- 归档 backend-code-lint 变更
2026-04-24 13:01:48 +08:00
4c78ab6cc8 chore: 新增 backend-code-lint 变更计划,更新 .gitignore 2026-04-24 00:19:56 +08:00
52007c9461 feat: 前端 ESLint 规则增强,自动检测 LLM 编码违规
- 启用 TanStack Query flat/recommended(7 条规则)
- 新增 no-console(允许 warn/error)、consistent-type-imports(inline 风格)、no-non-null-assertion 规则
- 新增自定义规则 no-hardcoded-color-in-style,检测 JSX style 中硬编码颜色值
- 将 ESLint 检查集成到 build 命令(tsc -b && eslint . && vite build)
- 修复现有代码中的 lint 违规(import 顺序、type import 风格、unused vars)
- 使用 @typescript-eslint/rule-tester 编写自定义规则集成测试
2026-04-23 22:47:32 +08:00
086dd1fed7 refactor: 重构智能合并提示词,优化交互体验并增加语义审查能力
- 三阶段模型:计划审批(1次) → 自动合并(仅异常中断) → 汇总收尾(1次),减少冗余介入
- 新增语义审查:合并后自动分析代码冗余/过时模式/未集成基础设施/风格不一致
- 完善安全锚点体系:全局锚点+分支级锚点,分级回退机制
- 修复多个逻辑缺陷:stash含未跟踪文件、分支匹配过宽、远端分支拉取重名、语义修复交叉污染等
- 明确提问工具选项和交互方式,消除歧义引导
2026-04-23 22:35:24 +08:00
250 changed files with 10986 additions and 3788 deletions

7
.editorconfig Normal file
View File

@@ -0,0 +1,7 @@
root = true
[*]
end_of_line = lf
trim_trailing_whitespace = true
insert_final_newline = true
charset = utf-8

9
.gitattributes vendored Normal file
View File

@@ -0,0 +1,9 @@
* text=auto eol=lf
assets/*.png filter=lfs diff=lfs merge=lfs -text
assets/**/*.png filter=lfs diff=lfs merge=lfs -text
assets/*.icns filter=lfs diff=lfs merge=lfs -text
assets/**/*.icns filter=lfs diff=lfs merge=lfs -text
assets/*.ico filter=lfs diff=lfs merge=lfs -text
assets/**/*.ico filter=lfs diff=lfs merge=lfs -text
frontend/public/*.png filter=lfs diff=lfs merge=lfs -text
frontend/public/**/*.png filter=lfs diff=lfs merge=lfs -text

210
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,210 @@
name: Release
on:
push:
tags:
- 'v*.*.*'
permissions:
contents: read
jobs:
prepare:
name: Prepare Release
runs-on: ubuntu-latest
permissions:
contents: read
outputs:
version: ${{ steps.version.outputs.version }}
steps:
- name: Checkout
uses: actions/checkout@v5
- name: Setup Go
uses: actions/setup-go@v6
with:
go-version-file: go.work
cache-dependency-path: |
backend/go.sum
versionctl/go.sum
- name: Verify tag and VERSION
id: version
run: |
version=$(go run ./versionctl print)
go run ./versionctl verify-tag "${GITHUB_REF_NAME}"
printf 'version=%s\n' "$version" >> "$GITHUB_OUTPUT"
build-linux:
name: Build Linux Assets
needs: prepare
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- name: Checkout
uses: actions/checkout@v5
- name: Setup Go
uses: actions/setup-go@v6
with:
go-version-file: go.work
cache-dependency-path: |
backend/go.sum
versionctl/go.sum
- name: Setup Bun
uses: oven-sh/setup-bun@v2
- name: Install Linux desktop build dependencies
run: |
sudo apt-get update
sudo apt-get install -y libayatana-appindicator3-dev libgtk-3-dev
- name: Preflight Linux release toolchain
run: |
set -euo pipefail
command -v go
go version
command -v bun
bun --version
command -v gcc
gcc --version
command -v pkg-config
pkg-config --modversion ayatana-appindicator3-0.1
pkg-config --modversion gtk+-3.0
- name: Build Linux release assets
run: make release-assets-linux
- name: Upload Linux release assets
uses: actions/upload-artifact@v4
with:
name: release-linux
path: build/release/*
build-windows:
name: Build Windows Assets
needs: prepare
runs-on: windows-latest
permissions:
contents: read
steps:
- name: Checkout
uses: actions/checkout@v5
- name: Setup Go
uses: actions/setup-go@v6
with:
go-version-file: go.work
cache-dependency-path: |
backend/go.sum
versionctl/go.sum
- name: Setup Bun
uses: oven-sh/setup-bun@v2
- name: Setup MSYS2 toolchain
uses: msys2/setup-msys2@v2
with:
msystem: MINGW64
path-type: inherit
update: true
install: >-
make
mingw-w64-x86_64-gcc
- name: Preflight Windows release toolchain
shell: msys2 {0}
run: |
set -euo pipefail
command -v go
go version
command -v bun
bun --version
command -v make
make --version
command -v gcc
gcc --version
command -v windres
windres --version
if command -v powershell.exe >/dev/null 2>&1; then
powershell.exe -NoProfile -Command '$PSVersionTable.PSVersion.ToString()'
else
command -v powershell
powershell -NoProfile -Command '$PSVersionTable.PSVersion.ToString()'
fi
- name: Build Windows release assets
shell: msys2 {0}
run: make release-assets-windows
- name: Upload Windows release assets
uses: actions/upload-artifact@v4
with:
name: release-windows
path: build/release/*
build-macos:
name: Build macOS Assets
needs: prepare
runs-on: macos-latest
permissions:
contents: read
steps:
- name: Checkout
uses: actions/checkout@v5
- name: Setup Go
uses: actions/setup-go@v6
with:
go-version-file: go.work
cache-dependency-path: |
backend/go.sum
versionctl/go.sum
- name: Setup Bun
uses: oven-sh/setup-bun@v2
- name: Preflight macOS release toolchain
run: |
set -euo pipefail
command -v go
go version
command -v bun
bun --version
command -v ditto
xcrun --find lipo
xcrun --find vtool
- name: Build macOS release assets
run: make release-assets-macos
- name: Upload macOS release assets
uses: actions/upload-artifact@v4
with:
name: release-macos
path: build/release/*
draft-release:
name: Create Draft Release
needs: [prepare, build-linux, build-windows, build-macos]
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- name: Download release assets
uses: actions/download-artifact@v4
with:
pattern: release-*
merge-multiple: true
path: dist
- name: Publish draft release
uses: softprops/action-gh-release@v2
with:
name: v${{ needs.prepare.outputs.version }}
tag_name: ${{ github.ref_name }}
draft: true
files: |
dist/*

3
.gitignore vendored
View File

@@ -401,13 +401,16 @@ cython_debug/
# Custom # Custom
.claude .claude
.opencode .opencode
.codex
openspec/changes/archive openspec/changes/archive
temp temp
.agents .agents
skills-lock.json skills-lock.json
.worktrees .worktrees
!scripts/build/ !scripts/build/
backend/bin
# Embedfs generated # Embedfs generated
embedfs/assets/ embedfs/assets/
embedfs/frontend-dist/ embedfs/frontend-dist/
backend/cmd/desktop/rsrc_windows_*.syso

3
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,3 @@
{
"files.eol": "\n"
}

184
LICENSE Normal file
View File

@@ -0,0 +1,184 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction, and
distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by the copyright
owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all other entities
that control, are controlled by, or are under common control with that entity.
For the purposes of this definition, "control" means (i) the power, direct or
indirect, to cause the direction or management of such entity, whether by
contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity exercising
permissions granted by this License.
"Source" form shall mean the preferred form for making modifications, including
but not limited to software source code, documentation source, and configuration
files.
"Object" form shall mean any form resulting from mechanical transformation or
translation of a Source form, including but not limited to compiled object code,
generated documentation, and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or Object form,
made available under the License, as indicated by a copyright notice that is
included in or attached to the work (an example is provided in the Appendix
below).
"Derivative Works" shall mean any work, whether in Source or Object form, that
is based on (or derived from) the Work and for which the editorial revisions,
annotations, elaborations, or other modifications represent, as a whole, an
original work of authorship. For the purposes of this License, Derivative Works
shall not include works that remain separable from, or merely link (or bind by
name) to the interfaces of, the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including the original version
of the Work and any modifications or additions to that Work or Derivative Works
thereof, that is intentionally submitted to Licensor for inclusion in the Work
by the copyright owner or by an individual or Legal Entity authorized to submit
on behalf of the copyright owner. For the purposes of this definition,
"submitted" means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems, and
issue tracking systems that are managed by, or on behalf of, the Licensor for
the purpose of discussing and improving the Work, but excluding communication
that is conspicuously marked or otherwise designated in writing by the copyright
owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf
of whom a Contribution has been received by Licensor and subsequently
incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of this
License, each Contributor hereby grants to You a perpetual, worldwide,
non-exclusive, no-charge, royalty-free, irrevocable copyright license to
reproduce, prepare Derivative Works of, publicly display, publicly perform,
sublicense, and distribute the Work and such Derivative Works in Source or
Object form.
3. Grant of Patent License. Subject to the terms and conditions of this License,
each Contributor hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section) patent
license to make, have made, use, offer to sell, sell, import, and otherwise
transfer the Work, where such license applies only to those patent claims
licensable by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s) with the Work
to which such Contribution(s) was submitted. If You institute patent litigation
against any entity (including a cross-claim or counterclaim in a lawsuit)
alleging that the Work or a Contribution incorporated within the Work
constitutes direct or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate as of the date
such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the Work or
Derivative Works thereof in any medium, with or without modifications, and in
Source or Object form, provided that You meet the following conditions:
(a) You must give any other recipients of the Work or Derivative Works a copy of
this License; and
(b) You must cause any modified files to carry prominent notices stating that
You changed the files; and
(c) You must retain, in the Source form of any Derivative Works that You
distribute, all copyright, patent, trademark, and attribution notices from the
Source form of the Work, excluding those notices that do not pertain to any part
of the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its distribution, then
any Derivative Works that You distribute must include a readable copy of the
attribution notices contained within such NOTICE file, excluding those notices
that do not pertain to any part of the Derivative Works, in at least one of the
following places: within a NOTICE text file distributed as part of the
Derivative Works; within the Source form or documentation, if provided along
with the Derivative Works; or, within a display generated by the Derivative
Works, if and wherever such third-party notices normally appear. The contents of
the NOTICE file are for informational purposes only and do not modify the
License. You may add Your own attribution notices within Derivative Works that
You distribute, alongside or as an addendum to the NOTICE text from the Work,
provided that such additional attribution notices cannot be construed as
modifying the License.
You may add Your own copyright statement to Your modifications and may provide
additional or different license terms and conditions for use, reproduction, or
distribution of Your modifications, or for any such Derivative Works as a whole,
provided Your use, reproduction, and distribution of the Work otherwise complies
with the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise, any
Contribution intentionally submitted for inclusion in the Work by You to the
Licensor shall be under the terms and conditions of this License, without any
additional terms or conditions. Notwithstanding the above, nothing herein shall
supersede or modify the terms of any separate license agreement you may have
executed with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade names,
trademarks, service marks, or product names of the Licensor, except as required
for reasonable and customary use in describing the origin of the Work and
reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in
writing, Licensor provides the Work (and each Contributor provides its
Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied, including, without limitation, any warranties
or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any risks
associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory, whether in
tort (including negligence), contract, or otherwise, unless required by
applicable law (such as deliberate and grossly negligent acts) or agreed to in
writing, shall any Contributor be liable to You for damages, including any
direct, indirect, special, incidental, or consequential damages of any character
arising as a result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill, work stoppage,
computer failure or malfunction, or any and all other commercial damages or
losses), even if such Contributor has been advised of the possibility of such
damages.
9. Accepting Warranty or Additional Liability. While redistributing the Work or
Derivative Works thereof, You may choose to offer, and charge a fee for,
acceptance of support, warranty, indemnity, or other liability obligations
and/or rights consistent with this License. However, in accepting such
obligations, You may act only on Your own behalf and on Your sole
responsibility, not on behalf of any other Contributor, and only if You agree to
indemnify, defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason of your
accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following boilerplate
notice, with the fields enclosed by brackets "[]" replaced with your own
identifying information. (Don't include the brackets!) The text should be
enclosed in the appropriate comment syntax for the file format. We also
recommend that a file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier identification within
third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

366
Makefile
View File

@@ -1,199 +1,251 @@
.PHONY: all dev build test lint clean \ .PHONY: \
backend-build backend-run backend-dev backend-test backend-test-unit backend-test-integration backend-test-coverage \ lint test clean \
backend-lint backend-clean backend-deps backend-generate \ version-sync version-check version-bump \
backend-db-up backend-db-down backend-db-status backend-db-create \ server-run server-build server-lint server-test server-clean \
test-mysql-up test-mysql-down test-mysql test-mysql-quick \ desktop-build-mac desktop-build-win desktop-build-linux \
frontend-build frontend-dev frontend-test frontend-test-watch frontend-test-coverage frontend-test-e2e frontend-lint frontend-clean \ desktop-lint desktop-test desktop-clean \
desktop-build desktop-build-mac desktop-build-win desktop-build-linux \ release-assets-linux release-assets-windows release-assets-macos \
desktop-dev desktop-package-mac desktop-package-win desktop-package-linux desktop-clean \ _backend-lint _backend-test _backend-clean _backend-build \
desktop-prepare-frontend desktop-prepare-embedfs _versionctl-lint _versionctl-test \
_frontend-install _frontend-build _frontend-check _frontend-test _frontend-dev _frontend-clean \
_desktop-test _desktop-clean _desktop-prepare-frontend _desktop-prepare-embedfs _desktop-prepare-windows-resource \
_server-run-backend _server-run-frontend
# Delay shell lookups until a target needs them, then cache the result for this make run.
lazy_shell = $(or $($(1)),$(eval $(1) := $(shell $(2)))$($(1)))
VERSION = $(call lazy_shell,_VERSION,go run ./versionctl print)
GIT_COMMIT ?= $(call lazy_shell,_GIT_COMMIT,git rev-parse --short HEAD 2>/dev/null || printf 'unknown')
BUILD_TIME ?= $(call lazy_shell,_BUILD_TIME,date -u +"%Y-%m-%dT%H:%M:%SZ")
GO_LDFLAGS = -X nex/backend/pkg/buildinfo.version=$(VERSION) -X nex/backend/pkg/buildinfo.commit=$(GIT_COMMIT) -X nex/backend/pkg/buildinfo.buildTime=$(BUILD_TIME)
GO_LDFLAGS_WIN = $(GO_LDFLAGS) -H=windowsgui
RELEASE_DIR := build/release
SERVER_LINUX_ASSET = $(call lazy_shell,_SERVER_LINUX_ASSET,go run ./versionctl asset-name server linux amd64)
SERVER_WINDOWS_ASSET = $(call lazy_shell,_SERVER_WINDOWS_ASSET,go run ./versionctl asset-name server windows amd64)
SERVER_DARWIN_AMD64_ASSET = $(call lazy_shell,_SERVER_DARWIN_AMD64_ASSET,go run ./versionctl asset-name server darwin amd64)
SERVER_DARWIN_ARM64_ASSET = $(call lazy_shell,_SERVER_DARWIN_ARM64_ASSET,go run ./versionctl asset-name server darwin arm64)
DESKTOP_LINUX_ASSET = $(call lazy_shell,_DESKTOP_LINUX_ASSET,go run ./versionctl asset-name desktop linux)
DESKTOP_WINDOWS_ASSET = $(call lazy_shell,_DESKTOP_WINDOWS_ASSET,go run ./versionctl asset-name desktop windows)
DESKTOP_MACOS_ASSET = $(call lazy_shell,_DESKTOP_MACOS_ASSET,go run ./versionctl asset-name desktop macos)
# ============================================ # ============================================
# 顶层便捷命令 # 全局命令
# ============================================ # ============================================
dev: lint: _backend-lint _frontend-check _versionctl-lint
@echo "🚀 Starting development environment..." @printf 'Lint complete\n'
@$(MAKE) -j2 backend-dev frontend-dev
build: backend-build frontend-build test: _backend-test _frontend-test _desktop-test _versionctl-test
@echo "✅ Build complete" @printf 'All tests passed\n'
test: backend-test frontend-test clean: _backend-clean _frontend-clean _desktop-clean
@echo "✅ All tests passed" @printf 'Clean complete\n'
lint: backend-lint frontend-lint
@echo "✅ Lint complete"
all: build test lint
# ============================================ # ============================================
# 后端 # 版本管理
# ============================================ # ============================================
backend-build: version-sync:
cd backend && go build -o bin/server ./cmd/server go run ./versionctl sync
backend-run: version-check:
cd backend && go run ./cmd/server go run ./versionctl check
backend-dev: version-bump: BUMP ?= patch
cd backend && go run ./cmd/server version-bump:
$(eval _BUMP_ARG := $(if $(SET_VERSION),$(SET_VERSION),$(BUMP)))
backend-test: $(eval _NEW_VERSION := $(shell go run ./versionctl bump $(_BUMP_ARG)))
cd backend && go test ./... -v git add VERSION frontend/
git commit -m "chore: 版本升迁 v$(_NEW_VERSION)"
backend-test-unit: git tag "v$(_NEW_VERSION)"
cd backend && go test ./internal/... ./pkg/... -v @printf '版本升迁完成: v%s\n' "$(_NEW_VERSION)"
backend-test-integration:
cd backend && go test ./tests/... -v
backend-test-coverage:
cd backend && go test ./... -coverprofile=coverage.out
cd backend && go tool cover -html=coverage.out -o coverage.html
@echo "Coverage report generated: backend/coverage.html"
backend-lint:
cd backend && go tool golangci-lint run ./...
backend-clean:
rm -rf backend/bin/ backend/coverage.out backend/coverage.html
backend-deps:
cd backend && go mod tidy
backend-generate:
cd backend && go generate ./...
backend-db-up:
@echo "Running database migration up..."
cd backend && goose -dir migrations/sqlite sqlite3 "$(DB_PATH)" up
backend-db-down:
@echo "Running database migration down..."
cd backend && goose -dir migrations/sqlite sqlite3 "$(DB_PATH)" down
backend-db-status:
@echo "Checking database migration status..."
cd backend && goose -dir migrations/sqlite sqlite3 "$(DB_PATH)" status
backend-db-create:
@read -p "Migration name: " name; \
cd backend && goose -dir migrations/sqlite create $$name sql; \
cd backend && goose -dir migrations/mysql create $$name sql
# ============================================ # ============================================
# MySQL 专项测试 # Server 模式
# ============================================ # ============================================
test-mysql-up: server-run:
@echo "Starting MySQL test container..." @$(MAKE) -j2 _server-run-backend _server-run-frontend
cd backend/tests/mysql && docker-compose up -d
@echo "Waiting for MySQL to be ready..."
@for i in $$(seq 1 30); do \
if docker exec nex-mysql-test mysqladmin ping -h localhost -u root -ptestpass --silent 2>/dev/null; then \
echo "MySQL is ready!"; \
exit 0; \
fi; \
echo "Waiting... ($$i/30)"; \
sleep 1; \
done; \
echo "MySQL failed to start"; \
exit 1
test-mysql-down: server-build: version-check _backend-build _frontend-build
@echo "Stopping MySQL test container..." @printf 'Server build complete\n'
cd backend/tests/mysql && docker-compose down -v
test-mysql: test-mysql-up server-lint: _backend-lint _frontend-check
@echo "Running MySQL tests..." @printf 'Server lint complete\n'
cd backend && go test -tags=mysql ./tests/mysql/... -v -count=1
$(MAKE) test-mysql-down
test-mysql-quick: server-test: _backend-test _frontend-test
@echo "Running MySQL tests (without container management)..." @printf 'Server tests passed\n'
cd backend && go test -tags=mysql ./tests/mysql/... -v -count=1
server-clean: _backend-clean _frontend-clean
@printf 'Server artifacts cleaned\n'
_server-run-backend:
@$(MAKE) -C backend run
_server-run-frontend: _frontend-install
cd frontend && bun run dev
# ============================================ # ============================================
# 前端 # Desktop 模式
# ============================================ # ============================================
frontend-build: desktop-build-mac: version-check _desktop-prepare-frontend _desktop-prepare-embedfs
cd frontend && bun install && bun run build @printf 'Building macOS desktop...\n'
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-mac-arm64 ./cmd/desktop
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-mac-amd64 ./cmd/desktop
lipo -create build/nex-mac-arm64 build/nex-mac-amd64 -output build/nex-mac-universal
@printf 'Packaging macOS app bundle...\n'
mkdir -p build/Nex.app/Contents/MacOS build/Nex.app/Contents/Resources
cp build/nex-mac-universal build/Nex.app/Contents/MacOS/nex
@if [ -f assets/icon.icns ]; then \
cp assets/icon.icns build/Nex.app/Contents/Resources/; \
else \
printf 'Missing assets/icon.icns\n'; \
fi
@MIN_MACOS_VERSION=$$(vtool -show-build build/nex-mac-universal | awk '/minos / {print $$2; exit}'); \
if [ -z "$$MIN_MACOS_VERSION" ]; then \
printf 'Unable to read macOS minimum version\n'; \
exit 1; \
fi; \
go run ./versionctl macos-plist "$$MIN_MACOS_VERSION" > build/Nex.app/Contents/Info.plist
chmod +x build/Nex.app/Contents/MacOS/nex
@printf 'macOS desktop build complete\n'
frontend-dev: desktop-build-win: version-check _desktop-prepare-frontend _desktop-prepare-embedfs _desktop-prepare-windows-resource
cd frontend && bun dev @printf 'Building Windows desktop...\n'
ifeq ($(OS),Windows_NT)
powershell -NoProfile -Command "New-Item -ItemType Directory -Path 'build' -Force | Out-Null"
cd backend && set "CGO_ENABLED=1"&& set "GOOS=windows"&& set "GOARCH=amd64"&& go build -ldflags "$(GO_LDFLAGS_WIN)" -o ../build/nex-win-amd64.exe ./cmd/desktop
else
mkdir -p build
cd backend && CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS_WIN)" -o ../build/nex-win-amd64.exe ./cmd/desktop
endif
@printf 'Windows desktop build complete\n'
frontend-test: desktop-build-linux: version-check _desktop-prepare-frontend _desktop-prepare-embedfs
cd frontend && bun run test @printf 'Building Linux desktop...\n'
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-linux-amd64 ./cmd/desktop
@printf 'Linux desktop build complete\n'
frontend-test-watch: desktop-lint: _backend-lint _frontend-check
cd frontend && bun run test:watch @printf 'Desktop lint complete\n'
frontend-test-coverage: desktop-test: _desktop-test
cd frontend && bun run test:coverage @printf 'Desktop tests passed\n'
frontend-test-e2e: desktop-clean: _desktop-clean
cd frontend && bun run test:e2e @printf 'Desktop artifacts cleaned\n'
frontend-lint: _desktop-test:
cd frontend && bun run lint cd backend && go test ./cmd/desktop/... -v
frontend-clean: _desktop-clean:
rm -rf frontend/dist frontend/.next frontend/node_modules rm -rf build/ embedfs/assets embedfs/frontend-dist backend/cmd/desktop/rsrc_windows_amd64.syso
# ============================================ _desktop-prepare-frontend: _frontend-install
# 桌面应用 @printf 'Preparing frontend for desktop...\n'
# ============================================ ifeq ($(OS),Windows_NT)
powershell -NoProfile -Command "Copy-Item -LiteralPath 'frontend/.env.desktop' -Destination 'frontend/.env.production.local' -Force"
desktop-build: desktop-build-mac desktop-build-win desktop-build-linux cd frontend && bun run build
@echo "✅ Desktop builds complete for all platforms" powershell -NoProfile -Command "Remove-Item -LiteralPath 'frontend/.env.production.local' -Force -ErrorAction SilentlyContinue"
else
desktop-prepare-frontend:
@echo "📦 Preparing frontend for desktop..."
cd frontend && cp .env.desktop .env.production.local cd frontend && cp .env.desktop .env.production.local
cd frontend && bun install && bun run build cd frontend && bun run build
rm -f frontend/.env.production.local rm -f frontend/.env.production.local
endif
desktop-prepare-embedfs: _desktop-prepare-embedfs:
@echo "📦 Preparing embedded filesystem..." @printf 'Preparing embedded filesystem...\n'
ifeq ($(OS),Windows_NT)
powershell -NoProfile -Command "Remove-Item -LiteralPath 'embedfs/assets' -Recurse -Force -ErrorAction SilentlyContinue; Remove-Item -LiteralPath 'embedfs/frontend-dist' -Recurse -Force -ErrorAction SilentlyContinue; Copy-Item -LiteralPath 'assets' -Destination 'embedfs/assets' -Recurse; Copy-Item -LiteralPath 'frontend/dist' -Destination 'embedfs/frontend-dist' -Recurse"
else
rm -rf embedfs/assets embedfs/frontend-dist rm -rf embedfs/assets embedfs/frontend-dist
cp -r assets embedfs/assets cp -r assets embedfs/assets
cp -r frontend/dist embedfs/frontend-dist cp -r frontend/dist embedfs/frontend-dist
endif
desktop-build-mac: desktop-prepare-frontend desktop-prepare-embedfs _desktop-prepare-windows-resource:
@echo "🍎 Building macOS..." @printf 'Preparing Windows executable icon...\n'
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -o ../build/nex-mac-arm64 ./cmd/desktop ifeq ($(OS),Windows_NT)
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -o ../build/nex-mac-amd64 ./cmd/desktop cd backend/cmd/desktop && windres -O coff -F pe-x86-64 -i icon_windows.rc -o rsrc_windows_amd64.syso
else
desktop-build-win: desktop-prepare-frontend desktop-prepare-embedfs @if command -v x86_64-w64-mingw32-windres >/dev/null 2>&1; then \
@echo "🪟 Building Windows..." cd backend/cmd/desktop && x86_64-w64-mingw32-windres -O coff -F pe-x86-64 -i icon_windows.rc -o rsrc_windows_amd64.syso; \
cd backend && CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -ldflags "-H=windowsgui" -o ../build/nex-win-amd64.exe ./cmd/desktop elif command -v windres >/dev/null 2>&1; then \
cd backend/cmd/desktop && windres -O coff -F pe-x86-64 -i icon_windows.rc -o rsrc_windows_amd64.syso; \
desktop-build-linux: desktop-prepare-frontend desktop-prepare-embedfs else \
@echo "🐧 Building Linux..." printf 'Missing windres for Windows icon resource generation\n'; \
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -o ../build/nex-linux-amd64 ./cmd/desktop exit 1; \
fi
desktop-dev: desktop-prepare-frontend desktop-prepare-embedfs endif
@echo "🖥️ Starting desktop app in dev mode..."
cd backend && go run ./cmd/desktop
desktop-package-mac:
./scripts/build/package-macos.sh
desktop-package-win:
@echo "⚠️ Windows packaging not implemented yet"
desktop-package-linux:
@echo "⚠️ Linux packaging not implemented yet"
desktop-clean:
rm -rf build/ embedfs/assets embedfs/frontend-dist
# ============================================ # ============================================
# 清理 # 发布资产
# ============================================ # ============================================
clean: backend-clean frontend-clean desktop-clean release-assets-linux: version-check desktop-build-linux
@echo "✅ Clean complete" rm -rf "$(RELEASE_DIR)"
mkdir -p "$(RELEASE_DIR)"
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-linux-amd64 ./cmd/server
tar -C build -czf "$(RELEASE_DIR)/$(SERVER_LINUX_ASSET)" nex-server-linux-amd64
tar -C build -czf "$(RELEASE_DIR)/$(DESKTOP_LINUX_ASSET)" nex-linux-amd64
release-assets-windows: version-check desktop-build-win
ifeq ($(OS),Windows_NT)
powershell -NoProfile -Command "Remove-Item -LiteralPath '$(RELEASE_DIR)' -Recurse -Force -ErrorAction SilentlyContinue; New-Item -ItemType Directory -Path '$(RELEASE_DIR)' -Force | Out-Null"
cd backend && set "CGO_ENABLED=1"&& set "GOOS=windows"&& set "GOARCH=amd64"&& go build -ldflags "$(GO_LDFLAGS_WIN)" -o ../build/nex-server-win-amd64.exe ./cmd/server
powershell -NoProfile -Command "Compress-Archive -LiteralPath 'build/nex-server-win-amd64.exe' -DestinationPath '$(RELEASE_DIR)/$(SERVER_WINDOWS_ASSET)' -Force"
powershell -NoProfile -Command "Compress-Archive -LiteralPath 'build/nex-win-amd64.exe' -DestinationPath '$(RELEASE_DIR)/$(DESKTOP_WINDOWS_ASSET)' -Force"
else
@printf 'release-assets-windows requires Windows\n'
@exit 1
endif
release-assets-macos: version-check desktop-build-mac
rm -rf "$(RELEASE_DIR)"
mkdir -p "$(RELEASE_DIR)"
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-darwin-amd64 ./cmd/server
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-darwin-arm64 ./cmd/server
tar -C build -czf "$(RELEASE_DIR)/$(SERVER_DARWIN_AMD64_ASSET)" nex-server-darwin-amd64
tar -C build -czf "$(RELEASE_DIR)/$(SERVER_DARWIN_ARM64_ASSET)" nex-server-darwin-arm64
ditto -c -k --keepParent build/Nex.app "$(RELEASE_DIR)/$(DESKTOP_MACOS_ASSET)"
# ============================================
# 共享 helper targets
# ============================================
_backend-build:
@$(MAKE) -C backend build
_backend-lint:
@$(MAKE) -C backend lint
_backend-test:
@$(MAKE) -C backend test
_backend-clean:
@$(MAKE) -C backend clean
_versionctl-lint:
@$(MAKE) -C versionctl lint
_versionctl-test:
@$(MAKE) -C versionctl test
_frontend-install:
cd frontend && bun install
_frontend-build: _frontend-install
cd frontend && bun run build
_frontend-check: _frontend-install
cd frontend && bun run check
_frontend-test: _frontend-install
cd frontend && bun run test
_frontend-dev: _frontend-install
cd frontend && bun run dev
_frontend-clean:
rm -rf frontend/dist frontend/.next frontend/coverage frontend/playwright-report frontend/test-results frontend/tsconfig.tsbuildinfo

179
README.md
View File

@@ -36,13 +36,9 @@ nex/
├── assets/ # 应用资源 ├── assets/ # 应用资源
│ ├── icon.png # 托盘图标 │ ├── icon.png # 托盘图标
│ ├── AppIcon.icns # macOS 应用图标 │ ├── icon.icns # macOS 应用图标
│ └── icon.ico # Windows 应用图标 │ └── icon.ico # Windows 应用图标
├── scripts/ # 构建脚本
│ └── build/
│ └── package-macos.sh # macOS .app 打包脚本
└── README.md # 本文件 └── README.md # 本文件
``` ```
@@ -51,7 +47,7 @@ nex/
- **双协议支持**:同时支持 OpenAI 和 Anthropic 协议 - **双协议支持**:同时支持 OpenAI 和 Anthropic 协议
- **跨协议转换**Hub-and-Spoke 架构实现 OpenAI ↔ Anthropic 双向转换 - **跨协议转换**Hub-and-Spoke 架构实现 OpenAI ↔ Anthropic 双向转换
- **统一模型 ID**`provider_id/model_name` 格式全局唯一标识模型(如 `openai/gpt-4` - **统一模型 ID**`provider_id/model_name` 格式全局唯一标识模型(如 `openai/gpt-4`
- **Smart Passthrough**:同协议请求零序列化开销,仅改写 model 字段 - **Smart Passthrough**:同协议请求跳过 Canonical 全量转换,仅在 JSON 层改写 model 字段
- **流式响应**:完整支持 SSE 流式传输,包括跨协议流式转换 - **流式响应**:完整支持 SSE 流式传输,包括跨协议流式转换
- **Function Calling**支持工具调用Tools - **Function Calling**支持工具调用Tools
- **Thinking / Reasoning**:支持 OpenAI `reasoning_effort` 和 Anthropic `thinking` 配置 - **Thinking / Reasoning**:支持 OpenAI `reasoning_effort` 和 Anthropic `thinking` 配置
@@ -95,7 +91,7 @@ JSON: {"level":"info","logger":"handler.proxy","msg":"处理请求","method":
- **图表库**: Recharts - **图表库**: Recharts
- **路由**: React Router v7 - **路由**: React Router v7
- **数据获取**: TanStack Query v5 - **数据获取**: TanStack Query v5
- **样式**: SCSS Modules - **样式**: TDesign 组件 props 优先TDesign tokens 次之SCSS 作为兜底补充
- **测试**: Vitest + React Testing Library + Playwright - **测试**: Vitest + React Testing Library + Playwright
## 快速开始 ## 快速开始
@@ -105,18 +101,14 @@ JSON: {"level":"info","logger":"handler.proxy","msg":"处理请求","method":
**构建桌面应用** **构建桌面应用**
```bash ```bash
# macOS (arm64 + amd64) # macOS (arm64 + amd64,并打包为 .app)
make desktop-build-mac make desktop-build-mac
make desktop-package-mac # 打包为 .app
# Windows # Windows
make desktop-build-win make desktop-build-win
# Linux # Linux
make desktop-build-linux make desktop-build-linux
# 构建所有平台
make desktop-build
``` ```
**使用桌面应用** **使用桌面应用**
@@ -137,50 +129,54 @@ make desktop-build
- Xfce: 需要 libappindicator - Xfce: 需要 libappindicator
- 其他支持 StatusNotifierItem 规范的环境 - 其他支持 StatusNotifierItem 规范的环境
### CLI 模式 ### Server 模式(前后端分离)
#### 后端
```bash ```bash
cd backend make server-run
go mod download
go run cmd/server/main.go
``` ```
后端服务将在 `http://localhost:9826` 启动。首次启动会自动: `make server-run` 会并行启动:
- 后端服务:`http://localhost:9826`
- 前端开发服务器:`http://localhost:5173`
前端请求会继续通过 Vite proxy 转发到后端。后端首次启动会自动:
- 创建配置文件 `~/.nex/config.yaml` - 创建配置文件 `~/.nex/config.yaml`
- 初始化数据库 `~/.nex/config.db` - 初始化数据库 `~/.nex/config.db`
- 运行数据库迁移 - 运行数据库迁移
- 创建日志目录 `~/.nex/log/` - 创建日志目录 `~/.nex/log/`
### 前端 **构建 server 模式产物**
```bash ```bash
cd frontend make server-build
bun install
bun dev
``` ```
前端开发服务器将在 `http://localhost:5173` 启动API 请求通过 Vite proxy 转发到后端。
## API 接口 ## API 接口
### 代理接口(对外部应用) ### 代理接口(对外部应用)
代理接口统一使用 `/{protocol}/*path` 路由格式,模型 ID 使用 `provider_id/model_name` 格式(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough最小化 JSON 改写保持参数保真;跨协议请求走完整 decode/encode 转换。 代理接口统一使用 `/{protocol}/*path` 路由格式,模型 ID 使用 `provider_id/model_name` 格式(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough最小化 JSON 改写保持未改写字段的 JSON 内容和类型不变;跨协议请求走完整 decode/encode 转换。
**OpenAI 协议**`protocol=openai` **OpenAI 协议**`protocol=openai`
- `POST /openai/chat/completions` - 对话补全 - `POST /openai/v1/chat/completions` - 对话补全
- `GET /openai/models` - 模型列表(本地数据库聚合) - `GET /openai/v1/models` - 模型列表(本地数据库聚合)
- `GET /openai/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询) - `GET /openai/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
- `POST /openai/embeddings` - 嵌入 - `POST /openai/v1/embeddings` - 嵌入
- `POST /openai/rerank` - 重排序 - `POST /openai/v1/rerank` - 重排序
**Anthropic 协议**`protocol=anthropic` **Anthropic 协议**`protocol=anthropic`
- `POST /anthropic/v1/messages` - 消息对话 - `POST /anthropic/v1/messages` - 消息对话
- `GET /anthropic/v1/models` - 模型列表(本地数据库聚合) - `GET /anthropic/v1/models` - 模型列表(本地数据库聚合)
- `GET /anthropic/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询) - `GET /anthropic/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
路径边界:网关只剥离第一段协议前缀,剩余路径保持协议原生形态交给 adapter。OpenAI adapter 接收 `/v1/chat/completions``/v1/models``/v1/embeddings``/v1/rerank`,并在构建上游 URL 时去掉 `/v1`Anthropic adapter 接收 `/v1/messages``/v1/models`。因此 OpenAI 供应商 `base_url` 配置到版本路径一级(如 `https://api.openai.com/v1`Anthropic 供应商 `base_url` 配置到域名级(如 `https://api.anthropic.com`)。
代理错误边界:网关层错误统一返回 `{"error":"...","code":"..."}`,例如 `INVALID_JSON``MODEL_NOT_FOUND``CONVERSION_FAILED``UPSTREAM_UNAVAILABLE`。只要上游已经返回 HTTP 响应,非 2xx 的 status、过滤 hop-by-hop header 后的 headers 和 body 会直接透传,不包装为应用错误或协议错误。
模型路由边界:只有 adapter 明确适配的接口会解析请求体中的 `model` 并使用统一模型 ID 路由;未知接口即使包含顶层 `model` 也按无 model 透传处理。
流式边界:同协议无响应 model 改写时原样透传 SSE frame 和 `[DONE]`;同协议需要响应 model 改写时只解析 SSE frame 的 `data` JSON 并改写 `model`;跨协议流式仍走 provider decoder → Canonical stream event → client encoder。
### 管理接口(对前端) ### 管理接口(对前端)
#### 供应商管理 #### 供应商管理
@@ -203,6 +199,9 @@ bun dev
查询参数支持:`provider_id``model_name``start_date``end_date``group_by` 查询参数支持:`provider_id``model_name``start_date``end_date``group_by`
#### 版本信息
- `GET /api/version` - 获取后端构建版本信息(`version``commit``build_time`),用于前端 About 页面诊断前后端版本一致性
## 配置 ## 配置
配置支持多种方式,优先级为:**CLI 参数 > 环境变量 > 配置文件 > 默认值** 配置支持多种方式,优先级为:**CLI 参数 > 环境变量 > 配置文件 > 默认值**
@@ -276,50 +275,100 @@ export NEX_DATABASE_DBNAME=nex
## 测试 ## 测试
```bash ```bash
# 顶层便捷命令 # 全局默认测试(不含 MySQL 和前端 E2E
make test # 运行所有测试 make test
# 后端测试 # 产品级测试
make backend-test # 后端测试 make server-test
make backend-test-coverage # 后端覆盖率 make desktop-test
make backend-test-unit # 后端单元测试
make backend-test-integration # 后端集成测试
# 前端测试
make frontend-test # 前端测试
make frontend-test-e2e # 前端 E2E 测试
make frontend-test-coverage # 前端覆盖率
``` ```
backend 分类测试、MySQL 专项测试和前端 E2E 测试请分别查看 `backend/README.md``frontend/README.md`
## 开发 ## 开发
```bash ```bash
# 顶层便捷命令 # 首次克隆后安装 Git hooks
make dev # 启动开发环境(并行启动后端和前端) lefthook install
make build # 构建所有产物
make lint # 检查所有代码
make clean # 清理所有构建产物
# 后端开发 # 全局命令
make backend-build # 构建后端 make lint # 前后端共享检查
make backend-run # 运行后端 make test # 默认全量测试(不含 MySQL/E2E
make backend-dev # 后端开发模式 make clean # 清理所有构建产物和测试报告
make backend-lint # 后端代码检查
make backend-clean # 清理后端构建产物
# 数据库操作 # server 模式
make backend-db-up # 数据库迁移 make server-run # 并行启动后端和前端开发服务
make backend-db-down # 数据库回滚 make server-build # 构建 backend/bin/server 和 frontend/dist
make backend-db-status # 数据库迁移状态 make server-lint # server 模式检查
make backend-db-create # 创建新迁移 make server-test # server 模式测试
make server-clean # 清理 server 模式产物
# 前端开发 # desktop 模式
make frontend-build # 构建前端 make desktop-build-mac # 构建 macOS 桌面应用
make frontend-dev # 前端开发模式 make desktop-build-win # 构建 Windows 桌面应用
make frontend-lint # 前端代码检查 make desktop-build-linux # 构建 Linux 桌面应用
make frontend-clean # 清理前端构建产物 make desktop-lint # desktop 模式检查
make desktop-test # desktop 专属测试
make desktop-clean # 清理 desktop 产物
``` ```
## 版本与发布
### 统一版本源
- 仓库根目录 `VERSION` 是全仓唯一版本源,格式固定为 `x.y.z`
- `frontend/package.json` 和前端 `.env.*` 中的 `VITE_APP_VERSION` 由仓库工具同步,不能手工漂移
### 本地版本演进
```bash
# 递增版本(自动 sync + check + commit + tag
make version-bump BUMP=minor
# 或指定具体版本号
make version-bump SET_VERSION=1.0.0
# 推送到远程
git push --follow-tags
```
手动同步和校验:
```bash
make version-sync
make version-check
```
### 本地生成发布资产
```bash
# Linux: server + desktop
make release-assets-linux
# Windows: server + desktop需在 Windows 环境执行)
make release-assets-windows
# macOS: darwin-amd64 server、darwin-arm64 server、desktop universal
make release-assets-macos
```
生成的版本化发布资产位于 `build/release/`
### GitHub Draft Release
- 推送 `vX.Y.Z` tag 后,`.github/workflows/release.yml` 会自动执行发布流水线
- 三个平台 job 会在正式构建前先检查 `go``bun` 和各自的平台打包工具链,缺失时快速失败并在日志中输出诊断信息
- Windows 发布 job 在 `MSYS2 / MINGW64` shell 中执行,并继承 `setup-go` / `setup-bun` 准备好的工具链路径
- 流水线会先校验 tag 与 `VERSION` 一致,再构建以下资产并上传到 GitHub Draft Release
- Linux server
- Windows server
- darwin-amd64 server
- darwin-arm64 server
- Linux desktop
- Windows desktop
- macOS desktop universal
- Release 默认以 Draft 形式创建,需人工检查后再公开发布
## 开发规范 ## 开发规范
详见各子项目的 README.md 详见各子项目的 README.md
@@ -328,4 +377,4 @@ make frontend-clean # 清理前端构建产物
## 许可证 ## 许可证
MIT Apache License 2.0

1
VERSION Normal file
View File

@@ -0,0 +1 @@
0.1.1

Binary file not shown.

View File

@@ -1,64 +0,0 @@
# Assets
应用资源文件目录。
## 文件说明
| 文件 | 用途 | 尺寸 | 格式 |
|------|------|------|------|
| `icon.svg` | 源图标 | 64x64 | SVG |
| `icon.png` | 托盘图标 | 64x64 | PNG |
| `AppIcon.icns` | macOS 应用图标 | 多尺寸 | ICNS |
| `icon.ico` | Windows 应用图标 | 256x256 | ICO |
## 替换图标
### 1. 准备图标
推荐使用 SVG 格式的源图标,尺寸至少 256x256。
### 2. 生成各平台图标
**托盘图标 (PNG)**
```bash
magick your-icon.svg -resize 64x64 icon.png
```
**macOS 应用图标 (ICNS)**
```bash
mkdir icon.iconset
magick your-icon.svg -resize 16x16 icon.iconset/icon_16x16.png
magick your-icon.svg -resize 32x32 icon.iconset/icon_16x16@2x.png
magick your-icon.svg -resize 32x32 icon.iconset/icon_32x32.png
magick your-icon.svg -resize 64x64 icon.iconset/icon_32x32@2x.png
magick your-icon.svg -resize 128x128 icon.iconset/icon_128x128.png
magick your-icon.svg -resize 256x256 icon.iconset/icon_128x128@2x.png
iconutil -c icns icon.iconset -o AppIcon.icns
rm -rf icon.iconset
```
**Windows 应用图标 (ICO)**
```bash
magick your-icon.svg -resize 256x256 icon.ico
```
### 3. 替换文件
将生成的文件放入此目录,然后重新构建桌面应用:
```bash
./scripts/build/build-darwin-arm64.sh
```
## macOS Template 图标
macOS 支持 Template 图标,自动适配深浅色模式:
- 使用黑色 + 透明设计
- 文件名以 `Template` 结尾(如 `iconTemplate.png`
- 黑色在深色模式下自动变为白色
## 设计建议
- 托盘图标应简洁,在小尺寸下清晰可辨
- 避免过多细节和文字
- 使用高对比度颜色
- macOS 建议使用 Template 图标风格

BIN
assets/icon.icns LFS Normal file

Binary file not shown.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 264 KiB

After

Width:  |  Height:  |  Size: 130 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.0 KiB

After

Width:  |  Height:  |  Size: 131 B

View File

@@ -1,13 +0,0 @@
<svg width="64" height="64" viewBox="0 0 64 64" xmlns="http://www.w3.org/2000/svg">
<rect width="64" height="64" rx="12" fill="#4A90D9"/>
<polygon points="32,8 52,20 52,44 32,56 12,44 12,20" fill="none" stroke="white" stroke-width="3"/>
<circle cx="32" cy="32" r="6" fill="white"/>
<line x1="32" y1="32" x2="20" y2="20" stroke="white" stroke-width="2"/>
<line x1="32" y1="32" x2="44" y2="20" stroke="white" stroke-width="2"/>
<line x1="32" y1="32" x2="20" y2="44" stroke="white" stroke-width="2"/>
<line x1="32" y1="32" x2="44" y2="44" stroke="white" stroke-width="2"/>
<circle cx="20" cy="20" r="3" fill="white"/>
<circle cx="44" cy="20" r="3" fill="white"/>
<circle cx="20" cy="44" r="3" fill="white"/>
<circle cx="44" cy="44" r="3" fill="white"/>
</svg>

Before

Width:  |  Height:  |  Size: 779 B

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

91
backend/.golangci.yml Normal file
View File

@@ -0,0 +1,91 @@
run:
timeout: 5m
tests: true
linters:
disable-all: true
enable:
- forbidigo
- errorlint
- errcheck
- staticcheck
- revive
- gocritic
- gosec
- bodyclose
- noctx
- nilerr
- goimports
- gocyclo
linters-settings:
errcheck:
check-blank: true
check-type-assertions: true
exclude-functions:
- fmt.Fprintf
forbidigo:
analyze-types: true
forbid:
- p: '^fmt\.Print.*$'
msg: 使用 zap logger不要直接输出到 stdout/stderr
- p: '^fmt\.Fprint.*$'
msg: 使用 zap logger不要直接输出到 stdout/stderr
- p: '^log\.(Print|Println|Printf|Fatal|Fatalln|Fatalf|Panic|Panicln|Panicf)$'
msg: 使用 zap logger不要使用标准库 log
- p: '^zap\.L$'
msg: 通过依赖注入传递 *zap.Logger不要使用全局 logger
- p: '^zap\.S$'
msg: 不使用 Sugar logger
revive:
rules:
- name: exported
- name: var-naming
- name: indent-error-flow
- name: error-strings
- name: error-return
- name: blank-imports
- name: context-as-argument
- name: unexported-return
goimports:
local-prefixes: nex/backend
gocyclo:
min-complexity: 10
issues:
exclude-dirs:
- tests/mocks
exclude-generated: true
exclude-rules:
- path: '(_test\.go|tests/)'
linters:
- forbidigo
- path: '(_test\.go|tests/)'
linters:
- errcheck
source: '(^\s*_\s*=|,\s*_)'
- path: 'tests/integration/e2e_conversion_test\.go'
linters:
- errcheck
- path: '(_test\.go|tests/)'
linters:
- revive
text: '^exported:'
- path: '(_test\.go|tests/)'
linters:
- gosec
text: 'G(101|401|501)'
- path: '(_test\.go|tests/)'
linters:
- gocyclo
text: 'cyclomatic complexity (1[1-9]|20) of .* is high \(> 10\)'
- linters:
- revive
text: '(that stutters|BuildUrl should be BuildURL|ConvertHttpRequest should be ConvertHTTPRequest|ConvertHttpResponse should be ConvertHTTPResponse)'
- path: 'internal/conversion/.*\.go'
linters:
- gocyclo
- gocritic
- path: '(internal/provider/client\.go|internal/service/model_service_impl\.go|internal/service/stats_buffer\.go|internal/handler/proxy_handler\.go|cmd/(desktop|server)/main\.go)'
linters:
- gocyclo

97
backend/Makefile Normal file
View File

@@ -0,0 +1,97 @@
.PHONY: \
build run \
test test-unit test-integration test-coverage \
lint clean \
migrate-up migrate-down migrate-status migrate-create \
mysql-up mysql-down mysql-test mysql-test-quick
VERSION := $(shell go run ../versionctl print)
GIT_COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || printf 'unknown')
BUILD_TIME ?= $(shell date -u +"%Y-%m-%dT%H:%M:%SZ")
GO_LDFLAGS := -X nex/backend/pkg/buildinfo.version=$(VERSION) -X nex/backend/pkg/buildinfo.commit=$(GIT_COMMIT) -X nex/backend/pkg/buildinfo.buildTime=$(BUILD_TIME)
DB_DRIVER ?= sqlite3
DB_DSN ?= $(HOME)/.nex/config.db
ifeq ($(DB_DRIVER),mysql)
GOOSE_DIR := migrations/mysql
GOOSE_DRIVER := mysql
else ifeq ($(DB_DRIVER),sqlite3)
GOOSE_DIR := migrations/sqlite
GOOSE_DRIVER := sqlite3
else
$(error unsupported DB_DRIVER '$(DB_DRIVER)', use sqlite3 or mysql)
endif
build:
go build -ldflags "$(GO_LDFLAGS)" -o bin/server ./cmd/server
run:
go run -ldflags "$(GO_LDFLAGS)" ./cmd/server
test:
go test ./internal/... ./pkg/... ./tests/... ./cmd/server/... -v
test-unit:
go test ./internal/... ./pkg/... -v
test-integration:
go test ./tests/... -v
test-coverage:
go test ./... -coverprofile=coverage.out
go tool cover -html=coverage.out -o coverage.html
@printf 'Coverage report generated: backend/coverage.html\n'
lint:
go tool golangci-lint run ./...
clean:
rm -rf bin/ coverage.out coverage.html
migrate-up:
@printf 'Running database migration up...\n'
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" up
migrate-down:
@printf 'Running database migration down...\n'
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" down
migrate-status:
@printf 'Checking database migration status...\n'
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" status
migrate-create:
@printf 'Migration name: '; \
read name; \
goose -dir migrations/sqlite create $$name sql; \
goose -dir migrations/mysql create $$name sql
mysql-up:
@printf 'Starting MySQL test container...\n'
cd tests/mysql && docker-compose up -d
@printf 'Waiting for MySQL to be ready...\n'
@for i in $$(seq 1 30); do \
if docker exec nex-mysql-test mysqladmin ping -h localhost -u root -ptestpass --silent 2>/dev/null; then \
printf 'MySQL is ready\n'; \
exit 0; \
fi; \
printf 'Waiting... (%s/30)\n' $$i; \
sleep 1; \
done; \
printf 'MySQL failed to start\n'; \
exit 1
mysql-down:
@printf 'Stopping MySQL test container...\n'
cd tests/mysql && docker-compose down -v
mysql-test:
@set -e; \
$(MAKE) mysql-up; \
trap '$(MAKE) mysql-down' EXIT; \
go test -tags=mysql ./tests/mysql/... -v -count=1
mysql-test-quick:
@printf 'Running MySQL tests without container management...\n'
go test -tags=mysql ./tests/mysql/... -v -count=1

View File

@@ -4,10 +4,10 @@ AI 网关后端服务,提供统一的大模型 API 代理接口。
## 功能特性 ## 功能特性
- 支持 OpenAI 协议(`/openai/v1/...` - 支持 OpenAI 协议(`/openai/v1/...`,例如 `/openai/v1/chat/completions`
- 支持 Anthropic 协议(`/anthropic/v1/...` - 支持 Anthropic 协议(`/anthropic/v1/...`
- 支持 Hub-and-Spoke 跨协议双向转换OpenAI ↔ Anthropic - 支持 Hub-and-Spoke 跨协议双向转换OpenAI ↔ Anthropic
- 同协议透传(零语义损失、零序列化开销 - 同协议透传(跳过 Canonical 全量转换,保持协议语义
- 支持流式响应SSE - 支持流式响应SSE
- 支持 Function Calling / Tools - 支持 Function Calling / Tools
- 支持 Thinking / Reasoning - 支持 Thinking / Reasoning
@@ -54,7 +54,7 @@ func NewProxyHandler(..., logger *zap.Logger) *ProxyHandler {
使用 `pkg/logger/field.go` 中定义的字段构造函数: 使用 `pkg/logger/field.go` 中定义的字段构造函数:
```go ```go
logger.Info("请求开始", logger.Debug("请求开始",
pkglogger.Method("POST"), pkglogger.Method("POST"),
pkglogger.Path("/v1/chat"), pkglogger.Path("/v1/chat"),
pkglogger.RequestID("xxx"), pkglogger.RequestID("xxx"),
@@ -220,7 +220,7 @@ OpenAI Response ← Canonical Response ← Anthropic Response
### Smart Passthrough 机制 ### Smart Passthrough 机制
同协议请求走 Smart Passthrough 路径,**零序列化开销** 同协议请求走 Smart Passthrough 路径,不进入 Canonical 全量转换
``` ```
1. 检测 clientProtocol == providerProtocol 1. 检测 clientProtocol == providerProtocol
@@ -229,12 +229,14 @@ OpenAI Response ← Canonical Response ← Anthropic Response
4. 响应中仅改写 model 字段upstream_model_name → unified_id 4. 响应中仅改写 model 字段upstream_model_name → unified_id
``` ```
Smart Passthrough 保持未改写 JSON 字段的内容和类型不变,不承诺保留原始字节顺序、空白或对象字段顺序。
### 流式转换器层次 ### 流式转换器层次
``` ```
StreamConverter (接口) StreamConverter (接口)
├── PassthroughStreamConverter # 直接透传,无任何处理 ├── PassthroughStreamConverter # 直接透传,无任何处理
├── SmartPassthroughStreamConverter # 同协议 + 逐 chunk 改写 model ├── SmartPassthroughStreamConverter # 同协议 + 按 SSE frame 改写 data JSON model
└── CanonicalStreamConverter # 跨协议完整转换decode → encode └── CanonicalStreamConverter # 跨协议完整转换decode → encode
``` ```
@@ -301,6 +303,7 @@ StreamConverter (接口)
| `PROTOCOL_CONSTRAINT_VIOLATION` | 协议约束违反 | | `PROTOCOL_CONSTRAINT_VIOLATION` | 协议约束违反 |
| `ENCODING_FAILURE` | 编码失败 | | `ENCODING_FAILURE` | 编码失败 |
| `INTERFACE_NOT_SUPPORTED` | 接口不支持(如 Anthropic Embeddings | | `INTERFACE_NOT_SUPPORTED` | 接口不支持(如 Anthropic Embeddings |
| `UNSUPPORTED_MULTIMODAL` | 跨协议暂不支持多模态内容 |
### AppError 预定义错误 ### AppError 预定义错误
@@ -434,24 +437,37 @@ docker run -d -e NEX_SERVER_PORT=9000 -e NEX_LOG_LEVEL=info nex-server
## 测试 ## 测试
```bash ```bash
# 运行所有测试 # 运行 backend 默认测试
make test make test
# 分类测试
make test-unit
make test-integration
# 生成覆盖率报告 # 生成覆盖率报告
make test-coverage make test-coverage
# MySQL 专项测试
make mysql-up
make mysql-down
make mysql-test
make mysql-test-quick
``` ```
## 数据库迁移 ## 数据库迁移
```bash ```bash
# 使用 Makefile # 使用 Makefile
make migrate-up DB_PATH=~/.nex/config.db make migrate-up DB_DSN=~/.nex/config.db
make migrate-down DB_PATH=~/.nex/config.db make migrate-down DB_DSN=~/.nex/config.db
make migrate-status DB_PATH=~/.nex/config.db make migrate-status DB_DSN=~/.nex/config.db
# 创建新迁移 # 创建新迁移
make migrate-create make migrate-create
# MySQL 迁移
make migrate-up DB_DRIVER=mysql DB_DSN='user:pass@tcp(localhost:3306)/nex'
# 或直接使用 goose # 或直接使用 goose
goose -dir migrations sqlite3 ~/.nex/config.db up goose -dir migrations sqlite3 ~/.nex/config.db up
``` ```
@@ -460,15 +476,15 @@ goose -dir migrations sqlite3 ~/.nex/config.db up
### 代理接口 ### 代理接口
使用 `/{protocol}/v1/{path}` URL 前缀路由 使用 `/{protocol}/*path` URL 前缀路由。网关只剥离第一段协议前缀,不在 Handler 中统一添加或移除 `/v1`;剩余 path 是协议原生 nativePath由对应 adapter 识别和组合上游 URL。
#### OpenAI 协议 #### OpenAI 协议
``` ```
POST /openai/chat/completions POST /openai/v1/chat/completions
GET /openai/models GET /openai/v1/models
POST /openai/embeddings POST /openai/v1/embeddings
POST /openai/rerank POST /openai/v1/rerank
``` ```
#### Anthropic 协议 #### Anthropic 协议
@@ -478,10 +494,20 @@ POST /anthropic/v1/messages
GET /anthropic/v1/models GET /anthropic/v1/models
``` ```
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传,零序列化开销 **协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传或 Smart Passthrough跳过 Canonical 全量转换
**统一模型 ID**:代理请求中的 `model` 字段使用 `provider_id/model_name` 格式(如 `openai/gpt-4`),网关据此路由到对应供应商。同协议时自动改写为上游 `model_name`,跨协议时通过全量转换处理。 **统一模型 ID**:代理请求中的 `model` 字段使用 `provider_id/model_name` 格式(如 `openai/gpt-4`),网关据此路由到对应供应商。同协议时自动改写为上游 `model_name`,跨协议时通过全量转换处理。
**base_url 约定**
- OpenAI 供应商配置到版本路径一级,例如 `https://api.openai.com/v1`;当客户端请求 `/openai/v1/chat/completions`OpenAI adapter 会把 nativePath `/v1/chat/completions` 映射为上游 path `/chat/completions`,最终 URL 为 `https://api.openai.com/v1/chat/completions`
- Anthropic 供应商配置到域名级,例如 `https://api.anthropic.com`
**模型提取边界**:只有 adapter 明确适配的 Chat、Embeddings、Rerank 等接口会提取 `model` 并尝试统一模型 ID 路由。未知接口不做顶层 `model` 猜测,直接按无 model 透传。
**流式透传边界**:同协议无响应 model 改写时 raw passthrough保留 SSE frame 边界和 `[DONE]`;同协议需要改写时按 SSE frame 解析 `data` JSON仅改写 `model`;跨协议继续使用 StreamDecoder → CanonicalStreamConverter → StreamEncoder。
**错误边界**:网关层代理错误返回 `{"error":"...","code":"..."}`。已收到上游 HTTP 响应时,非 2xx status、过滤 hop-by-hop header 后的 headers 和 body 直接透传;没有收到上游响应的连接/DNS/TLS/超时错误返回 `UPSTREAM_UNAVAILABLE`
### 管理接口 ### 管理接口
#### 供应商管理 #### 供应商管理
@@ -509,7 +535,7 @@ GET /anthropic/v1/models
- Anthropic 协议:配置到域名,不包含版本路径,如 `https://api.anthropic.com` - Anthropic 协议:配置到域名,不包含版本路径,如 `https://api.anthropic.com`
**对外 URL 格式** **对外 URL 格式**
- OpenAI 协议:`/{protocol}/{endpoint}`,如 `/openai/chat/completions``/openai/models``/openai/embeddings` - OpenAI 协议:`/{protocol}/v1/{endpoint}`,如 `/openai/v1/chat/completions``/openai/v1/models``/openai/v1/embeddings`
- Anthropic 协议:`/{protocol}/v1/{endpoint}`,如 `/anthropic/v1/messages``/anthropic/v1/models` - Anthropic 协议:`/{protocol}/v1/{endpoint}`,如 `/anthropic/v1/messages``/anthropic/v1/models`
#### 模型管理 #### 模型管理
@@ -551,6 +577,20 @@ GET /anthropic/v1/models
查询参数:`provider_id``model_name``start_date`YYYY-MM-DD`end_date``group_by`provider/model/date 查询参数:`provider_id``model_name``start_date`YYYY-MM-DD`end_date``group_by`provider/model/date
#### 版本信息
- `GET /api/version` - 获取后端构建版本信息
响应字段来源于构建阶段注入的 `buildinfo` 元数据:
```json
{
"version": "0.1.0",
"commit": "abc1234",
"build_time": "2026-05-05T00:00:00Z"
}
```
#### 健康检查 #### 健康检查
- `GET /health` - 返回 `{"status": "ok"}` - `GET /health` - 返回 `{"status": "ok"}`
@@ -558,9 +598,12 @@ GET /anthropic/v1/models
## 开发 ## 开发
```bash ```bash
make build # 构建 make build # 构建 backend/bin/server
make lint # 代码检查 make run # 运行后端服务
make deps # 整理依赖 make lint # 代码检查
make clean # 清理 backend 构建产物
go mod tidy # 整理依赖
go generate ./... # 刷新 mock 等生成代码
``` ```
环境要求Go 1.26 或更高版本 环境要求Go 1.26 或更高版本
@@ -609,6 +652,7 @@ err := v.Validate(myStruct)
- **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节 - **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节
- **字符串拼接**:使用 `strings.Join`,不手写循环拼接 - **字符串拼接**:使用 `strings.Join`,不手写循环拼接
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配 - **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配lint 强约束errorlint 禁止 `err ==` 直接比较和 `err.(*T)` 直接断言)
- **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()` - **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()`
- **日志 error 字段**:使用 `zap.Error(err)`,不使用 `zap.String("error", err.Error())` 手工字符串化
- **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片 - **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片

View File

@@ -6,19 +6,16 @@ import (
"fmt" "fmt"
"os/exec" "os/exec"
"strings" "strings"
"go.uber.org/zap"
) )
func showError(title, message string) { func showError(title, message string) {
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`, script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`,
escapeAppleScript(message), escapeAppleScript(title)) escapeAppleScript(message), escapeAppleScript(title))
exec.Command("osascript", "-e", script).Run() if err := exec.Command("osascript", "-e", script).Run(); err != nil {
} dialogLogger().Warn("显示错误对话框失败", zap.Error(err))
}
func showAbout() {
message := "Nex Gateway\n\nAI Gateway - 统一的大模型 API 网关\n\nhttps://github.com/nex/gateway"
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "关于 Nex Gateway"`,
escapeAppleScript(message))
exec.Command("osascript", "-e", script).Run()
} }
func escapeAppleScript(s string) string { func escapeAppleScript(s string) string {

View File

@@ -4,7 +4,6 @@ package main
import ( import (
"fmt" "fmt"
"os"
"os/exec" "os/exec"
"sync" "sync"
) )
@@ -63,26 +62,6 @@ func showError(title, message string) {
exec.Command("xmessage", "-center", exec.Command("xmessage", "-center",
fmt.Sprintf("%s: %s", title, message)).Run() fmt.Sprintf("%s: %s", title, message)).Run()
default: default:
fmt.Fprintf(os.Stderr, "错误: %s: %s\n", title, message) dialogLogger().Error("无法显示错误对话框")
}
}
func showAbout() {
message := "Nex Gateway\n\nAI Gateway - 统一的大模型 API 网关\n\nhttps://github.com/nex/gateway"
switch dialogTool {
case toolZenity:
exec.Command("zenity", "--info",
"--title=关于 Nex Gateway",
fmt.Sprintf("--text=%s", message)).Run()
case toolKdialog:
exec.Command("kdialog", "--msgbox", message, "--title", "关于 Nex Gateway").Run()
case toolNotifySend:
exec.Command("notify-send", "关于 Nex Gateway", message).Run()
case toolXmessage:
exec.Command("xmessage", "-center",
fmt.Sprintf("关于 Nex Gateway: %s", message)).Run()
default:
fmt.Fprintf(os.Stderr, "关于 Nex Gateway: %s\n", message)
} }
} }

View File

@@ -0,0 +1,15 @@
package main
import (
"go.uber.org/zap"
pkgLogger "nex/backend/pkg/logger"
)
func dialogLogger() *zap.Logger {
if zapLogger != nil {
return zapLogger
}
return pkgLogger.NewMinimal()
}

View File

@@ -3,36 +3,60 @@
package main package main
import ( import (
"errors"
"fmt"
"syscall" "syscall"
"unsafe" "unsafe"
"go.uber.org/zap"
) )
const ( const (
MB_ICONERROR = 0x10 mbIconError = 0x10
MB_ICONINFORMATION = 0x40 mbIconInformation = 0x40
) )
var ( var (
user32 = syscall.NewLazyDLL("user32.dll") user32 = syscall.NewLazyDLL("user32.dll")
procMessageBoxW = user32.NewProc("MessageBoxW") procMessageBoxW = user32.NewProc("MessageBoxW")
callMessageBoxW = func(hwnd, text, caption, flags uintptr) (uintptr, error) {
ret, _, err := procMessageBoxW.Call(hwnd, text, caption, flags)
return ret, err
}
) )
func showError(title, message string) { func showError(title, message string) {
messageBox(title, message, MB_ICONERROR) if err := messageBox(title, message, mbIconError); err != nil {
if zapLogger != nil {
zapLogger.Warn("显示错误对话框失败", zap.Error(err))
}
}
} }
func showAbout() { func messageBox(title, message string, flags uint) error {
message := "Nex Gateway\n\nAI Gateway - 统一的大模型 API 网关\n\nhttps://github.com/nex/gateway" titlePtr, err := syscall.UTF16PtrFromString(title)
messageBox("关于 Nex Gateway", message, MB_ICONINFORMATION) if err != nil {
} return err
}
func messageBox(title, message string, flags uint) { messagePtr, err := syscall.UTF16PtrFromString(message)
titlePtr, _ := syscall.UTF16PtrFromString(title) if err != nil {
messagePtr, _ := syscall.UTF16PtrFromString(message) return err
procMessageBoxW.Call( }
ret, callErr := callMessageBoxW(
0, 0,
uintptr(unsafe.Pointer(messagePtr)), uintptr(unsafe.Pointer(messagePtr)),
uintptr(unsafe.Pointer(titlePtr)), uintptr(unsafe.Pointer(titlePtr)),
uintptr(flags), uintptr(flags),
) )
if ret != 0 {
return nil
}
if callErr != nil && !errors.Is(callErr, syscall.Errno(0)) {
return callErr
}
return fmt.Errorf("MessageBoxW 调用失败")
} }

View File

@@ -0,0 +1 @@
1 ICON "../../../assets/icon.ico"

View File

@@ -13,10 +13,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/getlantern/systray" "nex/embedfs"
"github.com/gin-gonic/gin"
"github.com/gofrs/flock"
"go.uber.org/zap"
"nex/backend/internal/config" "nex/backend/internal/config"
"nex/backend/internal/conversion" "nex/backend/internal/conversion"
@@ -28,9 +25,14 @@ import (
"nex/backend/internal/provider" "nex/backend/internal/provider"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"nex/backend/internal/service" "nex/backend/internal/service"
pkgLogger "nex/backend/pkg/logger" "nex/backend/pkg/buildinfo"
"nex/embedfs" "github.com/getlantern/systray"
"github.com/gin-gonic/gin"
"github.com/gofrs/flock"
"go.uber.org/zap"
pkgLogger "nex/backend/pkg/logger"
) )
var ( var (
@@ -48,15 +50,19 @@ func main() {
singleLock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway.lock")) singleLock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway.lock"))
if err := singleLock.Lock(); err != nil { if err := singleLock.Lock(); err != nil {
minimalLogger.Error("已有 Nex 实例运行") minimalLogger.Error("已有 Nex 实例运行")
showError("Nex Gateway", "已有 Nex 实例运行") showError(appName, "已有 Nex 实例运行")
os.Exit(1) os.Exit(1)
} }
defer singleLock.Unlock() defer func() {
if err := singleLock.Unlock(); err != nil {
minimalLogger.Warn("释放实例锁失败", zap.Error(err))
}
}()
if err := checkPortAvailable(port); err != nil { if err := checkPortAvailable(port); err != nil {
minimalLogger.Error("端口不可用", zap.Error(err)) minimalLogger.Error("端口不可用", zap.Error(err))
showError("Nex Gateway", err.Error()) showError(appName, err.Error())
os.Exit(1) return
} }
cfg, err := config.LoadConfig() cfg, err := config.LoadConfig()
@@ -75,7 +81,11 @@ func main() {
if err != nil { if err != nil {
minimalLogger.Fatal("初始化日志失败", zap.Error(err)) minimalLogger.Fatal("初始化日志失败", zap.Error(err))
} }
defer zapLogger.Sync() defer func() {
if err := zapLogger.Sync(); err != nil {
minimalLogger.Warn("同步日志失败", zap.Error(err))
}
}()
cfg.PrintSummary(zapLogger) cfg.PrintSummary(zapLogger)
@@ -120,6 +130,7 @@ func main() {
providerHandler := handler.NewProviderHandler(providerService) providerHandler := handler.NewProviderHandler(providerService)
modelHandler := handler.NewModelHandler(modelService) modelHandler := handler.NewModelHandler(modelService)
statsHandler := handler.NewStatsHandler(statsService) statsHandler := handler.NewStatsHandler(statsService)
versionHandler := handler.NewVersionHandler()
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
r := gin.New() r := gin.New()
@@ -129,7 +140,7 @@ func main() {
r.Use(middleware.Logging(zapLogger)) r.Use(middleware.Logging(zapLogger))
r.Use(middleware.CORS()) r.Use(middleware.CORS())
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler) setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler, versionHandler)
setupStaticFiles(r) setupStaticFiles(r)
server = &http.Server{ server = &http.Server{
@@ -142,24 +153,30 @@ func main() {
shutdownCtx, shutdownCancel = context.WithCancel(context.Background()) shutdownCtx, shutdownCancel = context.WithCancel(context.Background())
go func() { go func() {
zapLogger.Info("AI Gateway 启动", zap.String("addr", server.Addr)) zapLogger.Info("AI Gateway 启动",
zap.String("addr", server.Addr),
zap.String("version", buildinfo.Version()),
zap.String("commit", buildinfo.Commit()),
zap.String("build_time", buildinfo.BuildTime()))
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error())) zapLogger.Fatal("服务器启动失败", zap.Error(err))
} }
}() }()
go func() { go func() {
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil { if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
zapLogger.Warn("无法打开浏览器", zap.String("error", err.Error())) zapLogger.Warn("无法打开浏览器", zap.Error(err))
} }
}() }()
setupSystray(port) setupSystray(port)
} }
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) { func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler, versionHandler *handler.VersionHandler) {
r.Any("/v1/*path", proxyHandler.HandleProxy) r.Any("/openai/*path", withProtocol("openai", proxyHandler.HandleProxy))
r.Any("/anthropic/*path", withProtocol("anthropic", proxyHandler.HandleProxy))
r.GET("/api/version", versionHandler.GetVersion)
providers := r.Group("/api/providers") providers := r.Group("/api/providers")
{ {
@@ -190,12 +207,26 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
}) })
} }
func setupStaticFiles(r *gin.Engine) { func withProtocol(protocol string, next gin.HandlerFunc) gin.HandlerFunc {
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist") return func(c *gin.Context) {
if err != nil { c.Params = append(c.Params, gin.Param{Key: "protocol", Value: protocol})
zapLogger.Fatal("无法加载前端资源", zap.String("error", err.Error())) next(c)
} }
}
func setupStaticFiles(r *gin.Engine) {
distFS, err := frontendDistFS()
if err != nil {
zapLogger.Fatal("无法加载前端资源", zap.Error(err))
}
setupStaticFilesWithFS(r, distFS)
}
func frontendDistFS() (fs.FS, error) {
return fs.Sub(embedfs.FrontendDist, "frontend-dist")
}
func setupStaticFilesWithFS(r *gin.Engine, distFS fs.FS) {
getContentType := func(path string) string { getContentType := func(path string) string {
if strings.HasSuffix(path, ".js") { if strings.HasSuffix(path, ".js") {
return "application/javascript" return "application/javascript"
@@ -228,20 +259,23 @@ func setupStaticFiles(r *gin.Engine) {
c.Data(200, getContentType(filepath), data) c.Data(200, getContentType(filepath), data)
}) })
r.GET("/favicon.svg", func(c *gin.Context) { r.GET("/icon.png", func(c *gin.Context) {
data, err := fs.ReadFile(distFS, "favicon.svg") data, err := fs.ReadFile(distFS, "icon.png")
if err != nil { if err != nil {
c.Status(404) c.Status(404)
return return
} }
c.Data(200, "image/svg+xml", data) c.Data(200, "image/png", data)
}) })
r.NoRoute(func(c *gin.Context) { r.NoRoute(func(c *gin.Context) {
path := c.Request.URL.Path path := c.Request.URL.Path
if strings.HasPrefix(path, "/api/") || if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") || strings.HasPrefix(path, "/openai/") ||
strings.HasPrefix(path, "/anthropic/") ||
path == "/openai" ||
path == "/anthropic" ||
strings.HasPrefix(path, "/health") { strings.HasPrefix(path, "/health") {
c.JSON(404, gin.H{"error": "not found"}) c.JSON(404, gin.H{"error": "not found"})
return return
@@ -266,11 +300,10 @@ func setupSystray(port int) {
icon, err = embedfs.Assets.ReadFile("assets/icon.png") icon, err = embedfs.Assets.ReadFile("assets/icon.png")
} }
if err != nil { if err != nil {
zapLogger.Error("无法加载托盘图标", zap.String("error", err.Error())) zapLogger.Error("无法加载托盘图标", zap.Error(err))
} }
systray.SetIcon(icon) systray.SetIcon(icon)
systray.SetTitle("Nex Gateway") systray.SetTooltip(appTooltip)
systray.SetTooltip("AI Gateway")
mOpen := systray.AddMenuItem("打开管理界面", "在浏览器中打开") mOpen := systray.AddMenuItem("打开管理界面", "在浏览器中打开")
systray.AddSeparator() systray.AddSeparator()
@@ -279,17 +312,15 @@ func setupSystray(port int) {
mPort := systray.AddMenuItem(fmt.Sprintf("端口: %d", port), "") mPort := systray.AddMenuItem(fmt.Sprintf("端口: %d", port), "")
mPort.Disable() mPort.Disable()
systray.AddSeparator() systray.AddSeparator()
mAbout := systray.AddMenuItem("关于", "")
systray.AddSeparator()
mQuit := systray.AddMenuItem("退出", "停止服务并退出") mQuit := systray.AddMenuItem("退出", "停止服务并退出")
go func() { go func() {
for { for {
select { select {
case <-mOpen.ClickedCh: case <-mOpen.ClickedCh:
openBrowser(fmt.Sprintf("http://localhost:%d", port)) if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
case <-mAbout.ClickedCh: zapLogger.Warn("打开浏览器失败", zap.Error(err))
showAbout() }
case <-mQuit.ClickedCh: case <-mQuit.ClickedCh:
doShutdown() doShutdown()
systray.Quit() systray.Quit()
@@ -308,7 +339,9 @@ func doShutdown() {
if server != nil { if server != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
server.Shutdown(ctx) if err := server.Shutdown(ctx); err != nil && zapLogger != nil {
zapLogger.Warn("关闭服务器失败", zap.Error(err))
}
} }
if shutdownCancel != nil { if shutdownCancel != nil {
@@ -346,8 +379,8 @@ func (s *SingletonLock) Lock() error {
return nil return nil
} }
func (s *SingletonLock) Unlock() { func (s *SingletonLock) Unlock() error {
s.flock.Unlock() return s.flock.Unlock()
} }
func openBrowser(url string) error { func openBrowser(url string) error {

View File

@@ -3,17 +3,59 @@
package main package main
import ( import (
"errors"
"syscall"
"testing" "testing"
) )
func TestMessageBoxW_WindowsOnly(t *testing.T) { func withMessageBoxW(t *testing.T, fn func(hwnd, text, caption, flags uintptr) (uintptr, error)) {
messageBox("测试标题", "测试消息", MB_ICONINFORMATION) t.Helper()
old := callMessageBoxW
callMessageBoxW = fn
t.Cleanup(func() {
callMessageBoxW = old
})
}
func TestMessageBoxW_WindowsOnly_InvalidUTF16(t *testing.T) {
err := messageBox("bad\x00title", "测试消息", mbIconInformation)
if err == nil {
t.Fatal("包含 NUL 字符时应该返回错误")
}
}
func TestMessageBoxW_WindowsOnly_SuccessIgnoresLastError(t *testing.T) {
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
return 1, syscall.Errno(123)
})
if err := messageBox("测试标题", "测试消息", mbIconInformation); err != nil {
t.Fatalf("MessageBoxW 返回成功时应忽略 last error: %v", err)
}
}
func TestMessageBoxW_WindowsOnly_FailureUsesReturnValue(t *testing.T) {
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
return 0, syscall.Errno(5)
})
err := messageBox("测试标题", "测试消息", mbIconInformation)
if !errors.Is(err, syscall.Errno(5)) {
t.Fatalf("MessageBoxW 返回 0 时应返回调用错误: %v", err)
}
} }
func TestShowError_WindowsBranch(t *testing.T) { func TestShowError_WindowsBranch(t *testing.T) {
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
return 0, syscall.Errno(5)
})
defer func() {
if recovered := recover(); recovered != nil {
t.Fatalf("showError 不应因 MessageBoxW 失败而 panic: %v", recovered)
}
}()
showError("测试错误", "这是一条测试错误消息") showError("测试错误", "这是一条测试错误消息")
} }
func TestShowAbout_WindowsBranch(t *testing.T) {
showAbout()
}

View File

@@ -0,0 +1,9 @@
package main
const (
appName = "Nex"
appTooltip = appName
appDescription = "AI Gateway - 统一的大模型 API 网关"
// #nosec G101 -- 项目官网地址不是凭据
appWebsite = "https://github.com/nex/gateway"
)

View File

@@ -0,0 +1,13 @@
package main
import "testing"
func TestDesktopMetadata(t *testing.T) {
if appName != "Nex" {
t.Fatalf("appName = %q, want %q", appName, "Nex")
}
if appTooltip != appName {
t.Fatalf("appTooltip = %q, want %q", appTooltip, appName)
}
}

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"errors"
"net" "net"
"net/http" "net/http"
"testing" "testing"
@@ -21,19 +22,12 @@ func TestCheckPortAvailable(t *testing.T) {
func TestCheckPortOccupied(t *testing.T) { func TestCheckPortOccupied(t *testing.T) {
port := 19827 port := 19827
listener, err := net.Listen("tcp", ":19827") listener, err := net.Listen("tcp", ":19827") //nolint:gosec // 需要验证 checkPortAvailable 对通配地址占用的检测行为
if err != nil { if err != nil {
t.Fatalf("无法启动测试服务器: %v", err) t.Fatalf("无法启动测试服务器: %v", err)
} }
defer listener.Close() defer listener.Close()
go func() {
conn, err := listener.Accept()
if err == nil {
conn.Close()
}
}()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
err = checkPortAvailable(port) err = checkPortAvailable(port)
@@ -47,13 +41,19 @@ func TestCheckPortOccupied(t *testing.T) {
func TestCheckPortAvailableAfterClose(t *testing.T) { func TestCheckPortAvailableAfterClose(t *testing.T) {
port := 19828 port := 19828
listener, err := net.Listen("tcp", ":19828") listener, err := net.Listen("tcp", "127.0.0.1:19828")
if err != nil { if err != nil {
t.Fatalf("无法启动测试服务器: %v", err) t.Fatalf("无法启动测试服务器: %v", err)
} }
server := &http.Server{} server := &http.Server{ReadHeaderTimeout: time.Second}
go server.Serve(listener) defer server.Close()
go func() {
err := server.Serve(listener)
if err != nil && err != http.ErrServerClosed && !errors.Is(err, net.ErrClosed) {
t.Errorf("serve failed: %v", err)
}
}()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)

View File

@@ -0,0 +1,44 @@
package main
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"testing/fstest"
"nex/backend/internal/handler"
"github.com/gin-gonic/gin"
)
func TestSetupRoutes_VersionDoesNotFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
setupRoutes(r, &handler.ProxyHandler{}, &handler.ProviderHandler{}, &handler.ModelHandler{}, &handler.StatsHandler{}, handler.NewVersionHandler())
setupStaticFilesWithFS(r, fstest.MapFS{
"index.html": {Data: []byte("<html>fallback</html>")},
})
req := httptest.NewRequest(http.MethodGet, "/api/version", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
}
if contentType := w.Header().Get("Content-Type"); contentType == "text/html; charset=utf-8" {
t.Fatalf("版本接口不应返回 SPA fallback HTML")
}
var result map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
for _, key := range []string{"version", "commit", "build_time"} {
if result[key] == "" {
t.Fatalf("响应缺少 %s 字段: %#v", key, result)
}
}
}

View File

@@ -14,7 +14,11 @@ func TestSingletonLock_FirstLockSuccess(t *testing.T) {
if err := lock.Lock(); err != nil { if err := lock.Lock(); err != nil {
t.Fatalf("首次加锁应成功,但返回错误: %v", err) t.Fatalf("首次加锁应成功,但返回错误: %v", err)
} }
defer lock.Unlock() defer func() {
if err := lock.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
}()
} }
func TestSingletonLock_DuplicateLockFails(t *testing.T) { func TestSingletonLock_DuplicateLockFails(t *testing.T) {
@@ -25,12 +29,18 @@ func TestSingletonLock_DuplicateLockFails(t *testing.T) {
if err := lock1.Lock(); err != nil { if err := lock1.Lock(); err != nil {
t.Fatalf("首次加锁应成功: %v", err) t.Fatalf("首次加锁应成功: %v", err)
} }
defer lock1.Unlock() defer func() {
if err := lock1.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
}()
lock2 := NewSingletonLock(lockPath) lock2 := NewSingletonLock(lockPath)
err := lock2.Lock() err := lock2.Lock()
if err == nil { if err == nil {
lock2.Unlock() if unlockErr := lock2.Unlock(); unlockErr != nil {
t.Fatalf("解锁失败: %v", unlockErr)
}
t.Fatal("重复加锁应失败,但返回 nil") t.Fatal("重复加锁应失败,但返回 nil")
} }
} }
@@ -43,16 +53,22 @@ func TestSingletonLock_UnlockThenRelock(t *testing.T) {
if err := lock1.Lock(); err != nil { if err := lock1.Lock(); err != nil {
t.Fatalf("首次加锁应成功: %v", err) t.Fatalf("首次加锁应成功: %v", err)
} }
lock1.Unlock() if err := lock1.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
lock2 := NewSingletonLock(lockPath) lock2 := NewSingletonLock(lockPath)
if err := lock2.Lock(); err != nil { if err := lock2.Lock(); err != nil {
t.Fatalf("释放后重新加锁应成功: %v", err) t.Fatalf("释放后重新加锁应成功: %v", err)
} }
lock2.Unlock() if err := lock2.Unlock(); err != nil {
t.Fatalf("解锁失败: %v", err)
}
} }
func TestSingletonLock_UnlockWithoutLock(t *testing.T) { func TestSingletonLock_UnlockWithoutLock(t *testing.T) {
lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock")) lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock"))
lock.Unlock() if err := lock.Unlock(); err != nil {
t.Fatalf("未加锁时解锁失败: %v", err)
}
} }

View File

@@ -1,73 +1,26 @@
package main package main
import ( import (
"io/fs" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
"testing/fstest"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"nex/embedfs"
) )
func TestSetupStaticFiles(t *testing.T) { func TestSetupStaticFiles(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist") distFS, err := frontendDistFS()
if err != nil { if err != nil {
t.Skipf("跳过测试: 前端资源未构建: %v", err) t.Skipf("跳过测试: 前端资源未构建: %v", err)
return return
} }
getContentType := func(path string) string {
if strings.HasSuffix(path, ".js") {
return "application/javascript"
}
if strings.HasSuffix(path, ".css") {
return "text/css"
}
if strings.HasSuffix(path, ".svg") {
return "image/svg+xml"
}
return "application/octet-stream"
}
r := gin.New() r := gin.New()
r.GET("/assets/*filepath", func(c *gin.Context) { setupStaticFilesWithFS(r, distFS)
filepath := c.Param("filepath")
data, err := fs.ReadFile(distFS, "assets"+filepath)
if err != nil {
c.Status(404)
return
}
c.Data(200, getContentType(filepath), data)
})
r.GET("/favicon.svg", func(c *gin.Context) {
data, err := fs.ReadFile(distFS, "favicon.svg")
if err != nil {
c.Status(404)
return
}
c.Data(200, "image/svg+xml", data)
})
r.NoRoute(func(c *gin.Context) {
path := c.Request.URL.Path
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/health") {
c.JSON(404, gin.H{"error": "not found"})
return
}
data, err := fs.ReadFile(distFS, "index.html")
if err != nil {
c.Status(500)
return
}
c.Data(200, "text/html; charset=utf-8", data)
})
t.Run("API 404", func(t *testing.T) { t.Run("API 404", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil) req := httptest.NewRequest("GET", "/api/test", nil)
@@ -79,6 +32,32 @@ func TestSetupStaticFiles(t *testing.T) {
} }
}) })
t.Run("OpenAI proxy prefix 404", func(t *testing.T) {
req := httptest.NewRequest("GET", "/openai/", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("期望状态码 404, 实际 %d", w.Code)
}
if !strings.Contains(w.Body.String(), "not found") {
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
}
})
t.Run("Anthropic proxy prefix 404", func(t *testing.T) {
req := httptest.NewRequest("GET", "/anthropic/", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("期望状态码 404, 实际 %d", w.Code)
}
if !strings.Contains(w.Body.String(), "not found") {
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
}
})
t.Run("SPA fallback", func(t *testing.T) { t.Run("SPA fallback", func(t *testing.T) {
req := httptest.NewRequest("GET", "/providers", nil) req := httptest.NewRequest("GET", "/providers", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -121,3 +100,139 @@ func TestSetupStaticFiles(t *testing.T) {
t.Log("静态文件服务测试通过") t.Log("静态文件服务测试通过")
} }
func TestSetupStaticFilesWithFS_IconPNG(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
setupStaticFilesWithFS(r, fstest.MapFS{
"icon.png": {Data: []byte("png")},
"index.html": {Data: []byte("<html>fallback</html>")},
})
req := httptest.NewRequest("GET", "/icon.png", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码 200, 实际 %d", w.Code)
}
if w.Header().Get("Content-Type") != "image/png" {
t.Fatalf("期望 Content-Type image/png, 实际 %s", w.Header().Get("Content-Type"))
}
if w.Body.String() != "png" {
t.Fatalf("期望返回 PNG 内容,实际 %q", w.Body.String())
}
}
func TestWithProtocolAndStaticRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
distFS, err := frontendDistFS()
if err != nil {
t.Skipf("跳过测试: 前端资源未构建: %v", err)
return
}
r := gin.New()
var gotProtocol string
var gotPath string
r.Any("/openai/*path", withProtocol("openai", func(c *gin.Context) {
gotProtocol = c.Param("protocol")
gotPath = c.Param("path")
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
}))
r.Any("/anthropic/*path", withProtocol("anthropic", func(c *gin.Context) {
gotProtocol = c.Param("protocol")
gotPath = c.Param("path")
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
}))
setupStaticFilesWithFS(r, distFS)
t.Run("OpenAI route enters proxy handler wrapper", func(t *testing.T) {
gotProtocol = ""
gotPath = ""
req := httptest.NewRequest("POST", "/openai/v1/chat/completions", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("期望状态码 200, 实际 %d", w.Code)
}
if gotProtocol != "openai" {
t.Errorf("期望 protocol=openai, 实际 %s", gotProtocol)
}
if gotPath != "/v1/chat/completions" {
t.Errorf("期望 path=/v1/chat/completions, 实际 %s", gotPath)
}
})
t.Run("Anthropic route enters proxy handler wrapper", func(t *testing.T) {
gotProtocol = ""
gotPath = ""
req := httptest.NewRequest("POST", "/anthropic/v1/messages", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("期望状态码 200, 实际 %d", w.Code)
}
if gotProtocol != "anthropic" {
t.Errorf("期望 protocol=anthropic, 实际 %s", gotProtocol)
}
if gotPath != "/v1/messages" {
t.Errorf("期望 path=/v1/messages, 实际 %s", gotPath)
}
})
t.Run("Static assets are not hijacked", func(t *testing.T) {
gotProtocol = ""
gotPath = ""
req := httptest.NewRequest("GET", "/assets/test.js", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if gotProtocol != "" || gotPath != "" {
t.Errorf("静态资源不应进入代理包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
}
if w.Code == http.StatusOK {
if !strings.HasPrefix(w.Header().Get("Content-Type"), "application/javascript") {
t.Errorf("期望 JS Content-Type, 实际 %s", w.Header().Get("Content-Type"))
}
return
}
if w.Code != http.StatusNotFound {
t.Errorf("期望静态资源返回 200 或 404, 实际 %d", w.Code)
}
})
t.Run("SPA path keeps fallback", func(t *testing.T) {
req := httptest.NewRequest("GET", "/providers", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("期望状态码 200, 实际 %d", w.Code)
}
if !strings.Contains(w.Header().Get("Content-Type"), "text/html") {
t.Errorf("期望返回 HTML实际 %s", w.Header().Get("Content-Type"))
}
})
t.Run("Unknown proxy-like path does not return index html", func(t *testing.T) {
req := httptest.NewRequest("GET", "/openai/unknown", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("显式代理路由应进入代理包装器,实际状态码 %d", w.Code)
}
if gotProtocol != "openai" || gotPath != "/unknown" {
t.Errorf("期望 unknown 代理路径进入 openai 包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
}
})
}

View File

@@ -22,6 +22,7 @@ import (
"nex/backend/internal/provider" "nex/backend/internal/provider"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"nex/backend/internal/service" "nex/backend/internal/service"
"nex/backend/pkg/buildinfo"
pkgLogger "nex/backend/pkg/logger" pkgLogger "nex/backend/pkg/logger"
) )
@@ -44,7 +45,11 @@ func main() {
if err != nil { if err != nil {
minimalLogger.Fatal("初始化日志失败", zap.Error(err)) minimalLogger.Fatal("初始化日志失败", zap.Error(err))
} }
defer zapLogger.Sync() defer func() {
if err := zapLogger.Sync(); err != nil {
minimalLogger.Warn("同步日志失败", zap.Error(err))
}
}()
cfg.PrintSummary(zapLogger) cfg.PrintSummary(zapLogger)
@@ -88,6 +93,7 @@ func main() {
providerHandler := handler.NewProviderHandler(providerService) providerHandler := handler.NewProviderHandler(providerService)
modelHandler := handler.NewModelHandler(modelService) modelHandler := handler.NewModelHandler(modelService)
statsHandler := handler.NewStatsHandler(statsService) statsHandler := handler.NewStatsHandler(statsService)
versionHandler := handler.NewVersionHandler()
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
r := gin.New() r := gin.New()
@@ -97,7 +103,7 @@ func main() {
r.Use(middleware.Logging(zapLogger)) r.Use(middleware.Logging(zapLogger))
r.Use(middleware.CORS()) r.Use(middleware.CORS())
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler) setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler, versionHandler)
srv := &http.Server{ srv := &http.Server{
Addr: fmt.Sprintf(":%d", cfg.Server.Port), Addr: fmt.Sprintf(":%d", cfg.Server.Port),
@@ -107,7 +113,11 @@ func main() {
} }
go func() { go func() {
zapLogger.Info("AI Gateway 启动", zap.String("addr", srv.Addr)) zapLogger.Info("AI Gateway 启动",
zap.String("addr", srv.Addr),
zap.String("version", buildinfo.Version()),
zap.String("commit", buildinfo.Commit()),
zap.String("build_time", buildinfo.BuildTime()))
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
zapLogger.Fatal("服务器启动失败", zap.Error(err)) zapLogger.Fatal("服务器启动失败", zap.Error(err))
} }
@@ -131,8 +141,9 @@ func main() {
zapLogger.Info("服务器已关闭") zapLogger.Info("服务器已关闭")
} }
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) { func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler, versionHandler *handler.VersionHandler) {
r.Any("/:protocol/*path", proxyHandler.HandleProxy) r.Any("/:protocol/*path", proxyHandler.HandleProxy)
r.GET("/api/version", versionHandler.GetVersion)
providers := r.Group("/api/providers") providers := r.Group("/api/providers")
{ {

View File

@@ -0,0 +1,37 @@
package main
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"nex/backend/internal/handler"
"github.com/gin-gonic/gin"
)
func TestSetupRoutes_Version(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
setupRoutes(r, &handler.ProxyHandler{}, &handler.ProviderHandler{}, &handler.ModelHandler{}, &handler.StatsHandler{}, handler.NewVersionHandler())
req := httptest.NewRequest(http.MethodGet, "/api/version", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
}
var result map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
for _, key := range []string{"version", "commit", "build_time"} {
if result[key] == "" {
t.Fatalf("响应缺少 %s 字段: %#v", key, result)
}
}
}

View File

@@ -1,6 +1,7 @@
package config package config
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@@ -58,7 +59,10 @@ type LogConfig struct {
// DefaultConfig returns default config values // DefaultConfig returns default config values
func DefaultConfig() *Config { func DefaultConfig() *Config {
// Use home dir for default paths // Use home dir for default paths
homeDir, _ := os.UserHomeDir() homeDir, err := os.UserHomeDir()
if err != nil {
homeDir = "."
}
nexDir := filepath.Join(homeDir, ".nex") nexDir := filepath.Join(homeDir, ".nex")
return &Config{ return &Config{
@@ -97,7 +101,7 @@ func GetConfigDir() (string, error) {
return "", err return "", err
} }
configDir := filepath.Join(homeDir, ".nex") configDir := filepath.Join(homeDir, ".nex")
if err := os.MkdirAll(configDir, 0755); err != nil { if err := os.MkdirAll(configDir, 0o755); err != nil {
return "", err return "", err
} }
return configDir, nil return configDir, nil
@@ -123,7 +127,10 @@ func GetConfigPath() (string, error) {
// setupDefaults 设置默认配置值 // setupDefaults 设置默认配置值
func setupDefaults(v *viper.Viper) { func setupDefaults(v *viper.Viper) {
homeDir, _ := os.UserHomeDir() homeDir, err := os.UserHomeDir()
if err != nil {
homeDir = "."
}
nexDir := filepath.Join(homeDir, ".nex") nexDir := filepath.Join(homeDir, ".nex")
v.SetDefault("server.port", 9826) v.SetDefault("server.port", 9826)
@@ -177,27 +184,33 @@ func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
// 绑定所有 flag 到 viper // 绑定所有 flag 到 viper
// 注意:必须在设置默认值之后绑定 // 注意:必须在设置默认值之后绑定
v.BindPFlag("server.port", flagSet.Lookup("server-port")) bindPFlag(v, "server.port", flagSet.Lookup("server-port"))
v.BindPFlag("server.read_timeout", flagSet.Lookup("server-read-timeout")) bindPFlag(v, "server.read_timeout", flagSet.Lookup("server-read-timeout"))
v.BindPFlag("server.write_timeout", flagSet.Lookup("server-write-timeout")) bindPFlag(v, "server.write_timeout", flagSet.Lookup("server-write-timeout"))
v.BindPFlag("database.driver", flagSet.Lookup("database-driver")) bindPFlag(v, "database.driver", flagSet.Lookup("database-driver"))
v.BindPFlag("database.path", flagSet.Lookup("database-path")) bindPFlag(v, "database.path", flagSet.Lookup("database-path"))
v.BindPFlag("database.host", flagSet.Lookup("database-host")) bindPFlag(v, "database.host", flagSet.Lookup("database-host"))
v.BindPFlag("database.port", flagSet.Lookup("database-port")) bindPFlag(v, "database.port", flagSet.Lookup("database-port"))
v.BindPFlag("database.user", flagSet.Lookup("database-user")) bindPFlag(v, "database.user", flagSet.Lookup("database-user"))
v.BindPFlag("database.password", flagSet.Lookup("database-password")) bindPFlag(v, "database.password", flagSet.Lookup("database-password"))
v.BindPFlag("database.dbname", flagSet.Lookup("database-dbname")) bindPFlag(v, "database.dbname", flagSet.Lookup("database-dbname"))
v.BindPFlag("database.max_idle_conns", flagSet.Lookup("database-max-idle-conns")) bindPFlag(v, "database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
v.BindPFlag("database.max_open_conns", flagSet.Lookup("database-max-open-conns")) bindPFlag(v, "database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
v.BindPFlag("database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime")) bindPFlag(v, "database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
v.BindPFlag("log.level", flagSet.Lookup("log-level")) bindPFlag(v, "log.level", flagSet.Lookup("log-level"))
v.BindPFlag("log.path", flagSet.Lookup("log-path")) bindPFlag(v, "log.path", flagSet.Lookup("log-path"))
v.BindPFlag("log.max_size", flagSet.Lookup("log-max-size")) bindPFlag(v, "log.max_size", flagSet.Lookup("log-max-size"))
v.BindPFlag("log.max_backups", flagSet.Lookup("log-max-backups")) bindPFlag(v, "log.max_backups", flagSet.Lookup("log-max-backups"))
v.BindPFlag("log.max_age", flagSet.Lookup("log-max-age")) bindPFlag(v, "log.max_age", flagSet.Lookup("log-max-age"))
v.BindPFlag("log.compress", flagSet.Lookup("log-compress")) bindPFlag(v, "log.compress", flagSet.Lookup("log-compress"))
}
func bindPFlag(v *viper.Viper, key string, flag *pflag.Flag) {
if err := v.BindPFlag(key, flag); err != nil {
panic(fmt.Sprintf("绑定 flag %s 失败: %v", key, err))
}
} }
// setupEnv 绑定环境变量 // setupEnv 绑定环境变量
@@ -218,10 +231,17 @@ func setupConfigFile(v *viper.Viper, configPath string) error {
return appErrors.Wrap(appErrors.ErrInternal, err) return appErrors.Wrap(appErrors.ErrInternal, err)
} }
// 配置文件不存在,创建默认配置文件 // 配置文件不存在,创建默认配置文件
if err := v.SafeWriteConfig(); err != nil { writeErr := v.SafeWriteConfigAs(configPath)
// 忽略写入错误(可能目录已存在等) if writeErr == nil {
return nil return nil
} }
var alreadyExistsErr viper.ConfigFileAlreadyExistsError
if errors.As(writeErr, &alreadyExistsErr) {
return nil
}
return appErrors.Wrap(appErrors.ErrInternal, writeErr)
} }
return nil return nil
} }
@@ -246,7 +266,9 @@ func LoadConfigFromPath(configPath string) (*Config, error) {
setupFlags(v, flagSet) setupFlags(v, flagSet)
// 3. 解析 CLI 参数(忽略错误,因为可能没有参数) // 3. 解析 CLI 参数(忽略错误,因为可能没有参数)
flagSet.Parse(os.Args[1:]) if err := flagSet.Parse(os.Args[1:]); err != nil {
return nil, appErrors.Wrap(appErrors.ErrInvalidRequest, err)
}
// 4. 获取配置文件路径(可能被 --config 参数覆盖) // 4. 获取配置文件路径(可能被 --config 参数覆盖)
if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" { if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" {
@@ -295,11 +317,11 @@ func SaveConfig(cfg *Config) error {
// Ensure directory exists // Ensure directory exists
dir := filepath.Dir(configPath) dir := filepath.Dir(configPath)
if err := os.MkdirAll(dir, 0755); err != nil { if err := os.MkdirAll(dir, 0o755); err != nil {
return appErrors.Wrap(appErrors.ErrInternal, err) return appErrors.Wrap(appErrors.ErrInternal, err)
} }
return os.WriteFile(configPath, data, 0600) return os.WriteFile(configPath, data, 0o600)
} }
// Validate validates the config // Validate validates the config

View File

@@ -236,7 +236,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
configPath := filepath.Join(dir, "config.yaml") configPath := filepath.Join(dir, "config.yaml")
data, err := yaml.Marshal(cfg) data, err := yaml.Marshal(cfg)
require.NoError(t, err) require.NoError(t, err)
err = os.WriteFile(configPath, data, 0644) err = os.WriteFile(configPath, data, 0o600)
require.NoError(t, err) require.NoError(t, err)
// 加载配置 // 加载配置

View File

@@ -6,15 +6,15 @@ import (
// Provider 供应商模型 // Provider 供应商模型
type Provider struct { type Provider struct {
ID string `gorm:"primaryKey" json:"id"` ID string `gorm:"primaryKey" json:"id"`
Name string `gorm:"not null" json:"name"` Name string `gorm:"not null" json:"name"`
APIKey string `gorm:"not null" json:"api_key"` APIKey string `gorm:"not null" json:"api_key"`
BaseURL string `gorm:"not null" json:"base_url"` BaseURL string `gorm:"not null" json:"base_url"`
Protocol string `gorm:"column:protocol;default:'openai'" json:"protocol"` Protocol string `gorm:"column:protocol;default:'openai'" json:"protocol"`
Enabled bool `gorm:"default:true" json:"enabled"` Enabled bool `gorm:"default:true" json:"enabled"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"` Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
} }
// Model 模型配置id 为 UUID 自动生成UNIQUE(provider_id, model_name) // Model 模型配置id 为 UUID 自动生成UNIQUE(provider_id, model_name)
@@ -47,4 +47,3 @@ func (Model) TableName() string {
func (UsageStats) TableName() string { func (UsageStats) TableName() string {
return "usage_stats" return "usage_stats"
} }

View File

@@ -141,7 +141,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
Message: err.Message, Message: err.Message,
}, },
} }
body, _ := json.Marshal(errMsg) body, marshalErr := json.Marshal(errMsg)
if marshalErr != nil {
return []byte(`{"type":"error","error":{"type":"internal_error","message":"internal error"}}`), statusCode
}
return body, statusCode return body, statusCode
} }
@@ -235,7 +238,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
return "", nil, err return "", nil, err
} }
rewriteFunc := func(newModel string) ([]byte, error) { rewriteFunc := func(newModel string) ([]byte, error) {
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m) return json.Marshal(m)
} }
return current, rewriteFunc, nil return current, rewriteFunc, nil
@@ -269,7 +276,11 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
switch ifaceType { switch ifaceType {
case conversion.InterfaceTypeChat: case conversion.InterfaceTypeChat:
// Chat 响应必须有 model 字段,存在则改写,不存在则添加 // Chat 响应必须有 model 字段,存在则改写,不存在则添加
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m) return json.Marshal(m)
default: default:
return body, nil return body, nil

View File

@@ -2,6 +2,7 @@ package anthropic
import ( import (
"encoding/json" "encoding/json"
"errors"
"testing" "testing"
"nex/backend/internal/conversion" "nex/backend/internal/conversion"
@@ -48,14 +49,36 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
} }
} }
func TestAdapter_APIReferenceNativePaths(t *testing.T) {
a := NewAdapter()
// docs/api_reference/anthropic defines messages and models under /v1.
tests := []struct {
path string
expected conversion.InterfaceType
}{
{"/v1/messages", conversion.InterfaceTypeChat},
{"/v1/models", conversion.InterfaceTypeModels},
{"/v1/models/claude-sonnet-4-5", conversion.InterfaceTypeModelInfo},
{"/messages", conversion.InterfaceTypePassthrough},
{"/models", conversion.InterfaceTypePassthrough},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
assert.Equal(t, tt.expected, a.DetectInterfaceType(tt.path))
})
}
}
func TestAdapter_BuildUrl(t *testing.T) { func TestAdapter_BuildUrl(t *testing.T) {
a := NewAdapter() a := NewAdapter()
tests := []struct { tests := []struct {
name string name string
nativePath string nativePath string
interfaceType conversion.InterfaceType interfaceType conversion.InterfaceType
expected string expected string
}{ }{
{"聊天", "/v1/messages", conversion.InterfaceTypeChat, "/v1/messages"}, {"聊天", "/v1/messages", conversion.InterfaceTypeChat, "/v1/messages"},
{"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"}, {"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"},
@@ -102,9 +125,9 @@ func TestAdapter_SupportsInterface(t *testing.T) {
a := NewAdapter() a := NewAdapter()
tests := []struct { tests := []struct {
name string name string
interfaceType conversion.InterfaceType interfaceType conversion.InterfaceType
expected bool expected bool
}{ }{
{"聊天", conversion.InterfaceTypeChat, true}, {"聊天", conversion.InterfaceTypeChat, true},
{"模型", conversion.InterfaceTypeModels, true}, {"模型", conversion.InterfaceTypeModels, true},
@@ -141,8 +164,8 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
t.Run("解码嵌入请求", func(t *testing.T) { t.Run("解码嵌入请求", func(t *testing.T) {
_, err := a.DecodeEmbeddingRequest([]byte(`{}`)) _, err := a.DecodeEmbeddingRequest([]byte(`{}`))
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
@@ -150,24 +173,24 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3") provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
_, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider) _, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider)
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.True(t, errors.As(err, &convErr))
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
t.Run("解码嵌入响应", func(t *testing.T) { t.Run("解码嵌入响应", func(t *testing.T) {
_, err := a.DecodeEmbeddingResponse([]byte(`{}`)) _, err := a.DecodeEmbeddingResponse([]byte(`{}`))
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
t.Run("编码嵌入响应", func(t *testing.T) { t.Run("编码嵌入响应", func(t *testing.T) {
_, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{}) _, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{})
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
} }
@@ -178,8 +201,8 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
t.Run("解码重排序请求", func(t *testing.T) { t.Run("解码重排序请求", func(t *testing.T) {
_, err := a.DecodeRerankRequest([]byte(`{}`)) _, err := a.DecodeRerankRequest([]byte(`{}`))
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
@@ -187,24 +210,24 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3") provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
_, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider) _, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider)
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
t.Run("解码重排序响应", func(t *testing.T) { t.Run("解码重排序响应", func(t *testing.T) {
_, err := a.DecodeRerankResponse([]byte(`{}`)) _, err := a.DecodeRerankResponse([]byte(`{}`))
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
t.Run("编码重排序响应", func(t *testing.T) { t.Run("编码重排序响应", func(t *testing.T) {
_, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{}) _, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{})
require.Error(t, err) require.Error(t, err)
convErr, ok := err.(*conversion.ConversionError) var convErr *conversion.ConversionError
require.True(t, ok) require.ErrorAs(t, err, &convErr)
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code) assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
}) })
} }

View File

@@ -28,7 +28,10 @@ func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
var canonicalMsgs []canonical.CanonicalMessage var canonicalMsgs []canonical.CanonicalMessage
for _, msg := range req.Messages { for _, msg := range req.Messages {
decoded := decodeMessage(msg) decoded, err := decodeMessage(msg)
if err != nil {
return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析消息内容失败").WithCause(err)
}
canonicalMsgs = append(canonicalMsgs, decoded...) canonicalMsgs = append(canonicalMsgs, decoded...)
} }
@@ -94,10 +97,13 @@ func decodeSystem(system any) any {
} }
// decodeMessage 解码 Anthropic 消息 // decodeMessage 解码 Anthropic 消息
func decodeMessage(msg Message) []canonical.CanonicalMessage { func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) {
switch msg.Role { switch msg.Role {
case "user": case "user":
blocks := decodeContentBlocks(msg.Content) blocks, err := decodeContentBlocks(msg.Content)
if err != nil {
return nil, err
}
var toolResults []canonical.ContentBlock var toolResults []canonical.ContentBlock
var others []canonical.ContentBlock var others []canonical.ContentBlock
for _, b := range blocks { for _, b := range blocks {
@@ -117,58 +123,83 @@ func decodeMessage(msg Message) []canonical.CanonicalMessage {
if len(result) == 0 { if len(result) == 0 {
result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}}) result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}})
} }
return result return result, nil
case "assistant": case "assistant":
blocks := decodeContentBlocks(msg.Content) blocks, err := decodeContentBlocks(msg.Content)
if err != nil {
return nil, err
}
if len(blocks) == 0 { if len(blocks) == 0 {
blocks = append(blocks, canonical.NewTextBlock("")) blocks = append(blocks, canonical.NewTextBlock(""))
} }
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}} return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil
} }
return nil return nil, nil
} }
// decodeContentBlocks 解码内容块列表 // decodeContentBlocks 解码内容块列表
func decodeContentBlocks(content any) []canonical.ContentBlock { func decodeContentBlocks(content any) ([]canonical.ContentBlock, error) {
switch v := content.(type) { switch v := content.(type) {
case string: case string:
return []canonical.ContentBlock{canonical.NewTextBlock(v)} return []canonical.ContentBlock{canonical.NewTextBlock(v)}, nil
case []any: case []any:
var blocks []canonical.ContentBlock var blocks []canonical.ContentBlock
for _, item := range v { for _, item := range v {
if m, ok := item.(map[string]any); ok { if m, ok := item.(map[string]any); ok {
block := decodeSingleContentBlock(m) block, err := decodeSingleContentBlock(m)
if err != nil {
return nil, err
}
if block != nil { if block != nil {
blocks = append(blocks, *block) blocks = append(blocks, *block)
} }
} }
} }
if len(blocks) > 0 { if len(blocks) > 0 {
return blocks return blocks, nil
} }
return []canonical.ContentBlock{canonical.NewTextBlock("")} return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
case nil: case nil:
return []canonical.ContentBlock{canonical.NewTextBlock("")} return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
default: default:
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))} return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}, nil
} }
} }
// decodeSingleContentBlock 解码单个内容块 // decodeSingleContentBlock 解码单个内容块
func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock { func decodeSingleContentBlock(m map[string]any) (*canonical.ContentBlock, error) {
t, _ := m["type"].(string) t, ok := m["type"].(string)
if !ok {
return nil, nil
}
switch t { switch t {
case "text": case "text":
text, _ := m["text"].(string) text, ok := m["text"].(string)
return &canonical.ContentBlock{Type: "text", Text: text} if !ok {
text = ""
}
return &canonical.ContentBlock{Type: "text", Text: text}, nil
case "tool_use": case "tool_use":
id, _ := m["id"].(string) id, ok := m["id"].(string)
name, _ := m["name"].(string) if !ok {
input, _ := json.Marshal(m["input"]) id = ""
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input} }
name, ok := m["name"].(string)
if !ok {
name = ""
}
input, err := json.Marshal(m["input"])
if err != nil {
return nil, err
}
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}, nil
case "tool_result": case "tool_result":
toolUseID, _ := m["tool_use_id"].(string) toolUseID, ok := m["tool_use_id"].(string)
if !ok {
toolUseID = ""
}
isErr := false isErr := false
if ie, ok := m["is_error"].(bool); ok { if ie, ok := m["is_error"].(bool); ok {
isErr = ie isErr = ie
@@ -179,7 +210,11 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
case string: case string:
content = json.RawMessage(fmt.Sprintf("%q", cv)) content = json.RawMessage(fmt.Sprintf("%q", cv))
default: default:
content, _ = json.Marshal(cv) encoded, err := json.Marshal(cv)
if err != nil {
return nil, err
}
content = encoded
} }
} else { } else {
content = json.RawMessage(`""`) content = json.RawMessage(`""`)
@@ -189,15 +224,18 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
ToolUseID: toolUseID, ToolUseID: toolUseID,
Content: content, Content: content,
IsError: &isErr, IsError: &isErr,
} }, nil
case "thinking": case "thinking":
thinking, _ := m["thinking"].(string) thinking, ok := m["thinking"].(string)
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking} if !ok {
thinking = ""
}
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}, nil
case "redacted_thinking": case "redacted_thinking":
// 丢弃 // 丢弃
return nil return nil, nil
} }
return nil return nil, nil
} }
// decodeTools 解码工具定义 // decodeTools 解码工具定义
@@ -232,7 +270,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
return canonical.NewToolChoiceAny() return canonical.NewToolChoiceAny()
} }
case map[string]any: case map[string]any:
t, _ := v["type"].(string) t, ok := v["type"].(string)
if !ok {
return nil
}
switch t { switch t {
case "auto": case "auto":
return canonical.NewToolChoiceAuto() return canonical.NewToolChoiceAuto()
@@ -241,7 +282,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
case "any": case "any":
return canonical.NewToolChoiceAny() return canonical.NewToolChoiceAny()
case "tool": case "tool":
name, _ := v["name"].(string) name, ok := v["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name) return canonical.NewToolChoiceNamed(name)
} }
} }

View File

@@ -182,7 +182,7 @@ func encodeContentBlocks(blocks []canonical.ContentBlock) []map[string]any {
result = append(result, m) result = append(result, m)
case "tool_result": case "tool_result":
m := map[string]any{ m := map[string]any{
"type": "tool_result", "type": "tool_result",
"tool_use_id": b.ToolUseID, "tool_use_id": b.ToolUseID,
} }
if b.Content != nil { if b.Content != nil {
@@ -335,11 +335,11 @@ func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
} }
result := map[string]any{ result := map[string]any{
"id": resp.ID, "id": resp.ID,
"type": "message", "type": "message",
"role": "assistant", "role": "assistant",
"model": resp.Model, "model": resp.Model,
"content": blocks, "content": blocks,
"stop_reason": sr, "stop_reason": sr,
"stop_sequence": nil, "stop_sequence": nil,
"usage": usage, "usage": usage,

View File

@@ -33,7 +33,8 @@ func TestEncodeRequest_Basic(t *testing.T) {
assert.Equal(t, true, result["stream"]) assert.Equal(t, true, result["stream"])
assert.Equal(t, float64(1024), result["max_tokens"]) assert.Equal(t, float64(1024), result["max_tokens"])
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 1) assert.Len(t, msgs, 1)
} }
@@ -55,17 +56,20 @@ func TestEncodeRequest_ToolMergeIntoUser(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
require.True(t, ok)
// tool 消息应被合并到相邻 user 消息 // tool 消息应被合并到相邻 user 消息
foundToolResult := false foundToolResult := false
for _, m := range msgs { for _, m := range msgs {
msgMap := m.(map[string]any) msgMap, ok := m.(map[string]any)
require.True(t, ok)
if msgMap["role"] == "user" { if msgMap["role"] == "user" {
content, ok := msgMap["content"].([]any) content, ok := msgMap["content"].([]any)
if ok { if ok {
for _, c := range content { for _, c := range content {
block := c.(map[string]any) block, ok := c.(map[string]any)
require.True(t, ok)
if block["type"] == "tool_result" { if block["type"] == "tool_result" {
foundToolResult = true foundToolResult = true
} }
@@ -93,8 +97,10 @@ func TestEncodeRequest_FirstUserGuarantee(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
firstMsg := msgs[0].(map[string]any) require.True(t, ok)
firstMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "user", firstMsg["role"]) assert.Equal(t, "user", firstMsg["role"])
} }
@@ -140,9 +146,11 @@ func TestEncodeResponse_Basic(t *testing.T) {
assert.Equal(t, "assistant", result["role"]) assert.Equal(t, "assistant", result["role"])
assert.Equal(t, "end_turn", result["stop_reason"]) assert.Equal(t, "end_turn", result["stop_reason"])
content := result["content"].([]any) content, ok := result["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 1) assert.Len(t, content, 1)
block := content[0].(map[string]any) block, ok := content[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "text", block["type"]) assert.Equal(t, "text", block["type"])
assert.Equal(t, "你好", block["text"]) assert.Equal(t, "你好", block["text"])
} }
@@ -160,10 +168,12 @@ func TestEncodeModelsResponse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
data := result["data"].([]any) data, ok := result["data"].([]any)
require.True(t, ok)
assert.Len(t, data, 1) assert.Len(t, data, 1)
model := data[0].(map[string]any) model, ok := data[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "claude-3-opus", model["id"]) assert.Equal(t, "claude-3-opus", model["id"])
// created 应为 RFC3339 格式 // created 应为 RFC3339 格式
createdAt, ok := model["created_at"].(string) createdAt, ok := model["created_at"].(string)
@@ -280,11 +290,14 @@ func TestEncodeRequest_ConsecutiveRoleMerge(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 1) assert.Len(t, msgs, 1)
userMsg := msgs[0].(map[string]any) userMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "user", userMsg["role"]) assert.Equal(t, "user", userMsg["role"])
content := userMsg["content"].([]any) content, ok := userMsg["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 2) assert.Len(t, content, 2)
} }
@@ -321,7 +334,8 @@ func TestEncodeResponse_ReasoningTokens(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any) usage, ok := result["usage"].(map[string]any)
require.True(t, ok)
_, hasReasoning := usage["reasoning_tokens"] _, hasReasoning := usage["reasoning_tokens"]
assert.False(t, hasReasoning) assert.False(t, hasReasoning)
} }
@@ -341,9 +355,11 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
content := result["content"].([]any) content, ok := result["content"].([]any)
require.True(t, ok)
assert.Len(t, content, 1) assert.Len(t, content, 1)
block := content[0].(map[string]any) block, ok := content[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "tool_use", block["type"]) assert.Equal(t, "tool_use", block["type"])
assert.Equal(t, "tool_1", block["id"]) assert.Equal(t, "tool_1", block["id"])
assert.Equal(t, "search", block["name"]) assert.Equal(t, "search", block["name"])

View File

@@ -28,7 +28,7 @@ func NewStreamDecoder() *StreamDecoder {
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
data := rawChunk data := rawChunk
if len(d.utf8Remainder) > 0 { if len(d.utf8Remainder) > 0 {
data = append(d.utf8Remainder, rawChunk...) data = append(append([]byte{}, d.utf8Remainder...), rawChunk...)
d.utf8Remainder = nil d.utf8Remainder = nil
} }
@@ -50,9 +50,10 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
for _, line := range strings.Split(text, "\n") { for _, line := range strings.Split(text, "\n") {
line = strings.TrimRight(line, "\r") line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "event: ") { switch {
case strings.HasPrefix(line, "event: "):
eventType = strings.TrimPrefix(line, "event: ") eventType = strings.TrimPrefix(line, "event: ")
} else if strings.HasPrefix(line, "data: ") { case strings.HasPrefix(line, "data: "):
eventData = strings.TrimPrefix(line, "data: ") eventData = strings.TrimPrefix(line, "data: ")
if eventType != "" && eventData != "" { if eventType != "" && eventData != "" {
chunkEvents := d.processEvent(eventType, []byte(eventData)) chunkEvents := d.processEvent(eventType, []byte(eventData))
@@ -60,8 +61,8 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
} }
eventType = "" eventType = ""
eventData = "" eventData = ""
} else if line == "" { case line == "":
// SSE 事件分隔符 continue
} }
} }
@@ -135,7 +136,7 @@ func (d *StreamDecoder) processMessageStart(data []byte) []canonical.CanonicalSt
// processContentBlockStart 处理内容块开始事件 // processContentBlockStart 处理内容块开始事件
func (d *StreamDecoder) processContentBlockStart(data []byte) []canonical.CanonicalStreamEvent { func (d *StreamDecoder) processContentBlockStart(data []byte) []canonical.CanonicalStreamEvent {
var raw struct { var raw struct {
Index int `json:"index"` Index int `json:"index"`
ContentBlock struct { ContentBlock struct {
Type string `json:"type"` Type string `json:"type"`
Text string `json:"text"` Text string `json:"text"`

View File

@@ -47,23 +47,23 @@ func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
checkValue string checkValue string
}{ }{
{ {
name: "text_delta", name: "text_delta",
deltaType: "text_delta", deltaType: "text_delta",
deltaData: map[string]any{"type": "text_delta", "text": "你好"}, deltaData: map[string]any{"type": "text_delta", "text": "你好"},
checkField: "text", checkField: "text",
checkValue: "你好", checkValue: "你好",
}, },
{ {
name: "input_json_delta", name: "input_json_delta",
deltaType: "input_json_delta", deltaType: "input_json_delta",
deltaData: map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"}, deltaData: map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"},
checkField: "partial_json", checkField: "partial_json",
checkValue: "{\"key\":", checkValue: "{\"key\":",
}, },
{ {
name: "thinking_delta", name: "thinking_delta",
deltaType: "thinking_delta", deltaType: "thinking_delta",
deltaData: map[string]any{"type": "thinking_delta", "thinking": "思考中"}, deltaData: map[string]any{"type": "thinking_delta", "thinking": "思考中"},
checkField: "thinking", checkField: "thinking",
checkValue: "思考中", checkValue: "思考中",
}, },
@@ -74,7 +74,7 @@ func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
payload := map[string]any{ payload := map[string]any{
"type": "content_block_delta", "type": "content_block_delta",
"index": 0, "index": 0,
"delta": tt.deltaData, "delta": tt.deltaData,
} }
raw := makeAnthropicEvent("content_block_delta", payload) raw := makeAnthropicEvent("content_block_delta", payload)
@@ -298,7 +298,7 @@ func TestStreamDecoder_WebSearchToolResult_Suppressed(t *testing.T) {
"type": "content_block_start", "type": "content_block_start",
"index": 3, "index": 3,
"content_block": map[string]any{ "content_block": map[string]any{
"type": "web_search_tool_result", "type": "web_search_tool_result",
"tool_use_id": "search_1", "tool_use_id": "search_1",
}, },
} }
@@ -331,8 +331,8 @@ func TestStreamDecoder_CitationsDelta_Discarded(t *testing.T) {
"type": "content_block_delta", "type": "content_block_delta",
"index": 0, "index": 0,
"delta": map[string]any{ "delta": map[string]any{
"type": "citations_delta", "type": "citations_delta",
"citation": map[string]any{"title": "ref1"}, "citation": map[string]any{"title": "ref1"},
}, },
} }
raw := makeAnthropicEvent("content_block_delta", payload) raw := makeAnthropicEvent("content_block_delta", payload)
@@ -466,7 +466,7 @@ func TestStreamDecoder_MessageDelta_UsageNotAccumulated(t *testing.T) {
}, },
} }
deltaPayload1 := map[string]any{ deltaPayload1 := map[string]any{
"type": "message_delta", "type": "message_delta",
"delta": map[string]any{"stop_reason": "end_turn"}, "delta": map[string]any{"stop_reason": "end_turn"},
"usage": map[string]any{"output_tokens": 25}, "usage": map[string]any{"output_tokens": 25},
} }
@@ -478,7 +478,7 @@ func TestStreamDecoder_MessageDelta_UsageNotAccumulated(t *testing.T) {
assert.Equal(t, 25, events[0].Usage.OutputTokens) assert.Equal(t, 25, events[0].Usage.OutputTokens)
deltaPayload2 := map[string]any{ deltaPayload2 := map[string]any{
"type": "message_delta", "type": "message_delta",
"delta": map[string]any{"stop_reason": "end_turn"}, "delta": map[string]any{"stop_reason": "end_turn"},
"usage": map[string]any{"output_tokens": 30}, "usage": map[string]any{"output_tokens": 30},
} }

View File

@@ -50,16 +50,24 @@ func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent)
} }
if event.Message != nil { if event.Message != nil {
msg := map[string]any{ msg := map[string]any{
"id": event.Message.ID, "id": event.Message.ID,
"model": event.Message.Model, "type": "message",
"role": "assistant", "role": "assistant",
"content": []any{},
"model": event.Message.Model,
"stop_reason": nil,
"stop_sequence": nil,
} }
if event.Message.Usage != nil { if event.Message.Usage != nil {
usage := map[string]any{ msg["usage"] = map[string]any{
"input_tokens": event.Message.Usage.InputTokens, "input_tokens": event.Message.Usage.InputTokens,
"output_tokens": event.Message.Usage.OutputTokens, "output_tokens": event.Message.Usage.OutputTokens,
} }
msg["usage"] = usage } else {
msg["usage"] = map[string]any{
"input_tokens": 0,
"output_tokens": 0,
}
} }
payload["message"] = msg payload["message"] = msg
} }
@@ -147,6 +155,10 @@ func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent)
payload["usage"] = map[string]any{ payload["usage"] = map[string]any{
"output_tokens": event.Usage.OutputTokens, "output_tokens": event.Usage.OutputTokens,
} }
} else {
payload["usage"] = map[string]any{
"output_tokens": 0,
}
} }
return e.marshalEvent("message_delta", payload) return e.marshalEvent("message_delta", payload)
} }

View File

@@ -21,8 +21,55 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
s := string(chunks[0]) s := string(chunks[0])
assert.True(t, strings.HasPrefix(s, "event: message_start\n")) assert.True(t, strings.HasPrefix(s, "event: message_start\n"))
assert.Contains(t, s, "data: ") assert.Contains(t, s, "data: ")
assert.Contains(t, s, "msg_1")
assert.Contains(t, s, "claude-3") var payload map[string]any
lines := strings.Split(s, "\n")
for _, l := range lines {
if strings.HasPrefix(l, "data: ") {
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
break
}
}
msg, ok := payload["message"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "msg_1", msg["id"])
assert.Equal(t, "message", msg["type"])
assert.Equal(t, "assistant", msg["role"])
assert.Equal(t, []any{}, msg["content"])
assert.Equal(t, "claude-3", msg["model"])
assert.Nil(t, msg["stop_reason"])
assert.Nil(t, msg["stop_sequence"])
usage, okU := msg["usage"].(map[string]any)
require.True(t, okU)
assert.Equal(t, float64(0), usage["input_tokens"])
assert.Equal(t, float64(0), usage["output_tokens"])
}
func TestStreamEncoder_MessageStart_WithUsage(t *testing.T) {
e := NewStreamEncoder()
event := canonical.NewMessageStartEventWithUsage("msg_2", "gpt-4", &canonical.CanonicalUsage{InputTokens: 100, OutputTokens: 50})
chunks := e.EncodeEvent(event)
require.Len(t, chunks, 1)
s := string(chunks[0])
var payload map[string]any
lines := strings.Split(s, "\n")
for _, l := range lines {
if strings.HasPrefix(l, "data: ") {
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
break
}
}
msg, ok := payload["message"].(map[string]any)
require.True(t, ok)
usage, okU := msg["usage"].(map[string]any)
require.True(t, okU)
assert.Equal(t, float64(100), usage["input_tokens"])
assert.Equal(t, float64(50), usage["output_tokens"])
} }
func TestStreamEncoder_ContentBlockDelta(t *testing.T) { func TestStreamEncoder_ContentBlockDelta(t *testing.T) {
@@ -80,7 +127,8 @@ func TestStreamEncoder_ContentBlockStart_Text(t *testing.T) {
break break
} }
} }
cb := payload["content_block"].(map[string]any) cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "text", cb["type"]) assert.Equal(t, "text", cb["type"])
} }
@@ -107,7 +155,8 @@ func TestStreamEncoder_ContentBlockStart_ToolUse(t *testing.T) {
break break
} }
} }
cb := payload["content_block"].(map[string]any) cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "tool_use", cb["type"]) assert.Equal(t, "tool_use", cb["type"])
assert.Equal(t, "toolu_1", cb["id"]) assert.Equal(t, "toolu_1", cb["id"])
assert.Equal(t, "search", cb["name"]) assert.Equal(t, "search", cb["name"])
@@ -131,7 +180,8 @@ func TestStreamEncoder_ContentBlockStart_Thinking(t *testing.T) {
break break
} }
} }
cb := payload["content_block"].(map[string]any) cb, ok := payload["content_block"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "thinking", cb["type"]) assert.Equal(t, "thinking", cb["type"])
} }
@@ -173,8 +223,13 @@ func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
break break
} }
} }
delta := payload["delta"].(map[string]any) delta, okd := payload["delta"].(map[string]any)
require.True(t, okd)
assert.Equal(t, "end_turn", delta["stop_reason"]) assert.Equal(t, "end_turn", delta["stop_reason"])
usage, oku := payload["usage"].(map[string]any)
require.True(t, oku, "message_delta SHALL always include usage")
assert.Equal(t, float64(0), usage["output_tokens"])
} }
func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) { func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
@@ -199,7 +254,8 @@ func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
break break
} }
} }
u := payload["usage"].(map[string]any) u, oku := payload["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(88), u["output_tokens"]) assert.Equal(t, float64(88), u["output_tokens"])
} }

View File

@@ -173,13 +173,15 @@ func TestDecodeMessage_UserWithOnlyToolResults(t *testing.T) {
} }
func TestDecodeContentBlocks_Nil(t *testing.T) { func TestDecodeContentBlocks_Nil(t *testing.T) {
blocks := decodeContentBlocks(nil) blocks, err := decodeContentBlocks(nil)
require.NoError(t, err)
assert.Len(t, blocks, 1) assert.Len(t, blocks, 1)
assert.Equal(t, "", blocks[0].Text) assert.Equal(t, "", blocks[0].Text)
} }
func TestDecodeContentBlocks_String(t *testing.T) { func TestDecodeContentBlocks_String(t *testing.T) {
blocks := decodeContentBlocks("hello") blocks, err := decodeContentBlocks("hello")
require.NoError(t, err)
assert.Len(t, blocks, 1) assert.Len(t, blocks, 1)
assert.Equal(t, "hello", blocks[0].Text) assert.Equal(t, "hello", blocks[0].Text)
} }
@@ -217,8 +219,10 @@ func TestEncodeToolChoice(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := encodeToolChoice(tt.choice) result := encodeToolChoice(tt.choice)
assert.Equal(t, tt.want["type"], result.(map[string]any)["type"]) r, ok := result.(map[string]any)
assert.Equal(t, tt.want["name"], result.(map[string]any)["name"]) require.True(t, ok)
assert.Equal(t, tt.want["type"], r["type"])
assert.Equal(t, tt.want["name"], r["name"])
}) })
} }
} }
@@ -315,12 +319,15 @@ func TestEncodeRequest_WithTools(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
tools := result["tools"].([]any) tools, okt := result["tools"].([]any)
require.True(t, okt)
assert.Len(t, tools, 1) assert.Len(t, tools, 1)
tool := tools[0].(map[string]any) tool, okt2 := tools[0].(map[string]any)
require.True(t, okt2)
assert.Equal(t, "search", tool["name"]) assert.Equal(t, "search", tool["name"])
assert.Equal(t, "Search things", tool["description"]) assert.Equal(t, "Search things", tool["description"])
tc := result["tool_choice"].(map[string]any) tc, oktc := result["tool_choice"].(map[string]any)
require.True(t, oktc)
assert.Equal(t, "auto", tc["type"]) assert.Equal(t, "auto", tc["type"])
} }
@@ -354,9 +361,9 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")}, Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
StopReason: &sr, StopReason: &sr,
Usage: canonical.CanonicalUsage{ Usage: canonical.CanonicalUsage{
InputTokens: 100, InputTokens: 100,
OutputTokens: 50, OutputTokens: 50,
CacheReadTokens: &cacheRead, CacheReadTokens: &cacheRead,
CacheCreationTokens: &cacheCreation, CacheCreationTokens: &cacheCreation,
}, },
} }
@@ -366,7 +373,8 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any) usage, oku := result["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(100), usage["input_tokens"]) assert.Equal(t, float64(100), usage["input_tokens"])
assert.Equal(t, float64(30), usage["cache_read_input_tokens"]) assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
assert.Equal(t, float64(10), usage["cache_creation_input_tokens"]) assert.Equal(t, float64(10), usage["cache_creation_input_tokens"])

View File

@@ -6,22 +6,22 @@ import (
// MessagesRequest Anthropic Messages 请求 // MessagesRequest Anthropic Messages 请求
type MessagesRequest struct { type MessagesRequest struct {
Model string `json:"model"` Model string `json:"model"`
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
System any `json:"system,omitempty"` System any `json:"system,omitempty"`
MaxTokens int `json:"max_tokens"` MaxTokens int `json:"max_tokens"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"` TopP *float64 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"` TopK *int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"` Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`
Metadata *RequestMetadata `json:"metadata,omitempty"` Metadata *RequestMetadata `json:"metadata,omitempty"`
Thinking *ThinkingConfig `json:"thinking,omitempty"` Thinking *ThinkingConfig `json:"thinking,omitempty"`
OutputConfig *OutputConfig `json:"output_config,omitempty"` OutputConfig *OutputConfig `json:"output_config,omitempty"`
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"` DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
Container any `json:"container,omitempty"` Container any `json:"container,omitempty"`
} }
// RequestMetadata 请求元数据 // RequestMetadata 请求元数据
@@ -122,8 +122,8 @@ type ContentBlock struct {
// ResponseUsage 响应用量 // ResponseUsage 响应用量
type ResponseUsage struct { type ResponseUsage struct {
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
CacheReadInputTokens *int `json:"cache_read_input_tokens,omitempty"` CacheReadInputTokens *int `json:"cache_read_input_tokens,omitempty"`
CacheCreationInputTokens *int `json:"cache_creation_input_tokens,omitempty"` CacheCreationInputTokens *int `json:"cache_creation_input_tokens,omitempty"`
} }

View File

@@ -38,8 +38,8 @@ type CanonicalEmbeddingResponse struct {
// EmbeddingData 嵌入数据项 // EmbeddingData 嵌入数据项
type EmbeddingData struct { type EmbeddingData struct {
Index int `json:"index"` Index int `json:"index"`
Embedding any `json:"embedding"` // 根据格式不同可能是 []float64 或 base64 字符串 Embedding any `json:"embedding"` // 根据格式不同可能是 []float64 或 base64 字符串
} }
// EmbeddingUsage 嵌入用量 // EmbeddingUsage 嵌入用量

View File

@@ -18,17 +18,17 @@ const (
type DeltaType string type DeltaType string
const ( const (
DeltaTypeText DeltaType = "text_delta" DeltaTypeText DeltaType = "text_delta"
DeltaTypeInputJSON DeltaType = "input_json_delta" DeltaTypeInputJSON DeltaType = "input_json_delta"
DeltaTypeThinking DeltaType = "thinking_delta" DeltaTypeThinking DeltaType = "thinking_delta"
) )
// StreamDelta 流式增量联合体 // StreamDelta 流式增量联合体
type StreamDelta struct { type StreamDelta struct {
Type string `json:"type"` Type string `json:"type"`
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
PartialJSON string `json:"partial_json,omitempty"` PartialJSON string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"` Thinking string `json:"thinking,omitempty"`
} }
// StreamContentBlock 流式内容块联合体 // StreamContentBlock 流式内容块联合体
@@ -48,12 +48,12 @@ type CanonicalStreamEvent struct {
Message *StreamMessage `json:"message,omitempty"` Message *StreamMessage `json:"message,omitempty"`
// ContentBlockStartEvent / ContentBlockDeltaEvent / ContentBlockStopEvent // ContentBlockStartEvent / ContentBlockDeltaEvent / ContentBlockStopEvent
Index *int `json:"index,omitempty"` Index *int `json:"index,omitempty"`
ContentBlock *StreamContentBlock `json:"content_block,omitempty"` ContentBlock *StreamContentBlock `json:"content_block,omitempty"`
Delta *StreamDelta `json:"delta,omitempty"` Delta *StreamDelta `json:"delta,omitempty"`
// MessageDeltaEvent // MessageDeltaEvent
StopReason *StopReason `json:"stop_reason,omitempty"` StopReason *StopReason `json:"stop_reason,omitempty"`
Usage *CanonicalUsage `json:"usage,omitempty"` Usage *CanonicalUsage `json:"usage,omitempty"`
// ErrorEvent // ErrorEvent

View File

@@ -40,8 +40,8 @@ type ContentBlock struct {
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
// ToolUseBlock // ToolUseBlock
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Input json.RawMessage `json:"input,omitempty"` Input json.RawMessage `json:"input,omitempty"`
// ToolResultBlock // ToolResultBlock
@@ -138,43 +138,43 @@ type ThinkingConfig struct {
// OutputFormat 输出格式联合体 // OutputFormat 输出格式联合体
type OutputFormat struct { type OutputFormat struct {
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Schema json.RawMessage `json:"schema,omitempty"` Schema json.RawMessage `json:"schema,omitempty"`
Strict *bool `json:"strict,omitempty"` Strict *bool `json:"strict,omitempty"`
} }
// CanonicalRequest 规范请求 // CanonicalRequest 规范请求
type CanonicalRequest struct { type CanonicalRequest struct {
Model string `json:"model"` Model string `json:"model"`
System any `json:"system,omitempty"` // nil, string, or []SystemBlock System any `json:"system,omitempty"` // nil, string, or []SystemBlock
Messages []CanonicalMessage `json:"messages"` Messages []CanonicalMessage `json:"messages"`
Tools []CanonicalTool `json:"tools,omitempty"` Tools []CanonicalTool `json:"tools,omitempty"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"` ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Parameters RequestParameters `json:"parameters"` Parameters RequestParameters `json:"parameters"`
Thinking *ThinkingConfig `json:"thinking,omitempty"` Thinking *ThinkingConfig `json:"thinking,omitempty"`
Stream bool `json:"stream"` Stream bool `json:"stream"`
UserID string `json:"user_id,omitempty"` UserID string `json:"user_id,omitempty"`
OutputFormat *OutputFormat `json:"output_format,omitempty"` OutputFormat *OutputFormat `json:"output_format,omitempty"`
ParallelToolUse *bool `json:"parallel_tool_use,omitempty"` ParallelToolUse *bool `json:"parallel_tool_use,omitempty"`
} }
// CanonicalUsage 规范用量 // CanonicalUsage 规范用量
type CanonicalUsage struct { type CanonicalUsage struct {
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
CacheReadTokens *int `json:"cache_read_tokens,omitempty"` CacheReadTokens *int `json:"cache_read_tokens,omitempty"`
CacheCreationTokens *int `json:"cache_creation_tokens,omitempty"` CacheCreationTokens *int `json:"cache_creation_tokens,omitempty"`
ReasoningTokens *int `json:"reasoning_tokens,omitempty"` ReasoningTokens *int `json:"reasoning_tokens,omitempty"`
} }
// CanonicalResponse 规范响应 // CanonicalResponse 规范响应
type CanonicalResponse struct { type CanonicalResponse struct {
ID string `json:"id"` ID string `json:"id"`
Model string `json:"model"` Model string `json:"model"`
Content []ContentBlock `json:"content"` Content []ContentBlock `json:"content"`
StopReason *StopReason `json:"stop_reason,omitempty"` StopReason *StopReason `json:"stop_reason,omitempty"`
Usage CanonicalUsage `json:"usage"` Usage CanonicalUsage `json:"usage"`
} }
// GetSystemString 获取系统消息字符串 // GetSystemString 获取系统消息字符串

View File

@@ -10,9 +10,9 @@ import (
func TestGetSystemString(t *testing.T) { func TestGetSystemString(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
system any system any
want string want string
}{ }{
{"string", "hello", "hello"}, {"string", "hello", "hello"},
{"nil", nil, ""}, {"nil", nil, ""},
@@ -97,11 +97,11 @@ func TestCanonicalRequest_RoundTrip(t *testing.T) {
func TestCanonicalResponse_RoundTrip(t *testing.T) { func TestCanonicalResponse_RoundTrip(t *testing.T) {
sr := StopReasonEndTurn sr := StopReasonEndTurn
resp := &CanonicalResponse{ resp := &CanonicalResponse{
ID: "resp-1", ID: "resp-1",
Model: "gpt-4", Model: "gpt-4",
Content: []ContentBlock{NewTextBlock("hello")}, Content: []ContentBlock{NewTextBlock("hello")},
StopReason: &sr, StopReason: &sr,
Usage: CanonicalUsage{InputTokens: 10, OutputTokens: 5}, Usage: CanonicalUsage{InputTokens: 10, OutputTokens: 5},
} }
data, err := json.Marshal(resp) data, err := json.Marshal(resp)

View File

@@ -3,11 +3,13 @@ package conversion
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/conversion/canonical"
pkglogger "nex/backend/pkg/logger" pkglogger "nex/backend/pkg/logger"
) )
@@ -71,7 +73,7 @@ func (e *ConversionEngine) IsPassthrough(clientProtocol, providerProtocol string
// ConvertHttpRequest 转换 HTTP 请求 // ConvertHttpRequest 转换 HTTP 请求
func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtocol, providerProtocol string, provider *TargetProvider) (*HTTPRequestSpec, error) { func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtocol, providerProtocol string, provider *TargetProvider) (*HTTPRequestSpec, error) {
nativePath := spec.URL nativePath, rawQuery := splitRequestPath(spec.URL)
if e.IsPassthrough(clientProtocol, providerProtocol) { if e.IsPassthrough(clientProtocol, providerProtocol) {
providerAdapter, err := e.registry.Get(providerProtocol) providerAdapter, err := e.registry.Get(providerProtocol)
@@ -96,8 +98,11 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
} }
} }
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
providerURL = appendRawQuery(providerURL, rawQuery)
return &HTTPRequestSpec{ return &HTTPRequestSpec{
URL: provider.BaseURL + nativePath, URL: joinBaseURL(provider.BaseURL, providerURL),
Method: spec.Method, Method: spec.Method,
Headers: providerAdapter.BuildHeaders(provider), Headers: providerAdapter.BuildHeaders(provider),
Body: rewrittenBody, Body: rewrittenBody,
@@ -114,7 +119,8 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
} }
interfaceType := clientAdapter.DetectInterfaceType(nativePath) interfaceType := clientAdapter.DetectInterfaceType(nativePath)
providerUrl := providerAdapter.BuildUrl(nativePath, interfaceType) providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
providerURL = appendRawQuery(providerURL, rawQuery)
providerHeaders := providerAdapter.BuildHeaders(provider) providerHeaders := providerAdapter.BuildHeaders(provider)
providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body) providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
if err != nil { if err != nil {
@@ -122,7 +128,7 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
} }
return &HTTPRequestSpec{ return &HTTPRequestSpec{
URL: provider.BaseURL + providerUrl, URL: joinBaseURL(provider.BaseURL, providerURL),
Method: spec.Method, Method: spec.Method,
Headers: providerHeaders, Headers: providerHeaders,
Body: providerBody, Body: providerBody,
@@ -134,24 +140,21 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
if e.IsPassthrough(clientProtocol, providerProtocol) { if e.IsPassthrough(clientProtocol, providerProtocol) {
// Smart Passthrough: 同协议时最小化改写 model 字段 // Smart Passthrough: 同协议时最小化改写 model 字段
if modelOverride != "" && len(spec.Body) > 0 { if modelOverride != "" && len(spec.Body) > 0 {
adapter, err := e.registry.Get(clientProtocol) adapter, getErr := e.registry.Get(clientProtocol)
if err != nil { if getErr == nil {
return &spec, nil rewrittenBody, rewriteErr := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
} if rewriteErr != nil {
rewrittenBody, err := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
if err != nil {
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体", e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
zap.Error(err), zap.Error(rewriteErr),
zap.String("interface", string(interfaceType))) zap.String("interface", string(interfaceType)))
return &spec, nil } else {
return &HTTPResponseSpec{
StatusCode: spec.StatusCode,
Headers: spec.Headers,
Body: rewrittenBody,
}, nil
}
} }
return &HTTPResponseSpec{
StatusCode: spec.StatusCode,
Headers: spec.Headers,
Body: rewrittenBody,
}, nil
} }
return &spec, nil return &spec, nil
} }
@@ -182,11 +185,10 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
if e.IsPassthrough(clientProtocol, providerProtocol) { if e.IsPassthrough(clientProtocol, providerProtocol) {
// Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段 // Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段
if modelOverride != "" { if modelOverride != "" {
adapter, err := e.registry.Get(clientProtocol) adapter, getErr := e.registry.Get(clientProtocol)
if err != nil { if getErr == nil {
return NewPassthroughStreamConverter(), nil return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
} }
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
} }
return NewPassthroughStreamConverter(), nil return NewPassthroughStreamConverter(), nil
} }
@@ -201,9 +203,9 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
} }
ctx := ConversionContext{ ctx := ConversionContext{
ConversionID: uuid.New().String(), ConversionID: uuid.New().String(),
InterfaceType: InterfaceTypeChat, InterfaceType: interfaceType,
Timestamp: time.Now(), Timestamp: time.Now(),
} }
return NewCanonicalStreamConverterWithMiddleware( return NewCanonicalStreamConverterWithMiddleware(
@@ -272,7 +274,7 @@ func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clie
func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) { func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
canonicalReq, err := clientAdapter.DecodeRequest(body) canonicalReq, err := clientAdapter.DecodeRequest(body)
if err != nil { if err != nil {
return nil, NewConversionError(ErrorCodeJSONParseError, "解码请求失败").WithCause(err) return nil, NewRequestJSONParseError("解码请求失败", err)
} }
ctx := NewConversionContext(InterfaceTypeChat) ctx := NewConversionContext(InterfaceTypeChat)
@@ -280,6 +282,9 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
if err != nil { if err != nil {
return nil, err return nil, err
} }
if containsUnsupportedMultimodal(canonicalReq) {
return nil, NewConversionError(ErrorCodeUnsupportedMultimodal, "跨协议暂不支持多模态内容")
}
encoded, err := providerAdapter.EncodeRequest(canonicalReq, provider) encoded, err := providerAdapter.EncodeRequest(canonicalReq, provider)
if err != nil { if err != nil {
@@ -291,7 +296,7 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) { func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
canonicalResp, err := providerAdapter.DecodeResponse(body) canonicalResp, err := providerAdapter.DecodeResponse(body)
if err != nil { if err != nil {
return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(err) return nil, NewResponseJSONParseError("解码响应失败", err)
} }
if modelOverride != "" { if modelOverride != "" {
canonicalResp.Model = modelOverride canonicalResp.Model = modelOverride
@@ -306,7 +311,7 @@ func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapte
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
models, err := providerAdapter.DecodeModelsResponse(body) models, err := providerAdapter.DecodeModelsResponse(body)
if err != nil { if err != nil {
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error())) e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.Error(err))
return body, nil return body, nil
} }
encoded, err := clientAdapter.EncodeModelsResponse(models) encoded, err := clientAdapter.EncodeModelsResponse(models)
@@ -320,12 +325,12 @@ func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdap
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) { func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
info, err := providerAdapter.DecodeModelInfoResponse(body) info, err := providerAdapter.DecodeModelInfoResponse(body)
if err != nil { if err != nil {
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error())) e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
return body, nil return body, nil
} }
encoded, err := clientAdapter.EncodeModelInfoResponse(info) encoded, err := clientAdapter.EncodeModelInfoResponse(info)
if err != nil { if err != nil {
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error())) e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
return body, nil return body, nil
} }
return encoded, nil return encoded, nil
@@ -334,7 +339,7 @@ func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerA
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) { func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
req, err := clientAdapter.DecodeEmbeddingRequest(body) req, err := clientAdapter.DecodeEmbeddingRequest(body)
if err != nil { if err != nil {
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error())) e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.Error(err))
return body, nil return body, nil
} }
return providerAdapter.EncodeEmbeddingRequest(req, provider) return providerAdapter.EncodeEmbeddingRequest(req, provider)
@@ -343,7 +348,7 @@ func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter P
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) { func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
resp, err := providerAdapter.DecodeEmbeddingResponse(body) resp, err := providerAdapter.DecodeEmbeddingResponse(body)
if err != nil { if err != nil {
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error())) e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.Error(err))
return body, nil return body, nil
} }
if modelOverride != "" { if modelOverride != "" {
@@ -355,21 +360,22 @@ func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerA
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) { func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
req, err := clientAdapter.DecodeRerankRequest(body) req, err := clientAdapter.DecodeRerankRequest(body)
if err != nil { if err != nil {
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error())) e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.Error(err))
return body, nil return body, nil
} }
return providerAdapter.EncodeRerankRequest(req, provider) return providerAdapter.EncodeRerankRequest(req, provider)
} }
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) { func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
resp, err := providerAdapter.DecodeRerankResponse(body) resp, decodeErr := providerAdapter.DecodeRerankResponse(body)
if err != nil { if decodeErr == nil {
return body, nil if modelOverride != "" {
resp.Model = modelOverride
}
return clientAdapter.EncodeRerankResponse(resp)
} }
if modelOverride != "" {
resp.Model = modelOverride return body, nil
}
return clientAdapter.EncodeRerankResponse(resp)
} }
// DetectInterfaceType 检测接口类型 // DetectInterfaceType 检测接口类型
@@ -378,6 +384,7 @@ func (e *ConversionEngine) DetectInterfaceType(nativePath, clientProtocol string
if err != nil { if err != nil {
return InterfaceTypePassthrough, err return InterfaceTypePassthrough, err
} }
nativePath, _ = splitRequestPath(nativePath)
return adapter.DetectInterfaceType(nativePath), nil return adapter.DetectInterfaceType(nativePath), nil
} }
@@ -391,9 +398,56 @@ func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol stri
"type": "internal_error", "type": "internal_error",
}, },
} }
body, _ := json.Marshal(fallback) body, marshalErr := json.Marshal(fallback)
return body, 500, nil if marshalErr == nil {
return body, 500, nil
}
return []byte(`{"error":{"message":"internal error","type":"internal_error"}}`), 500, nil
} }
body, statusCode := adapter.EncodeError(err) body, statusCode := adapter.EncodeError(err)
return body, statusCode, nil return body, statusCode, nil
} }
func splitRequestPath(rawPath string) (string, string) {
path, query, found := strings.Cut(rawPath, "?")
if !found {
return rawPath, ""
}
return path, query
}
func appendRawQuery(path, rawQuery string) string {
if rawQuery == "" {
return path
}
if strings.Contains(path, "?") {
return path + "&" + rawQuery
}
return path + "?" + rawQuery
}
func joinBaseURL(baseURL, path string) string {
if baseURL == "" {
return path
}
if path == "" {
return baseURL
}
return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/")
}
func containsUnsupportedMultimodal(req *canonical.CanonicalRequest) bool {
if req == nil {
return false
}
for _, msg := range req.Messages {
for _, block := range msg.Content {
switch block.Type {
case "image", "audio", "video", "file":
return true
}
}
}
return false
}

View File

@@ -0,0 +1,63 @@
package conversion_test
import (
"testing"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)
func TestConvertHttpRequest_SameProtocolUsesAdapterBuildURL(t *testing.T) {
tests := []struct {
name string
adapter conversion.ProtocolAdapter
clientProtocol string
providerProtocol string
baseURL string
nativePath string
expectedURL string
body []byte
}{
{
name: "openai base url includes version path",
adapter: openai.NewAdapter(),
clientProtocol: "openai",
providerProtocol: "openai",
baseURL: "http://example.com/v1",
nativePath: "/chat/completions",
expectedURL: "http://example.com/v1/chat/completions",
body: []byte(`{"model":"gpt-4","messages":[]}`),
},
{
name: "anthropic native path keeps v1",
adapter: anthropic.NewAdapter(),
clientProtocol: "anthropic",
providerProtocol: "anthropic",
baseURL: "http://example.com",
nativePath: "/v1/messages",
expectedURL: "http://example.com/v1/messages",
body: []byte(`{"model":"claude","messages":[]}`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
registry := conversion.NewMemoryRegistry()
engine := conversion.NewConversionEngine(registry, zap.NewNop())
require.NoError(t, registry.Register(tt.adapter))
out, err := engine.ConvertHttpRequest(conversion.HTTPRequestSpec{
URL: tt.nativePath,
Method: "POST",
Body: tt.body,
}, tt.clientProtocol, tt.providerProtocol, conversion.NewTargetProvider(tt.baseURL, "key", "upstream-model"))
require.NoError(t, err)
assert.Equal(t, tt.expectedURL, out.URL)
})
}
}

View File

@@ -2,6 +2,7 @@ package conversion
import ( import (
"encoding/json" "encoding/json"
"strings"
"testing" "testing"
"nex/backend/internal/conversion/canonical" "nex/backend/internal/conversion/canonical"
@@ -38,8 +39,8 @@ func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
} }
} }
func (m *mockProtocolAdapter) ProtocolName() string { return m.protocolName } func (m *mockProtocolAdapter) ProtocolName() string { return m.protocolName }
func (m *mockProtocolAdapter) ProtocolVersion() string { return "1.0" } func (m *mockProtocolAdapter) ProtocolVersion() string { return "1.0" }
func (m *mockProtocolAdapter) SupportsPassthrough() bool { return m.passthrough } func (m *mockProtocolAdapter) SupportsPassthrough() bool { return m.passthrough }
func (m *mockProtocolAdapter) DetectInterfaceType(nativePath string) InterfaceType { func (m *mockProtocolAdapter) DetectInterfaceType(nativePath string) InterfaceType {
@@ -190,14 +191,16 @@ func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel str
// noopStreamDecoder 空流式解码器 // noopStreamDecoder 空流式解码器
type noopStreamDecoder struct{} type noopStreamDecoder struct{}
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { return nil } func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil } return nil
}
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }
// noopStreamEncoder 空流式编码器 // noopStreamEncoder 空流式编码器
type noopStreamEncoder struct{} type noopStreamEncoder struct{}
func (e *noopStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte { return nil } func (e *noopStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte { return nil }
func (e *noopStreamEncoder) Flush() [][]byte { return nil } func (e *noopStreamEncoder) Flush() [][]byte { return nil }
// ============ 测试用例 ============ // ============ 测试用例 ============
@@ -285,19 +288,33 @@ func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
func TestConvertHttpRequest_Passthrough(t *testing.T) { func TestConvertHttpRequest_Passthrough(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, zap.NewNop()) engine := NewConversionEngine(registry, zap.NewNop())
_ = engine.RegisterAdapter(newMockAdapter("openai", true)) openaiAdapter := &buildURLMockAdapter{
mockProtocolAdapter: newMockAdapter("openai", true),
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
if interfaceType == InterfaceTypeChat {
return "/chat/completions"
}
return nativePath
},
}
openaiAdapter.ifaceType = InterfaceTypeChat
openaiAdapter.supportsIface[InterfaceTypeChat] = true
openaiAdapter.rewriteReqFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
return []byte(`{"model":"` + newModel + `","messages":[{"role":"user","content":"hi"}]}`), nil
}
_ = engine.RegisterAdapter(openaiAdapter)
provider := NewTargetProvider("https://api.openai.com/v1", "sk-test", "gpt-4") provider := NewTargetProvider("https://api.openai.com/v1", "sk-test", "gpt-4")
spec := HTTPRequestSpec{ spec := HTTPRequestSpec{
URL: "/chat/completions", URL: "/v1/chat/completions",
Method: "POST", Method: "POST",
Body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`), Body: []byte(`{"model":"openai/gpt-4","messages":[{"role":"user","content":"hi"}]}`),
} }
result, err := engine.ConvertHttpRequest(spec, "openai", "openai", provider) result, err := engine.ConvertHttpRequest(spec, "openai", "openai", provider)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL) assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
assert.Equal(t, spec.Body, result.Body) assert.JSONEq(t, `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`, string(result.Body))
} }
func TestConvertHttpRequest_CrossProtocol(t *testing.T) { func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
@@ -332,6 +349,77 @@ func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
assert.NotNil(t, result.Body) assert.NotNil(t, result.Body)
} }
func TestConvertHttpRequest_UsesProviderAdapterBuildURL(t *testing.T) {
registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, zap.NewNop())
openaiAdapter := &buildURLMockAdapter{
mockProtocolAdapter: newMockAdapter("openai", true),
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
if interfaceType == InterfaceTypeChat {
return "/chat/completions"
}
return nativePath
},
}
openaiAdapter.ifaceType = InterfaceTypeChat
openaiAdapter.supportsIface[InterfaceTypeChat] = true
openaiAdapter.rewriteReqFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
return []byte(`{"model":"` + newModel + `"}`), nil
}
require.NoError(t, registry.Register(openaiAdapter))
anthropicAdapter := &buildURLMockAdapter{
mockProtocolAdapter: newMockAdapter("anthropic", false),
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
if interfaceType == InterfaceTypeChat {
return "/v1/messages"
}
return nativePath
},
}
anthropicAdapter.ifaceType = InterfaceTypeChat
anthropicAdapter.supportsIface[InterfaceTypeChat] = true
require.NoError(t, registry.Register(anthropicAdapter))
t.Run("OpenAI to Anthropic", func(t *testing.T) {
provider := NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
spec := HTTPRequestSpec{
URL: "/v1/chat/completions",
Method: "POST",
Body: []byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"max_tokens":16}`),
}
result, err := engine.ConvertHttpRequest(spec, "openai", "anthropic", provider)
require.NoError(t, err)
assert.Equal(t, "https://api.anthropic.com/v1/messages", result.URL)
})
t.Run("Anthropic to OpenAI", func(t *testing.T) {
provider := NewTargetProvider("https://api.openai.com/v1", "key", "gpt-4")
spec := HTTPRequestSpec{
URL: "/v1/messages",
Method: "POST",
Body: []byte(`{"model":"p1/claude-3","max_tokens":16,"messages":[{"role":"user","content":"hi"}]}`),
}
result, err := engine.ConvertHttpRequest(spec, "anthropic", "openai", provider)
require.NoError(t, err)
assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
})
}
type buildURLMockAdapter struct {
*mockProtocolAdapter
buildURLFn func(string, InterfaceType) string
}
func (m *buildURLMockAdapter) BuildUrl(nativePath string, interfaceType InterfaceType) string {
if m.buildURLFn != nil {
return m.buildURLFn(nativePath, interfaceType)
}
return m.mockProtocolAdapter.BuildUrl(nativePath, interfaceType)
}
func TestConvertHttpResponse_Passthrough(t *testing.T) { func TestConvertHttpResponse_Passthrough(t *testing.T) {
registry := NewMemoryRegistry() registry := NewMemoryRegistry()
engine := NewConversionEngine(registry, zap.NewNop()) engine := NewConversionEngine(registry, zap.NewNop())
@@ -496,12 +584,13 @@ func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
_, ok := converter.(*SmartPassthroughStreamConverter) _, ok := converter.(*SmartPassthroughStreamConverter)
assert.True(t, ok) assert.True(t, ok)
// 验证 chunk 改写 // 验证 SSE frame 中的 data JSON 被改写
chunks := converter.ProcessChunk([]byte(`{"model":"gpt-4","choices":[]}`)) chunks := converter.ProcessChunk([]byte(`data: {"model":"gpt-4","choices":[]}` + "\n\n"))
require.Len(t, chunks, 1) require.Len(t, chunks, 1)
var resp map[string]interface{} var resp map[string]interface{}
require.NoError(t, json.Unmarshal(chunks[0], &resp)) payload := strings.TrimPrefix(strings.TrimSpace(string(chunks[0])), "data: ")
require.NoError(t, json.Unmarshal([]byte(payload), &resp))
assert.Equal(t, "openai/gpt-4", resp["model"]) assert.Equal(t, "openai/gpt-4", resp["model"])
} }
@@ -615,6 +704,7 @@ func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.Canonical
} }
return nil return nil
} }
func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent { func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
if d.flushFn != nil { if d.flushFn != nil {
return d.flushFn() return d.flushFn()
@@ -634,6 +724,7 @@ func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEve
} }
return nil return nil
} }
func (e *engineTestStreamEncoder) Flush() [][]byte { func (e *engineTestStreamEncoder) Flush() [][]byte {
if e.flushFn != nil { if e.flushFn != nil {
return e.flushFn() return e.flushFn()

View File

@@ -6,17 +6,24 @@ import "fmt"
type ErrorCode string type ErrorCode string
const ( const (
ErrorCodeInvalidInput ErrorCode = "INVALID_INPUT" ErrorCodeInvalidInput ErrorCode = "INVALID_INPUT"
ErrorCodeMissingRequiredField ErrorCode = "MISSING_REQUIRED_FIELD" ErrorCodeMissingRequiredField ErrorCode = "MISSING_REQUIRED_FIELD"
ErrorCodeIncompatibleFeature ErrorCode = "INCOMPATIBLE_FEATURE" ErrorCodeIncompatibleFeature ErrorCode = "INCOMPATIBLE_FEATURE"
ErrorCodeFieldMappingFailure ErrorCode = "FIELD_MAPPING_FAILURE" ErrorCodeFieldMappingFailure ErrorCode = "FIELD_MAPPING_FAILURE"
ErrorCodeToolCallParseError ErrorCode = "TOOL_CALL_PARSE_ERROR" ErrorCodeToolCallParseError ErrorCode = "TOOL_CALL_PARSE_ERROR"
ErrorCodeJSONParseError ErrorCode = "JSON_PARSE_ERROR" ErrorCodeJSONParseError ErrorCode = "JSON_PARSE_ERROR"
ErrorCodeStreamStateError ErrorCode = "STREAM_STATE_ERROR" ErrorCodeStreamStateError ErrorCode = "STREAM_STATE_ERROR"
ErrorCodeUTF8DecodeError ErrorCode = "UTF8_DECODE_ERROR" ErrorCodeUTF8DecodeError ErrorCode = "UTF8_DECODE_ERROR"
ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION" ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE" ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE"
ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED" ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
ErrorCodeUnsupportedMultimodal ErrorCode = "UNSUPPORTED_MULTIMODAL"
)
const (
ErrorDetailPhase = "phase"
ErrorPhaseRequest = "request"
ErrorPhaseResponse = "response"
) )
// ConversionError 协议转换错误 // ConversionError 协议转换错误
@@ -39,6 +46,20 @@ func NewConversionError(code ErrorCode, message string) *ConversionError {
} }
} }
// NewRequestJSONParseError 创建请求 JSON 解析错误。
func NewRequestJSONParseError(message string, cause error) *ConversionError {
return NewConversionError(ErrorCodeJSONParseError, message).
WithDetail(ErrorDetailPhase, ErrorPhaseRequest).
WithCause(cause)
}
// NewResponseJSONParseError 创建响应 JSON 解析错误。
func NewResponseJSONParseError(message string, cause error) *ConversionError {
return NewConversionError(ErrorCodeJSONParseError, message).
WithDetail(ErrorDetailPhase, ErrorPhaseResponse).
WithCause(cause)
}
// WithClientProtocol 设置客户端协议 // WithClientProtocol 设置客户端协议
func (e *ConversionError) WithClientProtocol(protocol string) *ConversionError { func (e *ConversionError) WithClientProtocol(protocol string) *ConversionError {
e.ClientProtocol = protocol e.ClientProtocol = protocol

View File

@@ -4,10 +4,10 @@ package conversion
type InterfaceType string type InterfaceType string
const ( const (
InterfaceTypeChat InterfaceType = "CHAT" InterfaceTypeChat InterfaceType = "CHAT"
InterfaceTypeModels InterfaceType = "MODELS" InterfaceTypeModels InterfaceType = "MODELS"
InterfaceTypeModelInfo InterfaceType = "MODEL_INFO" InterfaceTypeModelInfo InterfaceType = "MODEL_INFO"
InterfaceTypeEmbeddings InterfaceType = "EMBEDDINGS" InterfaceTypeEmbeddings InterfaceType = "EMBEDDINGS"
InterfaceTypeRerank InterfaceType = "RERANK" InterfaceTypeRerank InterfaceType = "RERANK"
InterfaceTypePassthrough InterfaceType = "PASSTHROUGH" InterfaceTypePassthrough InterfaceType = "PASSTHROUGH"
) )

View File

@@ -29,27 +29,27 @@ func (a *Adapter) SupportsPassthrough() bool { return true }
// DetectInterfaceType 根据路径检测接口类型 // DetectInterfaceType 根据路径检测接口类型
func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType { func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType {
switch { switch {
case nativePath == "/chat/completions": case nativePath == "/v1/chat/completions":
return conversion.InterfaceTypeChat return conversion.InterfaceTypeChat
case nativePath == "/models": case nativePath == "/v1/models":
return conversion.InterfaceTypeModels return conversion.InterfaceTypeModels
case isModelInfoPath(nativePath): case isModelInfoPath(nativePath):
return conversion.InterfaceTypeModelInfo return conversion.InterfaceTypeModelInfo
case nativePath == "/embeddings": case nativePath == "/v1/embeddings":
return conversion.InterfaceTypeEmbeddings return conversion.InterfaceTypeEmbeddings
case nativePath == "/rerank": case nativePath == "/v1/rerank":
return conversion.InterfaceTypeRerank return conversion.InterfaceTypeRerank
default: default:
return conversion.InterfaceTypePassthrough return conversion.InterfaceTypePassthrough
} }
} }
// isModelInfoPath 判断是否为模型详情路径(/models/{id},允许 id 含 / // isModelInfoPath 判断是否为模型详情路径(/v1/models/{id},允许 id 含 /
func isModelInfoPath(path string) bool { func isModelInfoPath(path string) bool {
if !strings.HasPrefix(path, "/models/") { if !strings.HasPrefix(path, "/v1/models/") {
return false return false
} }
suffix := path[len("/models/"):] suffix := path[len("/v1/models/"):]
return suffix != "" return suffix != ""
} }
@@ -60,6 +60,11 @@ func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.Interface
return "/chat/completions" return "/chat/completions"
case conversion.InterfaceTypeModels: case conversion.InterfaceTypeModels:
return "/models" return "/models"
case conversion.InterfaceTypeModelInfo:
if modelID, err := a.ExtractUnifiedModelID(nativePath); err == nil {
return "/models/" + modelID
}
return nativePath
case conversion.InterfaceTypeEmbeddings: case conversion.InterfaceTypeEmbeddings:
return "/embeddings" return "/embeddings"
case conversion.InterfaceTypeRerank: case conversion.InterfaceTypeRerank:
@@ -138,7 +143,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
Code: string(err.Code), Code: string(err.Code),
}, },
} }
body, _ := json.Marshal(errMsg) body, marshalErr := json.Marshal(errMsg)
if marshalErr != nil {
return []byte(`{"error":{"message":"internal error","type":"internal_error","code":"INTERNAL_ERROR"}}`), statusCode
}
return body, statusCode return body, statusCode
} }
@@ -218,12 +226,12 @@ func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse)
return encodeRerankResponse(resp) return encodeRerankResponse(resp)
} }
// ExtractUnifiedModelID 从路径中提取统一模型 ID/models/{provider_id}/{model_name} // ExtractUnifiedModelID 从路径中提取统一模型 ID/v1/models/{provider_id}/{model_name}
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) { func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
if !strings.HasPrefix(nativePath, "/models/") { if !strings.HasPrefix(nativePath, "/v1/models/") {
return "", fmt.Errorf("不是模型详情路径: %s", nativePath) return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
} }
suffix := nativePath[len("/models/"):] suffix := nativePath[len("/v1/models/"):]
if suffix == "" { if suffix == "" {
return "", fmt.Errorf("路径缺少模型 ID") return "", fmt.Errorf("路径缺少模型 ID")
} }
@@ -248,7 +256,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
return "", nil, err return "", nil, err
} }
rewriteFunc := func(newModel string) ([]byte, error) { rewriteFunc := func(newModel string) ([]byte, error) {
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m) return json.Marshal(m)
} }
return current, rewriteFunc, nil return current, rewriteFunc, nil
@@ -282,12 +294,20 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
switch ifaceType { switch ifaceType {
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings: case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加 // Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
return json.Marshal(m) return json.Marshal(m)
case conversion.InterfaceTypeRerank: case conversion.InterfaceTypeRerank:
// Rerank 响应:存在 model 字段则改写,不存在则不添加 // Rerank 响应:存在 model 字段则改写,不存在则不添加
if _, exists := m["model"]; exists { if _, exists := m["model"]; exists {
m["model"], _ = json.Marshal(newModel) encodedModel, err := json.Marshal(newModel)
if err != nil {
return nil, err
}
m["model"] = encodedModel
} }
return json.Marshal(m) return json.Marshal(m)
default: default:

View File

@@ -28,11 +28,11 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
path string path string
expected conversion.InterfaceType expected conversion.InterfaceType
}{ }{
{"聊天补全", "/chat/completions", conversion.InterfaceTypeChat}, {"聊天补全", "/v1/chat/completions", conversion.InterfaceTypeChat},
{"模型列表", "/models", conversion.InterfaceTypeModels}, {"模型列表", "/v1/models", conversion.InterfaceTypeModels},
{"模型详情", "/models/gpt-4", conversion.InterfaceTypeModelInfo}, {"模型详情", "/v1/models/openai/gpt-4", conversion.InterfaceTypeModelInfo},
{"嵌入接口", "/embeddings", conversion.InterfaceTypeEmbeddings}, {"嵌入接口", "/v1/embeddings", conversion.InterfaceTypeEmbeddings},
{"重排序接口", "/rerank", conversion.InterfaceTypeRerank}, {"重排序接口", "/v1/rerank", conversion.InterfaceTypeRerank},
{"未知路径", "/unknown", conversion.InterfaceTypePassthrough}, {"未知路径", "/unknown", conversion.InterfaceTypePassthrough},
} }
@@ -44,19 +44,42 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
} }
} }
func TestAdapter_OldPathsBecomePassthrough(t *testing.T) {
a := NewAdapter()
tests := []struct {
path string
expected conversion.InterfaceType
}{
{"/chat/completions", conversion.InterfaceTypePassthrough},
{"/models", conversion.InterfaceTypePassthrough},
{"/models/gpt-4.1", conversion.InterfaceTypePassthrough},
{"/embeddings", conversion.InterfaceTypePassthrough},
{"/rerank", conversion.InterfaceTypePassthrough},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
assert.Equal(t, tt.expected, a.DetectInterfaceType(tt.path))
})
}
}
func TestAdapter_BuildUrl(t *testing.T) { func TestAdapter_BuildUrl(t *testing.T) {
a := NewAdapter() a := NewAdapter()
tests := []struct { tests := []struct {
name string name string
nativePath string nativePath string
interfaceType conversion.InterfaceType interfaceType conversion.InterfaceType
expected string expected string
}{ }{
{"聊天", "/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"}, {"聊天", "/v1/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"},
{"模型", "/models", conversion.InterfaceTypeModels, "/models"}, {"模型", "/v1/models", conversion.InterfaceTypeModels, "/models"},
{"嵌入", "/embeddings", conversion.InterfaceTypeEmbeddings, "/embeddings"}, {"模型详情", "/v1/models/openai/gpt-4", conversion.InterfaceTypeModelInfo, "/models/openai/gpt-4"},
{"重排序", "/rerank", conversion.InterfaceTypeRerank, "/rerank"}, {"复杂模型详情", "/v1/models/azure/accounts/org/models/gpt-4", conversion.InterfaceTypeModelInfo, "/models/azure/accounts/org/models/gpt-4"},
{"嵌入", "/v1/embeddings", conversion.InterfaceTypeEmbeddings, "/embeddings"},
{"重排序", "/v1/rerank", conversion.InterfaceTypeRerank, "/rerank"},
{"默认透传", "/other", conversion.InterfaceTypePassthrough, "/other"}, {"默认透传", "/other", conversion.InterfaceTypePassthrough, "/other"},
} }
@@ -92,9 +115,9 @@ func TestAdapter_SupportsInterface(t *testing.T) {
a := NewAdapter() a := NewAdapter()
tests := []struct { tests := []struct {
name string name string
interfaceType conversion.InterfaceType interfaceType conversion.InterfaceType
expected bool expected bool
}{ }{
{"聊天", conversion.InterfaceTypeChat, true}, {"聊天", conversion.InterfaceTypeChat, true},
{"模型", conversion.InterfaceTypeModels, true}, {"模型", conversion.InterfaceTypeModels, true},
@@ -118,12 +141,12 @@ func TestIsModelInfoPath(t *testing.T) {
path string path string
expected bool expected bool
}{ }{
{"model_info", "/models/gpt-4", true}, {"model_info", "/v1/models/openai/gpt-4", true},
{"model_info_with_dots", "/models/gpt-4.1-preview", true}, {"model_info_with_dots", "/v1/models/openai/gpt-4.1-preview", true},
{"models_list", "/models", false}, {"models_list", "/v1/models", false},
{"nested_path", "/models/gpt-4/versions", true}, {"nested_path", "/v1/models/azure/accounts/org-123/models/gpt-4", true},
{"empty_suffix", "/models/", false}, {"empty_suffix", "/v1/models/", false},
{"unrelated", "/chat/completions", false}, {"unrelated", "/v1/chat/completions", false},
{"partial_prefix", "/model", false}, {"partial_prefix", "/model", false},
} }
@@ -134,6 +157,27 @@ func TestIsModelInfoPath(t *testing.T) {
} }
} }
func TestAdapter_ExtractUnifiedModelID(t *testing.T) {
a := NewAdapter()
t.Run("标准路径", func(t *testing.T) {
modelID, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
require.NoError(t, err)
assert.Equal(t, "openai/gpt-4", modelID)
})
t.Run("复杂路径", func(t *testing.T) {
modelID, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
require.NoError(t, err)
assert.Equal(t, "azure/accounts/org/models/gpt-4", modelID)
})
t.Run("非模型详情路径报错", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/v1/models")
require.Error(t, err)
})
}
func TestAdapter_EncodeError_InvalidInput(t *testing.T) { func TestAdapter_EncodeError_InvalidInput(t *testing.T) {
a := NewAdapter() a := NewAdapter()
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效") convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效")

View File

@@ -18,35 +18,35 @@ func TestExtractUnifiedModelID(t *testing.T) {
a := NewAdapter() a := NewAdapter()
t.Run("standard_path", func(t *testing.T) { t.Run("standard_path", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/models/openai/gpt-4") id, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "openai/gpt-4", id) assert.Equal(t, "openai/gpt-4", id)
}) })
t.Run("multi_segment_path", func(t *testing.T) { t.Run("multi_segment_path", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/models/azure/accounts/org/models/gpt-4") id, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "azure/accounts/org/models/gpt-4", id) assert.Equal(t, "azure/accounts/org/models/gpt-4", id)
}) })
t.Run("single_segment", func(t *testing.T) { t.Run("single_segment", func(t *testing.T) {
id, err := a.ExtractUnifiedModelID("/models/gpt-4") id, err := a.ExtractUnifiedModelID("/v1/models/gpt-4")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "gpt-4", id) assert.Equal(t, "gpt-4", id)
}) })
t.Run("non_model_path", func(t *testing.T) { t.Run("non_model_path", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/chat/completions") _, err := a.ExtractUnifiedModelID("/v1/chat/completions")
require.Error(t, err) require.Error(t, err)
}) })
t.Run("empty_suffix", func(t *testing.T) { t.Run("empty_suffix", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/models/") _, err := a.ExtractUnifiedModelID("/v1/models/")
require.Error(t, err) require.Error(t, err)
}) })
t.Run("models_list_no_slash", func(t *testing.T) { t.Run("models_list_no_slash", func(t *testing.T) {
_, err := a.ExtractUnifiedModelID("/models") _, err := a.ExtractUnifiedModelID("/v1/models")
require.Error(t, err) require.Error(t, err)
}) })
@@ -344,12 +344,12 @@ func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
path string path string
expected bool expected bool
}{ }{
{"simple_model_id", "/models/gpt-4", true}, {"simple_model_id", "/v1/models/gpt-4", true},
{"unified_model_id_with_slash", "/models/openai/gpt-4", true}, {"unified_model_id_with_slash", "/v1/models/openai/gpt-4", true},
{"models_list", "/models", false}, {"models_list", "/v1/models", false},
{"models_list_trailing_slash", "/models/", false}, {"models_list_trailing_slash", "/v1/models/", false},
{"chat_completions", "/chat/completions", false}, {"chat_completions", "/v1/chat/completions", false},
{"deeply_nested", "/models/azure/eastus/deployments/my-dept/models/gpt-4", true}, {"deeply_nested", "/v1/models/azure/eastus/deployments/my-dept/models/gpt-4", true},
} }
for _, tt := range tests { for _, tt := range tests {

View File

@@ -215,10 +215,16 @@ func decodeUserContent(content any) []canonical.ContentBlock {
var blocks []canonical.ContentBlock var blocks []canonical.ContentBlock
for _, item := range v { for _, item := range v {
if m, ok := item.(map[string]any); ok { if m, ok := item.(map[string]any); ok {
t, _ := m["type"].(string) t, ok := m["type"].(string)
if !ok {
continue
}
switch t { switch t {
case "text": case "text":
text, _ := m["text"].(string) text, ok := m["text"].(string)
if !ok {
text = ""
}
blocks = append(blocks, canonical.NewTextBlock(text)) blocks = append(blocks, canonical.NewTextBlock(text))
case "image_url": case "image_url":
blocks = append(blocks, canonical.ContentBlock{Type: "image"}) blocks = append(blocks, canonical.ContentBlock{Type: "image"})
@@ -242,9 +248,9 @@ func decodeUserContent(content any) []canonical.ContentBlock {
// contentPart 内容部分 // contentPart 内容部分
type contentPart struct { type contentPart struct {
Type string Type string
Text string Text string
Refusal string Refusal string
} }
// decodeContentParts 解码内容部分 // decodeContentParts 解码内容部分
@@ -256,13 +262,22 @@ func decodeContentParts(content any) []contentPart {
var result []contentPart var result []contentPart
for _, item := range parts { for _, item := range parts {
if m, ok := item.(map[string]any); ok { if m, ok := item.(map[string]any); ok {
t, _ := m["type"].(string) t, ok := m["type"].(string)
if !ok {
continue
}
switch t { switch t {
case "text": case "text":
text, _ := m["text"].(string) text, ok := m["text"].(string)
if !ok {
text = ""
}
result = append(result, contentPart{Type: "text", Text: text}) result = append(result, contentPart{Type: "text", Text: text})
case "refusal": case "refusal":
refusal, _ := m["refusal"].(string) refusal, ok := m["refusal"].(string)
if !ok {
refusal = ""
}
result = append(result, contentPart{Type: "refusal", Refusal: refusal}) result = append(result, contentPart{Type: "refusal", Refusal: refusal})
} }
} }
@@ -307,21 +322,33 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
return canonical.NewToolChoiceAny() return canonical.NewToolChoiceAny()
} }
case map[string]any: case map[string]any:
t, _ := v["type"].(string) t, ok := v["type"].(string)
if !ok {
return nil
}
switch t { switch t {
case "function": case "function":
if fn, ok := v["function"].(map[string]any); ok { if fn, ok := v["function"].(map[string]any); ok {
name, _ := fn["name"].(string) name, ok := fn["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name) return canonical.NewToolChoiceNamed(name)
} }
case "custom": case "custom":
if custom, ok := v["custom"].(map[string]any); ok { if custom, ok := v["custom"].(map[string]any); ok {
name, _ := custom["name"].(string) name, ok := custom["name"].(string)
if !ok {
name = ""
}
return canonical.NewToolChoiceNamed(name) return canonical.NewToolChoiceNamed(name)
} }
case "allowed_tools": case "allowed_tools":
if at, ok := v["allowed_tools"].(map[string]any); ok { if at, ok := v["allowed_tools"].(map[string]any); ok {
mode, _ := at["mode"].(string) mode, ok := at["mode"].(string)
if !ok {
mode = ""
}
if mode == "required" { if mode == "required" {
return canonical.NewToolChoiceAny() return canonical.NewToolChoiceAny()
} }
@@ -443,7 +470,7 @@ func decodeDeprecatedFields(req *ChatCompletionRequest) {
case map[string]any: case map[string]any:
if name, ok := v["name"].(string); ok { if name, ok := v["name"].(string); ok {
req.ToolChoice = map[string]any{ req.ToolChoice = map[string]any{
"type": "function", "type": "function",
"function": map[string]any{"name": name}, "function": map[string]any{"name": name},
} }
} }

View File

@@ -450,7 +450,7 @@ func encodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte
"object": "list", "object": "list",
"data": data, "data": data,
"model": resp.Model, "model": resp.Model,
"usage": resp.Usage, "usage": resp.Usage,
}) })
} }

View File

@@ -45,9 +45,11 @@ func TestEncodeRequest_SystemInjection(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
require.True(t, ok)
assert.Len(t, msgs, 2) assert.Len(t, msgs, 2)
firstMsg := msgs[0].(map[string]any) firstMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "system", firstMsg["role"]) assert.Equal(t, "system", firstMsg["role"])
assert.Equal(t, "你是助手", firstMsg["content"]) assert.Equal(t, "你是助手", firstMsg["content"])
} }
@@ -72,12 +74,15 @@ func TestEncodeRequest_ToolCalls(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
msgs := result["messages"].([]any) msgs, ok := result["messages"].([]any)
assistantMsg := msgs[0].(map[string]any) require.True(t, ok)
assistantMsg, ok := msgs[0].(map[string]any)
require.True(t, ok)
toolCalls, ok := assistantMsg["tool_calls"].([]any) toolCalls, ok := assistantMsg["tool_calls"].([]any)
require.True(t, ok) require.True(t, ok)
assert.Len(t, toolCalls, 1) assert.Len(t, toolCalls, 1)
tc := toolCalls[0].(map[string]any) tc, ok := toolCalls[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "call_1", tc["id"]) assert.Equal(t, "call_1", tc["id"])
} }
@@ -100,11 +105,11 @@ func TestEncodeRequest_Thinking(t *testing.T) {
func TestEncodeResponse_Basic(t *testing.T) { func TestEncodeResponse_Basic(t *testing.T) {
sr := canonical.StopReasonEndTurn sr := canonical.StopReasonEndTurn
resp := &canonical.CanonicalResponse{ resp := &canonical.CanonicalResponse{
ID: "resp-1", ID: "resp-1",
Model: "gpt-4", Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewTextBlock("你好")}, Content: []canonical.ContentBlock{canonical.NewTextBlock("你好")},
StopReason: &sr, StopReason: &sr,
Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5}, Usage: canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
} }
body, err := encodeResponse(resp) body, err := encodeResponse(resp)
@@ -115,9 +120,12 @@ func TestEncodeResponse_Basic(t *testing.T) {
assert.Equal(t, "resp-1", result["id"]) assert.Equal(t, "resp-1", result["id"])
assert.Equal(t, "chat.completion", result["object"]) assert.Equal(t, "chat.completion", result["object"])
choices := result["choices"].([]any) choices, ok := result["choices"].([]any)
choice := choices[0].(map[string]any) require.True(t, ok)
msg := choice["message"].(map[string]any) choice, ok := choices[0].(map[string]any)
require.True(t, ok)
msg, ok := choice["message"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "你好", msg["content"]) assert.Equal(t, "你好", msg["content"])
assert.Equal(t, "stop", choice["finish_reason"]) assert.Equal(t, "stop", choice["finish_reason"])
} }
@@ -126,9 +134,9 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
sr := canonical.StopReasonToolUse sr := canonical.StopReasonToolUse
input := json.RawMessage(`{"q":"test"}`) input := json.RawMessage(`{"q":"test"}`)
resp := &canonical.CanonicalResponse{ resp := &canonical.CanonicalResponse{
ID: "resp-2", ID: "resp-2",
Model: "gpt-4", Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)}, Content: []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)},
StopReason: &sr, StopReason: &sr,
} }
@@ -137,8 +145,12 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any) choices, okc := result["choices"].([]any)
msg := choices[0].(map[string]any)["message"].(map[string]any) require.True(t, okc)
msgMap, okm := choices[0].(map[string]any)
require.True(t, okm)
msg, okmsg := msgMap["message"].(map[string]any)
require.True(t, okmsg)
tcs, ok := msg["tool_calls"].([]any) tcs, ok := msg["tool_calls"].([]any)
require.True(t, ok) require.True(t, ok)
assert.Len(t, tcs, 1) assert.Len(t, tcs, 1)
@@ -158,7 +170,8 @@ func TestEncodeModelsResponse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "list", result["object"]) assert.Equal(t, "list", result["object"])
data := result["data"].([]any) data, okd := result["data"].([]any)
require.True(t, okd)
assert.Len(t, data, 2) assert.Len(t, data, 2)
} }
@@ -317,8 +330,12 @@ func TestEncodeResponse_Thinking(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any) choices, okch := result["choices"].([]any)
msg := choices[0].(map[string]any)["message"].(map[string]any) require.True(t, okch)
msgMap, okmm := choices[0].(map[string]any)
require.True(t, okmm)
msg, okmsg := msgMap["message"].(map[string]any)
require.True(t, okmsg)
assert.Equal(t, "回答", msg["content"]) assert.Equal(t, "回答", msg["content"])
assert.Equal(t, "思考过程", msg["reasoning_content"]) assert.Equal(t, "思考过程", msg["reasoning_content"])
} }

View File

@@ -18,9 +18,9 @@ func TestStreamDecoder_BasicText(t *testing.T) {
d := NewStreamDecoder() d := NewStreamDecoder()
chunk := map[string]any{ chunk := map[string]any{
"id": "chatcmpl-1", "id": "chatcmpl-1",
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -56,8 +56,8 @@ func TestStreamDecoder_ToolCalls(t *testing.T) {
idx := 0 idx := 0
chunk := map[string]any{ chunk := map[string]any{
"id": "chatcmpl-1", "id": "chatcmpl-1",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -98,8 +98,8 @@ func TestStreamDecoder_Thinking(t *testing.T) {
d := NewStreamDecoder() d := NewStreamDecoder()
chunk := map[string]any{ chunk := map[string]any{
"id": "chatcmpl-1", "id": "chatcmpl-1",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -127,8 +127,8 @@ func TestStreamDecoder_FinishReason(t *testing.T) {
d := NewStreamDecoder() d := NewStreamDecoder()
chunk := map[string]any{ chunk := map[string]any{
"id": "chatcmpl-1", "id": "chatcmpl-1",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -161,8 +161,8 @@ func TestStreamDecoder_DoneSignal(t *testing.T) {
// 先发送一个文本 chunk // 先发送一个文本 chunk
chunk := map[string]any{ chunk := map[string]any{
"id": "chatcmpl-1", "id": "chatcmpl-1",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -190,8 +190,8 @@ func TestStreamDecoder_RefusalReuse(t *testing.T) {
// 连续两个 refusal delta chunk // 连续两个 refusal delta chunk
for _, text := range []string{"拒绝", "原因"} { for _, text := range []string{"拒绝", "原因"} {
chunk := map[string]any{ chunk := map[string]any{
"id": "chatcmpl-1", "id": "chatcmpl-1",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -250,8 +250,8 @@ func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
idx0 := 0 idx0 := 0
chunk1 := map[string]any{ chunk1 := map[string]any{
"id": "chatcmpl-mt", "id": "chatcmpl-mt",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -274,8 +274,8 @@ func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
idx1 := 1 idx1 := 1
chunk2 := map[string]any{ chunk2 := map[string]any{
"id": "chatcmpl-mt", "id": "chatcmpl-mt",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -322,8 +322,8 @@ func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
d := NewStreamDecoder() d := NewStreamDecoder()
chunk1 := map[string]any{ chunk1 := map[string]any{
"id": "chatcmpl-multi", "id": "chatcmpl-multi",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -332,8 +332,8 @@ func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
}, },
} }
chunk2 := map[string]any{ chunk2 := map[string]any{
"id": "chatcmpl-multi", "id": "chatcmpl-multi",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -358,8 +358,8 @@ func TestStreamDecoder_UTF8Truncation(t *testing.T) {
d := NewStreamDecoder() d := NewStreamDecoder()
chunk := map[string]any{ chunk := map[string]any{
"id": "chatcmpl-utf8", "id": "chatcmpl-utf8",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -390,8 +390,8 @@ func TestStreamDecoder_ToolCallSubsequentDelta(t *testing.T) {
idx := 0 idx := 0
chunk1 := map[string]any{ chunk1 := map[string]any{
"id": "chatcmpl-tc", "id": "chatcmpl-tc",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,
@@ -412,8 +412,8 @@ func TestStreamDecoder_ToolCallSubsequentDelta(t *testing.T) {
}, },
} }
chunk2 := map[string]any{ chunk2 := map[string]any{
"id": "chatcmpl-tc", "id": "chatcmpl-tc",
"model": "gpt-4", "model": "gpt-4",
"choices": []any{ "choices": []any{
map[string]any{ map[string]any{
"index": 0, "index": 0,

View File

@@ -10,9 +10,9 @@ import (
// StreamEncoder OpenAI 流式编码器 // StreamEncoder OpenAI 流式编码器
type StreamEncoder struct { type StreamEncoder struct {
bufferedStart *canonical.CanonicalStreamEvent bufferedStart *canonical.CanonicalStreamEvent
toolCallIndexMap map[string]int toolCallIndexMap map[string]int
nextToolCallIndex int nextToolCallIndex int
} }
// NewStreamEncoder 创建 OpenAI 流式编码器 // NewStreamEncoder 创建 OpenAI 流式编码器
@@ -195,8 +195,8 @@ func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent)
func (e *StreamEncoder) encodeDelta(delta map[string]any) [][]byte { func (e *StreamEncoder) encodeDelta(delta map[string]any) [][]byte {
chunk := map[string]any{ chunk := map[string]any{
"choices": []map[string]any{{ "choices": []map[string]any{{
"index": 0, "index": 0,
"delta": delta, "delta": delta,
}}, }},
} }
return e.marshalChunk(chunk) return e.marshalChunk(chunk)

View File

@@ -27,8 +27,12 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
data := strings.TrimPrefix(s, "data: ") data := strings.TrimPrefix(s, "data: ")
data = strings.TrimRight(data, "\n") data = strings.TrimRight(data, "\n")
require.NoError(t, json.Unmarshal([]byte(data), &payload)) require.NoError(t, json.Unmarshal([]byte(data), &payload))
choices := payload["choices"].([]any) choices, okch := payload["choices"].([]any)
delta := choices[0].(map[string]any)["delta"].(map[string]any) require.True(t, okch)
msgMap, okmm := choices[0].(map[string]any)
require.True(t, okmm)
delta, okd := msgMap["delta"].(map[string]any)
require.True(t, okd)
assert.Equal(t, "assistant", delta["role"]) assert.Equal(t, "assistant", delta["role"])
} }

View File

@@ -177,7 +177,8 @@ func TestEncodeRerankResponse(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
assert.Equal(t, "rerank-1", result["model"]) assert.Equal(t, "rerank-1", result["model"])
results := result["results"].([]any) results, okr := result["results"].([]any)
require.True(t, okr)
assert.Len(t, results, 1) assert.Len(t, results, 1)
} }
@@ -356,9 +357,9 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
reasoning := 20 reasoning := 20
sr := canonical.StopReasonEndTurn sr := canonical.StopReasonEndTurn
resp := &canonical.CanonicalResponse{ resp := &canonical.CanonicalResponse{
ID: "r1", ID: "r1",
Model: "gpt-4", Model: "gpt-4",
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")}, Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
StopReason: &sr, StopReason: &sr,
Usage: canonical.CanonicalUsage{ Usage: canonical.CanonicalUsage{
InputTokens: 100, InputTokens: 100,
@@ -373,7 +374,8 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
usage := result["usage"].(map[string]any) usage, oku := result["usage"].(map[string]any)
require.True(t, oku)
assert.Equal(t, float64(100), usage["prompt_tokens"]) assert.Equal(t, float64(100), usage["prompt_tokens"])
ptd, ok := usage["prompt_tokens_details"].(map[string]any) ptd, ok := usage["prompt_tokens_details"].(map[string]any)
require.True(t, ok) require.True(t, ok)
@@ -412,8 +414,10 @@ func TestEncodeResponse_StopReasons(t *testing.T) {
var result map[string]any var result map[string]any
require.NoError(t, json.Unmarshal(body, &result)) require.NoError(t, json.Unmarshal(body, &result))
choices := result["choices"].([]any) choices, okch := result["choices"].([]any)
choice := choices[0].(map[string]any) require.True(t, okch)
choice, okc := choices[0].(map[string]any)
require.True(t, okc)
assert.Equal(t, tt.want, choice["finish_reason"]) assert.Equal(t, tt.want, choice["finish_reason"])
}) })
} }

View File

@@ -4,42 +4,42 @@ import "encoding/json"
// ChatCompletionRequest OpenAI Chat Completion 请求 // ChatCompletionRequest OpenAI Chat Completion 请求
type ChatCompletionRequest struct { type ChatCompletionRequest struct {
Model string `json:"model"` Model string `json:"model"`
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
Tools []Tool `json:"tools,omitempty"` Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"` MaxTokens *int `json:"max_tokens,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"` TopP *float64 `json:"top_p,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"` PresencePenalty *float64 `json:"presence_penalty,omitempty"`
Stop any `json:"stop,omitempty"` Stop any `json:"stop,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"` StreamOptions *StreamOptions `json:"stream_options,omitempty"`
User string `json:"user,omitempty"` User string `json:"user,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"` ReasoningEffort string `json:"reasoning_effort,omitempty"`
N *int `json:"n,omitempty"` N *int `json:"n,omitempty"`
Seed *int `json:"seed,omitempty"` Seed *int `json:"seed,omitempty"`
Logprobs *bool `json:"logprobs,omitempty"` Logprobs *bool `json:"logprobs,omitempty"`
TopLogprobs *int `json:"top_logprobs,omitempty"` TopLogprobs *int `json:"top_logprobs,omitempty"`
// 已废弃字段 // 已废弃字段
Functions []FunctionDef `json:"functions,omitempty"` Functions []FunctionDef `json:"functions,omitempty"`
FunctionCall any `json:"function_call,omitempty"` FunctionCall any `json:"function_call,omitempty"`
} }
// Message OpenAI 消息 // Message OpenAI 消息
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content any `json:"content"` Content any `json:"content"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"`
Refusal string `json:"refusal,omitempty"` Refusal string `json:"refusal,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"` ReasoningContent string `json:"reasoning_content,omitempty"`
// 已废弃 // 已废弃
FunctionCall *FunctionCallMsg `json:"function_call,omitempty"` FunctionCall *FunctionCallMsg `json:"function_call,omitempty"`
@@ -88,8 +88,8 @@ type FunctionDef struct {
// ResponseFormat OpenAI 响应格式 // ResponseFormat OpenAI 响应格式
type ResponseFormat struct { type ResponseFormat struct {
Type string `json:"type"` Type string `json:"type"`
JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"` JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"`
} }
// JSONSchemaDef JSON Schema 定义 // JSONSchemaDef JSON Schema 定义
@@ -118,7 +118,7 @@ type ChatCompletionResponse struct {
// Choice OpenAI 选择项 // Choice OpenAI 选择项
type Choice struct { type Choice struct {
Index int `json:"index"` Index int `json:"index"`
Message *Message `json:"message,omitempty"` Message *Message `json:"message,omitempty"`
Delta *Message `json:"delta,omitempty"` Delta *Message `json:"delta,omitempty"`
FinishReason *string `json:"finish_reason"` FinishReason *string `json:"finish_reason"`
@@ -127,10 +127,10 @@ type Choice struct {
// Usage OpenAI 用量 // Usage OpenAI 用量
type Usage struct { type Usage struct {
PromptTokens int `json:"prompt_tokens"` PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"` CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"` TotalTokens int `json:"total_tokens"`
PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"` PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"`
CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"` CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
} }

View File

@@ -1,6 +1,11 @@
package conversion package conversion
import "nex/backend/internal/conversion/canonical" import (
"bytes"
"strings"
"nex/backend/internal/conversion/canonical"
)
// StreamDecoder 流式解码器接口 // StreamDecoder 流式解码器接口
type StreamDecoder interface { type StreamDecoder interface {
@@ -39,11 +44,12 @@ func (c *PassthroughStreamConverter) Flush() [][]byte {
} }
// SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器 // SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器
// 逐 chunk 改写 model 字段 // 按 SSE frame 改写 data JSON 中的 model 字段
type SmartPassthroughStreamConverter struct { type SmartPassthroughStreamConverter struct {
adapter ProtocolAdapter adapter ProtocolAdapter
modelOverride string modelOverride string
interfaceType InterfaceType interfaceType InterfaceType
buffer []byte
} }
// NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器 // NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器
@@ -55,24 +61,45 @@ func NewSmartPassthroughStreamConverter(adapter ProtocolAdapter, modelOverride s
} }
} }
// ProcessChunk 改写 chunk 中的 model 字段 // ProcessChunk 按 SSE frame 改写 data JSON 中的 model 字段
func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte { func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
if len(rawChunk) == 0 { if len(rawChunk) == 0 {
return nil return nil
} }
rewrittenChunk, err := c.adapter.RewriteResponseModelName(rawChunk, c.modelOverride, c.interfaceType) c.buffer = append(c.buffer, rawChunk...)
if err != nil { frames, rest := splitSSEFrames(c.buffer)
// 改写失败,返回原始 chunk c.buffer = rest
return [][]byte{rawChunk}
}
return [][]byte{rewrittenChunk} result := make([][]byte, 0, len(frames))
for _, frame := range frames {
result = append(result, c.rewriteFrame(frame))
}
return result
} }
// Flush 无缓冲数据 func (c *SmartPassthroughStreamConverter) rewriteFrame(frame []byte) []byte {
payload, ok := sseFrameDataPayload(frame)
if !ok || strings.TrimSpace(payload) == "[DONE]" {
return frame
}
rewrittenPayload, err := c.adapter.RewriteResponseModelName([]byte(payload), c.modelOverride, c.interfaceType)
if err != nil {
return frame
}
return rebuildSSEFrameWithData(frame, string(rewrittenPayload))
}
// Flush 输出未形成完整 frame 的剩余数据
func (c *SmartPassthroughStreamConverter) Flush() [][]byte { func (c *SmartPassthroughStreamConverter) Flush() [][]byte {
return nil if len(c.buffer) == 0 {
return nil
}
frame := append([]byte(nil), c.buffer...)
c.buffer = nil
return [][]byte{c.rewriteFrame(frame)}
} }
// CanonicalStreamConverter 跨协议规范流式转换器 // CanonicalStreamConverter 跨协议规范流式转换器
@@ -153,3 +180,86 @@ func (c *CanonicalStreamConverter) applyModelOverride(event *canonical.Canonical
event.Message.Model = c.modelOverride event.Message.Model = c.modelOverride
} }
} }
func splitSSEFrames(data []byte) ([][]byte, []byte) {
var frames [][]byte
for len(data) > 0 {
idx, sepLen := findSSEFrameSeparator(data)
if idx < 0 {
break
}
end := idx + sepLen
frames = append(frames, append([]byte(nil), data[:end]...))
data = data[end:]
}
return frames, data
}
func findSSEFrameSeparator(data []byte) (int, int) {
lf := bytes.Index(data, []byte("\n\n"))
crlf := bytes.Index(data, []byte("\r\n\r\n"))
switch {
case lf < 0 && crlf < 0:
return -1, 0
case lf < 0:
return crlf, 4
case crlf < 0:
return lf, 2
case crlf <= lf:
return crlf, 4
default:
return lf, 2
}
}
func sseFrameDataPayload(frame []byte) (string, bool) {
text := strings.TrimRight(string(frame), "\r\n")
lines := strings.Split(text, "\n")
var dataLines []string
for _, line := range lines {
line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "data:") {
value := strings.TrimPrefix(line, "data:")
if strings.HasPrefix(value, " ") {
value = value[1:]
}
dataLines = append(dataLines, value)
}
}
if len(dataLines) == 0 {
return "", false
}
return strings.Join(dataLines, "\n"), true
}
func rebuildSSEFrameWithData(frame []byte, data string) []byte {
lineEnding, separator := sseLineEnding(frame)
text := strings.TrimRight(string(frame), "\r\n")
lines := strings.Split(text, "\n")
out := make([]string, 0, len(lines)+1)
dataWritten := false
for _, line := range lines {
line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "data:") {
if !dataWritten {
for _, dataLine := range strings.Split(data, "\n") {
out = append(out, "data: "+dataLine)
}
dataWritten = true
}
continue
}
out = append(out, line)
}
if !dataWritten {
out = append(out, "data: "+data)
}
return []byte(strings.Join(out, lineEnding) + separator)
}
func sseLineEnding(frame []byte) (string, string) {
if bytes.Contains(frame, []byte("\r\n")) {
return "\r\n", "\r\n\r\n"
}
return "\n", "\n\n"
}

View File

@@ -61,7 +61,7 @@ func initDB(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error)
return gorm.Open(mysql.Open(dsn), gormConfig) return gorm.Open(mysql.Open(dsn), gormConfig)
default: default:
dbDir := filepath.Dir(cfg.Path) dbDir := filepath.Dir(cfg.Path)
if err := os.MkdirAll(dbDir, 0755); err != nil { if err := os.MkdirAll(dbDir, 0o755); err != nil {
return nil, fmt.Errorf("创建数据库目录失败: %w", err) return nil, fmt.Errorf("创建数据库目录失败: %w", err)
} }
if zapLogger != nil { if zapLogger != nil {
@@ -95,7 +95,9 @@ func runMigrations(db *gorm.DB, driver string, zapLogger *zap.Logger) error {
zap.String("dir", migrationsSubDir)) zap.String("dir", migrationsSubDir))
} }
goose.SetDialect(gooseDialect) if err := goose.SetDialect(gooseDialect); err != nil {
return err
}
if err := goose.Up(sqlDB, migrationsDir); err != nil { if err := goose.Up(sqlDB, migrationsDir); err != nil {
return err return err
} }

View File

@@ -4,11 +4,11 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"nex/backend/internal/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/config"
) )
func TestInit_SQLite(t *testing.T) { func TestInit_SQLite(t *testing.T) {

View File

@@ -13,4 +13,3 @@ type Provider struct {
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
} }

View File

@@ -6,13 +6,13 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
) )
func TestProviderHandler_CreateProvider_Success(t *testing.T) { func TestProviderHandler_CreateProvider_Success(t *testing.T) {
@@ -24,9 +24,9 @@ func TestProviderHandler_CreateProvider_Success(t *testing.T) {
h := NewProviderHandler(mockSvc) h := NewProviderHandler(mockSvc)
body, _ := json.Marshal(map[string]string{ body, _ := json.Marshal(map[string]string{
"id": "p1", "id": "p1",
"name": "Test", "name": "Test",
"api_key": "sk-test", "api_key": "sk-test",
"base_url": "https://api.test.com", "base_url": "https://api.test.com",
}) })
w := httptest.NewRecorder() w := httptest.NewRecorder()

View File

@@ -9,23 +9,22 @@ import (
"strings" "strings"
"testing" "testing"
"nex/backend/internal/domain"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"gorm.io/gorm" "gorm.io/gorm"
"nex/backend/internal/domain"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
"nex/backend/tests/mocks"
) )
func init() { func init() {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
} }
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) { func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()

View File

@@ -20,8 +20,7 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
if id, ok := requestID.(string); ok { if id, ok := requestID.(string); ok {
requestIDStr = id requestIDStr = id
} }
logger.Debug("请求开始",
logger.Info("请求开始",
pkglogger.Method(c.Request.Method), pkglogger.Method(c.Request.Method),
pkglogger.Path(path), pkglogger.Path(path),
pkglogger.Query(query), pkglogger.Query(query),
@@ -34,7 +33,7 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
latency := time.Since(start) latency := time.Since(start)
statusCode := c.Writer.Status() statusCode := c.Writer.Status()
logger.Info("请求结束", logger.Debug("请求结束",
pkglogger.StatusCode(statusCode), pkglogger.StatusCode(statusCode),
pkglogger.Method(c.Request.Method), pkglogger.Method(c.Request.Method),
pkglogger.Path(path), pkglogger.Path(path),

View File

@@ -7,6 +7,8 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.uber.org/zap" "go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
) )
func init() { func init() {
@@ -65,6 +67,61 @@ func TestLogging(t *testing.T) {
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
} }
func TestLogging_DoesNotLogLifecycleAtInfoLevel(t *testing.T) {
core, logs := observer.New(zapcore.InfoLevel)
logger := zap.New(core)
w := serveLoggingRequest(logger)
assert.Equal(t, 200, w.Code)
assert.Empty(t, logs.FilterMessage("请求开始").All())
assert.Empty(t, logs.FilterMessage("请求结束").All())
}
func TestLogging_LogsLifecycleAtDebugLevel(t *testing.T) {
core, logs := observer.New(zapcore.DebugLevel)
logger := zap.New(core)
w := serveLoggingRequest(logger)
assert.Equal(t, 200, w.Code)
startLogs := logs.FilterMessage("请求开始").All()
endLogs := logs.FilterMessage("请求结束").All()
if assert.Len(t, startLogs, 1) {
fields := startLogs[0].ContextMap()
assert.Equal(t, "GET", fields["method"])
assert.Equal(t, "/test", fields["path"])
assert.Equal(t, "key=value", fields["query"])
assert.Equal(t, "existing-id-123", fields["request_id"])
assert.NotEmpty(t, fields["client_ip"])
}
if assert.Len(t, endLogs, 1) {
fields := endLogs[0].ContextMap()
assert.Equal(t, int64(200), fields["status"])
assert.Equal(t, "GET", fields["method"])
assert.Equal(t, "/test", fields["path"])
assert.Equal(t, int64(2), fields["body_size"])
assert.Equal(t, "existing-id-123", fields["request_id"])
assert.Contains(t, fields, "latency")
}
}
func serveLoggingRequest(logger *zap.Logger) *httptest.ResponseRecorder {
r := gin.New()
r.Use(RequestID())
r.Use(Logging(logger))
r.GET("/test", func(c *gin.Context) {
c.String(200, "ok")
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/test?key=value", nil)
req.Header.Set("X-Request-ID", "existing-id-123")
r.ServeHTTP(w, req)
return w
}
func TestRecovery_NoPanic(t *testing.T) { func TestRecovery_NoPanic(t *testing.T) {
logger := zap.NewNop() logger := zap.NewNop()

View File

@@ -4,13 +4,13 @@ import (
"errors" "errors"
"net/http" "net/http"
"nex/backend/internal/domain"
"nex/backend/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/service"
) )
// ModelHandler 模型管理处理器 // ModelHandler 模型管理处理器
@@ -58,16 +58,16 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
err := h.modelService.Create(model) err := h.modelService.Create(model)
if err != nil { if err != nil {
if err == appErrors.ErrProviderNotFound { if errors.Is(err, appErrors.ErrProviderNotFound) {
c.JSON(http.StatusBadRequest, gin.H{ c.JSON(http.StatusBadRequest, gin.H{
"error": "供应商不存在", "error": "供应商不存在",
}) })
return return
} }
if err == appErrors.ErrDuplicateModel { if errors.Is(err, appErrors.ErrDuplicateModel) {
c.JSON(http.StatusConflict, gin.H{ c.JSON(http.StatusConflict, gin.H{
"error": "同一供应商下模型名称已存在", "error": "同一供应商下模型名称已存在",
"code": appErrors.ErrDuplicateModel.Code, "code": appErrors.ErrDuplicateModel.Code,
}) })
return return
} }
@@ -101,7 +101,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
model, err := h.modelService.Get(id) model, err := h.modelService.Get(id)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到", "error": "模型未找到",
}) })
@@ -166,7 +166,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) {
err := h.modelService.Delete(id) err := h.modelService.Delete(id)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "模型未找到", "error": "模型未找到",
}) })

View File

@@ -4,13 +4,13 @@ import (
"errors" "errors"
"net/http" "net/http"
"nex/backend/internal/domain"
"nex/backend/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain"
"nex/backend/internal/service"
) )
// ProviderHandler 供应商管理处理器 // ProviderHandler 供应商管理处理器
@@ -55,7 +55,7 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
err := h.providerService.Create(provider) err := h.providerService.Create(provider)
if err != nil { if err != nil {
if err == appErrors.ErrInvalidProviderID { if errors.Is(err, appErrors.ErrInvalidProviderID) {
c.JSON(http.StatusBadRequest, gin.H{ c.JSON(http.StatusBadRequest, gin.H{
"error": appErrors.ErrInvalidProviderID.Message, "error": appErrors.ErrInvalidProviderID.Message,
"code": appErrors.ErrInvalidProviderID.Code, "code": appErrors.ErrInvalidProviderID.Code,
@@ -86,7 +86,7 @@ func (h *ProviderHandler) GetProvider(c *gin.Context) {
provider, err := h.providerService.Get(id) provider, err := h.providerService.Get(id)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到", "error": "供应商未找到",
}) })
@@ -113,7 +113,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
err := h.providerService.Update(id, req) err := h.providerService.Update(id, req)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到", "error": "供应商未找到",
}) })
@@ -145,7 +145,7 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
err := h.providerService.Delete(id) err := h.providerService.Delete(id)
if err != nil { if err != nil {
if err == gorm.ErrRecordNotFound { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": "供应商未找到", "error": "供应商未找到",
}) })

View File

@@ -3,30 +3,33 @@ package handler
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"errors"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"nex/backend/internal/conversion" "nex/backend/internal/conversion"
"nex/backend/internal/conversion/canonical" "nex/backend/internal/conversion/canonical"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/provider" "nex/backend/internal/provider"
"nex/backend/internal/service" "nex/backend/internal/service"
appErrors "nex/backend/pkg/errors"
"nex/backend/pkg/modelid" "nex/backend/pkg/modelid"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
pkglogger "nex/backend/pkg/logger" pkglogger "nex/backend/pkg/logger"
) )
// ProxyHandler 统一代理处理器 // ProxyHandler 统一代理处理器
type ProxyHandler struct { type ProxyHandler struct {
engine *conversion.ConversionEngine engine *conversion.ConversionEngine
client provider.ProviderClient client provider.ProviderClient
routingService service.RoutingService routingService service.RoutingService
providerService service.ProviderService providerService service.ProviderService
statsService service.StatsService statsService service.StatsService
logger *zap.Logger logger *zap.Logger
} }
// NewProxyHandler 创建统一代理处理器 // NewProxyHandler 创建统一代理处理器
@@ -46,7 +49,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
// 从 URL 提取 clientProtocol: /{protocol}/v1/... // 从 URL 提取 clientProtocol: /{protocol}/v1/...
clientProtocol := c.Param("protocol") clientProtocol := c.Param("protocol")
if clientProtocol == "" { if clientProtocol == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少协议前缀"}) h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "缺少协议前缀")
return return
} }
@@ -56,12 +59,13 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
path = "/" + path path = "/" + path
} }
nativePath := path nativePath := path
requestPath := appendRawQuery(nativePath, c.Request.URL.RawQuery)
// 获取 client adapter // 获取 client adapter
registry := h.engine.GetRegistry() registry := h.engine.GetRegistry()
clientAdapter, err := registry.Get(clientProtocol) clientAdapter, err := registry.Get(clientProtocol)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol}) h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
return return
} }
@@ -78,7 +82,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
if ifaceType == conversion.InterfaceTypeModelInfo { if ifaceType == conversion.InterfaceTypeModelInfo {
unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath) unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的模型 ID 格式"}) h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的模型 ID 格式")
return return
} }
h.handleModelInfo(c, unifiedID, clientAdapter) h.handleModelInfo(c, unifiedID, clientAdapter)
@@ -88,40 +92,50 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
// 读取请求体 // 读取请求体
body, err := io.ReadAll(c.Request.Body) body, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"}) h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "读取请求体失败")
return return
} }
// 解析统一模型 ID使用 adapter.ExtractModelName
var providerID, modelName string
if len(body) > 0 {
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
if err == nil && unifiedID != "" {
pid, mn, err := modelid.ParseUnifiedModelID(unifiedID)
if err == nil {
providerID = pid
modelName = mn
}
}
}
// 构建输入 HTTPRequestSpec // 构建输入 HTTPRequestSpec
inSpec := conversion.HTTPRequestSpec{ inSpec := conversion.HTTPRequestSpec{
URL: nativePath, URL: requestPath,
Method: c.Request.Method, Method: c.Request.Method,
Headers: extractHeaders(c), Headers: extractHeaders(c),
Body: body, Body: body,
} }
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
// 只有 adapter 明确适配的接口才提取 model。未知接口不做通用 model 猜测。
if len(body) == 0 || !supportsModelExtraction(ifaceType) {
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
if err != nil {
if isInvalidJSONError(err) {
h.writeProxyError(c, http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误")
return
}
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
if err != nil {
// 原始模型名兼容透传:非统一模型 ID 不参与路由。
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
if providerID == "" || modelName == "" {
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
return
}
// 路由 // 路由
routeResult, err := h.routingService.RouteByModelName(providerID, modelName) routeResult, err := h.routingService.RouteByModelName(providerID, modelName)
if err != nil { if err != nil {
// GET 请求或无法提取 model 时,直接转发到上游 h.writeRouteError(c, err)
if len(body) == 0 || modelName == "" {
h.forwardPassthrough(c, inSpec, clientProtocol)
return
}
h.writeError(c, err, clientProtocol)
return return
} }
@@ -138,12 +152,9 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
targetProvider := conversion.NewTargetProvider( targetProvider := conversion.NewTargetProvider(
routeResult.Provider.BaseURL, routeResult.Provider.BaseURL,
routeResult.Provider.APIKey, routeResult.Provider.APIKey,
routeResult.Model.ModelName, // 上游模型名,用于请求改写 routeResult.Model.ModelName, // 上游模型名,用于请求改写
) )
// 判断是否流式
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
// 计算统一模型 ID用于响应覆写 // 计算统一模型 ID用于响应覆写
unifiedModelID := routeResult.Model.UnifiedModelID() unifiedModelID := routeResult.Model.UnifiedModelID()
@@ -154,12 +165,34 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
} }
} }
func supportsModelExtraction(ifaceType conversion.InterfaceType) bool {
switch ifaceType {
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
return true
default:
return false
}
}
func isInvalidJSONError(err error) bool {
var syntaxErr *json.SyntaxError
var typeErr *json.UnmarshalTypeError
return errors.As(err, &syntaxErr) || errors.As(err, &typeErr)
}
func appendRawQuery(path, rawQuery string) string {
if rawQuery == "" {
return path
}
return path + "?" + rawQuery
}
// handleNonStream 处理非流式请求 // handleNonStream 处理非流式请求
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) { func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
// 转换请求 // 转换请求
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider) outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
if err != nil { if err != nil {
h.logger.Error("转换请求失败", zap.String("error", err.Error())) h.logger.Error("转换请求失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol) h.writeConversionError(c, err, clientProtocol)
return return
} }
@@ -167,31 +200,27 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
// 发送请求 // 发送请求
resp, err := h.client.Send(c.Request.Context(), *outSpec) resp, err := h.client.Send(c.Request.Context(), *outSpec)
if err != nil { if err != nil {
h.logger.Error("发送请求失败", zap.String("error", err.Error())) h.logger.Error("发送请求失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol) h.writeUpstreamUnavailable(c, err)
return
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
h.writeUpstreamResponse(c, *resp)
return return
} }
// 转换响应,传入 modelOverride跨协议场景覆写 model 字段) // 转换响应,传入 modelOverride跨协议场景覆写 model 字段)
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID) convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
if err != nil { if err != nil {
h.logger.Error("转换响应失败", zap.String("error", err.Error())) h.logger.Error("转换响应失败", zap.Error(err))
h.writeConversionError(c, err, clientProtocol) h.writeConversionError(c, err, clientProtocol)
return return
} }
// 设置响应头 h.writeConvertedResponse(c, *convertedResp)
for k, v := range convertedResp.Headers {
c.Header(k, v)
}
if c.GetHeader("Content-Type") == "" {
c.Header("Content-Type", "application/json")
}
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
go func() { go func() {
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) _ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
}() }()
} }
@@ -204,15 +233,23 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
return return
} }
// 创建流式转换器,传入 modelOverride跨协议场景覆写 model 字段) // 发送流式请求
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType) streamResp, err := h.client.SendStream(c.Request.Context(), *outSpec)
if err != nil { if err != nil {
h.writeConversionError(c, err, clientProtocol) h.writeUpstreamUnavailable(c, err)
return
}
if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
StatusCode: streamResp.StatusCode,
Headers: streamResp.Headers,
Body: streamResp.Body,
})
return return
} }
// 发送流式请求 // 创建流式转换器,传入 modelOverride跨协议场景覆写 model 字段)
eventChan, err := h.client.SendStream(c.Request.Context(), *outSpec) streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
if err != nil { if err != nil {
h.writeConversionError(c, err, clientProtocol) h.writeConversionError(c, err, clientProtocol)
return return
@@ -223,37 +260,61 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
c.Header("Connection", "keep-alive") c.Header("Connection", "keep-alive")
writer := bufio.NewWriter(c.Writer) writer := bufio.NewWriter(c.Writer)
flushed := false
for event := range eventChan { for event := range streamResp.Events {
if event.Error != nil { if event.Error != nil {
h.logger.Error("流读取错误", zap.String("error", event.Error.Error())) h.logger.Error("流读取错误", zap.Error(event.Error))
break break
} }
if event.Done { if event.Done {
// flush 转换器 // flush 转换器
chunks := streamConverter.Flush() chunks := streamConverter.Flush()
for _, chunk := range chunks { if err := h.writeStreamChunks(writer, chunks); err != nil {
writer.Write(chunk) h.logger.Warn("流式响应写回失败", zap.Error(err))
writer.Flush()
} }
flushed = true
break break
} }
chunks := streamConverter.ProcessChunk(event.Data) chunks := streamConverter.ProcessChunk(event.Data)
for _, chunk := range chunks { if err := h.writeStreamChunks(writer, chunks); err != nil {
writer.Write(chunk) h.logger.Warn("流式响应写回失败", zap.Error(err))
writer.Flush() break
}
}
if !flushed {
chunks := streamConverter.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("流式响应写回失败", zap.Error(err))
} }
} }
go func() { go func() {
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) _ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
}() }()
} }
func (h *ProxyHandler) writeStreamChunks(writer *bufio.Writer, chunks [][]byte) error {
for _, chunk := range chunks {
if _, err := writer.Write(chunk); err != nil {
return err
}
if err := writer.Flush(); err != nil {
return err
}
}
return nil
}
// isStreamRequest 判断是否流式请求 // isStreamRequest 判断是否流式请求
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool { func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol) ifaceType, err := h.engine.DetectInterfaceType(nativePath, clientProtocol)
if err != nil {
return false
}
if ifaceType != conversion.InterfaceTypeChat { if ifaceType != conversion.InterfaceTypeChat {
return false return false
} }
@@ -272,8 +333,8 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
// 从数据库查询所有启用的模型 // 从数据库查询所有启用的模型
models, err := h.providerService.ListEnabledModels() models, err := h.providerService.ListEnabledModels()
if err != nil { if err != nil {
h.logger.Error("查询启用模型失败", zap.String("error", err.Error())) h.logger.Error("查询启用模型失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"}) h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "查询模型失败")
return return
} }
@@ -294,8 +355,8 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
// 使用 adapter 编码返回 // 使用 adapter 编码返回
body, err := adapter.EncodeModelsResponse(modelList) body, err := adapter.EncodeModelsResponse(modelList)
if err != nil { if err != nil {
h.logger.Error("编码 Models 响应失败", zap.String("error", err.Error())) h.logger.Error("编码 Models 响应失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"}) h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
return return
} }
@@ -307,17 +368,14 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
// 解析统一模型 ID // 解析统一模型 ID
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID) providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{ h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的统一模型 ID 格式")
"error": "无效的统一模型 ID 格式",
"code": "INVALID_MODEL_ID",
})
return return
} }
// 从数据库查询模型 // 从数据库查询模型
model, err := h.providerService.GetModelByProviderAndName(providerID, modelName) model, err := h.providerService.GetModelByProviderAndName(providerID, modelName)
if err != nil { if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "模型未找到"}) h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", "模型未找到")
return return
} }
@@ -333,41 +391,103 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
body, err := adapter.EncodeModelInfoResponse(modelInfo) body, err := adapter.EncodeModelInfoResponse(modelInfo)
if err != nil { if err != nil {
h.logger.Error("编码 ModelInfo 响应失败", zap.Error(err)) h.logger.Error("编码 ModelInfo 响应失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"}) h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
return return
} }
c.Data(http.StatusOK, "application/json", body) c.Data(http.StatusOK, "application/json", body)
} }
// writeConversionError 写入转换错误 // writeConversionError 写入网关层转换错误
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) { func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
if convErr, ok := err.(*conversion.ConversionError); ok { var convErr *conversion.ConversionError
body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol) if errors.As(err, &convErr) {
c.Data(statusCode, "application/json", body) statusCode, code, message := mapConversionError(convErr)
h.writeProxyError(c, statusCode, code, message)
return return
} }
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", err.Error())
} }
// writeError 写入路由错误 func mapConversionError(err *conversion.ConversionError) (int, string, string) {
func (h *ProxyHandler) writeError(c *gin.Context, err error, clientProtocol string) { switch err.Code {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) case conversion.ErrorCodeJSONParseError:
if phase, ok := err.Details[conversion.ErrorDetailPhase].(string); ok && phase == conversion.ErrorPhaseRequest {
return http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误"
}
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
case conversion.ErrorCodeInvalidInput,
conversion.ErrorCodeMissingRequiredField,
conversion.ErrorCodeProtocolConstraint:
return http.StatusBadRequest, "INVALID_REQUEST", err.Message
case conversion.ErrorCodeInterfaceNotSupported:
return http.StatusBadRequest, "UNSUPPORTED_INTERFACE", err.Message
case conversion.ErrorCodeUnsupportedMultimodal:
return http.StatusBadRequest, "UNSUPPORTED_MULTIMODAL", err.Message
default:
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
}
}
func (h *ProxyHandler) writeRouteError(c *gin.Context, err error) {
if appErr, ok := appErrors.AsAppError(err); ok {
switch appErr.Code {
case appErrors.ErrModelNotFound.Code, appErrors.ErrModelDisabled.Code:
h.writeProxyError(c, appErr.HTTPStatus, "MODEL_NOT_FOUND", appErr.Message)
case appErrors.ErrProviderNotFound.Code, appErrors.ErrProviderDisabled.Code:
h.writeProxyError(c, appErr.HTTPStatus, "PROVIDER_NOT_FOUND", appErr.Message)
default:
h.writeProxyError(c, appErr.HTTPStatus, "INVALID_REQUEST", appErr.Message)
}
return
}
h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", err.Error())
}
func (h *ProxyHandler) writeUpstreamUnavailable(c *gin.Context, err error) {
h.logger.Error("上游不可达", zap.Error(err))
h.writeProxyError(c, http.StatusBadGateway, "UPSTREAM_UNAVAILABLE", "上游服务不可达")
}
func (h *ProxyHandler) writeProxyError(c *gin.Context, status int, code, message string) {
c.JSON(status, gin.H{
"error": message,
"code": code,
})
}
func (h *ProxyHandler) writeConvertedResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
for k, v := range resp.Headers {
c.Header(k, v)
}
contentType := headerValue(resp.Headers, "Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(resp.StatusCode, contentType, resp.Body)
}
func (h *ProxyHandler) writeUpstreamResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
for k, v := range filterHopByHopHeaders(resp.Headers) {
c.Header(k, v)
}
contentType := headerValue(resp.Headers, "Content-Type")
c.Data(resp.StatusCode, contentType, resp.Body)
} }
// forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求) // forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求)
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string) { func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string, ifaceType conversion.InterfaceType, isStream bool) {
registry := h.engine.GetRegistry() registry := h.engine.GetRegistry()
adapter, err := registry.Get(clientProtocol) adapter, err := registry.Get(clientProtocol)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol}) h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
return return
} }
providers, err := h.providerService.List() providers, err := h.providerService.List()
if err != nil || len(providers) == 0 { if err != nil || len(providers) == 0 {
h.logger.Warn("无可用供应商转发 GET 请求", zap.String("path", inSpec.URL)) h.logger.Warn("无可用供应商转发请求", zap.String("path", inSpec.URL))
c.JSON(http.StatusNotFound, gin.H{"error": "没有可用的供应商。请先创建供应商和模型。"}) h.writeProxyError(c, http.StatusNotFound, "PROVIDER_NOT_FOUND", "没有可用的供应商。请先创建供应商和模型。")
return return
} }
@@ -377,19 +497,18 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
providerProtocol = "openai" providerProtocol = "openai"
} }
ifaceType := adapter.DetectInterfaceType(inSpec.URL)
targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "") targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "")
var outSpec *conversion.HTTPRequestSpec var outSpec *conversion.HTTPRequestSpec
if clientProtocol == providerProtocol { if clientProtocol == providerProtocol {
upstreamURL := p.BaseURL + inSpec.URL upstreamPath := adapter.BuildUrl(stripRawQuery(inSpec.URL), ifaceType)
upstreamPath = appendRawQuery(upstreamPath, rawQueryFromPath(inSpec.URL))
headers := adapter.BuildHeaders(targetProvider) headers := adapter.BuildHeaders(targetProvider)
if _, ok := headers["Content-Type"]; !ok { if _, ok := headers["Content-Type"]; !ok {
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
} }
outSpec = &conversion.HTTPRequestSpec{ outSpec = &conversion.HTTPRequestSpec{
URL: upstreamURL, URL: joinBaseURL(p.BaseURL, upstreamPath),
Method: inSpec.Method, Method: inSpec.Method,
Headers: headers, Headers: headers,
Body: inSpec.Body, Body: inSpec.Body,
@@ -402,9 +521,18 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
} }
} }
if isStream {
h.forwardStream(c, *outSpec, clientProtocol, providerProtocol, ifaceType)
return
}
resp, err := h.client.Send(c.Request.Context(), *outSpec) resp, err := h.client.Send(c.Request.Context(), *outSpec)
if err != nil { if err != nil {
h.writeConversionError(c, err, clientProtocol) h.writeUpstreamUnavailable(c, err)
return
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
h.writeUpstreamResponse(c, *resp)
return return
} }
@@ -414,13 +542,111 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
return return
} }
for k, v := range convertedResp.Headers { h.writeConvertedResponse(c, *convertedResp)
c.Header(k, v) }
func (h *ProxyHandler) forwardStream(c *gin.Context, outSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, ifaceType conversion.InterfaceType) {
streamResp, err := h.client.SendStream(c.Request.Context(), outSpec)
if err != nil {
h.writeUpstreamUnavailable(c, err)
return
} }
if c.GetHeader("Content-Type") == "" { if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
c.Header("Content-Type", "application/json") h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
StatusCode: streamResp.StatusCode,
Headers: streamResp.Headers,
Body: streamResp.Body,
})
return
} }
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, "", ifaceType)
if err != nil {
h.writeConversionError(c, err, clientProtocol)
return
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
writer := bufio.NewWriter(c.Writer)
flushed := false
for event := range streamResp.Events {
if event.Error != nil {
h.logger.Error("透传流读取错误", zap.Error(event.Error))
break
}
if event.Done {
chunks := streamConverter.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
}
flushed = true
break
}
chunks := streamConverter.ProcessChunk(event.Data)
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
break
}
}
if !flushed {
chunks := streamConverter.Flush()
if err := h.writeStreamChunks(writer, chunks); err != nil {
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
}
}
}
func stripRawQuery(path string) string {
pathOnly, _, _ := strings.Cut(path, "?")
return pathOnly
}
func rawQueryFromPath(path string) string {
_, rawQuery, found := strings.Cut(path, "?")
if !found {
return ""
}
return rawQuery
}
func joinBaseURL(baseURL, path string) string {
return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/")
}
func headerValue(headers map[string]string, key string) string {
for k, v := range headers {
if strings.EqualFold(k, key) {
return v
}
}
return ""
}
func filterHopByHopHeaders(headers map[string]string) map[string]string {
if len(headers) == 0 {
return nil
}
hopByHop := map[string]struct{}{
"connection": {},
"transfer-encoding": {},
"keep-alive": {},
"proxy-authenticate": {},
"proxy-authorization": {},
"te": {},
"trailer": {},
"upgrade": {},
}
filtered := make(map[string]string, len(headers))
for k, v := range headers {
if _, skip := hopByHop[strings.ToLower(k)]; skip {
continue
}
filtered[k] = v
}
return filtered
} }
// extractHeaders 从 Gin context 提取请求头 // extractHeaders 从 Gin context 提取请求头

View File

@@ -5,30 +5,30 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
"nex/backend/tests/mocks"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/conversion"
"nex/backend/internal/conversion/anthropic"
"nex/backend/internal/conversion/openai"
"nex/backend/internal/domain"
"nex/backend/internal/provider"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
"nex/backend/tests/mocks"
) )
func init() { func init() {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
} }
func setupProxyEngine(t *testing.T) *conversion.ConversionEngine { func setupProxyEngine(t *testing.T) *conversion.ConversionEngine {
t.Helper() t.Helper()
registry := conversion.NewMemoryRegistry() registry := conversion.NewMemoryRegistry()
@@ -74,7 +74,7 @@ func TestProxyHandler_HandleProxy_NonStreamSuccess(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
@@ -93,8 +93,8 @@ func TestProxyHandler_HandleProxy_NonStreamSuccess(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -110,20 +110,20 @@ func TestProxyHandler_HandleProxy_RoutingError_WithBody(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(nil, appErrors.ErrModelNotFound) routingSvc.EXPECT().RouteByModelName("unknown", "model").Return(nil, appErrors.ErrModelNotFound)
providerSvc := mocks.NewMockProviderService(ctrl) providerSvc := mocks.NewMockProviderService(ctrl)
providerSvc.EXPECT().List().Return(nil, nil)
client := mocks.NewMockProviderClient(ctrl) client := mocks.NewMockProviderClient(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl) statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc) h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"unknown","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 404, w.Code) assert.Equal(t, 404, w.Code)
assert.Contains(t, w.Body.String(), "MODEL_NOT_FOUND")
} }
func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) { func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
@@ -132,7 +132,7 @@ func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
@@ -146,11 +146,12 @@ func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 500, w.Code) assert.Equal(t, 502, w.Code)
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
} }
func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) { func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
@@ -159,7 +160,7 @@ func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
@@ -173,11 +174,12 @@ func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 500, w.Code) assert.Equal(t, 502, w.Code)
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
} }
func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) { func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
@@ -186,12 +188,12 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
client := mocks.NewMockProviderClient(ctrl) client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) { client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
ch := make(chan provider.StreamEvent, 10) ch := make(chan provider.StreamEvent, 10)
go func() { go func() {
defer close(ch) defer close(ch)
@@ -200,7 +202,7 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")} ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
ch <- provider.StreamEvent{Done: true} ch <- provider.StreamEvent{Done: true}
}() }()
return ch, nil return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
}) })
providerSvc := mocks.NewMockProviderService(ctrl) providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl) statsSvc := mocks.NewMockStatsService(ctrl)
@@ -209,13 +211,14 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type")) assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type"))
assert.Contains(t, w.Body.String(), "Hello") assert.Contains(t, w.Body.String(), "Hello")
assert.Contains(t, w.Body.String(), "p1/gpt-4")
} }
func TestProxyHandler_HandleProxy_StreamError(t *testing.T) { func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
@@ -224,12 +227,12 @@ func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
client := mocks.NewMockProviderClient(ctrl) client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) { client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
return nil, context.DeadlineExceeded return nil, context.DeadlineExceeded
}) })
providerSvc := mocks.NewMockProviderService(ctrl) providerSvc := mocks.NewMockProviderService(ctrl)
@@ -238,11 +241,12 @@ func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 500, w.Code) assert.Equal(t, 502, w.Code)
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
} }
func TestProxyHandler_ForwardPassthrough_GET(t *testing.T) { func TestProxyHandler_ForwardPassthrough_GET(t *testing.T) {
@@ -262,8 +266,8 @@ func TestProxyHandler_ForwardPassthrough_GET(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -283,11 +287,11 @@ func TestProxyHandler_ForwardPassthrough_UnsupportedProtocol(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "unknown"}, {Key: "path", Value: "/models"}} c.Params = gin.Params{{Key: "protocol", Value: "unknown"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/unknown/models", nil) c.Request = httptest.NewRequest("GET", "/unknown/models", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 400, w.Code) assert.Equal(t, 404, w.Code)
} }
func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) { func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) {
@@ -305,8 +309,8 @@ func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -330,7 +334,7 @@ func TestProxyHandler_HandleProxy_ProviderProtocolDefault(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
@@ -348,8 +352,8 @@ func TestProxyHandler_HandleProxy_ProviderProtocolDefault(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -372,6 +376,7 @@ func TestProxyHandler_WriteConversionError_NonConversionError(t *testing.T) {
h.writeConversionError(c, context.DeadlineExceeded, "openai") h.writeConversionError(c, context.DeadlineExceeded, "openai")
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
assert.JSONEq(t, `{"error":"context deadline exceeded","code":"CONVERSION_FAILED"}`, w.Body.String())
} }
func TestProxyHandler_WriteConversionError_ConversionError(t *testing.T) { func TestProxyHandler_WriteConversionError_ConversionError(t *testing.T) {
@@ -391,7 +396,40 @@ func TestProxyHandler_WriteConversionError_ConversionError(t *testing.T) {
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "bad request") convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "bad request")
h.writeConversionError(c, convErr, "openai") h.writeConversionError(c, convErr, "openai")
assert.Equal(t, 500, w.Code) assert.Equal(t, 400, w.Code)
assert.JSONEq(t, `{"error":"bad request","code":"INVALID_REQUEST"}`, w.Body.String())
}
func TestProxyHandler_WriteConversionError_JSONPhase(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
client := mocks.NewMockProviderClient(ctrl)
routingSvc := mocks.NewMockRoutingService(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
t.Run("request json parse error", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
h.writeConversionError(c, conversion.NewRequestJSONParseError("解码请求失败", context.Canceled), "openai")
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.JSONEq(t, `{"error":"请求体 JSON 格式错误","code":"INVALID_JSON"}`, w.Body.String())
})
t.Run("response json parse error", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil)
h.writeConversionError(c, conversion.NewResponseJSONParseError("解码响应失败", context.Canceled), "openai")
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.JSONEq(t, `{"error":"解码响应失败","code":"CONVERSION_FAILED"}`, w.Body.String())
})
} }
func TestProxyHandler_HandleProxy_EmptyBody(t *testing.T) { func TestProxyHandler_HandleProxy_EmptyBody(t *testing.T) {
@@ -411,8 +449,8 @@ func TestProxyHandler_HandleProxy_EmptyBody(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -424,19 +462,19 @@ func TestProxyHandler_HandleStream_MidStreamError(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
client := mocks.NewMockProviderClient(ctrl) client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) { client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
ch := make(chan provider.StreamEvent, 10) ch := make(chan provider.StreamEvent, 10)
go func() { go func() {
defer close(ch) defer close(ch)
ch <- provider.StreamEvent{Data: []byte("data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n")} ch <- provider.StreamEvent{Data: []byte("data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n")}
ch <- provider.StreamEvent{Error: fmt.Errorf("connection reset by peer")} ch <- provider.StreamEvent{Error: fmt.Errorf("connection reset by peer")}
}() }()
return ch, nil return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
}) })
providerSvc := mocks.NewMockProviderService(ctrl) providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl) statsSvc := mocks.NewMockStatsService(ctrl)
@@ -445,8 +483,8 @@ func TestProxyHandler_HandleStream_MidStreamError(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -461,12 +499,12 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
client := mocks.NewMockProviderClient(ctrl) client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) { client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
ch := make(chan provider.StreamEvent, 10) ch := make(chan provider.StreamEvent, 10)
go func() { go func() {
defer close(ch) defer close(ch)
@@ -474,7 +512,7 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")} ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
ch <- provider.StreamEvent{Done: true} ch <- provider.StreamEvent{Done: true}
}() }()
return ch, nil return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
}) })
providerSvc := mocks.NewMockProviderService(ctrl) providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl) statsSvc := mocks.NewMockStatsService(ctrl)
@@ -483,8 +521,8 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -506,7 +544,7 @@ func TestProxyHandler_HandleStream_CreateStreamConverterError(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
@@ -517,8 +555,8 @@ func TestProxyHandler_HandleStream_CreateStreamConverterError(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
@@ -533,7 +571,7 @@ func TestProxyHandler_HandleStream_ConvertRequestError(t *testing.T) {
require.NoError(t, registry.Register(openai.NewAdapter())) require.NoError(t, registry.Register(openai.NewAdapter()))
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
@@ -544,8 +582,8 @@ func TestProxyHandler_HandleStream_ConvertRequestError(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
@@ -561,7 +599,7 @@ func TestProxyHandler_HandleNonStream_ConvertResponseError(t *testing.T) {
require.NoError(t, registry.Register(anthropic.NewAdapter())) require.NoError(t, registry.Register(anthropic.NewAdapter()))
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "anthropic", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "anthropic", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "claude-3", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "claude-3", Enabled: true},
}, nil) }, nil)
@@ -579,8 +617,8 @@ func TestProxyHandler_HandleNonStream_ConvertResponseError(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"claude-3","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
@@ -592,7 +630,7 @@ func TestProxyHandler_HandleNonStream_ResponseHeaders(t *testing.T) {
engine := setupProxyEngine(t) engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl) routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{ routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true}, Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil) }, nil)
@@ -611,8 +649,8 @@ func TestProxyHandler_HandleNonStream_ResponseHeaders(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -643,8 +681,8 @@ func TestProxyHandler_ForwardPassthrough_CrossProtocol(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -667,8 +705,8 @@ func TestProxyHandler_ForwardPassthrough_NoBody_NoModel(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -691,10 +729,10 @@ func TestIsStreamRequest_EdgeCases(t *testing.T) {
path string path string
expected bool expected bool
}{ }{
{"stream at end of JSON", `{"messages":[],"stream":true}`, "/chat/completions", true}, {"stream at end of JSON", `{"messages":[],"stream":true}`, "/v1/chat/completions", true},
{"stream with spaces", `{"stream" : true}`, "/chat/completions", true}, {"stream with spaces", `{"stream" : true}`, "/v1/chat/completions", true},
{"stream embedded in string value", `{"model":"stream:true"}`, "/chat/completions", false}, {"stream embedded in string value", `{"model":"stream:true"}`, "/v1/chat/completions", false},
{"empty body", "", "/chat/completions", false}, {"empty body", "", "/v1/chat/completions", false},
{"stream true embeddings", `{"model":"text-emb","stream":true}`, "/v1/embeddings", false}, {"stream true embeddings", `{"model":"text-emb","stream":true}`, "/v1/embeddings", false},
} }
@@ -721,8 +759,9 @@ func TestProxyHandler_WriteError_RouteError(t *testing.T) {
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest("POST", "/", nil) c.Request = httptest.NewRequest("POST", "/", nil)
h.writeError(c, fmt.Errorf("model not found"), "openai") h.writeRouteError(c, fmt.Errorf("model not found"))
assert.Equal(t, 404, w.Code) assert.Equal(t, 404, w.Code)
assert.Contains(t, w.Body.String(), "MODEL_NOT_FOUND")
} }
func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) { func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) {
@@ -742,8 +781,8 @@ func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -766,35 +805,35 @@ func TestIsStreamRequest(t *testing.T) {
name: "stream true", name: "stream true",
body: []byte(`{"model": "gpt-4", "stream": true}`), body: []byte(`{"model": "gpt-4", "stream": true}`),
clientProtocol: "openai", clientProtocol: "openai",
nativePath: "/chat/completions", nativePath: "/v1/chat/completions",
expected: true, expected: true,
}, },
{ {
name: "stream false", name: "stream false",
body: []byte(`{"model": "gpt-4", "stream": false}`), body: []byte(`{"model": "gpt-4", "stream": false}`),
clientProtocol: "openai", clientProtocol: "openai",
nativePath: "/chat/completions", nativePath: "/v1/chat/completions",
expected: false, expected: false,
}, },
{ {
name: "no stream field", name: "no stream field",
body: []byte(`{"model": "gpt-4"}`), body: []byte(`{"model": "gpt-4"}`),
clientProtocol: "openai", clientProtocol: "openai",
nativePath: "/chat/completions", nativePath: "/v1/chat/completions",
expected: false, expected: false,
}, },
{ {
name: "invalid json", name: "invalid json",
body: []byte(`{invalid}`), body: []byte(`{invalid}`),
clientProtocol: "openai", clientProtocol: "openai",
nativePath: "/chat/completions", nativePath: "/v1/chat/completions",
expected: false, expected: false,
}, },
{ {
name: "not chat endpoint", name: "not chat endpoint",
body: []byte(`{"model": "gpt-4", "stream": true}`), body: []byte(`{"model": "gpt-4", "stream": true}`),
clientProtocol: "openai", clientProtocol: "openai",
nativePath: "/models", nativePath: "/v1/models",
expected: false, expected: false,
}, },
{ {
@@ -832,8 +871,8 @@ func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
c.Request = httptest.NewRequest("GET", "/openai/models", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -844,7 +883,8 @@ func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
require.True(t, ok) require.True(t, ok)
assert.Len(t, data, 2) assert.Len(t, data, 2)
first := data[0].(map[string]interface{}) first, ok2 := data[0].(map[string]interface{})
require.True(t, ok2)
assert.Equal(t, "openai/gpt-4", first["id"]) assert.Equal(t, "openai/gpt-4", first["id"])
} }
@@ -862,8 +902,8 @@ func TestProxyHandler_HandleProxy_ModelInfo_LocalQuery(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/openai/gpt-4"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models/openai/gpt-4"}}
c.Request = httptest.NewRequest("GET", "/openai/models/openai/gpt-4", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models/openai/gpt-4", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -896,8 +936,8 @@ func TestProxyHandler_HandleProxy_Models_EmptySuffix_ForwardPassthrough(t *testi
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models/"}}
c.Request = httptest.NewRequest("GET", "/openai/models/", nil) c.Request = httptest.NewRequest("GET", "/openai/v1/models/", nil)
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -918,7 +958,7 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
client := mocks.NewMockProviderClient(ctrl) client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) { client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
var req map[string]interface{} var req map[string]interface{}
json.Unmarshal(spec.Body, &req) require.NoError(t, json.Unmarshal(spec.Body, &req))
assert.Equal(t, "gpt-4", req["model"]) assert.Equal(t, "gpt-4", req["model"])
return &conversion.HTTPResponseSpec{ return &conversion.HTTPResponseSpec{
@@ -934,8 +974,8 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -972,8 +1012,8 @@ func TestProxyHandler_HandleProxy_CrossProtocol_NonStream_UnifiedID(t *testing.T
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -994,7 +1034,7 @@ func TestProxyHandler_HandleProxy_CrossProtocol_Stream_UnifiedID(t *testing.T) {
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true}, Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true},
}, nil) }, nil)
client := mocks.NewMockProviderClient(ctrl) client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) { client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
ch := make(chan provider.StreamEvent, 10) ch := make(chan provider.StreamEvent, 10)
go func() { go func() {
defer close(ch) defer close(ch)
@@ -1012,7 +1052,7 @@ data: {"type":"message_stop"}
`)} `)}
ch <- provider.StreamEvent{Done: true} ch <- provider.StreamEvent{Done: true}
}() }()
return ch, nil return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
}) })
providerSvc := mocks.NewMockProviderService(ctrl) providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl) statsSvc := mocks.NewMockStatsService(ctrl)
@@ -1021,8 +1061,8 @@ data: {"type":"message_stop"}
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}],"stream":true}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -1059,8 +1099,8 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_Fidelity(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}],"custom_param":"should_be_preserved"}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}],"custom_param":"should_be_preserved"}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
@@ -1090,8 +1130,8 @@ func TestProxyHandler_HandleProxy_UnifiedID_ModelNotFound(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}} c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`))) c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c) h.HandleProxy(c)
assert.Equal(t, 404, w.Code) assert.Equal(t, 404, w.Code)
@@ -1100,3 +1140,314 @@ func TestProxyHandler_HandleProxy_UnifiedID_ModelNotFound(t *testing.T) {
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Contains(t, resp, "error") assert.Contains(t, resp, "error")
} }
func TestProxyHandler_HandleProxy_OpenAIAndAnthropicNativePaths(t *testing.T) {
tests := []struct {
name string
protocol string
path string
requestPath string
baseURL string
expectedURL string
body string
responseBody string
responseModel string
}{
{
name: "openai path keeps v1 after gateway prefix",
protocol: "openai",
path: "/v1/chat/completions",
requestPath: "/openai/v1/chat/completions",
baseURL: "https://api.test.com/v1",
expectedURL: "https://api.test.com/v1/chat/completions",
body: `{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`,
responseBody: `{"id":"r1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`,
responseModel: "p1/gpt-4",
},
{
name: "anthropic path keeps v1 after gateway prefix",
protocol: "anthropic",
path: "/v1/messages",
requestPath: "/anthropic/v1/messages",
baseURL: "https://api.anthropic.test",
expectedURL: "https://api.anthropic.test/v1/messages",
body: `{"model":"p1/gpt-4","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`,
responseBody: `{"id":"msg-1","type":"message","role":"assistant","model":"gpt-4","content":[{"type":"text","text":"ok"}],"stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`,
responseModel: "p1/gpt-4",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: tt.baseURL, Protocol: tt.protocol, Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
assert.Equal(t, tt.expectedURL, spec.URL)
return &conversion.HTTPResponseSpec{
StatusCode: http.StatusOK,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(tt.responseBody),
}, nil
})
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
statsSvc.EXPECT().Record(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: tt.protocol}, {Key: "path", Value: tt.path}}
c.Request = httptest.NewRequest("POST", tt.requestPath, bytes.NewReader([]byte(tt.body)))
h.HandleProxy(c)
require.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), tt.responseModel)
})
}
}
func TestProxyHandler_UpstreamNon2xx_Passthrough(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).Return(&conversion.HTTPResponseSpec{
StatusCode: http.StatusTooManyRequests,
Headers: map[string]string{
"Content-Type": "application/json",
"X-Upstream-Error": "rate-limit",
"Transfer-Encoding": "chunked",
},
Body: []byte(`{"error":{"message":"rate limited"}}`),
}, nil)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
h.HandleProxy(c)
require.Equal(t, http.StatusTooManyRequests, w.Code)
assert.JSONEq(t, `{"error":{"message":"rate limited"}}`, w.Body.String())
assert.Equal(t, "rate-limit", w.Header().Get("X-Upstream-Error"))
assert.Empty(t, w.Header().Get("Transfer-Encoding"))
}
func TestProxyHandler_StreamUpstreamNon2xx_Passthrough(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).Return(&provider.StreamResponse{
StatusCode: http.StatusServiceUnavailable,
Headers: map[string]string{"Content-Type": "application/json", "Connection": "close"},
Body: []byte(`{"error":"upstream down"}`),
}, nil)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
require.Equal(t, http.StatusServiceUnavailable, w.Code)
assert.JSONEq(t, `{"error":"upstream down"}`, w.Body.String())
assert.Empty(t, w.Header().Get("Connection"))
}
func TestFilterHopByHopHeaders(t *testing.T) {
filtered := filterHopByHopHeaders(map[string]string{
"Connection": "close",
"Transfer-Encoding": "chunked",
"Keep-Alive": "timeout=5",
"Proxy-Authenticate": "Basic",
"Proxy-Authorization": "Basic token",
"TE": "trailers",
"Trailer": "Expires",
"Upgrade": "websocket",
"Content-Type": "application/json",
"X-Request-ID": "req-1",
})
assert.Equal(t, map[string]string{
"Content-Type": "application/json",
"X-Request-ID": "req-1",
}, filtered)
}
func TestProxyHandler_UnknownInterface_DoesNotGuessModel(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
providerSvc.EXPECT().List().Return([]domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
assert.Equal(t, "https://api.test.com/v1/unknown?trace=1", spec.URL)
assert.JSONEq(t, `{"model":"p1/gpt-4","payload":true}`, string(spec.Body))
return &conversion.HTTPResponseSpec{
StatusCode: http.StatusOK,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"ok":true}`),
}, nil
})
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/unknown"}}
c.Request = httptest.NewRequest("POST", "/openai/unknown?trace=1", bytes.NewReader([]byte(`{"model":"p1/gpt-4","payload":true}`)))
h.HandleProxy(c)
require.Equal(t, http.StatusOK, w.Code)
assert.JSONEq(t, `{"ok":true}`, w.Body.String())
}
func TestProxyHandler_InvalidJSON_UsesGatewayError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
client := mocks.NewMockProviderClient(ctrl)
routingSvc := mocks.NewMockRoutingService(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":`)))
h.HandleProxy(c)
require.Equal(t, http.StatusBadRequest, w.Code)
assert.JSONEq(t, `{"error":"请求体 JSON 格式错误","code":"INVALID_JSON"}`, w.Body.String())
}
func TestProxyHandler_CrossProtocolMultimodal_Unsupported(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("anthropic_p", "claude").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "anthropic_p", Name: "Anthropic", APIKey: "sk-test", BaseURL: "https://api.anthropic.test", Protocol: "anthropic", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
body := []byte(`{"model":"anthropic_p/claude","messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader(body))
h.HandleProxy(c)
require.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, w.Body.String(), "UNSUPPORTED_MULTIMODAL")
}
func TestProxyHandler_SameProtocolMultimodal_SmartPassthrough(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
assert.Contains(t, string(spec.Body), "image_url")
assert.Contains(t, string(spec.Body), `"model":"gpt-4"`)
return &conversion.HTTPResponseSpec{
StatusCode: http.StatusOK,
Headers: map[string]string{"Content-Type": "application/json"},
Body: []byte(`{"id":"r1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`),
}, nil
})
providerSvc := mocks.NewMockProviderService(ctrl)
statsSvc := mocks.NewMockStatsService(ctrl)
statsSvc.EXPECT().Record(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
body := []byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader(body))
h.HandleProxy(c)
require.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "p1/gpt-4")
}
func TestProxyHandler_RawStreamPassthrough_PreservesSSEFrames(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
engine := setupProxyEngine(t)
routingSvc := mocks.NewMockRoutingService(ctrl)
providerSvc := mocks.NewMockProviderService(ctrl)
providerSvc.EXPECT().List().Return([]domain.Provider{
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
}, nil)
client := mocks.NewMockProviderClient(ctrl)
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
assert.Contains(t, string(spec.Body), `"model":"gpt-4"`)
ch := make(chan provider.StreamEvent, 3)
go func() {
defer close(ch)
ch <- provider.StreamEvent{Data: []byte("data: {\"model\":\"gpt-4\",\"choices\":[]}\n\n")}
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
ch <- provider.StreamEvent{Done: true}
}()
return &provider.StreamResponse{StatusCode: http.StatusOK, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
})
statsSvc := mocks.NewMockStatsService(ctrl)
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
h.HandleProxy(c)
require.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "data: {\"model\":\"gpt-4\",\"choices\":[]}\n\ndata: [DONE]\n\n", w.Body.String())
}

View File

@@ -5,9 +5,9 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/gin-gonic/gin"
"nex/backend/internal/service" "nex/backend/internal/service"
"github.com/gin-gonic/gin"
) )
// StatsHandler 统计处理器 // StatsHandler 统计处理器

View File

@@ -0,0 +1,26 @@
package handler
import (
"net/http"
"nex/backend/pkg/buildinfo"
"github.com/gin-gonic/gin"
)
// VersionHandler 提供后端构建版本信息。
type VersionHandler struct{}
// NewVersionHandler 创建版本信息处理器。
func NewVersionHandler() *VersionHandler {
return &VersionHandler{}
}
// GetVersion 返回构建注入的版本元数据。
func (h *VersionHandler) GetVersion(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"version": buildinfo.Version(),
"commit": buildinfo.Commit(),
"build_time": buildinfo.BuildTime(),
})
}

View File

@@ -0,0 +1,31 @@
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestVersionHandler_GetVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
h := NewVersionHandler()
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/api/version", nil)
h.GetVersion(c)
assert.Equal(t, http.StatusOK, w.Code)
var result map[string]string
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
assert.Equal(t, "dev", result["version"])
assert.Equal(t, "unknown", result["commit"])
assert.Equal(t, "unknown", result["build_time"])
}

View File

@@ -8,6 +8,7 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"strings"
"syscall" "syscall"
"time" "time"
@@ -43,6 +44,14 @@ type StreamEvent struct {
Done bool Done bool
} }
// StreamResponse 表示上游流式 HTTP 响应。
type StreamResponse struct {
StatusCode int
Headers map[string]string
Body []byte
Events <-chan StreamEvent
}
// Client 协议无关的供应商客户端 // Client 协议无关的供应商客户端
type Client struct { type Client struct {
httpClient *http.Client httpClient *http.Client
@@ -51,10 +60,11 @@ type Client struct {
} }
// ProviderClient 供应商客户端接口 // ProviderClient 供应商客户端接口
//
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks //go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
type ProviderClient interface { type ProviderClient interface {
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*StreamResponse, error)
} }
// NewClient 创建供应商客户端 // NewClient 创建供应商客户端
@@ -115,7 +125,7 @@ func (c *Client) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*co
} }
// SendStream 发送流式请求 // SendStream 发送流式请求
func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error) { func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*StreamResponse, error) {
var bodyReader io.Reader var bodyReader io.Reader
if len(spec.Body) > 0 { if len(spec.Body) > 0 {
bodyReader = bytes.NewReader(spec.Body) bodyReader = bytes.NewReader(spec.Body)
@@ -138,20 +148,29 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec
return nil, pkgErrors.ErrRequestSend.WithCause(err) return nil, pkgErrors.ErrRequestSend.WithCause(err)
} }
if resp.StatusCode != http.StatusOK { respHeaders := extractResponseHeaders(resp.Header)
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
defer resp.Body.Close() defer resp.Body.Close()
cancel() cancel()
errBody, _ := io.ReadAll(resp.Body) errBody, readErr := io.ReadAll(resp.Body)
if len(errBody) > 0 { if readErr != nil {
return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody)) return nil, pkgErrors.ErrResponseRead.WithCause(readErr)
} }
return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode) return &StreamResponse{
StatusCode: resp.StatusCode,
Headers: respHeaders,
Body: errBody,
}, nil
} }
eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize) eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize)
go c.readStream(streamCtx, cancel, resp.Body, eventChan) go c.readStream(streamCtx, cancel, resp.Body, eventChan)
return eventChan, nil return &StreamResponse{
StatusCode: resp.StatusCode,
Headers: respHeaders,
Events: eventChan,
}, nil
} }
// readStream 读取 SSE 流 // readStream 读取 SSE 流
@@ -184,7 +203,7 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
if isNetworkError(err) { if isNetworkError(err) {
c.logger.Error("流网络错误", zap.String("error", err.Error())) c.logger.Error("流网络错误", zap.Error(err))
eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)} eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
} else { } else {
c.logger.Error("流读取错误", zap.Error(err)) c.logger.Error("流读取错误", zap.Error(err))
@@ -204,15 +223,17 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
} }
for { for {
idx := bytes.Index(dataBuf, []byte("\n\n")) idx, sepLen := findSSEFrameSeparator(dataBuf)
if idx == -1 { if idx == -1 {
break break
} }
rawEvent := dataBuf[:idx] frameEnd := idx + sepLen
dataBuf = dataBuf[idx+2:] rawEvent := append([]byte(nil), dataBuf[:frameEnd]...)
dataBuf = dataBuf[frameEnd:]
if bytes.Contains(rawEvent, []byte("data: [DONE]")) { if isSSEDoneFrame(rawEvent) {
eventChan <- StreamEvent{Data: rawEvent}
eventChan <- StreamEvent{Done: true} eventChan <- StreamEvent{Done: true}
return return
} }
@@ -221,11 +242,66 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
} }
if err == io.EOF { if err == io.EOF {
if len(dataBuf) > 0 {
eventChan <- StreamEvent{Data: dataBuf}
}
return return
} }
} }
} }
func isSSEDoneFrame(frame []byte) bool {
payload, ok := sseFrameDataPayload(frame)
return ok && strings.TrimSpace(payload) == "[DONE]"
}
func sseFrameDataPayload(frame []byte) (string, bool) {
text := strings.TrimRight(string(frame), "\r\n")
lines := strings.Split(text, "\n")
var dataLines []string
for _, line := range lines {
line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "data:") {
value := strings.TrimPrefix(line, "data:")
if strings.HasPrefix(value, " ") {
value = value[1:]
}
dataLines = append(dataLines, value)
}
}
if len(dataLines) == 0 {
return "", false
}
return strings.Join(dataLines, "\n"), true
}
func extractResponseHeaders(header http.Header) map[string]string {
respHeaders := make(map[string]string)
for k, vs := range header {
if len(vs) > 0 {
respHeaders[k] = vs[0]
}
}
return respHeaders
}
func findSSEFrameSeparator(data []byte) (int, int) {
lf := bytes.Index(data, []byte("\n\n"))
crlf := bytes.Index(data, []byte("\r\n\r\n"))
switch {
case lf < 0 && crlf < 0:
return -1, 0
case lf < 0:
return crlf, 4
case crlf < 0:
return lf, 2
case crlf <= lf:
return crlf, 4
default:
return lf, 2
}
}
// isNetworkError 判断是否为网络相关错误 // isNetworkError 判断是否为网络相关错误
func isNetworkError(err error) bool { func isNetworkError(err error) bool {
if err == nil { if err == nil {

View File

@@ -41,7 +41,8 @@ func TestClient_Send_Success(t *testing.T) {
assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization")) assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"id":"test","model":"gpt-4"}`)) _, err := w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
require.NoError(t, err)
})) }))
defer server.Close() defer server.Close()
@@ -65,7 +66,8 @@ func TestClient_Send_Success(t *testing.T) {
func TestClient_Send_ErrorResponse(t *testing.T) { func TestClient_Send_ErrorResponse(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"error":{"message":"Invalid API key"}}`)) _, err := w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
require.NoError(t, err)
})) }))
defer server.Close() defer server.Close()
@@ -108,11 +110,13 @@ func TestClient_SendStream_CreatesChannel(t *testing.T) {
Body: []byte(`{}`), Body: []byte(`{}`),
} }
eventChan, err := client.SendStream(context.Background(), spec) streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, eventChan) require.NotNil(t, streamResp)
require.Equal(t, http.StatusOK, streamResp.StatusCode)
require.NotNil(t, streamResp.Events)
for range eventChan { for range streamResp.Events {
} }
} }
@@ -130,8 +134,10 @@ func TestClient_SendStream_ErrorResponse(t *testing.T) {
Body: []byte(`{}`), Body: []byte(`{}`),
} }
_, err := client.SendStream(context.Background(), spec) streamResp, err := client.SendStream(context.Background(), spec)
assert.Error(t, err) require.NoError(t, err)
require.NotNil(t, streamResp)
assert.Equal(t, http.StatusInternalServerError, streamResp.StatusCode)
} }
func TestClient_SendStream_SSEEvents(t *testing.T) { func TestClient_SendStream_SSEEvents(t *testing.T) {
@@ -140,12 +146,15 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
require.True(t, ok) require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n")) _, err := w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n")) _, err = w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n")) _, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
})) }))
@@ -159,24 +168,73 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
Body: []byte(`{"model":"gpt-4","messages":[],"stream":true}`), Body: []byte(`{"model":"gpt-4","messages":[],"stream":true}`),
} }
eventChan, err := client.SendStream(context.Background(), spec) streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, streamResp)
var dataEvents [][]byte var dataEvents [][]byte
var doneEvents int var doneEvents int
for event := range eventChan { for event := range streamResp.Events {
if event.Done { switch {
case event.Done:
doneEvents++ doneEvents++
} else if event.Error != nil { case event.Error != nil:
t.Fatalf("unexpected error: %v", event.Error) t.Fatalf("unexpected error: %v", event.Error)
} else { default:
dataEvents = append(dataEvents, event.Data) dataEvents = append(dataEvents, event.Data)
} }
} }
assert.Equal(t, 2, len(dataEvents), "expected exactly 2 data events from SSE stream") assert.Equal(t, 3, len(dataEvents), "expected 2 data frames plus DONE frame from SSE stream")
assert.Contains(t, string(dataEvents[0]), "Hello") assert.Contains(t, string(dataEvents[0]), "Hello")
assert.Contains(t, string(dataEvents[1]), "World") assert.Contains(t, string(dataEvents[1]), "World")
assert.Contains(t, string(dataEvents[2]), "[DONE]")
assert.Equal(t, 1, doneEvents)
assert.Contains(t, string(dataEvents[0]), "\n\n")
}
func TestClient_SendStream_DONEOnlyWhenDataPayloadEqualsDone(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher)
require.True(t, ok)
_, err := w.Write([]byte("data: {\"text\":\"data: [DONE] is plain text\"}\n\n"))
require.NoError(t, err)
flusher.Flush()
_, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush()
}))
defer server.Close()
client := NewClient(zap.NewNop())
spec := conversion.HTTPRequestSpec{
URL: server.URL + "/v1/chat/completions",
Method: "POST",
Body: []byte(`{}`),
}
streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err)
require.NotNil(t, streamResp)
var dataEvents [][]byte
var doneEvents int
for event := range streamResp.Events {
switch {
case event.Done:
doneEvents++
case event.Error != nil:
t.Fatalf("unexpected error: %v", event.Error)
default:
dataEvents = append(dataEvents, event.Data)
}
}
require.Len(t, dataEvents, 2)
assert.Contains(t, string(dataEvents[0]), "plain text")
assert.Contains(t, string(dataEvents[1]), "[DONE]")
assert.Equal(t, 1, doneEvents) assert.Equal(t, 1, doneEvents)
} }
@@ -197,13 +255,13 @@ func TestClient_SendStream_ContextCancellation(t *testing.T) {
Body: []byte(`{}`), Body: []byte(`{}`),
} }
eventChan, err := client.SendStream(ctx, spec) streamResp, err := client.SendStream(ctx, spec)
require.NoError(t, err) require.NoError(t, err)
cancel() cancel()
var gotError bool var gotError bool
for event := range eventChan { for event := range streamResp.Events {
if event.Error != nil { if event.Error != nil {
gotError = true gotError = true
} }
@@ -215,7 +273,8 @@ func TestClient_Send_EmptyBody(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method) assert.Equal(t, "GET", r.Method)
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"result":"ok"}`)) _, err := w.Write([]byte(`{"result":"ok"}`))
require.NoError(t, err)
})) }))
defer server.Close() defer server.Close()
@@ -238,10 +297,12 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
require.True(t, ok) require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\n")) _, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n")) _, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
})) }))
@@ -255,21 +316,22 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
Body: []byte(`{}`), Body: []byte(`{}`),
} }
eventChan, err := client.SendStream(context.Background(), spec) streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err) require.NoError(t, err)
var dataCount int var dataCount int
var doneCount int var doneCount int
for event := range eventChan { for event := range streamResp.Events {
if event.Done { switch {
case event.Done:
doneCount++ doneCount++
} else if event.Error != nil { case event.Error != nil:
t.Fatalf("unexpected error: %v", event.Error) t.Fatalf("unexpected error: %v", event.Error)
} else { default:
dataCount++ dataCount++
} }
} }
assert.Equal(t, 1, dataCount, "expected exactly 1 data event from slow SSE") assert.Equal(t, 2, dataCount, "expected 1 data frame plus DONE frame from slow SSE")
assert.Equal(t, 1, doneCount, "expected exactly 1 done event from slow SSE") assert.Equal(t, 1, doneCount, "expected exactly 1 done event from slow SSE")
} }
@@ -279,10 +341,12 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
require.True(t, ok) require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n")) _, err := w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
w.Write([]byte("data: [DONE]\n\n")) _, err = w.Write([]byte("data: [DONE]\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
})) }))
@@ -296,19 +360,19 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
Body: []byte(`{}`), Body: []byte(`{}`),
} }
eventChan, err := client.SendStream(context.Background(), spec) streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err) require.NoError(t, err)
var dataEvents int var dataEvents int
var doneEvents int var doneEvents int
for event := range eventChan { for event := range streamResp.Events {
if event.Done { if event.Done {
doneEvents++ doneEvents++
} else { } else {
dataEvents++ dataEvents++
} }
} }
assert.Equal(t, 2, dataEvents, "expected exactly 2 data events from split SSE") assert.Equal(t, 3, dataEvents, "expected 2 data frames plus DONE frame from split SSE")
assert.Equal(t, 1, doneEvents) assert.Equal(t, 1, doneEvents)
} }
@@ -364,13 +428,14 @@ func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
require.True(t, ok) require.True(t, ok)
w.Write([]byte("data: {\"id\":\"1\"}\n\n")) _, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
require.NoError(t, err)
flusher.Flush() flusher.Flush()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
if hijacker, ok := w.(http.Hijacker); ok { if hijacker, ok := w.(http.Hijacker); ok {
conn, _, _ := hijacker.Hijack() conn, _, _ := hijacker.Hijack()
if conn != nil { if conn != nil {
conn.Close() require.NoError(t, conn.Close())
} }
} }
})) }))
@@ -384,11 +449,11 @@ func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
Body: []byte(`{}`), Body: []byte(`{}`),
} }
eventChan, err := client.SendStream(context.Background(), spec) streamResp, err := client.SendStream(context.Background(), spec)
require.NoError(t, err) require.NoError(t, err)
var gotData bool var gotData bool
for event := range eventChan { for event := range streamResp.Events {
if event.Error != nil { if event.Error != nil {
} else if !event.Done { } else if !event.Done {
gotData = true gotData = true

View File

@@ -3,10 +3,11 @@ package repository
import ( import (
"time" "time"
"gorm.io/gorm"
"nex/backend/internal/config" "nex/backend/internal/config"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
) )

View File

@@ -3,13 +3,13 @@ package repository
import ( import (
"testing" "testing"
"nex/backend/internal/domain"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/gorm" "gorm.io/gorm"
testHelpers "nex/backend/tests" testHelpers "nex/backend/tests"
"nex/backend/internal/domain"
) )
func setupTestDB(t *testing.T) *gorm.DB { func setupTestDB(t *testing.T) *gorm.DB {

View File

@@ -3,11 +3,11 @@ package repository
import ( import (
"time" "time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"nex/backend/internal/config" "nex/backend/internal/config"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"gorm.io/gorm"
"gorm.io/gorm/clause"
) )
type statsRepository struct { type statsRepository struct {
@@ -19,8 +19,8 @@ func NewStatsRepository(db *gorm.DB) StatsRepository {
} }
func (r *statsRepository) Record(providerID, modelName string) error { func (r *statsRepository) Record(providerID, modelName string) error {
today := time.Now().Format("2006-01-02") now := time.Now()
todayTime, _ := time.Parse("2006-01-02", today) todayTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
stats := config.UsageStats{ stats := config.UsageStats{
ProviderID: providerID, ProviderID: providerID,

View File

@@ -1,11 +1,15 @@
package service package service
import ( import (
"github.com/google/uuid" "errors"
appErrors "nex/backend/pkg/errors"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"github.com/google/uuid"
"gorm.io/gorm"
appErrors "nex/backend/pkg/errors"
) )
type modelService struct { type modelService struct {
@@ -108,7 +112,11 @@ func (s *modelService) Delete(id string) error {
func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error { func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error {
existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName) existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
if err != nil { if err != nil {
return nil // 未找到,不重复 if errors.Is(err, gorm.ErrRecordNotFound) {
return nil // 未找到,不重复
}
return err
} }
if excludeID != "" && existing.ID == excludeID { if excludeID != "" && existing.ID == excludeID {
return nil // 排除自身 return nil // 排除自身

View File

@@ -3,10 +3,10 @@ package service
import ( import (
"strings" "strings"
"nex/backend/pkg/modelid"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"nex/backend/pkg/modelid"
appErrors "nex/backend/pkg/errors" appErrors "nex/backend/pkg/errors"
) )

View File

@@ -4,10 +4,11 @@ import (
"strings" "strings"
"sync" "sync"
"go.uber.org/zap"
"nex/backend/internal/domain" "nex/backend/internal/domain"
"nex/backend/internal/repository" "nex/backend/internal/repository"
"go.uber.org/zap"
pkglogger "nex/backend/pkg/logger" pkglogger "nex/backend/pkg/logger"
) )
@@ -34,7 +35,9 @@ func NewRoutingCache(
func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) { func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
if v, ok := c.providers.Load(id); ok { if v, ok := c.providers.Load(id); ok {
return v.(*domain.Provider), nil if provider, ok := v.(*domain.Provider); ok {
return provider, nil
}
} }
provider, err := c.providerRepo.GetByID(id) provider, err := c.providerRepo.GetByID(id)
@@ -43,7 +46,9 @@ func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
} }
if v, ok := c.providers.Load(id); ok { if v, ok := c.providers.Load(id); ok {
return v.(*domain.Provider), nil if provider, ok := v.(*domain.Provider); ok {
return provider, nil
}
} }
c.providers.Store(id, provider) c.providers.Store(id, provider)
@@ -54,7 +59,9 @@ func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, er
key := providerID + "/" + modelName key := providerID + "/" + modelName
if v, ok := c.models.Load(key); ok { if v, ok := c.models.Load(key); ok {
return v.(*domain.Model), nil if model, ok := v.(*domain.Model); ok {
return model, nil
}
} }
model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName) model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName)
@@ -63,7 +70,9 @@ func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, er
} }
if v, ok := c.models.Load(key); ok { if v, ok := c.models.Load(key); ok {
return v.(*domain.Model), nil if model, ok := v.(*domain.Model); ok {
return model, nil
}
} }
c.models.Store(key, model) c.models.Store(key, model)
@@ -97,7 +106,12 @@ func (c *RoutingCache) invalidateModelsByProvider(providerID string) {
prefix := providerID + "/" prefix := providerID + "/"
count := 0 count := 0
c.models.Range(func(key, value interface{}) bool { c.models.Range(func(key, value interface{}) bool {
if strings.HasPrefix(key.(string), prefix) { keyStr, ok := key.(string)
if !ok {
return true
}
if strings.HasPrefix(keyStr, prefix) {
c.models.Delete(key) c.models.Delete(key)
count++ count++
} }

View File

@@ -5,11 +5,11 @@ import (
"sync" "sync"
"testing" "testing"
"nex/backend/internal/domain"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap" "go.uber.org/zap"
"nex/backend/internal/domain"
) )
type mockModelRepo struct { type mockModelRepo struct {
@@ -189,7 +189,8 @@ func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) {
var openaiCount, anthropicCount int var openaiCount, anthropicCount int
cache.models.Range(func(key, value interface{}) bool { cache.models.Range(func(key, value interface{}) bool {
if key.(string) == "anthropic/claude" { keyStr, ok := key.(string)
if ok && keyStr == "anthropic/claude" {
anthropicCount++ anthropicCount++
} }
return true return true

Some files were not shown because too many files have changed in this diff Show More