Compare commits
98 Commits
9359ca7f62
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 2c401f7ae6 | |||
| a9972360c2 | |||
| b00fa4dcee | |||
| 92525b39c3 | |||
| 38a2555c7b | |||
| 9622d44aac | |||
| 155244433f | |||
| 2c043c6cf7 | |||
| f5c82b6980 | |||
| 9105a36097 | |||
| f1ee646ca4 | |||
| b9b487c591 | |||
| 4c62c071fb | |||
| b2e9dd8b7f | |||
| d143c5f3df | |||
| 4eebdfb8db | |||
| b517946585 | |||
| 4ddae6be74 | |||
| 195762ff97 | |||
| bcf5ca89e5 | |||
| 365943e4c4 | |||
| 4c6b49099d | |||
| 4c78ab6cc8 | |||
| 52007c9461 | |||
| 086dd1fed7 | |||
| 1d7e839b49 | |||
| fa7babf13b | |||
| 280099b89c | |||
| 0a92a25451 | |||
| 8c075194e5 | |||
| 53e477d383 | |||
| 1522c87c74 | |||
| e0d05c9869 | |||
| 5b401e29cb | |||
| 65ac9f740a | |||
| 58ebcaa299 | |||
| 5b765c8b5e | |||
| b3258e76df | |||
| 64dc66afa6 | |||
| 15f08ee2ca | |||
| 380586afa6 | |||
| ebb70809bf | |||
| 7399afbc5c | |||
| c0669e4b07 | |||
| 05c04091b3 | |||
| 0b05e08705 | |||
| df253559a5 | |||
| 669cbb8c51 | |||
| 5ae9d85272 | |||
| 72aebef625 | |||
| f5e45d032e | |||
| b03e5f809f | |||
| ec563aaa16 | |||
| 873f09d3bf | |||
| 5e7267db07 | |||
| 7b28cee7a1 | |||
| 934c8dea77 | |||
| 7d91fe345e | |||
| 4e86adffb7 | |||
| 5d58acf5a6 | |||
| 81dcecb723 | |||
| 141f5f886f | |||
| 7fa5af483b | |||
| f488b9cc15 | |||
| 59179094ed | |||
| 4fc5fb4764 | |||
| feff97acbd | |||
| b7e205f4b6 | |||
| 24f03595a7 | |||
| 395887667d | |||
| 44d6af026a | |||
| 6e11ada42c | |||
| da790db75b | |||
| e1af978c56 | |||
| 980875ecf3 | |||
| 7f0f831226 | |||
| f3a207fa16 | |||
| 56ecc73d1b | |||
| 1ae9336cbe | |||
| 3fa5827de3 | |||
| cfb0edf802 | |||
| aea360bce8 | |||
| d92db73937 | |||
| bc1ee612d9 | |||
| 1dac347d3b | |||
| 26810d9410 | |||
| b14685d9a5 | |||
| 4dc518a5f4 | |||
| b92974716f | |||
| 2b1c5e96c3 | |||
| 6eeb38c15e | |||
| 49818ed4d8 | |||
| ddd284c1ca | |||
| c5b3d9dfc7 | |||
| 870004af23 | |||
| 5dd26d29a7 | |||
| 47ecbadc7c | |||
| 1580b5b838 |
7
.editorconfig
Normal file
7
.editorconfig
Normal file
@@ -0,0 +1,7 @@
|
||||
root = true
|
||||
|
||||
[*]
|
||||
end_of_line = lf
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
charset = utf-8
|
||||
8
.gitattributes
vendored
Normal file
8
.gitattributes
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
* text=auto eol=lf
|
||||
|
||||
assets/*.png filter=lfs diff=lfs merge=lfs -text
|
||||
assets/**/*.png filter=lfs diff=lfs merge=lfs -text
|
||||
assets/*.icns filter=lfs diff=lfs merge=lfs -text
|
||||
assets/**/*.icns filter=lfs diff=lfs merge=lfs -text
|
||||
assets/*.ico filter=lfs diff=lfs merge=lfs -text
|
||||
assets/**/*.ico filter=lfs diff=lfs merge=lfs -text
|
||||
151
.github/workflows/release.yml
vendored
Normal file
151
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,151 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*.*.*'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
prepare:
|
||||
name: Prepare Release
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
outputs:
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.work
|
||||
|
||||
- name: Verify tag and VERSION
|
||||
id: version
|
||||
run: |
|
||||
version=$(go run ./backend/cmd/versionctl print)
|
||||
go run ./backend/cmd/versionctl verify-tag "${GITHUB_REF_NAME}"
|
||||
printf 'version=%s\n' "$version" >> "$GITHUB_OUTPUT"
|
||||
|
||||
build-linux:
|
||||
name: Build Linux Assets
|
||||
needs: prepare
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.work
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
|
||||
- name: Install Linux desktop build dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libayatana-appindicator3-dev libgtk-3-dev
|
||||
|
||||
- name: Build Linux release assets
|
||||
run: make release-assets-linux
|
||||
|
||||
- name: Upload Linux release assets
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-linux
|
||||
path: build/release/*
|
||||
|
||||
build-windows:
|
||||
name: Build Windows Assets
|
||||
needs: prepare
|
||||
runs-on: windows-latest
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.work
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
|
||||
- name: Setup MSYS2 toolchain
|
||||
uses: msys2/setup-msys2@v2
|
||||
with:
|
||||
update: true
|
||||
install: >-
|
||||
make
|
||||
mingw-w64-x86_64-gcc
|
||||
|
||||
- name: Build Windows release assets
|
||||
shell: msys2 {0}
|
||||
run: make release-assets-windows
|
||||
|
||||
- name: Upload Windows release assets
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-windows
|
||||
path: build/release/*
|
||||
|
||||
build-macos:
|
||||
name: Build macOS Assets
|
||||
needs: prepare
|
||||
runs-on: macos-latest
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.work
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
|
||||
- name: Build macOS release assets
|
||||
run: make release-assets-macos
|
||||
|
||||
- name: Upload macOS release assets
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-macos
|
||||
path: build/release/*
|
||||
|
||||
draft-release:
|
||||
name: Create Draft Release
|
||||
needs: [prepare, build-linux, build-windows, build-macos]
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Download release assets
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: release-*
|
||||
merge-multiple: true
|
||||
path: dist
|
||||
|
||||
- name: Publish draft release
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
name: v${{ needs.prepare.outputs.version }}
|
||||
tag_name: ${{ github.ref_name }}
|
||||
draft: true
|
||||
files: |
|
||||
dist/*
|
||||
95
.gitignore
vendored
95
.gitignore
vendored
@@ -182,6 +182,10 @@ build/Release
|
||||
node_modules/
|
||||
jspm_packages/
|
||||
|
||||
# Test
|
||||
playwright-report
|
||||
test-results
|
||||
|
||||
# TypeScript v1 declaration files
|
||||
typings/
|
||||
|
||||
@@ -313,8 +317,99 @@ Network Trash Folder
|
||||
Temporary Items
|
||||
.apdisk
|
||||
|
||||
### Python.gitignore ###
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# PyInstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Environments
|
||||
.venv/
|
||||
venv/
|
||||
ENV/
|
||||
env/
|
||||
.python-version
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
|
||||
# Pyre
|
||||
.pyre/
|
||||
|
||||
# pytype
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# Custom
|
||||
.claude
|
||||
.opencode
|
||||
.codex
|
||||
openspec/changes/archive
|
||||
temp
|
||||
.agents
|
||||
skills-lock.json
|
||||
.worktrees
|
||||
!scripts/build/
|
||||
|
||||
# Embedfs generated
|
||||
embedfs/assets/
|
||||
embedfs/frontend-dist/
|
||||
backend/cmd/desktop/rsrc_windows_*.syso
|
||||
|
||||
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"files.eol": "\n"
|
||||
}
|
||||
184
LICENSE
Normal file
184
LICENSE
Normal file
@@ -0,0 +1,184 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction, and
|
||||
distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by the copyright
|
||||
owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all other entities
|
||||
that control, are controlled by, or are under common control with that entity.
|
||||
For the purposes of this definition, "control" means (i) the power, direct or
|
||||
indirect, to cause the direction or management of such entity, whether by
|
||||
contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity exercising
|
||||
permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications, including
|
||||
but not limited to software source code, documentation source, and configuration
|
||||
files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical transformation or
|
||||
translation of a Source form, including but not limited to compiled object code,
|
||||
generated documentation, and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or Object form,
|
||||
made available under the License, as indicated by a copyright notice that is
|
||||
included in or attached to the work (an example is provided in the Appendix
|
||||
below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object form, that
|
||||
is based on (or derived from) the Work and for which the editorial revisions,
|
||||
annotations, elaborations, or other modifications represent, as a whole, an
|
||||
original work of authorship. For the purposes of this License, Derivative Works
|
||||
shall not include works that remain separable from, or merely link (or bind by
|
||||
name) to the interfaces of, the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including the original version
|
||||
of the Work and any modifications or additions to that Work or Derivative Works
|
||||
thereof, that is intentionally submitted to Licensor for inclusion in the Work
|
||||
by the copyright owner or by an individual or Legal Entity authorized to submit
|
||||
on behalf of the copyright owner. For the purposes of this definition,
|
||||
"submitted" means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems, and
|
||||
issue tracking systems that are managed by, or on behalf of, the Licensor for
|
||||
the purpose of discussing and improving the Work, but excluding communication
|
||||
that is conspicuously marked or otherwise designated in writing by the copyright
|
||||
owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf
|
||||
of whom a Contribution has been received by Licensor and subsequently
|
||||
incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this
|
||||
License, each Contributor hereby grants to You a perpetual, worldwide,
|
||||
non-exclusive, no-charge, royalty-free, irrevocable copyright license to
|
||||
reproduce, prepare Derivative Works of, publicly display, publicly perform,
|
||||
sublicense, and distribute the Work and such Derivative Works in Source or
|
||||
Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this License,
|
||||
each Contributor hereby grants to You a perpetual, worldwide, non-exclusive,
|
||||
no-charge, royalty-free, irrevocable (except as stated in this section) patent
|
||||
license to make, have made, use, offer to sell, sell, import, and otherwise
|
||||
transfer the Work, where such license applies only to those patent claims
|
||||
licensable by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s) with the Work
|
||||
to which such Contribution(s) was submitted. If You institute patent litigation
|
||||
against any entity (including a cross-claim or counterclaim in a lawsuit)
|
||||
alleging that the Work or a Contribution incorporated within the Work
|
||||
constitutes direct or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate as of the date
|
||||
such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the Work or
|
||||
Derivative Works thereof in any medium, with or without modifications, and in
|
||||
Source or Object form, provided that You meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or Derivative Works a copy of
|
||||
this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices stating that
|
||||
You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works that You
|
||||
distribute, all copyright, patent, trademark, and attribution notices from the
|
||||
Source form of the Work, excluding those notices that do not pertain to any part
|
||||
of the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its distribution, then
|
||||
any Derivative Works that You distribute must include a readable copy of the
|
||||
attribution notices contained within such NOTICE file, excluding those notices
|
||||
that do not pertain to any part of the Derivative Works, in at least one of the
|
||||
following places: within a NOTICE text file distributed as part of the
|
||||
Derivative Works; within the Source form or documentation, if provided along
|
||||
with the Derivative Works; or, within a display generated by the Derivative
|
||||
Works, if and wherever such third-party notices normally appear. The contents of
|
||||
the NOTICE file are for informational purposes only and do not modify the
|
||||
License. You may add Your own attribution notices within Derivative Works that
|
||||
You distribute, alongside or as an addendum to the NOTICE text from the Work,
|
||||
provided that such additional attribution notices cannot be construed as
|
||||
modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and may provide
|
||||
additional or different license terms and conditions for use, reproduction, or
|
||||
distribution of Your modifications, or for any such Derivative Works as a whole,
|
||||
provided Your use, reproduction, and distribution of the Work otherwise complies
|
||||
with the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise, any
|
||||
Contribution intentionally submitted for inclusion in the Work by You to the
|
||||
Licensor shall be under the terms and conditions of this License, without any
|
||||
additional terms or conditions. Notwithstanding the above, nothing herein shall
|
||||
supersede or modify the terms of any separate license agreement you may have
|
||||
executed with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade names,
|
||||
trademarks, service marks, or product names of the Licensor, except as required
|
||||
for reasonable and customary use in describing the origin of the Work and
|
||||
reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in
|
||||
writing, Licensor provides the Work (and each Contributor provides its
|
||||
Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
KIND, either express or implied, including, without limitation, any warranties
|
||||
or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any risks
|
||||
associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory, whether in
|
||||
tort (including negligence), contract, or otherwise, unless required by
|
||||
applicable law (such as deliberate and grossly negligent acts) or agreed to in
|
||||
writing, shall any Contributor be liable to You for damages, including any
|
||||
direct, indirect, special, incidental, or consequential damages of any character
|
||||
arising as a result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill, work stoppage,
|
||||
computer failure or malfunction, or any and all other commercial damages or
|
||||
losses), even if such Contributor has been advised of the possibility of such
|
||||
damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing the Work or
|
||||
Derivative Works thereof, You may choose to offer, and charge a fee for,
|
||||
acceptance of support, warranty, indemnity, or other liability obligations
|
||||
and/or rights consistent with this License. However, in accepting such
|
||||
obligations, You may act only on Your own behalf and on Your sole
|
||||
responsibility, not on behalf of any other Contributor, and only if You agree to
|
||||
indemnify, defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason of your
|
||||
accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following boilerplate
|
||||
notice, with the fields enclosed by brackets "[]" replaced with your own
|
||||
identifying information. (Don't include the brackets!) The text should be
|
||||
enclosed in the appropriate comment syntax for the file format. We also
|
||||
recommend that a file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier identification within
|
||||
third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
232
Makefile
Normal file
232
Makefile
Normal file
@@ -0,0 +1,232 @@
|
||||
.PHONY: \
|
||||
lint test clean \
|
||||
version-sync version-check \
|
||||
server-run server-build server-lint server-test server-clean \
|
||||
desktop-build-mac desktop-build-win desktop-build-linux \
|
||||
desktop-lint desktop-test desktop-clean \
|
||||
release-assets-linux release-assets-windows release-assets-macos \
|
||||
_backend-lint _backend-test _backend-clean _backend-build \
|
||||
_frontend-install _frontend-build _frontend-check _frontend-test _frontend-dev _frontend-clean \
|
||||
_desktop-test _desktop-clean _desktop-prepare-frontend _desktop-prepare-embedfs _desktop-prepare-windows-resource \
|
||||
_server-run-backend _server-run-frontend
|
||||
|
||||
VERSION := $(shell go run ./backend/cmd/versionctl print)
|
||||
GIT_COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || printf 'unknown')
|
||||
BUILD_TIME ?= $(shell date -u +"%Y-%m-%dT%H:%M:%SZ")
|
||||
GO_LDFLAGS := -X nex/backend/pkg/buildinfo.version=$(VERSION) -X nex/backend/pkg/buildinfo.commit=$(GIT_COMMIT) -X nex/backend/pkg/buildinfo.buildTime=$(BUILD_TIME)
|
||||
GO_LDFLAGS_WIN := $(GO_LDFLAGS) -H=windowsgui
|
||||
RELEASE_DIR := build/release
|
||||
SERVER_LINUX_ASSET := $(shell go run ./backend/cmd/versionctl asset-name server linux amd64)
|
||||
SERVER_WINDOWS_ASSET := $(shell go run ./backend/cmd/versionctl asset-name server windows amd64)
|
||||
SERVER_DARWIN_AMD64_ASSET := $(shell go run ./backend/cmd/versionctl asset-name server darwin amd64)
|
||||
SERVER_DARWIN_ARM64_ASSET := $(shell go run ./backend/cmd/versionctl asset-name server darwin arm64)
|
||||
DESKTOP_LINUX_ASSET := $(shell go run ./backend/cmd/versionctl asset-name desktop linux)
|
||||
DESKTOP_WINDOWS_ASSET := $(shell go run ./backend/cmd/versionctl asset-name desktop windows)
|
||||
DESKTOP_MACOS_ASSET := $(shell go run ./backend/cmd/versionctl asset-name desktop macos)
|
||||
|
||||
# ============================================
|
||||
# 全局命令
|
||||
# ============================================
|
||||
|
||||
lint: _backend-lint _frontend-check
|
||||
@printf 'Lint complete\n'
|
||||
|
||||
test: _backend-test _frontend-test _desktop-test
|
||||
@printf 'All tests passed\n'
|
||||
|
||||
clean: _backend-clean _frontend-clean _desktop-clean
|
||||
@printf 'Clean complete\n'
|
||||
|
||||
# ============================================
|
||||
# 版本管理
|
||||
# ============================================
|
||||
|
||||
version-sync:
|
||||
go run ./backend/cmd/versionctl sync
|
||||
|
||||
version-check:
|
||||
go run ./backend/cmd/versionctl check
|
||||
|
||||
# ============================================
|
||||
# Server 模式
|
||||
# ============================================
|
||||
|
||||
server-run:
|
||||
@$(MAKE) -j2 _server-run-backend _server-run-frontend
|
||||
|
||||
server-build: version-check _backend-build _frontend-build
|
||||
@printf 'Server build complete\n'
|
||||
|
||||
server-lint: _backend-lint _frontend-check
|
||||
@printf 'Server lint complete\n'
|
||||
|
||||
server-test: _backend-test _frontend-test
|
||||
@printf 'Server tests passed\n'
|
||||
|
||||
server-clean: _backend-clean _frontend-clean
|
||||
@printf 'Server artifacts cleaned\n'
|
||||
|
||||
_server-run-backend:
|
||||
@$(MAKE) -C backend run
|
||||
|
||||
_server-run-frontend: _frontend-install
|
||||
cd frontend && bun run dev
|
||||
|
||||
# ============================================
|
||||
# Desktop 模式
|
||||
# ============================================
|
||||
|
||||
desktop-build-mac: version-check _desktop-prepare-frontend _desktop-prepare-embedfs
|
||||
@printf 'Building macOS desktop...\n'
|
||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-mac-arm64 ./cmd/desktop
|
||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-mac-amd64 ./cmd/desktop
|
||||
lipo -create build/nex-mac-arm64 build/nex-mac-amd64 -output build/nex-mac-universal
|
||||
@printf 'Packaging macOS app bundle...\n'
|
||||
mkdir -p build/Nex.app/Contents/MacOS build/Nex.app/Contents/Resources
|
||||
cp build/nex-mac-universal build/Nex.app/Contents/MacOS/nex
|
||||
@if [ -f assets/icon.icns ]; then \
|
||||
cp assets/icon.icns build/Nex.app/Contents/Resources/; \
|
||||
else \
|
||||
printf 'Missing assets/icon.icns\n'; \
|
||||
fi
|
||||
@MIN_MACOS_VERSION=$$(vtool -show-build build/nex-mac-universal | awk '/minos / {print $$2; exit}'); \
|
||||
if [ -z "$$MIN_MACOS_VERSION" ]; then \
|
||||
printf 'Unable to read macOS minimum version\n'; \
|
||||
exit 1; \
|
||||
fi; \
|
||||
go run ./backend/cmd/versionctl macos-plist "$$MIN_MACOS_VERSION" > build/Nex.app/Contents/Info.plist
|
||||
chmod +x build/Nex.app/Contents/MacOS/nex
|
||||
@printf 'macOS desktop build complete\n'
|
||||
|
||||
desktop-build-win: version-check _desktop-prepare-frontend _desktop-prepare-embedfs _desktop-prepare-windows-resource
|
||||
@printf 'Building Windows desktop...\n'
|
||||
ifeq ($(OS),Windows_NT)
|
||||
powershell -NoProfile -Command "New-Item -ItemType Directory -Path 'build' -Force | Out-Null"
|
||||
cd backend && set "CGO_ENABLED=1"&& set "GOOS=windows"&& set "GOARCH=amd64"&& go build -ldflags "$(GO_LDFLAGS_WIN)" -o ../build/nex-win-amd64.exe ./cmd/desktop
|
||||
else
|
||||
mkdir -p build
|
||||
cd backend && CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS_WIN)" -o ../build/nex-win-amd64.exe ./cmd/desktop
|
||||
endif
|
||||
@printf 'Windows desktop build complete\n'
|
||||
|
||||
desktop-build-linux: version-check _desktop-prepare-frontend _desktop-prepare-embedfs
|
||||
@printf 'Building Linux desktop...\n'
|
||||
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-linux-amd64 ./cmd/desktop
|
||||
@printf 'Linux desktop build complete\n'
|
||||
|
||||
desktop-lint: _backend-lint _frontend-check
|
||||
@printf 'Desktop lint complete\n'
|
||||
|
||||
desktop-test: _desktop-test
|
||||
@printf 'Desktop tests passed\n'
|
||||
|
||||
desktop-clean: _desktop-clean
|
||||
@printf 'Desktop artifacts cleaned\n'
|
||||
|
||||
_desktop-test:
|
||||
cd backend && go test ./cmd/desktop/... -v
|
||||
|
||||
_desktop-clean:
|
||||
rm -rf build/ embedfs/assets embedfs/frontend-dist backend/cmd/desktop/rsrc_windows_amd64.syso
|
||||
|
||||
_desktop-prepare-frontend: _frontend-install
|
||||
@printf 'Preparing frontend for desktop...\n'
|
||||
ifeq ($(OS),Windows_NT)
|
||||
powershell -NoProfile -Command "Copy-Item -LiteralPath 'frontend/.env.desktop' -Destination 'frontend/.env.production.local' -Force"
|
||||
cd frontend && bun run build
|
||||
powershell -NoProfile -Command "Remove-Item -LiteralPath 'frontend/.env.production.local' -Force -ErrorAction SilentlyContinue"
|
||||
else
|
||||
cd frontend && cp .env.desktop .env.production.local
|
||||
cd frontend && bun run build
|
||||
rm -f frontend/.env.production.local
|
||||
endif
|
||||
|
||||
_desktop-prepare-embedfs:
|
||||
@printf 'Preparing embedded filesystem...\n'
|
||||
ifeq ($(OS),Windows_NT)
|
||||
powershell -NoProfile -Command "Remove-Item -LiteralPath 'embedfs/assets' -Recurse -Force -ErrorAction SilentlyContinue; Remove-Item -LiteralPath 'embedfs/frontend-dist' -Recurse -Force -ErrorAction SilentlyContinue; Copy-Item -LiteralPath 'assets' -Destination 'embedfs/assets' -Recurse; Copy-Item -LiteralPath 'frontend/dist' -Destination 'embedfs/frontend-dist' -Recurse"
|
||||
else
|
||||
rm -rf embedfs/assets embedfs/frontend-dist
|
||||
cp -r assets embedfs/assets
|
||||
cp -r frontend/dist embedfs/frontend-dist
|
||||
endif
|
||||
|
||||
_desktop-prepare-windows-resource:
|
||||
@printf 'Preparing Windows executable icon...\n'
|
||||
ifeq ($(OS),Windows_NT)
|
||||
cd backend/cmd/desktop && windres -O coff -F pe-x86-64 -i icon_windows.rc -o rsrc_windows_amd64.syso
|
||||
else
|
||||
@if command -v x86_64-w64-mingw32-windres >/dev/null 2>&1; then \
|
||||
cd backend/cmd/desktop && x86_64-w64-mingw32-windres -O coff -F pe-x86-64 -i icon_windows.rc -o rsrc_windows_amd64.syso; \
|
||||
elif command -v windres >/dev/null 2>&1; then \
|
||||
cd backend/cmd/desktop && windres -O coff -F pe-x86-64 -i icon_windows.rc -o rsrc_windows_amd64.syso; \
|
||||
else \
|
||||
printf 'Missing windres for Windows icon resource generation\n'; \
|
||||
exit 1; \
|
||||
fi
|
||||
endif
|
||||
|
||||
# ============================================
|
||||
# 发布资产
|
||||
# ============================================
|
||||
|
||||
release-assets-linux: version-check desktop-build-linux
|
||||
rm -rf "$(RELEASE_DIR)"
|
||||
mkdir -p "$(RELEASE_DIR)"
|
||||
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-linux-amd64 ./cmd/server
|
||||
tar -C build -czf "$(RELEASE_DIR)/$(SERVER_LINUX_ASSET)" nex-server-linux-amd64
|
||||
tar -C build -czf "$(RELEASE_DIR)/$(DESKTOP_LINUX_ASSET)" nex-linux-amd64
|
||||
|
||||
release-assets-windows: version-check desktop-build-win
|
||||
ifeq ($(OS),Windows_NT)
|
||||
powershell -NoProfile -Command "Remove-Item -LiteralPath '$(RELEASE_DIR)' -Recurse -Force -ErrorAction SilentlyContinue; New-Item -ItemType Directory -Path '$(RELEASE_DIR)' -Force | Out-Null"
|
||||
cd backend && set "CGO_ENABLED=1"&& set "GOOS=windows"&& set "GOARCH=amd64"&& go build -ldflags "$(GO_LDFLAGS_WIN)" -o ../build/nex-server-win-amd64.exe ./cmd/server
|
||||
powershell -NoProfile -Command "Compress-Archive -LiteralPath 'build/nex-server-win-amd64.exe' -DestinationPath '$(RELEASE_DIR)/$(SERVER_WINDOWS_ASSET)' -Force"
|
||||
powershell -NoProfile -Command "Compress-Archive -LiteralPath 'build/nex-win-amd64.exe' -DestinationPath '$(RELEASE_DIR)/$(DESKTOP_WINDOWS_ASSET)' -Force"
|
||||
else
|
||||
@printf 'release-assets-windows requires Windows\n'
|
||||
@exit 1
|
||||
endif
|
||||
|
||||
release-assets-macos: version-check desktop-build-mac
|
||||
rm -rf "$(RELEASE_DIR)"
|
||||
mkdir -p "$(RELEASE_DIR)"
|
||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-darwin-amd64 ./cmd/server
|
||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-darwin-arm64 ./cmd/server
|
||||
tar -C build -czf "$(RELEASE_DIR)/$(SERVER_DARWIN_AMD64_ASSET)" nex-server-darwin-amd64
|
||||
tar -C build -czf "$(RELEASE_DIR)/$(SERVER_DARWIN_ARM64_ASSET)" nex-server-darwin-arm64
|
||||
ditto -c -k --keepParent build/Nex.app "$(RELEASE_DIR)/$(DESKTOP_MACOS_ASSET)"
|
||||
|
||||
# ============================================
|
||||
# 共享 helper targets
|
||||
# ============================================
|
||||
|
||||
_backend-build:
|
||||
@$(MAKE) -C backend build
|
||||
|
||||
_backend-lint:
|
||||
@$(MAKE) -C backend lint
|
||||
|
||||
_backend-test:
|
||||
@$(MAKE) -C backend test
|
||||
|
||||
_backend-clean:
|
||||
@$(MAKE) -C backend clean
|
||||
|
||||
_frontend-install:
|
||||
cd frontend && bun install
|
||||
|
||||
_frontend-build: _frontend-install
|
||||
cd frontend && bun run build
|
||||
|
||||
_frontend-check: _frontend-install
|
||||
cd frontend && bun run check
|
||||
|
||||
_frontend-test: _frontend-install
|
||||
cd frontend && bun run test
|
||||
|
||||
_frontend-dev: _frontend-install
|
||||
cd frontend && bun run dev
|
||||
|
||||
_frontend-clean:
|
||||
rm -rf frontend/dist frontend/.next frontend/coverage frontend/playwright-report frontend/test-results frontend/tsconfig.tsbuildinfo
|
||||
379
README.md
379
README.md
@@ -6,106 +6,373 @@
|
||||
|
||||
```
|
||||
nex/
|
||||
├── backend/ # Go 后端服务
|
||||
│ ├── main.go
|
||||
│ ├── go.mod
|
||||
│ └── internal/
|
||||
│ ├── handler/ # HTTP 处理器
|
||||
│ ├── protocol/ # 协议适配器
|
||||
│ ├── provider/ # 供应商客户端
|
||||
│ ├── router/ # 模型路由
|
||||
│ ├── stats/ # 统计记录
|
||||
│ └── config/ # 配置与数据库
|
||||
├── backend/ # Go 后端服务(分层架构)
|
||||
│ ├── cmd/
|
||||
│ │ ├── server/ # CLI 主程序入口
|
||||
│ │ └── desktop/ # 桌面应用入口
|
||||
│ ├── internal/
|
||||
│ │ ├── handler/ # HTTP 处理器 + 中间件
|
||||
│ │ ├── service/ # 业务逻辑层
|
||||
│ │ ├── repository/ # 数据访问层
|
||||
│ │ ├── domain/ # 领域模型
|
||||
│ │ ├── conversion/ # 协议转换引擎(OpenAI/Anthropic 适配器)
|
||||
│ │ ├── provider/ # 供应商客户端
|
||||
│ │ └── config/ # 配置管理
|
||||
│ ├── pkg/ # 公共包(logger/errors/validator)
|
||||
│ ├── migrations/ # 数据库迁移
|
||||
│ └── tests/ # 测试(unit/integration)
|
||||
│
|
||||
├── frontend/ # React 前端界面
|
||||
├── frontend/ # React 前端界面
|
||||
│ ├── src/
|
||||
│ │ ├── api/ # API 层(统一请求封装 + 字段转换)
|
||||
│ │ ├── hooks/ # TanStack Query hooks
|
||||
│ │ ├── components/ # 通用组件(AppLayout)
|
||||
│ │ ├── pages/ # 页面(Providers, Stats, NotFound)
|
||||
│ │ ├── routes/ # React Router 路由配置
|
||||
│ │ ├── types/ # TypeScript 类型定义
|
||||
│ │ └── __tests__/ # 测试(API、Hooks、组件)
|
||||
│ ├── e2e/ # Playwright E2E 测试
|
||||
│ │ ├── api/ # API 层(统一请求封装 + 字段转换)
|
||||
│ │ ├── hooks/ # TanStack Query hooks
|
||||
│ │ ├── components/ # 通用组件(AppLayout)
|
||||
│ │ ├── pages/ # 页面(Providers, Stats)
|
||||
│ │ ├── routes/ # React Router 路由配置
|
||||
│ │ ├── types/ # TypeScript 类型定义
|
||||
│ │ └── __tests__/ # 单元测试 + 组件测试
|
||||
│ ├── e2e/ # Playwright E2E 测试
|
||||
│ └── package.json
|
||||
│
|
||||
└── README.md # 本文件
|
||||
├── assets/ # 应用资源
|
||||
│ ├── icon.png # 托盘图标
|
||||
│ ├── icon.icns # macOS 应用图标
|
||||
│ └── icon.ico # Windows 应用图标
|
||||
│
|
||||
└── README.md # 本文件
|
||||
```
|
||||
|
||||
## 功能特性
|
||||
|
||||
- **双协议支持**:同时支持 OpenAI 和 Anthropic 协议
|
||||
- **透明代理**:对 OpenAI 兼容供应商透传请求
|
||||
- **流式响应**:完整支持 SSE 流式传输
|
||||
- **跨协议转换**:Hub-and-Spoke 架构实现 OpenAI ↔ Anthropic 双向转换
|
||||
- **统一模型 ID**:`provider_id/model_name` 格式全局唯一标识模型(如 `openai/gpt-4`)
|
||||
- **Smart Passthrough**:同协议请求跳过 Canonical 全量转换,仅在 JSON 层改写 model 字段
|
||||
- **流式响应**:完整支持 SSE 流式传输,包括跨协议流式转换
|
||||
- **Function Calling**:支持工具调用(Tools)
|
||||
- **多供应商管理**:配置和管理多个供应商
|
||||
- **Thinking / Reasoning**:支持 OpenAI `reasoning_effort` 和 Anthropic `thinking` 配置
|
||||
- **扩展接口**:支持 Embeddings 和 Rerank 接口
|
||||
- **多供应商管理**:配置和管理多个供应商(供应商 ID 仅限字母、数字、下划线)
|
||||
- **用量统计**:按供应商、模型、日期统计请求数量
|
||||
- **Web 配置界面**:提供供应商和模型配置管理
|
||||
|
||||
## 技术栈
|
||||
|
||||
### 后端
|
||||
- **Go 1.21+**
|
||||
- **Gin** - HTTP 框架
|
||||
- **GORM** - ORM
|
||||
- **SQLite** - 数据库
|
||||
- **语言**: Go 1.26+
|
||||
- **HTTP 框架**: Gin
|
||||
- **ORM**: GORM
|
||||
- **数据库**: SQLite / MySQL
|
||||
- **日志**: zap + lumberjack(结构化日志 + 日志轮转 + 模块标识)
|
||||
- **配置**: Viper + pflag(多层配置:CLI > 环境变量 > 配置文件 > 默认值)
|
||||
- **验证**: go-playground/validator/v10
|
||||
- **迁移**: goose
|
||||
|
||||
#### 日志模块标识规范
|
||||
|
||||
每个模块通过依赖注入获取带模块标识的 logger,日志输出格式为 `[module.name]`:
|
||||
|
||||
```
|
||||
Console: INFO [handler.proxy] 处理请求 method=POST path=/v1/chat
|
||||
JSON: {"level":"info","logger":"handler.proxy","msg":"处理请求","method":"POST"}
|
||||
```
|
||||
|
||||
模块命名规范:
|
||||
- 单一职责包:`database`、`config`
|
||||
- 多实体包:`handler.proxy`、`service.provider`
|
||||
- 子包:`handler.middleware`
|
||||
|
||||
### 前端
|
||||
- **Bun** - 运行时
|
||||
- **Vite** - 构建工具
|
||||
- **TypeScript** (strict mode) - 类型系统
|
||||
- **React** - UI 框架
|
||||
- **Ant Design 5** - UI 组件库
|
||||
- **React Router v7** - 路由
|
||||
- **TanStack Query v5** - 数据获取
|
||||
- **SCSS Modules** - 样式方案
|
||||
- **Vitest + Playwright** - 测试
|
||||
- **运行时**: Bun
|
||||
- **构建工具**: Vite
|
||||
- **语言**: TypeScript (strict mode)
|
||||
- **框架**: React
|
||||
- **UI 组件库**: TDesign React
|
||||
- **图表库**: Recharts
|
||||
- **路由**: React Router v7
|
||||
- **数据获取**: TanStack Query v5
|
||||
- **样式**: SCSS Modules
|
||||
- **测试**: Vitest + React Testing Library + Playwright
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 后端
|
||||
### 桌面应用(推荐)
|
||||
|
||||
**构建桌面应用**:
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
go mod download
|
||||
go run main.go
|
||||
# macOS (arm64 + amd64,并打包为 .app)
|
||||
make desktop-build-mac
|
||||
|
||||
# Windows
|
||||
make desktop-build-win
|
||||
|
||||
# Linux
|
||||
make desktop-build-linux
|
||||
```
|
||||
|
||||
后端服务将在 `http://localhost:9826` 启动。
|
||||
**使用桌面应用**:
|
||||
- 双击启动应用(macOS: Nex.app,Windows: nex-win-amd64.exe,Linux: nex-linux-amd64)
|
||||
- 系统托盘图标出现,浏览器自动打开管理界面
|
||||
- 点击托盘图标显示菜单,可打开管理界面或退出
|
||||
- 关闭浏览器后服务继续运行,可通过托盘重新打开
|
||||
|
||||
### 前端
|
||||
**注意事项**:
|
||||
- 桌面应用需要 CGO 支持
|
||||
- macOS: 自带 Xcode Command Line Tools
|
||||
- Linux: 自带 gcc,部分桌面环境需要 `libappindicator3-dev`
|
||||
- Windows: 需要 MinGW-w64 或在 Windows 环境构建
|
||||
|
||||
**Linux 桌面环境兼容性**:
|
||||
- GNOME: 需要 AppIndicator 扩展
|
||||
- KDE Plasma: 原生支持
|
||||
- Xfce: 需要 libappindicator
|
||||
- 其他支持 StatusNotifierItem 规范的环境
|
||||
|
||||
### Server 模式(前后端分离)
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
bun install
|
||||
bun dev
|
||||
make server-run
|
||||
```
|
||||
|
||||
前端开发服务器将在 `http://localhost:5173` 启动。
|
||||
`make server-run` 会并行启动:
|
||||
- 后端服务:`http://localhost:9826`
|
||||
- 前端开发服务器:`http://localhost:5173`
|
||||
|
||||
前端请求会继续通过 Vite proxy 转发到后端。后端首次启动会自动:
|
||||
- 创建配置文件 `~/.nex/config.yaml`
|
||||
- 初始化数据库 `~/.nex/config.db`
|
||||
- 运行数据库迁移
|
||||
- 创建日志目录 `~/.nex/log/`
|
||||
|
||||
**构建 server 模式产物**:
|
||||
|
||||
```bash
|
||||
make server-build
|
||||
```
|
||||
|
||||
## API 接口
|
||||
|
||||
### 代理接口(对外部应用)
|
||||
|
||||
- `POST /v1/chat/completions` - OpenAI Chat Completions API
|
||||
- `POST /v1/messages` - Anthropic Messages API
|
||||
代理接口统一使用 `/{protocol}/*path` 路由格式,模型 ID 使用 `provider_id/model_name` 格式(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough,最小化 JSON 改写并保持未改写字段的 JSON 内容和类型不变;跨协议请求走完整 decode/encode 转换。
|
||||
|
||||
**OpenAI 协议**(`protocol=openai`):
|
||||
- `POST /openai/v1/chat/completions` - 对话补全
|
||||
- `GET /openai/v1/models` - 模型列表(本地数据库聚合)
|
||||
- `GET /openai/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
|
||||
- `POST /openai/v1/embeddings` - 嵌入
|
||||
- `POST /openai/v1/rerank` - 重排序
|
||||
|
||||
**Anthropic 协议**(`protocol=anthropic`):
|
||||
- `POST /anthropic/v1/messages` - 消息对话
|
||||
- `GET /anthropic/v1/models` - 模型列表(本地数据库聚合)
|
||||
- `GET /anthropic/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
|
||||
|
||||
路径边界:网关只剥离第一段协议前缀,剩余路径保持协议原生形态交给 adapter。OpenAI adapter 接收 `/v1/chat/completions`、`/v1/models`、`/v1/embeddings`、`/v1/rerank`,并在构建上游 URL 时去掉 `/v1`;Anthropic adapter 接收 `/v1/messages`、`/v1/models`。因此 OpenAI 供应商 `base_url` 配置到版本路径一级(如 `https://api.openai.com/v1`),Anthropic 供应商 `base_url` 配置到域名级(如 `https://api.anthropic.com`)。
|
||||
|
||||
代理错误边界:网关层错误统一返回 `{"error":"...","code":"..."}`,例如 `INVALID_JSON`、`MODEL_NOT_FOUND`、`CONVERSION_FAILED`、`UPSTREAM_UNAVAILABLE`。只要上游已经返回 HTTP 响应,非 2xx 的 status、过滤 hop-by-hop header 后的 headers 和 body 会直接透传,不包装为应用错误或协议错误。
|
||||
|
||||
模型路由边界:只有 adapter 明确适配的接口会解析请求体中的 `model` 并使用统一模型 ID 路由;未知接口即使包含顶层 `model` 也按无 model 透传处理。
|
||||
|
||||
流式边界:同协议无响应 model 改写时原样透传 SSE frame 和 `[DONE]`;同协议需要响应 model 改写时只解析 SSE frame 的 `data` JSON 并改写 `model`;跨协议流式仍走 provider decoder → Canonical stream event → client encoder。
|
||||
|
||||
### 管理接口(对前端)
|
||||
|
||||
- `GET/POST/PUT/DELETE /api/providers` - 供应商管理
|
||||
- `GET/POST/PUT/DELETE /api/models` - 模型管理
|
||||
- `GET /api/stats` - 统计查询
|
||||
#### 供应商管理
|
||||
- `GET /api/providers` - 列出所有供应商
|
||||
- `POST /api/providers` - 创建供应商(`id` 仅限字母、数字、下划线,长度 1-64)
|
||||
- `GET /api/providers/:id` - 获取供应商
|
||||
- `PUT /api/providers/:id` - 更新供应商(`id` 不可修改)
|
||||
- `DELETE /api/providers/:id` - 删除供应商
|
||||
|
||||
## 配置存储
|
||||
#### 模型管理
|
||||
- `GET /api/models` - 列出模型(支持 `?provider_id=xxx` 过滤)
|
||||
- `POST /api/models` - 创建模型(`id` 由系统自动生成 UUID,`provider_id` + `model_name` 联合唯一)
|
||||
- `GET /api/models/:id` - 获取模型(响应含 `unified_id` 字段,格式 `provider_id/model_name`)
|
||||
- `PUT /api/models/:id` - 更新模型(不可修改 `id`)
|
||||
- `DELETE /api/models/:id` - 删除模型
|
||||
|
||||
配置数据存储在用户目录:`~/.nex/config.db`
|
||||
#### 统计查询
|
||||
- `GET /api/stats` - 查询统计
|
||||
- `GET /api/stats/aggregate` - 聚合统计
|
||||
|
||||
查询参数支持:`provider_id`、`model_name`、`start_date`、`end_date`、`group_by`
|
||||
|
||||
## 配置
|
||||
|
||||
配置支持多种方式,优先级为:**CLI 参数 > 环境变量 > 配置文件 > 默认值**
|
||||
|
||||
### 配置文件
|
||||
|
||||
配置文件位于 `~/.nex/config.yaml`,首次启动自动生成:
|
||||
|
||||
```yaml
|
||||
server:
|
||||
port: 9826
|
||||
read_timeout: 30s
|
||||
write_timeout: 30s
|
||||
|
||||
database:
|
||||
driver: sqlite # sqlite 或 mysql
|
||||
path: ~/.nex/config.db # SQLite 数据库文件路径
|
||||
# --- MySQL 配置(driver=mysql 时生效)---
|
||||
# host: localhost
|
||||
# port: 3306
|
||||
# user: nex
|
||||
# password: ""
|
||||
# dbname: nex
|
||||
max_idle_conns: 10
|
||||
max_open_conns: 100
|
||||
conn_max_lifetime: 1h
|
||||
|
||||
log:
|
||||
level: info
|
||||
path: ~/.nex/log
|
||||
max_size: 100 # MB
|
||||
max_backups: 10
|
||||
max_age: 30 # 天
|
||||
compress: true
|
||||
```
|
||||
|
||||
### 环境变量
|
||||
|
||||
所有配置项支持环境变量,使用 `NEX_` 前缀:
|
||||
|
||||
```bash
|
||||
export NEX_SERVER_PORT=9000
|
||||
export NEX_DATABASE_PATH=/data/nex.db
|
||||
export NEX_LOG_LEVEL=debug
|
||||
|
||||
# MySQL 模式
|
||||
export NEX_DATABASE_DRIVER=mysql
|
||||
export NEX_DATABASE_HOST=db.example.com
|
||||
export NEX_DATABASE_PORT=3306
|
||||
export NEX_DATABASE_USER=nex
|
||||
export NEX_DATABASE_PASSWORD=secret
|
||||
export NEX_DATABASE_DBNAME=nex
|
||||
```
|
||||
|
||||
命名规则:配置路径转大写 + 下划线(如 `server.port` → `NEX_SERVER_PORT`)。
|
||||
|
||||
### CLI 参数
|
||||
|
||||
```bash
|
||||
./server --server-port 9000 --log-level debug --database-path /tmp/test.db
|
||||
```
|
||||
|
||||
命名规则:配置路径转 kebab-case(如 `server.port` → `--server-port`)。
|
||||
|
||||
### 数据文件
|
||||
|
||||
- `~/.nex/config.yaml` - 配置文件
|
||||
- `~/.nex/config.db` - SQLite 数据库(MySQL 模式下不使用本地数据库文件)
|
||||
- `~/.nex/log/` - 日志目录
|
||||
|
||||
## 测试
|
||||
|
||||
```bash
|
||||
# 全局默认测试(不含 MySQL 和前端 E2E)
|
||||
make test
|
||||
|
||||
# 产品级测试
|
||||
make server-test
|
||||
make desktop-test
|
||||
```
|
||||
|
||||
backend 分类测试、MySQL 专项测试和前端 E2E 测试请分别查看 `backend/README.md` 与 `frontend/README.md`。
|
||||
|
||||
## 开发
|
||||
|
||||
```bash
|
||||
# 首次克隆后安装 Git hooks
|
||||
lefthook install
|
||||
|
||||
# 全局命令
|
||||
make lint # 前后端共享检查
|
||||
make test # 默认全量测试(不含 MySQL/E2E)
|
||||
make clean # 清理所有构建产物和测试报告
|
||||
|
||||
# server 模式
|
||||
make server-run # 并行启动后端和前端开发服务
|
||||
make server-build # 构建 backend/bin/server 和 frontend/dist
|
||||
make server-lint # server 模式检查
|
||||
make server-test # server 模式测试
|
||||
make server-clean # 清理 server 模式产物
|
||||
|
||||
# desktop 模式
|
||||
make desktop-build-mac # 构建 macOS 桌面应用
|
||||
make desktop-build-win # 构建 Windows 桌面应用
|
||||
make desktop-build-linux # 构建 Linux 桌面应用
|
||||
make desktop-lint # desktop 模式检查
|
||||
make desktop-test # desktop 专属测试
|
||||
make desktop-clean # 清理 desktop 产物
|
||||
```
|
||||
|
||||
## 版本与发布
|
||||
|
||||
### 统一版本源
|
||||
|
||||
- 仓库根目录 `VERSION` 是全仓唯一版本源,格式固定为 `x.y.z`
|
||||
- `frontend/package.json` 和前端 `.env.*` 中的 `VITE_APP_VERSION` 由仓库工具同步,不能手工漂移
|
||||
|
||||
### 本地版本演进
|
||||
|
||||
1. 手工修改根目录 `VERSION` 为新的 `x.y.z`
|
||||
2. 同步镜像文件:
|
||||
|
||||
```bash
|
||||
make version-sync
|
||||
```
|
||||
|
||||
3. 校验版本一致性:
|
||||
|
||||
```bash
|
||||
make version-check
|
||||
```
|
||||
|
||||
4. 提交版本变更后,创建发布 tag:
|
||||
|
||||
```bash
|
||||
git tag -a vX.Y.Z -m "Release vX.Y.Z"
|
||||
git push origin main
|
||||
git push origin vX.Y.Z
|
||||
```
|
||||
|
||||
### 本地生成发布资产
|
||||
|
||||
```bash
|
||||
# Linux: server + desktop
|
||||
make release-assets-linux
|
||||
|
||||
# Windows: server + desktop(需在 Windows 环境执行)
|
||||
make release-assets-windows
|
||||
|
||||
# macOS: darwin-amd64 server、darwin-arm64 server、desktop universal
|
||||
make release-assets-macos
|
||||
```
|
||||
|
||||
生成的版本化发布资产位于 `build/release/`。
|
||||
|
||||
### GitHub Draft Release
|
||||
|
||||
- 推送 `vX.Y.Z` tag 后,`.github/workflows/release.yml` 会自动执行发布流水线
|
||||
- 流水线会先校验 tag 与 `VERSION` 一致,再构建以下资产并上传到 GitHub Draft Release:
|
||||
- Linux server
|
||||
- Windows server
|
||||
- darwin-amd64 server
|
||||
- darwin-arm64 server
|
||||
- Linux desktop
|
||||
- Windows desktop
|
||||
- macOS desktop universal
|
||||
- Release 默认以 Draft 形式创建,需人工检查后再公开发布
|
||||
|
||||
## 开发规范
|
||||
|
||||
详见各子项目的 README.md:
|
||||
- [后端 README](backend/README.md)
|
||||
- [前端 README](frontend/README.md)
|
||||
- [后端 README](backend/README.md) - 分层架构、依赖注入、数据库迁移
|
||||
- [前端 README](frontend/README.md) - TypeScript strict、SCSS Modules、测试策略
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT
|
||||
Apache License 2.0
|
||||
|
||||
BIN
assets/icon.icns
LFS
Normal file
BIN
assets/icon.icns
LFS
Normal file
Binary file not shown.
BIN
assets/icon.ico
LFS
Normal file
BIN
assets/icon.ico
LFS
Normal file
Binary file not shown.
BIN
assets/icon.png
LFS
Normal file
BIN
assets/icon.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/128x128/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/128x128/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/16x16/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/16x16/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/22x22/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/22x22/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/24x24/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/24x24/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/256x256/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/256x256/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/32x32/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/32x32/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/48x48/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/48x48/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/512x512/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/512x512/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/64x64/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/64x64/apps/nex.png
LFS
Normal file
Binary file not shown.
91
backend/.golangci.yml
Normal file
91
backend/.golangci.yml
Normal file
@@ -0,0 +1,91 @@
|
||||
run:
|
||||
timeout: 5m
|
||||
tests: true
|
||||
|
||||
linters:
|
||||
disable-all: true
|
||||
enable:
|
||||
- forbidigo
|
||||
- errorlint
|
||||
- errcheck
|
||||
- staticcheck
|
||||
- revive
|
||||
- gocritic
|
||||
- gosec
|
||||
- bodyclose
|
||||
- noctx
|
||||
- nilerr
|
||||
- goimports
|
||||
- gocyclo
|
||||
|
||||
linters-settings:
|
||||
errcheck:
|
||||
check-blank: true
|
||||
check-type-assertions: true
|
||||
exclude-functions:
|
||||
- fmt.Fprintf
|
||||
forbidigo:
|
||||
analyze-types: true
|
||||
forbid:
|
||||
- p: '^fmt\.Print.*$'
|
||||
msg: 使用 zap logger,不要直接输出到 stdout/stderr
|
||||
- p: '^fmt\.Fprint.*$'
|
||||
msg: 使用 zap logger,不要直接输出到 stdout/stderr
|
||||
- p: '^log\.(Print|Println|Printf|Fatal|Fatalln|Fatalf|Panic|Panicln|Panicf)$'
|
||||
msg: 使用 zap logger,不要使用标准库 log
|
||||
- p: '^zap\.L$'
|
||||
msg: 通过依赖注入传递 *zap.Logger,不要使用全局 logger
|
||||
- p: '^zap\.S$'
|
||||
msg: 不使用 Sugar logger
|
||||
revive:
|
||||
rules:
|
||||
- name: exported
|
||||
- name: var-naming
|
||||
- name: indent-error-flow
|
||||
- name: error-strings
|
||||
- name: error-return
|
||||
- name: blank-imports
|
||||
- name: context-as-argument
|
||||
- name: unexported-return
|
||||
goimports:
|
||||
local-prefixes: nex/backend
|
||||
gocyclo:
|
||||
min-complexity: 10
|
||||
|
||||
issues:
|
||||
exclude-dirs:
|
||||
- tests/mocks
|
||||
exclude-generated: true
|
||||
exclude-rules:
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- forbidigo
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- errcheck
|
||||
source: '(^\s*_\s*=|,\s*_)'
|
||||
- path: 'tests/integration/e2e_conversion_test\.go'
|
||||
linters:
|
||||
- errcheck
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- revive
|
||||
text: '^exported:'
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- gosec
|
||||
text: 'G(101|401|501)'
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- gocyclo
|
||||
text: 'cyclomatic complexity (1[1-9]|20) of .* is high \(> 10\)'
|
||||
- linters:
|
||||
- revive
|
||||
text: '(that stutters|BuildUrl should be BuildURL|ConvertHttpRequest should be ConvertHTTPRequest|ConvertHttpResponse should be ConvertHTTPResponse)'
|
||||
- path: 'internal/conversion/.*\.go'
|
||||
linters:
|
||||
- gocyclo
|
||||
- gocritic
|
||||
- path: '(internal/provider/client\.go|internal/service/model_service_impl\.go|internal/service/stats_buffer\.go|internal/handler/proxy_handler\.go|cmd/(desktop|server)/main\.go)'
|
||||
linters:
|
||||
- gocyclo
|
||||
@@ -1,45 +1,97 @@
|
||||
.PHONY: build run test test-coverage clean migrate-up migrate-down migrate-status migrate-create lint
|
||||
.PHONY: \
|
||||
build run \
|
||||
test test-unit test-integration test-coverage \
|
||||
lint clean \
|
||||
migrate-up migrate-down migrate-status migrate-create \
|
||||
mysql-up mysql-down mysql-test mysql-test-quick
|
||||
|
||||
VERSION := $(shell go run ./cmd/versionctl print)
|
||||
GIT_COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || printf 'unknown')
|
||||
BUILD_TIME ?= $(shell date -u +"%Y-%m-%dT%H:%M:%SZ")
|
||||
GO_LDFLAGS := -X nex/backend/pkg/buildinfo.version=$(VERSION) -X nex/backend/pkg/buildinfo.commit=$(GIT_COMMIT) -X nex/backend/pkg/buildinfo.buildTime=$(BUILD_TIME)
|
||||
|
||||
DB_DRIVER ?= sqlite3
|
||||
DB_DSN ?= $(HOME)/.nex/config.db
|
||||
|
||||
ifeq ($(DB_DRIVER),mysql)
|
||||
GOOSE_DIR := migrations/mysql
|
||||
GOOSE_DRIVER := mysql
|
||||
else ifeq ($(DB_DRIVER),sqlite3)
|
||||
GOOSE_DIR := migrations/sqlite
|
||||
GOOSE_DRIVER := sqlite3
|
||||
else
|
||||
$(error unsupported DB_DRIVER '$(DB_DRIVER)', use sqlite3 or mysql)
|
||||
endif
|
||||
|
||||
# 构建
|
||||
build:
|
||||
go build -o bin/server ./cmd/server
|
||||
go build -ldflags "$(GO_LDFLAGS)" -o bin/server ./cmd/server
|
||||
|
||||
# 运行
|
||||
run:
|
||||
go run ./cmd/server
|
||||
go run -ldflags "$(GO_LDFLAGS)" ./cmd/server
|
||||
|
||||
# 测试
|
||||
test:
|
||||
go test ./... -v
|
||||
go test ./internal/... ./pkg/... ./tests/... ./cmd/server/... -v
|
||||
|
||||
test-unit:
|
||||
go test ./internal/... ./pkg/... -v
|
||||
|
||||
test-integration:
|
||||
go test ./tests/... -v
|
||||
|
||||
# 测试覆盖率
|
||||
test-coverage:
|
||||
go test ./... -coverprofile=coverage.out
|
||||
go tool cover -html=coverage.out -o coverage.html
|
||||
@echo "Coverage report generated: coverage.html"
|
||||
@printf 'Coverage report generated: backend/coverage.html\n'
|
||||
|
||||
lint:
|
||||
go tool golangci-lint run ./...
|
||||
|
||||
# 清理
|
||||
clean:
|
||||
rm -rf bin/ coverage.out coverage.html
|
||||
|
||||
# 数据库迁移
|
||||
migrate-up:
|
||||
goose -dir migrations sqlite3 $(DB_PATH) up
|
||||
@printf 'Running database migration up...\n'
|
||||
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" up
|
||||
|
||||
migrate-down:
|
||||
goose -dir migrations sqlite3 $(DB_PATH) down
|
||||
@printf 'Running database migration down...\n'
|
||||
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" down
|
||||
|
||||
migrate-status:
|
||||
goose -dir migrations sqlite3 $(DB_PATH) status
|
||||
@printf 'Checking database migration status...\n'
|
||||
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" status
|
||||
|
||||
migrate-create:
|
||||
@read -p "Migration name: " name; \
|
||||
goose -dir migrations create $$name sql
|
||||
@printf 'Migration name: '; \
|
||||
read name; \
|
||||
goose -dir migrations/sqlite create $$name sql; \
|
||||
goose -dir migrations/mysql create $$name sql
|
||||
|
||||
# 代码检查
|
||||
lint:
|
||||
golangci-lint run ./...
|
||||
mysql-up:
|
||||
@printf 'Starting MySQL test container...\n'
|
||||
cd tests/mysql && docker-compose up -d
|
||||
@printf 'Waiting for MySQL to be ready...\n'
|
||||
@for i in $$(seq 1 30); do \
|
||||
if docker exec nex-mysql-test mysqladmin ping -h localhost -u root -ptestpass --silent 2>/dev/null; then \
|
||||
printf 'MySQL is ready\n'; \
|
||||
exit 0; \
|
||||
fi; \
|
||||
printf 'Waiting... (%s/30)\n' $$i; \
|
||||
sleep 1; \
|
||||
done; \
|
||||
printf 'MySQL failed to start\n'; \
|
||||
exit 1
|
||||
|
||||
# 安装依赖
|
||||
deps:
|
||||
go mod tidy
|
||||
mysql-down:
|
||||
@printf 'Stopping MySQL test container...\n'
|
||||
cd tests/mysql && docker-compose down -v
|
||||
|
||||
mysql-test:
|
||||
@set -e; \
|
||||
$(MAKE) mysql-up; \
|
||||
trap '$(MAKE) mysql-down' EXIT; \
|
||||
go test -tags=mysql ./tests/mysql/... -v -count=1
|
||||
|
||||
mysql-test-quick:
|
||||
@printf 'Running MySQL tests without container management...\n'
|
||||
go test -tags=mysql ./tests/mysql/... -v -count=1
|
||||
|
||||
@@ -4,25 +4,75 @@ AI 网关后端服务,提供统一的大模型 API 代理接口。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 支持 OpenAI 协议(`/v1/chat/completions`)
|
||||
- 支持 Anthropic 协议(`/v1/messages`)
|
||||
- 支持 OpenAI 协议(`/openai/v1/...`,例如 `/openai/v1/chat/completions`)
|
||||
- 支持 Anthropic 协议(`/anthropic/v1/...`)
|
||||
- 支持 Hub-and-Spoke 跨协议双向转换(OpenAI ↔ Anthropic)
|
||||
- 同协议透传(跳过 Canonical 全量转换,保持协议语义)
|
||||
- 支持流式响应(SSE)
|
||||
- 支持 Function Calling / Tools
|
||||
- 支持 Thinking / Reasoning
|
||||
- 支持扩展层接口(Models、Embeddings、Rerank)
|
||||
- 多供应商配置和路由
|
||||
- 用量统计
|
||||
- 结构化日志(zap + lumberjack)
|
||||
- 结构化日志(zap + lumberjack + 模块标识)
|
||||
- YAML 配置管理
|
||||
- 请求验证
|
||||
- 中间件支持(请求 ID、日志、恢复、CORS)
|
||||
|
||||
## 日志规范
|
||||
|
||||
### 模块标识
|
||||
|
||||
每个模块通过依赖注入获取带模块标识的 logger:
|
||||
|
||||
```go
|
||||
func NewProxyHandler(..., logger *zap.Logger) *ProxyHandler {
|
||||
return &ProxyHandler{
|
||||
logger: pkglogger.WithModule(logger, "handler.proxy"),
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
输出格式:
|
||||
- Console: `INFO [handler.proxy] 处理请求 method=POST path=/v1/chat`
|
||||
- JSON: `{"level":"info","logger":"handler.proxy","msg":"处理请求"}`
|
||||
|
||||
### 模块命名规范
|
||||
|
||||
| 模块 | 命名 |
|
||||
|------|------|
|
||||
| ProxyHandler | `handler.proxy` |
|
||||
| ProviderHandler | `handler.provider` |
|
||||
| Provider Client | `provider.client` |
|
||||
| ConversionEngine | `conversion.engine` |
|
||||
| RoutingCache | `service.routing_cache` |
|
||||
| StatsBuffer | `service.stats_buffer` |
|
||||
| Database | `database` |
|
||||
|
||||
### 标准字段
|
||||
|
||||
使用 `pkg/logger/field.go` 中定义的字段构造函数:
|
||||
|
||||
```go
|
||||
logger.Info("请求开始",
|
||||
pkglogger.Method("POST"),
|
||||
pkglogger.Path("/v1/chat"),
|
||||
pkglogger.RequestID("xxx"),
|
||||
)
|
||||
```
|
||||
|
||||
### GORM 日志
|
||||
|
||||
GORM 日志自动桥接到 zap,SQL 查询映射到 Debug 级别。
|
||||
|
||||
## 技术栈
|
||||
|
||||
- **语言**: Go 1.26+
|
||||
- **HTTP 框架**: Gin
|
||||
- **ORM**: GORM
|
||||
- **数据库**: SQLite
|
||||
- **数据库**: SQLite / MySQL
|
||||
- **日志**: zap + lumberjack
|
||||
- **配置**: gopkg.in/yaml.v3
|
||||
- **配置**: Viper + pflag(多层配置:CLI > 环境变量 > 配置文件 > 默认值)
|
||||
- **验证**: go-playground/validator/v10
|
||||
- **迁移**: goose
|
||||
|
||||
@@ -35,7 +85,7 @@ backend/
|
||||
│ └── main.go # 主程序入口(依赖注入)
|
||||
├── internal/
|
||||
│ ├── config/ # 配置管理
|
||||
│ │ ├── config.go # 配置加载/保存/验证
|
||||
│ │ ├── config.go # Viper 多层配置加载/验证
|
||||
│ │ └── models.go # GORM 数据模型
|
||||
│ ├── domain/ # 领域模型
|
||||
│ │ ├── provider.go
|
||||
@@ -43,24 +93,41 @@ backend/
|
||||
│ │ ├── stats.go
|
||||
│ │ └── route.go
|
||||
│ ├── handler/ # HTTP 处理器
|
||||
│ │ ├── middleware/ # 中间件
|
||||
│ │ ├── middleware/ # 中间件
|
||||
│ │ │ ├── request_id.go
|
||||
│ │ │ ├── logging.go
|
||||
│ │ │ ├── recovery.go
|
||||
│ │ │ └── cors.go
|
||||
│ │ ├── openai_handler.go
|
||||
│ │ ├── anthropic_handler.go
|
||||
│ │ ├── proxy_handler.go # 统一代理处理器
|
||||
│ │ ├── provider_handler.go
|
||||
│ │ ├── model_handler.go
|
||||
│ │ └── stats_handler.go
|
||||
│ ├── protocol/ # 协议适配器
|
||||
│ │ ├── openai/
|
||||
│ │ │ ├── types.go # 请求/响应类型 + 验证
|
||||
│ │ │ └── adapter.go # OpenAI 协议适配
|
||||
│ │ └── anthropic/
|
||||
│ │ ├── types.go # 请求/响应类型 + 验证
|
||||
│ │ ├── converter.go # 协议转换
|
||||
│ │ └── stream_converter.go # 流式转换
|
||||
│ ├── conversion/ # 协议转换引擎
|
||||
│ │ ├── canonical/ # Canonical Model
|
||||
│ │ │ ├── types.go # 核心请求/响应类型
|
||||
│ │ │ ├── stream.go # 流式事件类型
|
||||
│ │ │ └── extended.go # 扩展层 Models
|
||||
│ │ ├── openai/ # OpenAI 协议适配器
|
||||
│ │ │ ├── types.go
|
||||
│ │ │ ├── adapter.go
|
||||
│ │ │ ├── decoder.go
|
||||
│ │ │ ├── encoder.go
|
||||
│ │ │ ├── stream_decoder.go
|
||||
│ │ │ └── stream_encoder.go
|
||||
│ │ ├── anthropic/ # Anthropic 协议适配器
|
||||
│ │ │ ├── types.go
|
||||
│ │ │ ├── adapter.go
|
||||
│ │ │ ├── decoder.go
|
||||
│ │ │ ├── encoder.go
|
||||
│ │ │ ├── stream_decoder.go
|
||||
│ │ │ └── stream_encoder.go
|
||||
│ │ ├── adapter.go # ProtocolAdapter 接口 + Registry
|
||||
│ │ ├── stream.go # StreamDecoder/Encoder/Converter
|
||||
│ │ ├── middleware.go # Middleware 接口和 Chain
|
||||
│ │ ├── engine.go # ConversionEngine 门面
|
||||
│ │ ├── errors.go # ConversionError
|
||||
│ │ ├── interface.go # InterfaceType 枚举
|
||||
│ │ └── provider.go # TargetProvider
|
||||
│ ├── provider/ # 供应商客户端
|
||||
│ │ └── client.go
|
||||
│ ├── repository/ # 数据访问层
|
||||
@@ -84,19 +151,26 @@ backend/
|
||||
│ │ ├── errors.go
|
||||
│ │ └── wrap.go
|
||||
│ ├── logger/ # 日志系统
|
||||
│ │ ├── logger.go
|
||||
│ │ ├── rotate.go
|
||||
│ │ └── context.go
|
||||
│ │ ├── logger.go # 核心初始化
|
||||
│ │ ├── field.go # 标准字段定义
|
||||
│ │ ├── module.go # 模块日志器
|
||||
│ │ ├── context.go # Context 辅助函数
|
||||
│ │ ├── gorm.go # GORM 适配器
|
||||
│ │ ├── minimal.go # 最小化 logger
|
||||
│ │ └── rotate.go # 日志轮转
|
||||
│ ├── modelid/ # 统一模型 ID 工具包
|
||||
│ │ ├── model_id.go
|
||||
│ │ └── model_id_test.go
|
||||
│ └── validator/ # 验证器
|
||||
│ └── validator.go
|
||||
├── migrations/ # 数据库迁移
|
||||
│ ├── 001_initial_schema.sql
|
||||
│ └── 002_add_indexes.sql
|
||||
├── tests/ # 测试
|
||||
│ ├── helpers.go
|
||||
│ ├── integration/
|
||||
│ ├── unit/
|
||||
│ └── testdata/
|
||||
│ └── 20260421000001_initial_schema.sql
|
||||
├── tests/ # 集成测试
|
||||
│ ├── helpers.go # 测试辅助函数
|
||||
│ ├── config/ # 测试配置
|
||||
│ ├── integration/ # 集成测试
|
||||
│ │ └── e2e_conversion_test.go # E2E 协议转换测试
|
||||
│ └── mocks/ # Mock 实现
|
||||
├── Makefile
|
||||
├── go.mod
|
||||
└── README.md
|
||||
@@ -112,6 +186,136 @@ handler(HTTP 请求处理)
|
||||
→ repository(数据访问)
|
||||
```
|
||||
|
||||
代理请求通过 ConversionEngine 进行协议转换:
|
||||
|
||||
```
|
||||
Client Request (clientProtocol)
|
||||
→ ProxyHandler 路由到上游 provider
|
||||
→ ConversionEngine 请求转换 (clientProtocol → providerProtocol)
|
||||
→ ProviderClient 发送请求
|
||||
→ ConversionEngine 响应转换 (providerProtocol → clientProtocol)
|
||||
→ Client Response
|
||||
```
|
||||
|
||||
同协议时自动透传,跳过序列化开销。
|
||||
|
||||
## 协议转换架构
|
||||
|
||||
### Canonical Model 中间表示
|
||||
|
||||
所有协议转换都经过 Canonical Model 中间表示层,实现 Hub-and-Spoke 架构:
|
||||
|
||||
```
|
||||
OpenAI Request → Canonical Request → Anthropic Request
|
||||
(中间表示)
|
||||
OpenAI Response ← Canonical Response ← Anthropic Response
|
||||
```
|
||||
|
||||
**CanonicalRequest 核心字段**:
|
||||
- `Model` - 统一模型 ID
|
||||
- `Messages` - 消息列表(支持 text、tool_use、tool_result、thinking 类型)
|
||||
- `Tools` - 工具定义
|
||||
- `Thinking` - 推理配置(`budget_tokens`、`effort`)
|
||||
- `Parameters` - 通用参数(`max_tokens`、`temperature`、`top_p` 等)
|
||||
|
||||
### Smart Passthrough 机制
|
||||
|
||||
同协议请求走 Smart Passthrough 路径,不进入 Canonical 全量转换:
|
||||
|
||||
```
|
||||
1. 检测 clientProtocol == providerProtocol
|
||||
2. 仅改写请求体中的 model 字段:unified_id → upstream_model_name
|
||||
3. 直接转发请求到上游
|
||||
4. 响应中仅改写 model 字段:upstream_model_name → unified_id
|
||||
```
|
||||
|
||||
Smart Passthrough 保持未改写 JSON 字段的内容和类型不变,不承诺保留原始字节顺序、空白或对象字段顺序。
|
||||
|
||||
### 流式转换器层次
|
||||
|
||||
```
|
||||
StreamConverter (接口)
|
||||
├── PassthroughStreamConverter # 直接透传,无任何处理
|
||||
├── SmartPassthroughStreamConverter # 同协议 + 按 SSE frame 改写 data JSON model
|
||||
└── CanonicalStreamConverter # 跨协议完整转换(decode → encode)
|
||||
```
|
||||
|
||||
### InterfaceType 枚举
|
||||
|
||||
| 类型 | 说明 |
|
||||
|------|------|
|
||||
| `CHAT` | 对话补全(chat/completions、messages) |
|
||||
| `MODELS` | 模型列表 |
|
||||
| `MODEL_INFO` | 模型详情 |
|
||||
| `EMBEDDINGS` | 嵌入接口 |
|
||||
| `RERANK` | 重排序接口 |
|
||||
| `PASSTHROUGH` | 未知接口,直接透传 |
|
||||
|
||||
## 协议适配器特性
|
||||
|
||||
### OpenAI 适配器
|
||||
|
||||
**特有字段支持**:
|
||||
- `reasoning_effort` - 映射到 Canonical Thinking 配置(`none` → 禁用,其他 → `effort`)
|
||||
- `reasoning_content` - 非标准字段,映射到 Canonical thinking 块
|
||||
- `max_completion_tokens` - 新字段,优先于 `max_tokens`
|
||||
- `refusal` - 非标准字段,作为 text 块处理
|
||||
|
||||
**废弃字段兼容**:
|
||||
- `functions` / `function_call` - 自动转换为 `tools` / `tool_choice`
|
||||
|
||||
**消息处理**:
|
||||
- 合并连续同角色消息(Anthropic 不支持连续同角色)
|
||||
- 工具选择映射:`any` → `required`
|
||||
|
||||
### Anthropic 适配器
|
||||
|
||||
**特有字段支持**:
|
||||
- `thinking` - 推理配置(`type: enabled`、`budget_tokens`、`effort`)
|
||||
- `output_config` - 结构化输出配置
|
||||
- `disable_parallel_tool_use` - 禁用并行工具调用
|
||||
- `container` - 工具容器字段
|
||||
|
||||
**不支持的功能**:
|
||||
- Embeddings 接口(返回 `INTERFACE_NOT_SUPPORTED` 错误)
|
||||
|
||||
### 跨协议转换注意事项
|
||||
|
||||
| 源协议 | 目标协议 | 转换说明 |
|
||||
|--------|----------|----------|
|
||||
| OpenAI | Anthropic | `reasoning_effort` → `thinking`,消息角色合并 |
|
||||
| Anthropic | OpenAI | `thinking` 块 → `reasoning_content`,工具选择转换 |
|
||||
|
||||
## 错误码
|
||||
|
||||
### ConversionError 错误码
|
||||
|
||||
| 错误码 | 说明 |
|
||||
|--------|------|
|
||||
| `INVALID_INPUT` | 输入数据无效 |
|
||||
| `MISSING_REQUIRED_FIELD` | 缺少必填字段 |
|
||||
| `INCOMPATIBLE_FEATURE` | 功能不兼容(如跨协议不支持某特性) |
|
||||
| `FIELD_MAPPING_FAILURE` | 字段映射失败 |
|
||||
| `TOOL_CALL_PARSE_ERROR` | 工具调用解析错误 |
|
||||
| `JSON_PARSE_ERROR` | JSON 解析错误 |
|
||||
| `STREAM_STATE_ERROR` | 流式状态错误 |
|
||||
| `UTF8_DECODE_ERROR` | UTF-8 解码错误(流式 chunk 截断) |
|
||||
| `PROTOCOL_CONSTRAINT_VIOLATION` | 协议约束违反 |
|
||||
| `ENCODING_FAILURE` | 编码失败 |
|
||||
| `INTERFACE_NOT_SUPPORTED` | 接口不支持(如 Anthropic Embeddings) |
|
||||
| `UNSUPPORTED_MULTIMODAL` | 跨协议暂不支持多模态内容 |
|
||||
|
||||
### AppError 预定义错误
|
||||
|
||||
| 错误 | HTTP 状态码 | 说明 |
|
||||
|------|-------------|------|
|
||||
| `ErrModelNotFound` | 404 | 模型未找到 |
|
||||
| `ErrModelDisabled` | 404 | 模型已禁用 |
|
||||
| `ErrProviderNotFound` | 404 | 供应商未找到 |
|
||||
| `ErrInvalidProviderID` | 400 | 供应商 ID 格式无效 |
|
||||
| `ErrDuplicateModel` | 409 | 同一供应商下模型名称重复 |
|
||||
| `ErrImmutableField` | 400 | 不可修改字段(如供应商 ID) |
|
||||
|
||||
## 运行方式
|
||||
|
||||
### 安装依赖
|
||||
@@ -130,6 +334,10 @@ go run cmd/server/main.go
|
||||
|
||||
## 配置
|
||||
|
||||
配置支持多种方式:配置文件、环境变量、命令行参数,优先级为:**CLI 参数 > 环境变量 > 配置文件 > 默认值**
|
||||
|
||||
### 配置文件
|
||||
|
||||
配置文件位于 `~/.nex/config.yaml`,首次启动自动生成。
|
||||
|
||||
```yaml
|
||||
@@ -139,7 +347,14 @@ server:
|
||||
write_timeout: 30s
|
||||
|
||||
database:
|
||||
path: ~/.nex/config.db
|
||||
driver: sqlite # sqlite 或 mysql
|
||||
path: ~/.nex/config.db # SQLite 数据库文件路径
|
||||
# --- MySQL 配置(driver=mysql 时生效)---
|
||||
# host: localhost
|
||||
# port: 3306
|
||||
# user: nex
|
||||
# password: ""
|
||||
# dbname: nex
|
||||
max_idle_conns: 10
|
||||
max_open_conns: 100
|
||||
conn_max_lifetime: 1h
|
||||
@@ -153,28 +368,105 @@ log:
|
||||
compress: true
|
||||
```
|
||||
|
||||
### 环境变量
|
||||
|
||||
所有配置项都支持环境变量,使用 `NEX_` 前缀:
|
||||
|
||||
```bash
|
||||
export NEX_SERVER_PORT=9000
|
||||
export NEX_DATABASE_PATH=/data/nex.db
|
||||
export NEX_LOG_LEVEL=debug
|
||||
|
||||
# MySQL 模式
|
||||
export NEX_DATABASE_DRIVER=mysql
|
||||
export NEX_DATABASE_HOST=db.example.com
|
||||
export NEX_DATABASE_PORT=3306
|
||||
export NEX_DATABASE_USER=nex
|
||||
export NEX_DATABASE_PASSWORD=secret
|
||||
export NEX_DATABASE_DBNAME=nex
|
||||
```
|
||||
|
||||
命名规则:配置路径转大写 + 下划线 + `NEX_` 前缀(如 `server.port` → `NEX_SERVER_PORT`)。
|
||||
|
||||
### 命令行参数
|
||||
|
||||
```bash
|
||||
./server --server-port 9000 --log-level debug --database-path /tmp/test.db
|
||||
```
|
||||
|
||||
命名规则:配置路径转 kebab-case + `--` 前缀(如 `server.port` → `--server-port`)。
|
||||
|
||||
完整参数列表:
|
||||
|
||||
```
|
||||
服务器: --server-port, --server-read-timeout, --server-write-timeout
|
||||
数据库: --database-driver, --database-path, --database-host, --database-port, --database-user, --database-password, --database-dbname, --database-max-idle-conns, --database-max-open-conns, --database-conn-max-lifetime
|
||||
日志: --log-level, --log-path, --log-max-size, --log-max-backups, --log-max-age, --log-compress
|
||||
通用: --config (指定配置文件路径)
|
||||
```
|
||||
|
||||
### 使用示例
|
||||
|
||||
```bash
|
||||
# 默认配置
|
||||
./server
|
||||
|
||||
# 临时修改端口
|
||||
./server --server-port 9000
|
||||
|
||||
# 测试场景
|
||||
./server --database-path /tmp/test.db --log-level debug
|
||||
|
||||
# Docker 部署
|
||||
docker run -d -e NEX_SERVER_PORT=9000 -e NEX_LOG_LEVEL=info nex-server
|
||||
|
||||
# MySQL 模式
|
||||
./server --database-driver mysql --database-host db.example.com --database-user nex --database-password secret --database-dbname nex
|
||||
|
||||
# 自定义配置文件
|
||||
./server --config /path/to/custom.yaml
|
||||
```
|
||||
|
||||
数据文件:
|
||||
- `~/.nex/config.yaml` - 配置文件
|
||||
- `~/.nex/config.db` - SQLite 数据库
|
||||
- `~/.nex/config.db` - SQLite 数据库(MySQL 模式下不使用本地数据库文件)
|
||||
- `~/.nex/log/` - 日志目录
|
||||
|
||||
**MySQL 连接说明**:MySQL 连接使用 DSN 格式: `user:password@tcp(host:port)/dbname?charset=utf8mb4&parseTime=true&loc=Local`,最低支持 MySQL 8.0+。
|
||||
|
||||
## 测试
|
||||
|
||||
```bash
|
||||
# 运行所有测试
|
||||
# 运行 backend 默认测试
|
||||
make test
|
||||
|
||||
# 分类测试
|
||||
make test-unit
|
||||
make test-integration
|
||||
|
||||
# 生成覆盖率报告
|
||||
make test-coverage
|
||||
|
||||
# MySQL 专项测试
|
||||
make mysql-up
|
||||
make mysql-down
|
||||
make mysql-test
|
||||
make mysql-test-quick
|
||||
```
|
||||
|
||||
## 数据库迁移
|
||||
|
||||
```bash
|
||||
# 使用 Makefile
|
||||
make migrate-up DB_PATH=~/.nex/config.db
|
||||
make migrate-down DB_PATH=~/.nex/config.db
|
||||
make migrate-status DB_PATH=~/.nex/config.db
|
||||
make migrate-up DB_DSN=~/.nex/config.db
|
||||
make migrate-down DB_DSN=~/.nex/config.db
|
||||
make migrate-status DB_DSN=~/.nex/config.db
|
||||
|
||||
# 创建新迁移
|
||||
make migrate-create
|
||||
|
||||
# MySQL 迁移
|
||||
make migrate-up DB_DRIVER=mysql DB_DSN='user:pass@tcp(localhost:3306)/nex'
|
||||
|
||||
# 或直接使用 goose
|
||||
goose -dir migrations sqlite3 ~/.nex/config.db up
|
||||
@@ -184,41 +476,37 @@ goose -dir migrations sqlite3 ~/.nex/config.db up
|
||||
|
||||
### 代理接口
|
||||
|
||||
#### OpenAI Chat Completions
|
||||
使用 `/{protocol}/*path` URL 前缀路由。网关只剥离第一段协议前缀,不在 Handler 中统一添加或移除 `/v1`;剩余 path 是协议原生 nativePath,由对应 adapter 识别和组合上游 URL。
|
||||
|
||||
#### OpenAI 协议
|
||||
|
||||
```
|
||||
POST /v1/chat/completions
|
||||
POST /openai/v1/chat/completions
|
||||
GET /openai/v1/models
|
||||
POST /openai/v1/embeddings
|
||||
POST /openai/v1/rerank
|
||||
```
|
||||
|
||||
请求示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
#### Anthropic Messages
|
||||
#### Anthropic 协议
|
||||
|
||||
```
|
||||
POST /v1/messages
|
||||
POST /anthropic/v1/messages
|
||||
GET /anthropic/v1/models
|
||||
```
|
||||
|
||||
请求示例:
|
||||
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传或 Smart Passthrough,跳过 Canonical 全量转换。
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "claude-3-opus",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": [{"type": "text", "text": "Hello"}]}
|
||||
]
|
||||
}
|
||||
```
|
||||
**统一模型 ID**:代理请求中的 `model` 字段使用 `provider_id/model_name` 格式(如 `openai/gpt-4`),网关据此路由到对应供应商。同协议时自动改写为上游 `model_name`,跨协议时通过全量转换处理。
|
||||
|
||||
**base_url 约定**:
|
||||
- OpenAI 供应商配置到版本路径一级,例如 `https://api.openai.com/v1`;当客户端请求 `/openai/v1/chat/completions` 时,OpenAI adapter 会把 nativePath `/v1/chat/completions` 映射为上游 path `/chat/completions`,最终 URL 为 `https://api.openai.com/v1/chat/completions`。
|
||||
- Anthropic 供应商配置到域名级,例如 `https://api.anthropic.com`。
|
||||
|
||||
**模型提取边界**:只有 adapter 明确适配的 Chat、Embeddings、Rerank 等接口会提取 `model` 并尝试统一模型 ID 路由。未知接口不做顶层 `model` 猜测,直接按无 model 透传。
|
||||
|
||||
**流式透传边界**:同协议无响应 model 改写时 raw passthrough,保留 SSE frame 边界和 `[DONE]`;同协议需要改写时按 SSE frame 解析 `data` JSON,仅改写 `model`;跨协议继续使用 StreamDecoder → CanonicalStreamConverter → StreamEncoder。
|
||||
|
||||
**错误边界**:网关层代理错误返回 `{"error":"...","code":"..."}`。已收到上游 HTTP 响应时,非 2xx status、过滤 hop-by-hop header 后的 headers 和 body 直接透传;没有收到上游响应的连接/DNS/TLS/超时错误返回 `UPSTREAM_UNAVAILABLE`。
|
||||
|
||||
### 管理接口
|
||||
|
||||
@@ -230,22 +518,25 @@ POST /v1/messages
|
||||
- `PUT /api/providers/:id` - 更新供应商
|
||||
- `DELETE /api/providers/:id` - 删除供应商
|
||||
|
||||
创建供应商示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "openai",
|
||||
"name": "OpenAI",
|
||||
"api_key": "sk-...",
|
||||
"base_url": "https://api.openai.com/v1"
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"protocol": "openai"
|
||||
}
|
||||
```
|
||||
|
||||
**重要说明:**
|
||||
- `base_url` 应配置到 API 版本路径,不包含具体端点
|
||||
- OpenAI: `https://api.openai.com/v1`
|
||||
- GLM: `https://open.bigmodel.cn/api/paas/v4`
|
||||
- 其他 OpenAI 兼容供应商根据其文档配置版本路径
|
||||
**Protocol 字段**:标识上游供应商使用的协议类型,可选值 `"openai"`(默认)、`"anthropic"`。
|
||||
|
||||
**base_url 说明**:
|
||||
- OpenAI 协议:配置到 API 版本路径,如 `https://api.openai.com/v1`、`https://open.bigmodel.cn/api/paas/v4`
|
||||
- Anthropic 协议:配置到域名,不包含版本路径,如 `https://api.anthropic.com`
|
||||
|
||||
**对外 URL 格式**:
|
||||
- OpenAI 协议:`/{protocol}/v1/{endpoint}`,如 `/openai/v1/chat/completions`、`/openai/v1/models`、`/openai/v1/embeddings`
|
||||
- Anthropic 协议:`/{protocol}/v1/{endpoint}`,如 `/anthropic/v1/messages`、`/anthropic/v1/models`
|
||||
|
||||
#### 模型管理
|
||||
|
||||
@@ -255,43 +546,99 @@ POST /v1/messages
|
||||
- `PUT /api/models/:id` - 更新模型
|
||||
- `DELETE /api/models/:id` - 删除模型
|
||||
|
||||
创建模型示例:
|
||||
**创建请求**(id 由系统自动生成 UUID):
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "gpt-4",
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4"
|
||||
}
|
||||
```
|
||||
|
||||
**响应示例**:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
"unified_id": "openai/gpt-4",
|
||||
"enabled": true,
|
||||
"created_at": "2026-04-21T00:00:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
**统一模型 ID**:`unified_id` 字段为 `provider_id/model_name` 格式,用于代理请求的 `model` 参数。
|
||||
|
||||
#### 统计查询
|
||||
|
||||
- `GET /api/stats` - 查询统计
|
||||
- `GET /api/stats/aggregate` - 聚合统计
|
||||
|
||||
查询参数:
|
||||
查询参数:`provider_id`、`model_name`、`start_date`(YYYY-MM-DD)、`end_date`、`group_by`(provider/model/date)
|
||||
|
||||
- `provider_id` - 供应商 ID
|
||||
- `model_name` - 模型名称
|
||||
- `start_date` - 开始日期(YYYY-MM-DD)
|
||||
- `end_date` - 结束日期(YYYY-MM-DD)
|
||||
- `group_by` - 聚合维度(provider/model/date)
|
||||
#### 健康检查
|
||||
|
||||
- `GET /health` - 返回 `{"status": "ok"}`
|
||||
|
||||
## 开发
|
||||
|
||||
### 构建
|
||||
|
||||
```bash
|
||||
make build
|
||||
make build # 构建 backend/bin/server
|
||||
make run # 运行后端服务
|
||||
make lint # 代码检查
|
||||
make clean # 清理 backend 构建产物
|
||||
go mod tidy # 整理依赖
|
||||
go generate ./... # 刷新 mock 等生成代码
|
||||
```
|
||||
|
||||
### 代码检查
|
||||
环境要求:Go 1.26 或更高版本
|
||||
|
||||
```bash
|
||||
make lint
|
||||
## 公共库使用指南
|
||||
|
||||
### pkg/errors — 结构化错误
|
||||
|
||||
```go
|
||||
import (
|
||||
"errors"
|
||||
pkgErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
return pkgErrors.ErrRequestSend.WithCause(err)
|
||||
|
||||
var appErr *pkgErrors.AppError
|
||||
if errors.As(err, &appErr) {
|
||||
// appErr.Code, appErr.HTTPStatus, appErr.Message
|
||||
}
|
||||
```
|
||||
|
||||
### 环境要求
|
||||
### pkg/logger — 日志系统
|
||||
|
||||
- Go 1.26 或更高版本
|
||||
构造函数接受 `*zap.Logger` 参数,nil 时回退到 `zap.L()`:
|
||||
|
||||
```go
|
||||
func NewMyService(repo Repository, logger *zap.Logger) *MyService {
|
||||
if logger == nil {
|
||||
logger = zap.L()
|
||||
}
|
||||
return &MyService{repo: repo, logger: logger}
|
||||
}
|
||||
```
|
||||
|
||||
### pkg/validator — 请求验证
|
||||
|
||||
```go
|
||||
import "nex/backend/pkg/validator"
|
||||
|
||||
v := validator.Get()
|
||||
err := v.Validate(myStruct)
|
||||
```
|
||||
|
||||
## 编码规范
|
||||
|
||||
- **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节
|
||||
- **字符串拼接**:使用 `strings.Join`,不手写循环拼接
|
||||
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配(lint 强约束:errorlint 禁止 `err ==` 直接比较和 `err.(*T)` 直接断言)
|
||||
- **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()`
|
||||
- **日志 error 字段**:使用 `zap.Error(err)`,不使用 `zap.String("error", err.Error())` 手工字符串化
|
||||
- **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片
|
||||
|
||||
25
backend/cmd/desktop/dialog_darwin.go
Normal file
25
backend/cmd/desktop/dialog_darwin.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build darwin
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func showError(title, message string) {
|
||||
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`,
|
||||
escapeAppleScript(message), escapeAppleScript(title))
|
||||
if err := exec.Command("osascript", "-e", script).Run(); err != nil {
|
||||
dialogLogger().Warn("显示错误对话框失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// escapeAppleScript escapes s for safe embedding inside a double-quoted
// AppleScript string literal. Backslashes and double quotes are the only
// characters AppleScript requires escaping; a single-pass Replacer is
// equivalent to escaping backslashes first and quotes second.
func escapeAppleScript(s string) string {
	replacer := strings.NewReplacer(`\`, `\\`, `"`, `\"`)
	return replacer.Replace(s)
}
|
||||
67
backend/cmd/desktop/dialog_linux.go
Normal file
67
backend/cmd/desktop/dialog_linux.go
Normal file
@@ -0,0 +1,67 @@
|
||||
//go:build linux
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// dialogToolType identifies which command-line dialog utility was found on
// the host for presenting error messages.
type dialogToolType int

const (
	toolNone dialogToolType = iota // no known dialog tool installed
	toolZenity
	toolKdialog
	toolNotifySend
	toolXmessage
)

var (
	// dialogTool is the utility selected by detectDialogTool; read by showError.
	dialogTool dialogToolType
	// dialogToolOnce guards detection; defensive, since init already runs once.
	dialogToolOnce sync.Once
)

// init detects an available dialog tool exactly once at program start.
func init() {
	dialogToolOnce.Do(detectDialogTool)
}
|
||||
|
||||
func detectDialogTool() {
|
||||
tools := []struct {
|
||||
name string
|
||||
typ dialogToolType
|
||||
}{
|
||||
{"zenity", toolZenity},
|
||||
{"kdialog", toolKdialog},
|
||||
{"notify-send", toolNotifySend},
|
||||
{"xmessage", toolXmessage},
|
||||
}
|
||||
|
||||
for _, tool := range tools {
|
||||
if _, err := exec.LookPath(tool.name); err == nil {
|
||||
dialogTool = tool.typ
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
dialogTool = toolNone
|
||||
}
|
||||
|
||||
// showError presents an error message to the user with whichever dialog
// utility was detected at startup (see detectDialogTool). When no tool is
// available, the failure can only be logged.
//
// NOTE(review): the .Run() errors of the dialog commands are ignored — the
// dialog is best-effort — but the darwin implementation logs such failures;
// consider aligning the two.
func showError(title, message string) {
	switch dialogTool {
	case toolZenity:
		exec.Command("zenity", "--error",
			fmt.Sprintf("--title=%s", title),
			fmt.Sprintf("--text=%s", message)).Run()
	case toolKdialog:
		exec.Command("kdialog", "--error", message, "--title", title).Run()
	case toolNotifySend:
		// notify-send has no modal dialog; a critical-urgency notification
		// is the closest equivalent.
		exec.Command("notify-send", "-u", "critical", title, message).Run()
	case toolXmessage:
		exec.Command("xmessage", "-center",
			fmt.Sprintf("%s: %s", title, message)).Run()
	default:
		dialogLogger().Error("无法显示错误对话框")
	}
}
|
||||
15
backend/cmd/desktop/dialog_logger.go
Normal file
15
backend/cmd/desktop/dialog_logger.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
func dialogLogger() *zap.Logger {
|
||||
if zapLogger != nil {
|
||||
return zapLogger
|
||||
}
|
||||
|
||||
return pkgLogger.NewMinimal()
|
||||
}
|
||||
62
backend/cmd/desktop/dialog_windows.go
Normal file
62
backend/cmd/desktop/dialog_windows.go
Normal file
@@ -0,0 +1,62 @@
|
||||
//go:build windows
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Win32 MessageBox icon flags (MB_ICONERROR / MB_ICONINFORMATION).
const (
	mbIconError = 0x10
	mbIconInformation = 0x40
)

var (
	user32 = syscall.NewLazyDLL("user32.dll")
	procMessageBoxW = user32.NewProc("MessageBoxW")
	// callMessageBoxW wraps the raw proc call; declared as a variable so
	// tests can substitute a fake implementation.
	callMessageBoxW = func(hwnd, text, caption, flags uintptr) (uintptr, error) {
		ret, _, err := procMessageBoxW.Call(hwnd, text, caption, flags)
		return ret, err
	}
)
|
||||
|
||||
func showError(title, message string) {
|
||||
if err := messageBox(title, message, mbIconError); err != nil {
|
||||
if zapLogger != nil {
|
||||
zapLogger.Warn("显示错误对话框失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// messageBox shows a Win32 MessageBoxW with the given style flags and
// reports whether the call succeeded. MessageBoxW returns the id of the
// pressed button (non-zero) on success and 0 on failure.
func messageBox(title, message string, flags uint) error {
	titlePtr, err := syscall.UTF16PtrFromString(title)
	if err != nil {
		// title contains an interior NUL and cannot be passed to Win32.
		return err
	}

	messagePtr, err := syscall.UTF16PtrFromString(message)
	if err != nil {
		return err
	}

	ret, callErr := callMessageBoxW(
		0, // hwnd 0: no owner window
		uintptr(unsafe.Pointer(messagePtr)),
		uintptr(unsafe.Pointer(titlePtr)),
		uintptr(flags),
	)
	// Non-zero return means a button was pressed; Proc.Call always reports a
	// last-error value, which is stale on success and must be ignored here.
	if ret != 0 {
		return nil
	}

	// ret == 0: the call failed. Prefer the reported last error, but
	// syscall.Errno(0) ("operation completed successfully") carries no
	// information, so fall through to a generic error in that case.
	if callErr != nil && !errors.Is(callErr, syscall.Errno(0)) {
		return callErr
	}

	return fmt.Errorf("MessageBoxW 调用失败")
}
|
||||
33
backend/cmd/desktop/icon_test.go
Normal file
33
backend/cmd/desktop/icon_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"nex/embedfs"
|
||||
)
|
||||
|
||||
// TestIconSelection_Windows verifies that the embedded .ico tray icon used
// on Windows is present and readable.
func TestIconSelection_Windows(t *testing.T) {
	if runtime.GOOS != "windows" {
		t.Skip("图标格式选择测试仅在 Windows 上运行")
	}

	if err := testIconLoad("assets/icon.ico"); err != nil {
		t.Fatalf("Windows 应加载 .ico 文件: %v", err)
	}
}

// TestIconSelection_NonWindows verifies that the embedded .png tray icon
// used on all other platforms is present and readable.
func TestIconSelection_NonWindows(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("图标格式选择测试在非 Windows 平台运行")
	}

	if err := testIconLoad("assets/icon.png"); err != nil {
		t.Fatalf("非 Windows 平台应加载 .png 文件: %v", err)
	}
}

// testIconLoad attempts to read the given asset path from the embedded
// filesystem and returns any read error.
func testIconLoad(path string) error {
	_, err := embedfs.Assets.ReadFile(path)
	return err
}
|
||||
1
backend/cmd/desktop/icon_windows.rc
Normal file
1
backend/cmd/desktop/icon_windows.rc
Normal file
@@ -0,0 +1 @@
|
||||
1 ICON "../../../assets/icon.ico"
|
||||
407
backend/cmd/desktop/main.go
Normal file
407
backend/cmd/desktop/main.go
Normal file
@@ -0,0 +1,407 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nex/embedfs"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
"nex/backend/internal/conversion/openai"
|
||||
"nex/backend/internal/database"
|
||||
"nex/backend/internal/handler"
|
||||
"nex/backend/internal/handler/middleware"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/internal/service"
|
||||
"nex/backend/pkg/buildinfo"
|
||||
|
||||
"github.com/getlantern/systray"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/flock"
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// Process-wide state shared between main, the systray callbacks, and
// doShutdown. Set during startup in main; nil until then.
var (
	server *http.Server
	zapLogger *zap.Logger
	// shutdownCtx/shutdownCancel form a process-lifetime context; the cancel
	// is invoked by doShutdown. NOTE(review): shutdownCtx does not appear to
	// be read anywhere in this file — confirm it is used elsewhere.
	shutdownCtx context.Context
	shutdownCancel context.CancelFunc
)
|
||||
|
||||
// main boots the desktop build of the gateway: single-instance guard, port
// check, config + logging, database, services, HTTP stack, then hands
// control to the system tray loop (setupSystray blocks until quit).
func main() {
	// Desktop builds always listen on a fixed port.
	port := 9826

	// Bootstrap logger used until the configured logger is available.
	minimalLogger := pkgLogger.NewMinimal()

	// --- single-instance guard via an advisory file lock ---
	singleLock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway.lock"))
	if err := singleLock.Lock(); err != nil {
		minimalLogger.Error("已有 Nex 实例运行")
		showError(appName, "已有 Nex 实例运行")
		os.Exit(1)
	}
	defer func() {
		if err := singleLock.Unlock(); err != nil {
			minimalLogger.Warn("释放实例锁失败", zap.Error(err))
		}
	}()

	// Fail fast with a user-visible dialog if the fixed port is taken.
	if err := checkPortAvailable(port); err != nil {
		minimalLogger.Error("端口不可用", zap.Error(err))
		showError(appName, err.Error())
		return
	}

	cfg, err := config.LoadConfig()
	if err != nil {
		minimalLogger.Fatal("加载配置失败", zap.Error(err))
	}

	// Upgrade the bootstrap logger to the fully configured one; assigned to
	// the package-level zapLogger so dialog/systray helpers can use it.
	zapLogger, err = pkgLogger.Upgrade(minimalLogger, pkgLogger.Config{
		Level: cfg.Log.Level,
		Path: cfg.Log.Path,
		MaxSize: cfg.Log.MaxSize,
		MaxBackups: cfg.Log.MaxBackups,
		MaxAge: cfg.Log.MaxAge,
		Compress: cfg.Log.Compress,
	})
	if err != nil {
		minimalLogger.Fatal("初始化日志失败", zap.Error(err))
	}
	defer func() {
		if err := zapLogger.Sync(); err != nil {
			minimalLogger.Warn("同步日志失败", zap.Error(err))
		}
	}()

	cfg.PrintSummary(zapLogger)

	// --- persistence layer ---
	db, err := database.Init(&cfg.Database, zapLogger)
	if err != nil {
		zapLogger.Fatal("初始化数据库失败", zap.Error(err))
	}
	defer database.Close(db)

	providerRepo := repository.NewProviderRepository(db)
	modelRepo := repository.NewModelRepository(db)
	statsRepo := repository.NewStatsRepository(db)

	// Routing cache: preload failure is non-fatal (lazy loading takes over).
	routingCache := service.NewRoutingCache(modelRepo, providerRepo, zapLogger)
	if err := routingCache.Preload(); err != nil {
		zapLogger.Warn("缓存预热失败,将使用懒加载", zap.Error(err))
	}

	// Buffered stats writer: flush every 5s or every 100 records.
	statsBuffer := service.NewStatsBuffer(statsRepo, zapLogger,
		service.WithFlushInterval(5*time.Second),
		service.WithFlushThreshold(100))
	statsBuffer.Start()
	defer statsBuffer.Stop()

	// --- service layer ---
	providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
	modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
	routingService := service.NewRoutingService(routingCache)
	statsService := service.NewStatsService(statsRepo, statsBuffer)

	// --- protocol conversion adapters ---
	registry := conversion.NewMemoryRegistry()
	if err := registry.Register(openai.NewAdapter()); err != nil {
		zapLogger.Fatal("注册 OpenAI 适配器失败", zap.Error(err))
	}
	if err := registry.Register(anthropic.NewAdapter()); err != nil {
		zapLogger.Fatal("注册 Anthropic 适配器失败", zap.Error(err))
	}
	engine := conversion.NewConversionEngine(registry, zapLogger)

	providerClient := provider.NewClient(zapLogger)

	// --- HTTP handlers and router ---
	proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService, zapLogger)
	providerHandler := handler.NewProviderHandler(providerService)
	modelHandler := handler.NewModelHandler(modelService)
	statsHandler := handler.NewStatsHandler(statsService)

	gin.SetMode(gin.ReleaseMode)
	r := gin.New()

	r.Use(middleware.RequestID())
	r.Use(middleware.Recovery(zapLogger))
	r.Use(middleware.Logging(zapLogger))
	r.Use(middleware.CORS())

	setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
	setupStaticFiles(r)

	server = &http.Server{
		Addr: fmt.Sprintf(":%d", port),
		Handler: r,
		ReadTimeout: cfg.Server.ReadTimeout,
		WriteTimeout: cfg.Server.WriteTimeout,
	}

	shutdownCtx, shutdownCancel = context.WithCancel(context.Background())

	// Serve in the background; the foreground is owned by the systray loop.
	go func() {
		zapLogger.Info("AI Gateway 启动",
			zap.String("addr", server.Addr),
			zap.String("version", buildinfo.Version()),
			zap.String("commit", buildinfo.Commit()),
			zap.String("build_time", buildinfo.BuildTime()))
		if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			zapLogger.Fatal("服务器启动失败", zap.Error(err))
		}
	}()

	// Best-effort: open the management UI shortly after startup.
	go func() {
		time.Sleep(500 * time.Millisecond)
		if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
			zapLogger.Warn("无法打开浏览器", zap.Error(err))
		}
	}()

	// Blocks until the user quits from the tray menu.
	setupSystray(port)
}
|
||||
|
||||
// setupRoutes registers all HTTP routes: protocol-prefixed proxy catch-alls,
// the /api management endpoints, and the health check.
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
	// Proxy: the first path segment selects the protocol; the remainder is
	// passed through to the proxy handler as the native path.
	r.Any("/openai/*path", withProtocol("openai", proxyHandler.HandleProxy))
	r.Any("/anthropic/*path", withProtocol("anthropic", proxyHandler.HandleProxy))

	// Provider CRUD.
	providers := r.Group("/api/providers")
	{
		providers.GET("", providerHandler.ListProviders)
		providers.POST("", providerHandler.CreateProvider)
		providers.GET("/:id", providerHandler.GetProvider)
		providers.PUT("/:id", providerHandler.UpdateProvider)
		providers.DELETE("/:id", providerHandler.DeleteProvider)
	}

	// Model CRUD.
	models := r.Group("/api/models")
	{
		models.GET("", modelHandler.ListModels)
		models.POST("", modelHandler.CreateModel)
		models.GET("/:id", modelHandler.GetModel)
		models.PUT("/:id", modelHandler.UpdateModel)
		models.DELETE("/:id", modelHandler.DeleteModel)
	}

	// Usage statistics.
	stats := r.Group("/api/stats")
	{
		stats.GET("", statsHandler.GetStats)
		stats.GET("/aggregate", statsHandler.AggregateStats)
	}

	r.GET("/health", func(c *gin.Context) {
		c.JSON(200, gin.H{"status": "ok"})
	})
}
|
||||
|
||||
func withProtocol(protocol string, next gin.HandlerFunc) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Params = append(c.Params, gin.Param{Key: "protocol", Value: protocol})
|
||||
next(c)
|
||||
}
|
||||
}
|
||||
|
||||
// setupStaticFiles mounts the embedded frontend bundle onto the router.
// Missing embedded assets are a packaging error, hence fatal.
func setupStaticFiles(r *gin.Engine) {
	distFS, err := frontendDistFS()
	if err != nil {
		zapLogger.Fatal("无法加载前端资源", zap.Error(err))
	}
	setupStaticFilesWithFS(r, distFS)
}
|
||||
|
||||
// frontendDistFS returns the embedded frontend build output rooted at the
// "frontend-dist" directory.
func frontendDistFS() (fs.FS, error) {
	return fs.Sub(embedfs.FrontendDist, "frontend-dist")
}
|
||||
|
||||
func setupStaticFilesWithFS(r *gin.Engine, distFS fs.FS) {
|
||||
getContentType := func(path string) string {
|
||||
if strings.HasSuffix(path, ".js") {
|
||||
return "application/javascript"
|
||||
}
|
||||
if strings.HasSuffix(path, ".css") {
|
||||
return "text/css"
|
||||
}
|
||||
if strings.HasSuffix(path, ".svg") {
|
||||
return "image/svg+xml"
|
||||
}
|
||||
if strings.HasSuffix(path, ".png") {
|
||||
return "image/png"
|
||||
}
|
||||
if strings.HasSuffix(path, ".ico") {
|
||||
return "image/x-icon"
|
||||
}
|
||||
if strings.HasSuffix(path, ".woff") || strings.HasSuffix(path, ".woff2") {
|
||||
return "font/woff2"
|
||||
}
|
||||
return "application/octet-stream"
|
||||
}
|
||||
|
||||
r.GET("/assets/*filepath", func(c *gin.Context) {
|
||||
filepath := c.Param("filepath")
|
||||
data, err := fs.ReadFile(distFS, "assets"+filepath)
|
||||
if err != nil {
|
||||
c.Status(404)
|
||||
return
|
||||
}
|
||||
c.Data(200, getContentType(filepath), data)
|
||||
})
|
||||
|
||||
r.GET("/favicon.svg", func(c *gin.Context) {
|
||||
data, err := fs.ReadFile(distFS, "favicon.svg")
|
||||
if err != nil {
|
||||
c.Status(404)
|
||||
return
|
||||
}
|
||||
c.Data(200, "image/svg+xml", data)
|
||||
})
|
||||
|
||||
r.NoRoute(func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
|
||||
if strings.HasPrefix(path, "/api/") ||
|
||||
strings.HasPrefix(path, "/openai/") ||
|
||||
strings.HasPrefix(path, "/anthropic/") ||
|
||||
path == "/openai" ||
|
||||
path == "/anthropic" ||
|
||||
strings.HasPrefix(path, "/health") {
|
||||
c.JSON(404, gin.H{"error": "not found"})
|
||||
return
|
||||
}
|
||||
|
||||
data, err := fs.ReadFile(distFS, "index.html")
|
||||
if err != nil {
|
||||
c.Status(500)
|
||||
return
|
||||
}
|
||||
c.Data(200, "text/html; charset=utf-8", data)
|
||||
})
|
||||
}
|
||||
|
||||
// setupSystray runs the system tray event loop. systray.Run blocks until
// systray.Quit is called, so this is the final call in main.
func setupSystray(port int) {
	systray.Run(func() {
		// Windows requires an .ico tray icon; other platforms use the .png.
		var icon []byte
		var err error
		if runtime.GOOS == "windows" {
			icon, err = embedfs.Assets.ReadFile("assets/icon.ico")
		} else {
			icon, err = embedfs.Assets.ReadFile("assets/icon.png")
		}
		if err != nil {
			// Non-fatal: the tray remains functional without an icon.
			zapLogger.Error("无法加载托盘图标", zap.Error(err))
		}
		systray.SetIcon(icon)
		systray.SetTooltip(appTooltip)

		mOpen := systray.AddMenuItem("打开管理界面", "在浏览器中打开")
		systray.AddSeparator()
		// Status and port entries are informational only.
		mStatus := systray.AddMenuItem("状态: 运行中", "")
		mStatus.Disable()
		mPort := systray.AddMenuItem(fmt.Sprintf("端口: %d", port), "")
		mPort.Disable()
		systray.AddSeparator()
		mQuit := systray.AddMenuItem("退出", "停止服务并退出")

		// Menu click loop; returns (and ends the goroutine) on Quit.
		go func() {
			for {
				select {
				case <-mOpen.ClickedCh:
					if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
						zapLogger.Warn("打开浏览器失败", zap.Error(err))
					}
				case <-mQuit.ClickedCh:
					doShutdown()
					systray.Quit()
					return
				}
			}
		}()
	}, nil)
}
|
||||
|
||||
func doShutdown() {
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("正在关闭服务器...")
|
||||
}
|
||||
|
||||
if server != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := server.Shutdown(ctx); err != nil && zapLogger != nil {
|
||||
zapLogger.Warn("关闭服务器失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
if shutdownCancel != nil {
|
||||
shutdownCancel()
|
||||
}
|
||||
}
|
||||
|
||||
func checkPortAvailable(port int) error {
|
||||
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
return fmt.Errorf("端口 %d 已被占用\n\n可能原因:\n- 已有 Nex 实例运行\n- 其他程序占用了该端口\n\n请检查并关闭占用端口的程序", port)
|
||||
}
|
||||
ln.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
type SingletonLock struct {
|
||||
flock *flock.Flock
|
||||
}
|
||||
|
||||
func NewSingletonLock(lockPath string) *SingletonLock {
|
||||
return &SingletonLock{
|
||||
flock: flock.New(lockPath),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SingletonLock) Lock() error {
|
||||
locked, err := s.flock.TryLock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !locked {
|
||||
return fmt.Errorf("已有实例运行")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SingletonLock) Unlock() error {
|
||||
return s.flock.Unlock()
|
||||
}
|
||||
|
||||
// openBrowser opens url in the user's default browser, best-effort per
// platform. The launcher process is started asynchronously and its exit
// status is not observed. An error is returned only when no launcher
// command could be determined or started.
func openBrowser(url string) error {
	var launcher *exec.Cmd

	switch runtime.GOOS {
	case "darwin":
		launcher = exec.Command("open", url)
	case "windows":
		launcher = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
	case "linux":
		// Prefer the desktop-standard xdg-open, then common browsers.
		for _, candidate := range []string{"xdg-open", "google-chrome", "firefox"} {
			if _, lookErr := exec.LookPath(candidate); lookErr == nil {
				launcher = exec.Command(candidate, url)
				break
			}
		}
	}

	if launcher == nil {
		return fmt.Errorf("无法打开浏览器")
	}
	return launcher.Start()
}
|
||||
61
backend/cmd/desktop/messagebox_test.go
Normal file
61
backend/cmd/desktop/messagebox_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
//go:build windows
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"syscall"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// withMessageBoxW swaps in a fake MessageBoxW implementation for the
// duration of a test and restores the original on cleanup.
func withMessageBoxW(t *testing.T, fn func(hwnd, text, caption, flags uintptr) (uintptr, error)) {
	t.Helper()

	old := callMessageBoxW
	callMessageBoxW = fn
	t.Cleanup(func() {
		callMessageBoxW = old
	})
}

// An interior NUL makes UTF-16 conversion fail before the Win32 call.
func TestMessageBoxW_WindowsOnly_InvalidUTF16(t *testing.T) {
	err := messageBox("bad\x00title", "测试消息", mbIconInformation)
	if err == nil {
		t.Fatal("包含 NUL 字符时应该返回错误")
	}
}

// A non-zero return means success; the stale last-error must be ignored.
func TestMessageBoxW_WindowsOnly_SuccessIgnoresLastError(t *testing.T) {
	withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
		return 1, syscall.Errno(123)
	})

	if err := messageBox("测试标题", "测试消息", mbIconInformation); err != nil {
		t.Fatalf("MessageBoxW 返回成功时应忽略 last error: %v", err)
	}
}

// A zero return means failure; the reported errno must be surfaced.
func TestMessageBoxW_WindowsOnly_FailureUsesReturnValue(t *testing.T) {
	withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
		return 0, syscall.Errno(5)
	})

	err := messageBox("测试标题", "测试消息", mbIconInformation)
	if !errors.Is(err, syscall.Errno(5)) {
		t.Fatalf("MessageBoxW 返回 0 时应返回调用错误: %v", err)
	}
}

// showError must swallow messageBox failures rather than panic.
func TestShowError_WindowsBranch(t *testing.T) {
	withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
		return 0, syscall.Errno(5)
	})

	defer func() {
		if recovered := recover(); recovered != nil {
			t.Fatalf("showError 不应因 MessageBoxW 失败而 panic: %v", recovered)
		}
	}()

	showError("测试错误", "这是一条测试错误消息")
}
|
||||
9
backend/cmd/desktop/metadata.go
Normal file
9
backend/cmd/desktop/metadata.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package main
|
||||
|
||||
// Application metadata used by the tray UI.
const (
	appName = "Nex"
	appTooltip = appName
	appDescription = "AI Gateway - 统一的大模型 API 网关"
	// #nosec G101 -- project website URL, not a credential
	appWebsite = "https://github.com/nex/gateway"
)
|
||||
13
backend/cmd/desktop/metadata_test.go
Normal file
13
backend/cmd/desktop/metadata_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package main
|
||||
|
||||
import "testing"
|
||||
|
||||
// TestDesktopMetadata pins the user-visible application name and the
// invariant that the tray tooltip mirrors it.
func TestDesktopMetadata(t *testing.T) {
	if appName != "Nex" {
		t.Fatalf("appName = %q, want %q", appName, "Nex")
	}

	if appTooltip != appName {
		t.Fatalf("appTooltip = %q, want %q", appTooltip, appName)
	}
}
|
||||
69
backend/cmd/desktop/port_test.go
Normal file
69
backend/cmd/desktop/port_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestCheckPortAvailable exercises the happy path on a port assumed free.
// NOTE(review): hardcoded ports (19826-19828) can collide with other
// processes on shared CI hosts; consider deriving a free port dynamically.
func TestCheckPortAvailable(t *testing.T) {
	port := 19826

	err := checkPortAvailable(port)
	if err != nil {
		t.Fatalf("端口 %d 应该可用: %v", port, err)
	}

	t.Log("端口可用测试通过")
}

// TestCheckPortOccupied verifies that an active wildcard listener makes the
// check fail.
func TestCheckPortOccupied(t *testing.T) {
	port := 19827

	// NOTE(review): the address duplicates the port variable as a literal;
	// keep the two in sync.
	listener, err := net.Listen("tcp", ":19827") //nolint:gosec // 需要验证 checkPortAvailable 对通配地址占用的检测行为
	if err != nil {
		t.Fatalf("无法启动测试服务器: %v", err)
	}
	defer listener.Close()

	time.Sleep(100 * time.Millisecond)

	err = checkPortAvailable(port)
	if err == nil {
		t.Fatal("端口被占用时应该返回错误")
	}

	t.Log("端口占用检测测试通过")
}

// TestCheckPortAvailableAfterClose verifies the port becomes available again
// once the occupying server/listener is closed.
func TestCheckPortAvailableAfterClose(t *testing.T) {
	port := 19828

	listener, err := net.Listen("tcp", "127.0.0.1:19828")
	if err != nil {
		t.Fatalf("无法启动测试服务器: %v", err)
	}

	server := &http.Server{ReadHeaderTimeout: time.Second}
	defer server.Close()
	go func() {
		err := server.Serve(listener)
		if err != nil && err != http.ErrServerClosed && !errors.Is(err, net.ErrClosed) {
			t.Errorf("serve failed: %v", err)
		}
	}()

	time.Sleep(100 * time.Millisecond)

	listener.Close()
	time.Sleep(100 * time.Millisecond)

	err = checkPortAvailable(port)
	if err != nil {
		t.Fatalf("端口关闭后应该可用: %v", err)
	}

	t.Log("端口关闭后可用测试通过")
}
|
||||
74
backend/cmd/desktop/singleton_test.go
Normal file
74
backend/cmd/desktop/singleton_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSingletonLock_FirstLockSuccess(t *testing.T) {
|
||||
lockPath := filepath.Join(os.TempDir(), "nex-gateway-test-first.lock")
|
||||
defer os.Remove(lockPath)
|
||||
|
||||
lock := NewSingletonLock(lockPath)
|
||||
if err := lock.Lock(); err != nil {
|
||||
t.Fatalf("首次加锁应成功,但返回错误: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := lock.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func TestSingletonLock_DuplicateLockFails(t *testing.T) {
|
||||
lockPath := filepath.Join(os.TempDir(), "nex-gateway-test-dup.lock")
|
||||
defer os.Remove(lockPath)
|
||||
|
||||
lock1 := NewSingletonLock(lockPath)
|
||||
if err := lock1.Lock(); err != nil {
|
||||
t.Fatalf("首次加锁应成功: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := lock1.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
lock2 := NewSingletonLock(lockPath)
|
||||
err := lock2.Lock()
|
||||
if err == nil {
|
||||
if unlockErr := lock2.Unlock(); unlockErr != nil {
|
||||
t.Fatalf("解锁失败: %v", unlockErr)
|
||||
}
|
||||
t.Fatal("重复加锁应失败,但返回 nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingletonLock_UnlockThenRelock(t *testing.T) {
|
||||
lockPath := filepath.Join(os.TempDir(), "nex-gateway-test-relock.lock")
|
||||
defer os.Remove(lockPath)
|
||||
|
||||
lock1 := NewSingletonLock(lockPath)
|
||||
if err := lock1.Lock(); err != nil {
|
||||
t.Fatalf("首次加锁应成功: %v", err)
|
||||
}
|
||||
if err := lock1.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
|
||||
lock2 := NewSingletonLock(lockPath)
|
||||
if err := lock2.Lock(); err != nil {
|
||||
t.Fatalf("释放后重新加锁应成功: %v", err)
|
||||
}
|
||||
if err := lock2.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingletonLock_UnlockWithoutLock(t *testing.T) {
|
||||
lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock"))
|
||||
if err := lock.Unlock(); err != nil {
|
||||
t.Fatalf("未加锁时解锁失败: %v", err)
|
||||
}
|
||||
}
|
||||
213
backend/cmd/desktop/static_test.go
Normal file
213
backend/cmd/desktop/static_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestSetupStaticFiles(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
distFS, err := frontendDistFS()
|
||||
if err != nil {
|
||||
t.Skipf("跳过测试: 前端资源未构建: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
setupStaticFilesWithFS(r, distFS)
|
||||
|
||||
t.Run("API 404", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != 404 {
|
||||
t.Errorf("期望状态码 404, 实际 %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OpenAI proxy prefix 404", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/openai/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("期望状态码 404, 实际 %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "not found") {
|
||||
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Anthropic proxy prefix 404", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/anthropic/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("期望状态码 404, 实际 %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "not found") {
|
||||
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SPA fallback", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/providers", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MIME type for JS", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/assets/test.js", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == 200 {
|
||||
expected := "application/javascript"
|
||||
if w.Header().Get("Content-Type") != expected {
|
||||
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
|
||||
}
|
||||
} else {
|
||||
t.Log("文件不存在,跳过 MIME 类型验证")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MIME type for CSS", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/assets/test.css", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == 200 {
|
||||
expected := "text/css"
|
||||
if w.Header().Get("Content-Type") != expected {
|
||||
t.Errorf("期望 Content-Type %s, 实际 %s", expected, w.Header().Get("Content-Type"))
|
||||
}
|
||||
} else {
|
||||
t.Log("文件不存在,跳过 MIME 类型验证")
|
||||
}
|
||||
})
|
||||
|
||||
t.Log("静态文件服务测试通过")
|
||||
}
|
||||
|
||||
func TestWithProtocolAndStaticRoutes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
distFS, err := frontendDistFS()
|
||||
if err != nil {
|
||||
t.Skipf("跳过测试: 前端资源未构建: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
|
||||
var gotProtocol string
|
||||
var gotPath string
|
||||
r.Any("/openai/*path", withProtocol("openai", func(c *gin.Context) {
|
||||
gotProtocol = c.Param("protocol")
|
||||
gotPath = c.Param("path")
|
||||
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
|
||||
}))
|
||||
r.Any("/anthropic/*path", withProtocol("anthropic", func(c *gin.Context) {
|
||||
gotProtocol = c.Param("protocol")
|
||||
gotPath = c.Param("path")
|
||||
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
|
||||
}))
|
||||
setupStaticFilesWithFS(r, distFS)
|
||||
|
||||
t.Run("OpenAI route enters proxy handler wrapper", func(t *testing.T) {
|
||||
gotProtocol = ""
|
||||
gotPath = ""
|
||||
|
||||
req := httptest.NewRequest("POST", "/openai/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
if gotProtocol != "openai" {
|
||||
t.Errorf("期望 protocol=openai, 实际 %s", gotProtocol)
|
||||
}
|
||||
if gotPath != "/v1/chat/completions" {
|
||||
t.Errorf("期望 path=/v1/chat/completions, 实际 %s", gotPath)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Anthropic route enters proxy handler wrapper", func(t *testing.T) {
|
||||
gotProtocol = ""
|
||||
gotPath = ""
|
||||
|
||||
req := httptest.NewRequest("POST", "/anthropic/v1/messages", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
if gotProtocol != "anthropic" {
|
||||
t.Errorf("期望 protocol=anthropic, 实际 %s", gotProtocol)
|
||||
}
|
||||
if gotPath != "/v1/messages" {
|
||||
t.Errorf("期望 path=/v1/messages, 实际 %s", gotPath)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Static assets are not hijacked", func(t *testing.T) {
|
||||
gotProtocol = ""
|
||||
gotPath = ""
|
||||
|
||||
req := httptest.NewRequest("GET", "/assets/test.js", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if gotProtocol != "" || gotPath != "" {
|
||||
t.Errorf("静态资源不应进入代理包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
|
||||
}
|
||||
if w.Code == http.StatusOK {
|
||||
if !strings.HasPrefix(w.Header().Get("Content-Type"), "application/javascript") {
|
||||
t.Errorf("期望 JS Content-Type, 实际 %s", w.Header().Get("Content-Type"))
|
||||
}
|
||||
return
|
||||
}
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("期望静态资源返回 200 或 404, 实际 %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SPA path keeps fallback", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/providers", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Header().Get("Content-Type"), "text/html") {
|
||||
t.Errorf("期望返回 HTML,实际 %s", w.Header().Get("Content-Type"))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Unknown proxy-like path does not return index html", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/openai/unknown", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("显式代理路由应进入代理包装器,实际状态码 %d", w.Code)
|
||||
}
|
||||
if gotProtocol != "openai" || gotPath != "/unknown" {
|
||||
t.Errorf("期望 unknown 代理路径进入 openai 包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -3,43 +3,38 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pressly/goose/v3"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
"nex/backend/internal/conversion/openai"
|
||||
"nex/backend/internal/database"
|
||||
"nex/backend/internal/handler"
|
||||
"nex/backend/internal/handler/middleware"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/internal/service"
|
||||
"nex/backend/pkg/buildinfo"
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 1. 加载配置
|
||||
minimalLogger := pkgLogger.NewMinimal()
|
||||
|
||||
cfg, err := config.LoadConfig()
|
||||
if err != nil {
|
||||
log.Fatalf("加载配置失败: %v", err)
|
||||
}
|
||||
if err := cfg.Validate(); err != nil {
|
||||
log.Fatalf("配置验证失败: %v", err)
|
||||
minimalLogger.Fatal("加载配置失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 2. 初始化日志
|
||||
zapLogger, err := pkgLogger.New(pkgLogger.Config{
|
||||
zapLogger, err := pkgLogger.Upgrade(minimalLogger, pkgLogger.Config{
|
||||
Level: cfg.Log.Level,
|
||||
Path: cfg.Log.Path,
|
||||
MaxSize: cfg.Log.MaxSize,
|
||||
@@ -48,67 +43,85 @@ func main() {
|
||||
Compress: cfg.Log.Compress,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("初始化日志失败: %v", err)
|
||||
minimalLogger.Fatal("初始化日志失败", zap.Error(err))
|
||||
}
|
||||
defer zapLogger.Sync()
|
||||
defer func() {
|
||||
if err := zapLogger.Sync(); err != nil {
|
||||
minimalLogger.Warn("同步日志失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
// 3. 初始化数据库
|
||||
db, err := initDatabase(cfg)
|
||||
cfg.PrintSummary(zapLogger)
|
||||
|
||||
db, err := database.Init(&cfg.Database, zapLogger)
|
||||
if err != nil {
|
||||
zapLogger.Fatal("初始化数据库失败", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("初始化数据库失败", zap.Error(err))
|
||||
}
|
||||
defer closeDB(db)
|
||||
defer database.Close(db)
|
||||
|
||||
// 4. 初始化 repository 层
|
||||
providerRepo := repository.NewProviderRepository(db)
|
||||
modelRepo := repository.NewModelRepository(db)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
|
||||
// 5. 初始化 service 层
|
||||
providerService := service.NewProviderService(providerRepo)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo)
|
||||
routingService := service.NewRoutingService(modelRepo, providerRepo)
|
||||
statsService := service.NewStatsService(statsRepo)
|
||||
routingCache := service.NewRoutingCache(modelRepo, providerRepo, zapLogger)
|
||||
if err := routingCache.Preload(); err != nil {
|
||||
zapLogger.Warn("缓存预热失败,将使用懒加载", zap.Error(err))
|
||||
}
|
||||
|
||||
// 6. 初始化 provider client
|
||||
providerClient := provider.NewClient()
|
||||
statsBuffer := service.NewStatsBuffer(statsRepo, zapLogger,
|
||||
service.WithFlushInterval(5*time.Second),
|
||||
service.WithFlushThreshold(100))
|
||||
statsBuffer.Start()
|
||||
|
||||
// 7. 初始化 handler 层
|
||||
openaiHandler := handler.NewOpenAIHandler(providerClient, routingService, statsService)
|
||||
anthropicHandler := handler.NewAnthropicHandler(providerClient, routingService, statsService)
|
||||
providerService := service.NewProviderService(providerRepo, modelRepo, routingCache)
|
||||
modelService := service.NewModelService(modelRepo, providerRepo, routingCache)
|
||||
routingService := service.NewRoutingService(routingCache)
|
||||
statsService := service.NewStatsService(statsRepo, statsBuffer)
|
||||
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
if err := registry.Register(openai.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.Error(err))
|
||||
}
|
||||
if err := registry.Register(anthropic.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.Error(err))
|
||||
}
|
||||
engine := conversion.NewConversionEngine(registry, zapLogger)
|
||||
|
||||
providerClient := provider.NewClient(zapLogger)
|
||||
|
||||
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService, zapLogger)
|
||||
providerHandler := handler.NewProviderHandler(providerService)
|
||||
modelHandler := handler.NewModelHandler(modelService)
|
||||
statsHandler := handler.NewStatsHandler(statsService)
|
||||
|
||||
// 8. 创建 Gin 引擎
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
r := gin.New()
|
||||
|
||||
// 注册中间件(按正确顺序)
|
||||
r.Use(middleware.RequestID())
|
||||
r.Use(middleware.Recovery(zapLogger))
|
||||
r.Use(middleware.Logging(zapLogger))
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
// 注册路由
|
||||
setupRoutes(r, openaiHandler, anthropicHandler, providerHandler, modelHandler, statsHandler)
|
||||
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
|
||||
|
||||
// 9. 启动服务器
|
||||
srv := &http.Server{
|
||||
Addr: formatAddr(cfg.Server.Port),
|
||||
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
|
||||
Handler: r,
|
||||
ReadTimeout: cfg.Server.ReadTimeout,
|
||||
WriteTimeout: cfg.Server.WriteTimeout,
|
||||
}
|
||||
|
||||
go func() {
|
||||
zapLogger.Info("AI Gateway 启动", zap.String("addr", srv.Addr))
|
||||
zapLogger.Info("AI Gateway 启动",
|
||||
zap.String("addr", srv.Addr),
|
||||
zap.String("version", buildinfo.Version()),
|
||||
zap.String("commit", buildinfo.Commit()),
|
||||
zap.String("build_time", buildinfo.BuildTime()))
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("服务器启动失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待中断信号
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
@@ -119,100 +132,17 @@ func main() {
|
||||
defer cancel()
|
||||
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
zapLogger.Fatal("服务器强制关闭", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("服务器强制关闭", zap.Error(err))
|
||||
}
|
||||
|
||||
statsBuffer.Stop()
|
||||
|
||||
zapLogger.Info("服务器已关闭")
|
||||
}
|
||||
|
||||
func initDatabase(cfg *config.Config) (*gorm.DB, error) {
|
||||
db, err := gorm.Open(sqlite.Open(cfg.Database.Path), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Info),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
|
||||
r.Any("/:protocol/*path", proxyHandler.HandleProxy)
|
||||
|
||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||
log.Printf("警告: 启用 WAL 模式失败: %v", err)
|
||||
}
|
||||
|
||||
// 运行数据库迁移
|
||||
if err := runMigrations(db); err != nil {
|
||||
return nil, fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
// 配置连接池
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlDB.SetMaxIdleConns(cfg.Database.MaxIdleConns)
|
||||
sqlDB.SetMaxOpenConns(cfg.Database.MaxOpenConns)
|
||||
sqlDB.SetConnMaxLifetime(cfg.Database.ConnMaxLifetime)
|
||||
|
||||
// 记录连接池状态
|
||||
log.Printf("数据库连接池配置: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v",
|
||||
cfg.Database.MaxIdleConns, cfg.Database.MaxOpenConns, cfg.Database.ConnMaxLifetime)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// runMigrations 使用 goose 执行数据库迁移
|
||||
func runMigrations(db *gorm.DB) error {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
migrationsDir := getMigrationsDir()
|
||||
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
|
||||
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
|
||||
}
|
||||
|
||||
goose.SetDialect("sqlite3")
|
||||
if err := goose.Up(sqlDB, migrationsDir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getMigrationsDir resolves the migrations directory from this source file's
// location, falling back to a relative path when caller info is unavailable.
func getMigrationsDir() string {
	if _, filename, _, ok := runtime.Caller(0); ok {
		// cmd/server/main.go → backend/ → backend/migrations/
		candidate := filepath.Join(filepath.Dir(filename), "..", "..", "migrations")
		if abs, err := filepath.Abs(candidate); err == nil {
			return abs
		}
	}
	// Fallback when the absolute path cannot be derived.
	return "./migrations"
}
|
||||
|
||||
func closeDB(db *gorm.DB) {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sqlDB.Close()
|
||||
}
|
||||
|
||||
// formatAddr builds the TCP listen address for the given port (":<port>").
func formatAddr(port int) string {
	addr := fmt.Sprintf(":%d", port)
	return addr
}
|
||||
|
||||
func setupRoutes(r *gin.Engine, openaiHandler *handler.OpenAIHandler, anthropicHandler *handler.AnthropicHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
|
||||
// OpenAI 协议代理
|
||||
r.POST("/v1/chat/completions", openaiHandler.HandleChatCompletions)
|
||||
|
||||
// Anthropic 协议代理
|
||||
r.POST("/v1/messages", anthropicHandler.HandleMessages)
|
||||
|
||||
// 供应商管理 API
|
||||
providers := r.Group("/api/providers")
|
||||
{
|
||||
providers.GET("", providerHandler.ListProviders)
|
||||
@@ -222,7 +152,6 @@ func setupRoutes(r *gin.Engine, openaiHandler *handler.OpenAIHandler, anthropicH
|
||||
providers.DELETE("/:id", providerHandler.DeleteProvider)
|
||||
}
|
||||
|
||||
// 模型管理 API
|
||||
models := r.Group("/api/models")
|
||||
{
|
||||
models.GET("", modelHandler.ListModels)
|
||||
@@ -232,14 +161,12 @@ func setupRoutes(r *gin.Engine, openaiHandler *handler.OpenAIHandler, anthropicH
|
||||
models.DELETE("/:id", modelHandler.DeleteModel)
|
||||
}
|
||||
|
||||
// 统计查询 API
|
||||
stats := r.Group("/api/stats")
|
||||
{
|
||||
stats.GET("", statsHandler.GetStats)
|
||||
stats.GET("/aggregate", statsHandler.AggregateStats)
|
||||
}
|
||||
|
||||
// 健康检查
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
119
backend/cmd/versionctl/main.go
Normal file
119
backend/cmd/versionctl/main.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"nex/backend/pkg/projectversion"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := run(os.Args[1:]); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func run(args []string) error {
|
||||
if len(args) == 0 {
|
||||
return usageError()
|
||||
}
|
||||
|
||||
root, err := projectversion.FindRepoRoot(mustGetwd())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch args[0] {
|
||||
case "print":
|
||||
version, readErr := projectversion.ReadString(root)
|
||||
if readErr != nil {
|
||||
return readErr
|
||||
}
|
||||
fmt.Println(version)
|
||||
return nil
|
||||
case "sync":
|
||||
return projectversion.Sync(root)
|
||||
case "check":
|
||||
return projectversion.Check(root)
|
||||
case "verify-tag":
|
||||
if len(args) != 2 {
|
||||
return fmt.Errorf("verify-tag 需要一个 tag 参数")
|
||||
}
|
||||
return projectversion.VerifyTag(root, args[1])
|
||||
case "macos-plist":
|
||||
if len(args) != 2 {
|
||||
return fmt.Errorf("macos-plist 需要一个最低系统版本参数")
|
||||
}
|
||||
return printMacOSPlist(root, args[1])
|
||||
case "asset-name":
|
||||
return printAssetName(root, args[1:])
|
||||
default:
|
||||
return usageError()
|
||||
}
|
||||
}
|
||||
|
||||
func printMacOSPlist(root, minMacOSVersion string) error {
|
||||
version, err := projectversion.ReadString(root)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
plist, err := projectversion.DesktopInfoPlist(version, minMacOSVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Print(plist)
|
||||
return nil
|
||||
}
|
||||
|
||||
func printAssetName(root string, args []string) error {
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("asset-name 至少需要 kind 和 platform 参数")
|
||||
}
|
||||
|
||||
version, err := projectversion.ReadString(root)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch args[0] {
|
||||
case "server":
|
||||
if len(args) != 3 {
|
||||
return fmt.Errorf("server 资产命名需要 platform 和 arch 参数")
|
||||
}
|
||||
name, nameErr := projectversion.ServerAssetName(version, args[1], args[2])
|
||||
if nameErr != nil {
|
||||
return nameErr
|
||||
}
|
||||
fmt.Println(name)
|
||||
return nil
|
||||
case "desktop":
|
||||
if len(args) != 2 {
|
||||
return fmt.Errorf("desktop 资产命名只需要 platform 参数")
|
||||
}
|
||||
name, nameErr := projectversion.DesktopAssetName(version, args[1])
|
||||
if nameErr != nil {
|
||||
return nameErr
|
||||
}
|
||||
fmt.Println(name)
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("不支持的资产类型 %q", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
// mustGetwd returns the current working directory, terminating the process
// when it cannot be determined (CLI startup has no way to recover).
func mustGetwd() string {
	wd, err := os.Getwd()
	if err == nil {
		return wd
	}
	fmt.Fprintln(os.Stderr, err)
	os.Exit(1)
	return "" // unreachable: os.Exit does not return
}
|
||||
|
||||
// usageError reports the supported versionctl subcommands as an error.
func usageError() error {
	const usage = "用法: versionctl <print|sync|check|verify-tag|macos-plist|asset-name>"
	return fmt.Errorf("%s", usage)
}
|
||||
217
backend/go.mod
217
backend/go.mod
@@ -2,58 +2,249 @@ module nex/backend
|
||||
|
||||
go 1.26.2
|
||||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.12.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/pressly/goose/v3 v3.27.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
go.uber.org/zap v1.27.1
|
||||
gopkg.in/lumberjack.v2 v2.0.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
gorm.io/gorm v1.31.1
|
||||
replace nex/embedfs => ../embedfs
|
||||
|
||||
tool (
|
||||
github.com/golangci/golangci-lint/cmd/golangci-lint
|
||||
go.uber.org/mock/mockgen
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/getlantern/systray v1.2.2
|
||||
github.com/gin-gonic/gin v1.12.0
|
||||
github.com/go-playground/validator/v10 v10.30.2
|
||||
github.com/gofrs/flock v0.13.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/mitchellh/mapstructure v1.5.0
|
||||
github.com/pressly/goose/v3 v3.27.0
|
||||
github.com/spf13/pflag v1.0.10
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
go.uber.org/mock v0.6.0
|
||||
go.uber.org/zap v1.27.1
|
||||
gopkg.in/lumberjack.v2 v2.0.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/mysql v1.6.0
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
gorm.io/gorm v1.31.1
|
||||
nex/embedfs v0.0.0-00010101000000-000000000000
|
||||
)
|
||||
|
||||
require (
|
||||
4d63.com/gocheckcompilerdirectives v1.3.0 // indirect
|
||||
4d63.com/gochecknoglobals v0.2.2 // indirect
|
||||
filippo.io/edwards25519 v1.2.0 // indirect
|
||||
github.com/4meepo/tagalign v1.4.2 // indirect
|
||||
github.com/Abirdcfly/dupword v0.1.3 // indirect
|
||||
github.com/Antonboom/errname v1.0.0 // indirect
|
||||
github.com/Antonboom/nilnil v1.0.1 // indirect
|
||||
github.com/Antonboom/testifylint v1.5.2 // indirect
|
||||
github.com/BurntSushi/toml v1.6.0 // indirect
|
||||
github.com/Crocmagnon/fatcontext v0.7.1 // indirect
|
||||
github.com/Djarvur/go-err113 v0.0.0-20210108212216-aea10b59be24 // indirect
|
||||
github.com/GaijinEntertainment/go-exhaustruct/v3 v3.3.1 // indirect
|
||||
github.com/Masterminds/semver/v3 v3.3.0 // indirect
|
||||
github.com/OpenPeeDeeP/depguard/v2 v2.2.1 // indirect
|
||||
github.com/alecthomas/go-check-sumtype v0.3.1 // indirect
|
||||
github.com/alexkohler/nakedret/v2 v2.0.5 // indirect
|
||||
github.com/alexkohler/prealloc v1.0.0 // indirect
|
||||
github.com/alingse/asasalint v0.0.11 // indirect
|
||||
github.com/alingse/nilnesserr v0.1.2 // indirect
|
||||
github.com/ashanbrown/forbidigo v1.6.0 // indirect
|
||||
github.com/ashanbrown/makezero v1.2.0 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/bkielbasa/cyclop v1.2.3 // indirect
|
||||
github.com/blizzy78/varnamelen v0.8.0 // indirect
|
||||
github.com/bombsimon/wsl/v4 v4.5.0 // indirect
|
||||
github.com/breml/bidichk v0.3.2 // indirect
|
||||
github.com/breml/errchkjson v0.4.0 // indirect
|
||||
github.com/butuzov/ireturn v0.3.1 // indirect
|
||||
github.com/butuzov/mirror v1.3.0 // indirect
|
||||
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||
github.com/bytedance/sonic v1.15.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||
github.com/catenacyber/perfsprint v0.8.2 // indirect
|
||||
github.com/ccojocar/zxcvbn-go v1.0.2 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charithe/durationcheck v0.0.10 // indirect
|
||||
github.com/chavacava/garif v0.1.0 // indirect
|
||||
github.com/ckaznocha/intrange v0.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/curioswitch/go-reassign v0.3.0 // indirect
|
||||
github.com/daixiang0/gci v0.13.5 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/denis-tingaikin/go-header v0.5.0 // indirect
|
||||
github.com/ettle/strcase v0.2.0 // indirect
|
||||
github.com/fatih/color v1.18.0 // indirect
|
||||
github.com/fatih/structtag v1.2.0 // indirect
|
||||
github.com/firefart/nonamedreturns v1.0.5 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/fzipp/gocyclo v0.6.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
|
||||
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 // indirect
|
||||
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 // indirect
|
||||
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 // indirect
|
||||
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 // indirect
|
||||
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 // indirect
|
||||
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
|
||||
github.com/ghostiam/protogetter v0.3.9 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
github.com/go-critic/go-critic v0.12.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.30.2 // indirect
|
||||
github.com/go-sql-driver/mysql v1.9.3 // indirect
|
||||
github.com/go-stack/stack v1.8.0 // indirect
|
||||
github.com/go-toolsmith/astcast v1.1.0 // indirect
|
||||
github.com/go-toolsmith/astcopy v1.1.0 // indirect
|
||||
github.com/go-toolsmith/astequal v1.2.0 // indirect
|
||||
github.com/go-toolsmith/astfmt v1.1.0 // indirect
|
||||
github.com/go-toolsmith/astp v1.1.0 // indirect
|
||||
github.com/go-toolsmith/strparse v1.1.0 // indirect
|
||||
github.com/go-toolsmith/typep v1.1.0 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/go-xmlfmt/xmlfmt v1.1.3 // indirect
|
||||
github.com/gobwas/glob v0.2.3 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/golangci/dupl v0.0.0-20250308024227-f665c8d69b32 // indirect
|
||||
github.com/golangci/go-printf-func-name v0.1.0 // indirect
|
||||
github.com/golangci/gofmt v0.0.0-20250106114630-d62b90e6713d // indirect
|
||||
github.com/golangci/golangci-lint v1.64.8 // indirect
|
||||
github.com/golangci/misspell v0.6.0 // indirect
|
||||
github.com/golangci/plugin-module-register v0.1.1 // indirect
|
||||
github.com/golangci/revgrep v0.8.0 // indirect
|
||||
github.com/golangci/unconvert v0.0.0-20240309020433-c5143eacb3ed // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/gordonklaus/ineffassign v0.1.0 // indirect
|
||||
github.com/gostaticanalysis/analysisutil v0.7.1 // indirect
|
||||
github.com/gostaticanalysis/comment v1.5.0 // indirect
|
||||
github.com/gostaticanalysis/forcetypeassert v0.2.0 // indirect
|
||||
github.com/gostaticanalysis/nilerr v0.1.1 // indirect
|
||||
github.com/hashicorp/go-immutable-radix/v2 v2.1.0 // indirect
|
||||
github.com/hashicorp/go-version v1.7.0 // indirect
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
|
||||
github.com/hexops/gotextdiff v1.0.3 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jgautheron/goconst v1.7.1 // indirect
|
||||
github.com/jingyugao/rowserrcheck v1.1.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/jjti/go-spancheck v0.6.4 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/julz/importas v0.2.0 // indirect
|
||||
github.com/karamaru-alpha/copyloopvar v1.2.1 // indirect
|
||||
github.com/kisielk/errcheck v1.9.0 // indirect
|
||||
github.com/kkHAIKE/contextcheck v1.1.6 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/kulti/thelper v0.6.3 // indirect
|
||||
github.com/kunwardeep/paralleltest v1.0.10 // indirect
|
||||
github.com/lasiar/canonicalheader v1.1.2 // indirect
|
||||
github.com/ldez/exptostd v0.4.2 // indirect
|
||||
github.com/ldez/gomoddirectives v0.6.1 // indirect
|
||||
github.com/ldez/grignotin v0.9.0 // indirect
|
||||
github.com/ldez/tagliatelle v0.7.1 // indirect
|
||||
github.com/ldez/usetesting v0.4.2 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/leonklingele/grouper v1.1.2 // indirect
|
||||
github.com/macabu/inamedparam v0.1.3 // indirect
|
||||
github.com/maratori/testableexamples v1.0.0 // indirect
|
||||
github.com/maratori/testpackage v1.1.1 // indirect
|
||||
github.com/matoous/godox v1.1.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
|
||||
github.com/mfridman/interpolate v0.0.2 // indirect
|
||||
github.com/mgechev/revive v1.7.0 // indirect
|
||||
github.com/mitchellh/go-homedir v1.1.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/moricho/tparallel v0.3.2 // indirect
|
||||
github.com/nakabonne/nestif v0.3.1 // indirect
|
||||
github.com/nishanths/exhaustive v0.12.0 // indirect
|
||||
github.com/nishanths/predeclared v0.2.2 // indirect
|
||||
github.com/nunnatsa/ginkgolinter v0.19.1 // indirect
|
||||
github.com/olekukonko/tablewriter v0.0.5 // indirect
|
||||
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/polyfloyd/go-errorlint v1.7.1 // indirect
|
||||
github.com/prometheus/client_golang v1.12.1 // indirect
|
||||
github.com/prometheus/client_model v0.2.0 // indirect
|
||||
github.com/prometheus/common v0.32.1 // indirect
|
||||
github.com/prometheus/procfs v0.19.2 // indirect
|
||||
github.com/quasilyte/go-ruleguard v0.4.3-0.20240823090925-0fe6f58b47b1 // indirect
|
||||
github.com/quasilyte/go-ruleguard/dsl v0.3.22 // indirect
|
||||
github.com/quasilyte/gogrep v0.5.0 // indirect
|
||||
github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727 // indirect
|
||||
github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567 // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
github.com/quic-go/quic-go v0.59.0 // indirect
|
||||
github.com/raeperd/recvcheck v0.2.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
github.com/ryancurrah/gomodguard v1.3.5 // indirect
|
||||
github.com/ryanrolds/sqlclosecheck v0.5.1 // indirect
|
||||
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
||||
github.com/sanposhiho/wastedassign/v2 v2.1.0 // indirect
|
||||
github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 // indirect
|
||||
github.com/sashamelentyev/interfacebloat v1.1.0 // indirect
|
||||
github.com/sashamelentyev/usestdlibvars v1.28.0 // indirect
|
||||
github.com/securego/gosec/v2 v2.22.2 // indirect
|
||||
github.com/sethvargo/go-retry v0.3.0 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/sivchari/containedctx v1.0.3 // indirect
|
||||
github.com/sivchari/tenv v1.12.1 // indirect
|
||||
github.com/sonatard/noctx v0.1.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||
github.com/sourcegraph/go-diff v0.7.0 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/spf13/cobra v1.9.1 // indirect
|
||||
github.com/ssgreg/nlreturn/v2 v2.2.1 // indirect
|
||||
github.com/stbenjam/no-sprintf-host-port v0.2.0 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tdakkota/asciicheck v0.4.1 // indirect
|
||||
github.com/tetafro/godot v1.5.0 // indirect
|
||||
github.com/timakin/bodyclose v0.0.0-20241017074812-ed6a65f985e3 // indirect
|
||||
github.com/timonwong/loggercheck v0.10.1 // indirect
|
||||
github.com/tomarrell/wrapcheck/v2 v2.10.0 // indirect
|
||||
github.com/tommy-muehle/go-mnd/v2 v2.5.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.3.1 // indirect
|
||||
github.com/ultraware/funlen v0.2.0 // indirect
|
||||
github.com/ultraware/whitespace v0.2.0 // indirect
|
||||
github.com/uudashr/gocognit v1.2.0 // indirect
|
||||
github.com/uudashr/iface v1.3.1 // indirect
|
||||
github.com/xen0n/gosmopolitan v1.2.2 // indirect
|
||||
github.com/yagipy/maintidx v1.0.0 // indirect
|
||||
github.com/yeya24/promlinter v0.3.0 // indirect
|
||||
github.com/ykadowak/zerologlint v0.1.5 // indirect
|
||||
gitlab.com/bosi/decorder v0.4.2 // indirect
|
||||
go-simpler.org/musttag v0.13.0 // indirect
|
||||
go-simpler.org/sloglint v0.9.0 // indirect
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
|
||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/arch v0.22.0 // indirect
|
||||
golang.org/x/crypto v0.49.0 // indirect
|
||||
golang.org/x/exp/typeparams v0.0.0-20250210185358-939b2ce775ac // indirect
|
||||
golang.org/x/mod v0.33.0 // indirect
|
||||
golang.org/x/net v0.51.0 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
golang.org/x/tools v0.42.0 // indirect
|
||||
golang.org/x/tools/go/expect v0.1.1-deprecated // indirect
|
||||
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
honnef.co/go/tools v0.6.1 // indirect
|
||||
mvdan.cc/gofumpt v0.7.0 // indirect
|
||||
mvdan.cc/unparam v0.0.0-20240528143540-8a5130ca722f // indirect
|
||||
)
|
||||
|
||||
961
backend/go.sum
961
backend/go.sum
File diff suppressed because it is too large
Load Diff
@@ -1,11 +1,18 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
@@ -13,40 +20,49 @@ import (
|
||||
|
||||
// Config 应用配置
|
||||
type Config struct {
|
||||
Server ServerConfig `yaml:"server"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Log LogConfig `yaml:"log"`
|
||||
Server ServerConfig `yaml:"server" mapstructure:"server" validate:"required"`
|
||||
Database DatabaseConfig `yaml:"database" mapstructure:"database" validate:"required"`
|
||||
Log LogConfig `yaml:"log" mapstructure:"log" validate:"required"`
|
||||
}
|
||||
|
||||
// ServerConfig 服务器配置
|
||||
type ServerConfig struct {
|
||||
Port int `yaml:"port"`
|
||||
ReadTimeout time.Duration `yaml:"read_timeout"`
|
||||
WriteTimeout time.Duration `yaml:"write_timeout"`
|
||||
Port int `yaml:"port" mapstructure:"port" validate:"required,min=1,max=65535"`
|
||||
ReadTimeout time.Duration `yaml:"read_timeout" mapstructure:"read_timeout" validate:"required"`
|
||||
WriteTimeout time.Duration `yaml:"write_timeout" mapstructure:"write_timeout" validate:"required"`
|
||||
}
|
||||
|
||||
// DatabaseConfig 数据库配置
|
||||
type DatabaseConfig struct {
|
||||
Path string `yaml:"path"`
|
||||
MaxIdleConns int `yaml:"max_idle_conns"`
|
||||
MaxOpenConns int `yaml:"max_open_conns"`
|
||||
ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime"`
|
||||
Driver string `yaml:"driver" mapstructure:"driver" validate:"required,oneof=sqlite mysql"`
|
||||
Path string `yaml:"path" mapstructure:"path" validate:"required_if=Driver sqlite"`
|
||||
Host string `yaml:"host" mapstructure:"host" validate:"required_if=Driver mysql"`
|
||||
Port int `yaml:"port" mapstructure:"port" validate:"required_if=Driver mysql,min=1,max=65535"`
|
||||
User string `yaml:"user" mapstructure:"user" validate:"required_if=Driver mysql"`
|
||||
Password string `yaml:"password" mapstructure:"password"`
|
||||
DBName string `yaml:"dbname" mapstructure:"dbname" validate:"required_if=Driver mysql"`
|
||||
MaxIdleConns int `yaml:"max_idle_conns" mapstructure:"max_idle_conns" validate:"required,min=1"`
|
||||
MaxOpenConns int `yaml:"max_open_conns" mapstructure:"max_open_conns" validate:"required,min=1"`
|
||||
ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" mapstructure:"conn_max_lifetime" validate:"required"`
|
||||
}
|
||||
|
||||
// LogConfig 日志配置
|
||||
type LogConfig struct {
|
||||
Level string `yaml:"level"`
|
||||
Path string `yaml:"path"`
|
||||
MaxSize int `yaml:"max_size"`
|
||||
MaxBackups int `yaml:"max_backups"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
Compress bool `yaml:"compress"`
|
||||
Level string `yaml:"level" mapstructure:"level" validate:"required,oneof=debug info warn error"`
|
||||
Path string `yaml:"path" mapstructure:"path" validate:"required"`
|
||||
MaxSize int `yaml:"max_size" mapstructure:"max_size" validate:"required,min=1"`
|
||||
MaxBackups int `yaml:"max_backups" mapstructure:"max_backups" validate:"required,min=0"`
|
||||
MaxAge int `yaml:"max_age" mapstructure:"max_age" validate:"required,min=0"`
|
||||
Compress bool `yaml:"compress" mapstructure:"compress"`
|
||||
}
|
||||
|
||||
// DefaultConfig returns default config values
|
||||
func DefaultConfig() *Config {
|
||||
// Use home dir for default paths
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
homeDir = "."
|
||||
}
|
||||
nexDir := filepath.Join(homeDir, ".nex")
|
||||
|
||||
return &Config{
|
||||
@@ -56,7 +72,13 @@ func DefaultConfig() *Config {
|
||||
WriteTimeout: 30 * time.Second,
|
||||
},
|
||||
Database: DatabaseConfig{
|
||||
Driver: "sqlite",
|
||||
Path: filepath.Join(nexDir, "config.db"),
|
||||
Host: "",
|
||||
Port: 3306,
|
||||
User: "",
|
||||
Password: "",
|
||||
DBName: "nex",
|
||||
MaxIdleConns: 10,
|
||||
MaxOpenConns: 100,
|
||||
ConnMaxLifetime: 1 * time.Hour,
|
||||
@@ -79,7 +101,7 @@ func GetConfigDir() (string, error) {
|
||||
return "", err
|
||||
}
|
||||
configDir := filepath.Join(homeDir, ".nex")
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return configDir, nil
|
||||
@@ -103,29 +125,179 @@ func GetConfigPath() (string, error) {
|
||||
return filepath.Join(configDir, "config.yaml"), nil
|
||||
}
|
||||
|
||||
// setupDefaults 设置默认配置值
|
||||
func setupDefaults(v *viper.Viper) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
homeDir = "."
|
||||
}
|
||||
nexDir := filepath.Join(homeDir, ".nex")
|
||||
|
||||
v.SetDefault("server.port", 9826)
|
||||
v.SetDefault("server.read_timeout", "30s")
|
||||
v.SetDefault("server.write_timeout", "30s")
|
||||
|
||||
v.SetDefault("database.driver", "sqlite")
|
||||
v.SetDefault("database.path", filepath.Join(nexDir, "config.db"))
|
||||
v.SetDefault("database.host", "")
|
||||
v.SetDefault("database.port", 3306)
|
||||
v.SetDefault("database.user", "")
|
||||
v.SetDefault("database.password", "")
|
||||
v.SetDefault("database.dbname", "nex")
|
||||
v.SetDefault("database.max_idle_conns", 10)
|
||||
v.SetDefault("database.max_open_conns", 100)
|
||||
v.SetDefault("database.conn_max_lifetime", "1h")
|
||||
|
||||
v.SetDefault("log.level", "info")
|
||||
v.SetDefault("log.path", filepath.Join(nexDir, "log"))
|
||||
v.SetDefault("log.max_size", 100)
|
||||
v.SetDefault("log.max_backups", 10)
|
||||
v.SetDefault("log.max_age", 30)
|
||||
v.SetDefault("log.compress", true)
|
||||
}
|
||||
|
||||
// setupFlags 定义和绑定 CLI 参数
|
||||
func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
|
||||
// 定义所有配置项的 CLI 参数
|
||||
// 注意:这里不设置默认值,让 viper 的默认值生效
|
||||
flagSet.Int("server-port", 0, "服务器端口")
|
||||
flagSet.Duration("server-read-timeout", 0, "读超时")
|
||||
flagSet.Duration("server-write-timeout", 0, "写超时")
|
||||
|
||||
flagSet.String("database-driver", "", "数据库驱动:sqlite/mysql")
|
||||
flagSet.String("database-path", "", "数据库文件路径")
|
||||
flagSet.String("database-host", "", "MySQL 主机地址")
|
||||
flagSet.Int("database-port", 0, "MySQL 端口")
|
||||
flagSet.String("database-user", "", "MySQL 用户名")
|
||||
flagSet.String("database-password", "", "MySQL 密码")
|
||||
flagSet.String("database-dbname", "", "MySQL 数据库名")
|
||||
flagSet.Int("database-max-idle-conns", 0, "最大空闲连接数")
|
||||
flagSet.Int("database-max-open-conns", 0, "最大打开连接数")
|
||||
flagSet.Duration("database-conn-max-lifetime", 0, "连接最大生命周期")
|
||||
|
||||
flagSet.String("log-level", "", "日志级别:debug/info/warn/error")
|
||||
flagSet.String("log-path", "", "日志文件目录")
|
||||
flagSet.Int("log-max-size", 0, "单个日志文件最大大小 MB")
|
||||
flagSet.Int("log-max-backups", 0, "保留的旧日志文件最大数量")
|
||||
flagSet.Int("log-max-age", 0, "保留旧日志文件的最大天数")
|
||||
flagSet.Bool("log-compress", false, "是否压缩旧日志文件")
|
||||
|
||||
// 绑定所有 flag 到 viper
|
||||
// 注意:必须在设置默认值之后绑定
|
||||
bindPFlag(v, "server.port", flagSet.Lookup("server-port"))
|
||||
bindPFlag(v, "server.read_timeout", flagSet.Lookup("server-read-timeout"))
|
||||
bindPFlag(v, "server.write_timeout", flagSet.Lookup("server-write-timeout"))
|
||||
|
||||
bindPFlag(v, "database.driver", flagSet.Lookup("database-driver"))
|
||||
bindPFlag(v, "database.path", flagSet.Lookup("database-path"))
|
||||
bindPFlag(v, "database.host", flagSet.Lookup("database-host"))
|
||||
bindPFlag(v, "database.port", flagSet.Lookup("database-port"))
|
||||
bindPFlag(v, "database.user", flagSet.Lookup("database-user"))
|
||||
bindPFlag(v, "database.password", flagSet.Lookup("database-password"))
|
||||
bindPFlag(v, "database.dbname", flagSet.Lookup("database-dbname"))
|
||||
bindPFlag(v, "database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
|
||||
bindPFlag(v, "database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
|
||||
bindPFlag(v, "database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
|
||||
|
||||
bindPFlag(v, "log.level", flagSet.Lookup("log-level"))
|
||||
bindPFlag(v, "log.path", flagSet.Lookup("log-path"))
|
||||
bindPFlag(v, "log.max_size", flagSet.Lookup("log-max-size"))
|
||||
bindPFlag(v, "log.max_backups", flagSet.Lookup("log-max-backups"))
|
||||
bindPFlag(v, "log.max_age", flagSet.Lookup("log-max-age"))
|
||||
bindPFlag(v, "log.compress", flagSet.Lookup("log-compress"))
|
||||
}
|
||||
|
||||
func bindPFlag(v *viper.Viper, key string, flag *pflag.Flag) {
|
||||
if err := v.BindPFlag(key, flag); err != nil {
|
||||
panic(fmt.Sprintf("绑定 flag %s 失败: %v", key, err))
|
||||
}
|
||||
}
|
||||
|
||||
// setupEnv 绑定环境变量
|
||||
func setupEnv(v *viper.Viper) {
|
||||
v.SetEnvPrefix("NEX")
|
||||
v.AutomaticEnv()
|
||||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
}
|
||||
|
||||
// setupConfigFile 读取配置文件
|
||||
func setupConfigFile(v *viper.Viper, configPath string) error {
|
||||
v.SetConfigFile(configPath)
|
||||
v.SetConfigType("yaml")
|
||||
|
||||
// 尝试读取配置文件,如果不存在则忽略
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
// 配置文件不存在,创建默认配置文件
|
||||
writeErr := v.SafeWriteConfigAs(configPath)
|
||||
if writeErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var alreadyExistsErr viper.ConfigFileAlreadyExistsError
|
||||
if errors.As(writeErr, &alreadyExistsErr) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return appErrors.Wrap(appErrors.ErrInternal, writeErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadConfig loads config from YAML file, creates default if not exists
|
||||
func LoadConfig() (*Config, error) {
|
||||
configPath, err := GetConfigPath()
|
||||
if err != nil {
|
||||
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
return LoadConfigFromPath(configPath)
|
||||
}
|
||||
|
||||
cfg := DefaultConfig()
|
||||
// LoadConfigFromPath 从指定路径加载配置
|
||||
func LoadConfigFromPath(configPath string) (*Config, error) {
|
||||
// 1. 创建 Viper 实例
|
||||
v := viper.New()
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// Create default config file
|
||||
if saveErr := SaveConfig(cfg); saveErr != nil {
|
||||
return nil, appErrors.WithMessage(appErrors.ErrInternal, "创建默认配置失败")
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
// 2. 定义 CLI 参数
|
||||
flagSet := pflag.NewFlagSet("config", pflag.ContinueOnError)
|
||||
flagSet.String("config", configPath, "配置文件路径")
|
||||
setupFlags(v, flagSet)
|
||||
|
||||
// 3. 解析 CLI 参数(忽略错误,因为可能没有参数)
|
||||
if err := flagSet.Parse(os.Args[1:]); err != nil {
|
||||
return nil, appErrors.Wrap(appErrors.ErrInvalidRequest, err)
|
||||
}
|
||||
|
||||
// 4. 获取配置文件路径(可能被 --config 参数覆盖)
|
||||
if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" {
|
||||
configPath = configPathFlag
|
||||
}
|
||||
|
||||
// 5. 设置默认值
|
||||
setupDefaults(v)
|
||||
|
||||
// 6. 绑定环境变量
|
||||
setupEnv(v)
|
||||
|
||||
// 7. 读取配置文件
|
||||
if err := setupConfigFile(v, configPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 8. 反序列化到结构体
|
||||
cfg := &Config{}
|
||||
if err := v.Unmarshal(cfg, viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
|
||||
mapstructure.StringToTimeDurationHookFunc(),
|
||||
mapstructure.StringToSliceHookFunc(","),
|
||||
))); err != nil {
|
||||
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
if err := yaml.Unmarshal(data, cfg); err != nil {
|
||||
return nil, appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
// 9. 验证配置
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
@@ -145,27 +317,41 @@ func SaveConfig(cfg *Config) error {
|
||||
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(configPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
return os.WriteFile(configPath, data, 0644)
|
||||
return os.WriteFile(configPath, data, 0o600)
|
||||
}
|
||||
|
||||
// Validate validates the config
|
||||
func (c *Config) Validate() error {
|
||||
if c.Server.Port < 1 || c.Server.Port > 65535 {
|
||||
return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("无效的端口号: %d", c.Server.Port))
|
||||
validate := validator.New()
|
||||
if err := validate.Struct(c); err != nil {
|
||||
return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("配置验证失败: %v", err))
|
||||
}
|
||||
|
||||
validLevels := map[string]bool{"debug": true, "info": true, "warn": true, "error": true}
|
||||
if !validLevels[c.Log.Level] {
|
||||
return appErrors.WithMessage(appErrors.ErrInvalidRequest, fmt.Sprintf("无效的日志级别: %s", c.Log.Level))
|
||||
}
|
||||
|
||||
if c.Database.Path == "" {
|
||||
return appErrors.WithMessage(appErrors.ErrInvalidRequest, "数据库路径不能为空")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PrintSummary 打印配置摘要
|
||||
func (c *Config) PrintSummary(logger *zap.Logger) {
|
||||
logger.Info("AI Gateway 启动配置",
|
||||
zap.Int("server_port", c.Server.Port),
|
||||
zap.String("database_driver", c.Database.Driver),
|
||||
zap.String("log_level", c.Log.Level),
|
||||
)
|
||||
|
||||
if c.Database.Driver == "mysql" {
|
||||
logger.Info("数据库配置",
|
||||
zap.String("driver", "mysql"),
|
||||
zap.String("host", c.Database.Host),
|
||||
zap.Int("port", c.Database.Port),
|
||||
zap.String("database", c.Database.DBName),
|
||||
)
|
||||
} else {
|
||||
logger.Info("数据库配置",
|
||||
zap.String("driver", "sqlite"),
|
||||
zap.String("path", c.Database.Path),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -19,6 +20,12 @@ func TestDefaultConfig(t *testing.T) {
|
||||
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
|
||||
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
|
||||
|
||||
assert.Equal(t, "sqlite", cfg.Database.Driver)
|
||||
assert.Equal(t, "", cfg.Database.Host)
|
||||
assert.Equal(t, 3306, cfg.Database.Port)
|
||||
assert.Equal(t, "", cfg.Database.User)
|
||||
assert.Equal(t, "", cfg.Database.Password)
|
||||
assert.Equal(t, "nex", cfg.Database.DBName)
|
||||
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
|
||||
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
|
||||
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
|
||||
@@ -46,13 +53,13 @@ func TestConfig_Validate(t *testing.T) {
|
||||
name: "端口号为0无效",
|
||||
modify: func(c *Config) { c.Server.Port = 0 },
|
||||
wantErr: true,
|
||||
errMsg: "无效的端口号",
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "端口号超出范围无效",
|
||||
modify: func(c *Config) { c.Server.Port = 70000 },
|
||||
wantErr: true,
|
||||
errMsg: "无效的端口号",
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "端口号为1有效",
|
||||
@@ -68,7 +75,7 @@ func TestConfig_Validate(t *testing.T) {
|
||||
name: "无效日志级别",
|
||||
modify: func(c *Config) { c.Log.Level = "invalid" },
|
||||
wantErr: true,
|
||||
errMsg: "无效的日志级别",
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "debug级别有效",
|
||||
@@ -86,10 +93,75 @@ func TestConfig_Validate(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "数据库路径为空无效",
|
||||
name: "SQLite模式路径为空无效",
|
||||
modify: func(c *Config) { c.Database.Path = "" },
|
||||
wantErr: true,
|
||||
errMsg: "数据库路径不能为空",
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "driver值不合法",
|
||||
modify: func(c *Config) { c.Database.Driver = "postgres" },
|
||||
wantErr: true,
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "MySQL配置有效",
|
||||
modify: func(c *Config) {
|
||||
c.Database.Driver = "mysql"
|
||||
c.Database.Host = "localhost"
|
||||
c.Database.Port = 3306
|
||||
c.Database.User = "root"
|
||||
c.Database.DBName = "nex"
|
||||
c.Database.Path = ""
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "MySQL模式host为空无效",
|
||||
modify: func(c *Config) {
|
||||
c.Database.Driver = "mysql"
|
||||
c.Database.Host = ""
|
||||
c.Database.User = "root"
|
||||
c.Database.DBName = "nex"
|
||||
c.Database.Path = ""
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "MySQL模式user为空无效",
|
||||
modify: func(c *Config) {
|
||||
c.Database.Driver = "mysql"
|
||||
c.Database.Host = "localhost"
|
||||
c.Database.User = ""
|
||||
c.Database.DBName = "nex"
|
||||
c.Database.Path = ""
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "MySQL模式dbname为空无效",
|
||||
modify: func(c *Config) {
|
||||
c.Database.Driver = "mysql"
|
||||
c.Database.Host = "localhost"
|
||||
c.Database.User = "root"
|
||||
c.Database.DBName = ""
|
||||
c.Database.Path = ""
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "配置验证失败",
|
||||
},
|
||||
{
|
||||
name: "MySQL模式忽略path字段",
|
||||
modify: func(c *Config) {
|
||||
c.Database.Driver = "mysql"
|
||||
c.Database.Host = "localhost"
|
||||
c.Database.User = "root"
|
||||
c.Database.DBName = "nex"
|
||||
c.Database.Path = ""
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -100,7 +172,9 @@ func TestConfig_Validate(t *testing.T) {
|
||||
err := cfg.Validate()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -140,7 +214,10 @@ func TestSaveAndLoadConfig(t *testing.T) {
|
||||
WriteTimeout: 20 * time.Second,
|
||||
},
|
||||
Database: DatabaseConfig{
|
||||
Driver: "sqlite",
|
||||
Path: filepath.Join(dir, "test.db"),
|
||||
Port: 3306,
|
||||
DBName: "nex",
|
||||
MaxIdleConns: 5,
|
||||
MaxOpenConns: 50,
|
||||
ConnMaxLifetime: 30 * time.Minute,
|
||||
@@ -159,7 +236,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
|
||||
configPath := filepath.Join(dir, "config.yaml")
|
||||
data, err := yaml.Marshal(cfg)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(configPath, data, 0644)
|
||||
err = os.WriteFile(configPath, data, 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 加载配置
|
||||
@@ -174,3 +251,72 @@ func TestSaveAndLoadConfig(t *testing.T) {
|
||||
assert.Equal(t, cfg.Database.MaxIdleConns, loaded.Database.MaxIdleConns)
|
||||
assert.Equal(t, cfg.Log.Compress, loaded.Log.Compress)
|
||||
}
|
||||
|
||||
func TestCLIConfig(t *testing.T) {
|
||||
// 测试 CLI 参数配置(简化版本)
|
||||
// 注意:由于 flag.Parse 只能调用一次,这里只测试配置加载流程
|
||||
t.Run("配置加载流程", func(t *testing.T) {
|
||||
// 使用默认配置路径测试
|
||||
cfg := DefaultConfig()
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// 验证默认值正确
|
||||
assert.Equal(t, 9826, cfg.Server.Port)
|
||||
assert.Equal(t, "info", cfg.Log.Level)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEnvConfig(t *testing.T) {
|
||||
// 测试环境变量配置(简化版本)
|
||||
t.Run("环境变量前缀", func(t *testing.T) {
|
||||
// 验证环境变量前缀设置正确
|
||||
// 实际的环境变量测试需要独立的进程,这里只验证配置结构
|
||||
cfg := DefaultConfig()
|
||||
require.NotNil(t, cfg)
|
||||
assert.Equal(t, 9826, cfg.Server.Port)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfigPriority(t *testing.T) {
|
||||
// 测试配置优先级(简化版本)
|
||||
t.Run("默认值设置", func(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// 验证所有默认值
|
||||
assert.Equal(t, 9826, cfg.Server.Port)
|
||||
assert.Equal(t, 30*time.Second, cfg.Server.ReadTimeout)
|
||||
assert.Equal(t, 30*time.Second, cfg.Server.WriteTimeout)
|
||||
assert.Equal(t, "sqlite", cfg.Database.Driver)
|
||||
assert.Equal(t, 3306, cfg.Database.Port)
|
||||
assert.Equal(t, "nex", cfg.Database.DBName)
|
||||
assert.Equal(t, 10, cfg.Database.MaxIdleConns)
|
||||
assert.Equal(t, 100, cfg.Database.MaxOpenConns)
|
||||
assert.Equal(t, 1*time.Hour, cfg.Database.ConnMaxLifetime)
|
||||
assert.Equal(t, "info", cfg.Log.Level)
|
||||
assert.Equal(t, 100, cfg.Log.MaxSize)
|
||||
assert.Equal(t, 10, cfg.Log.MaxBackups)
|
||||
assert.Equal(t, 30, cfg.Log.MaxAge)
|
||||
assert.Equal(t, true, cfg.Log.Compress)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrintSummary(t *testing.T) {
|
||||
t.Run("SQLite模式摘要", func(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
assert.NotPanics(t, func() {
|
||||
cfg.PrintSummary(zap.NewNop())
|
||||
})
|
||||
})
|
||||
t.Run("MySQL模式摘要", func(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
cfg.Database.Driver = "mysql"
|
||||
cfg.Database.Host = "db.example.com"
|
||||
cfg.Database.Port = 3306
|
||||
cfg.Database.User = "nex"
|
||||
cfg.Database.DBName = "nex"
|
||||
assert.NotPanics(t, func() {
|
||||
cfg.PrintSummary(zap.NewNop())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,21 +6,22 @@ import (
|
||||
|
||||
// Provider 供应商模型
|
||||
type Provider struct {
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
APIKey string `gorm:"not null" json:"api_key"`
|
||||
BaseURL string `gorm:"not null" json:"base_url"`
|
||||
Enabled bool `gorm:"default:true" json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
APIKey string `gorm:"not null" json:"api_key"`
|
||||
BaseURL string `gorm:"not null" json:"base_url"`
|
||||
Protocol string `gorm:"column:protocol;default:'openai'" json:"protocol"`
|
||||
Enabled bool `gorm:"default:true" json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Models []Model `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models,omitempty"`
|
||||
}
|
||||
|
||||
// Model 模型配置
|
||||
// Model 模型配置(id 为 UUID 自动生成,UNIQUE(provider_id, model_name))
|
||||
type Model struct {
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
ProviderID string `gorm:"not null;index" json:"provider_id"`
|
||||
ModelName string `gorm:"not null;index" json:"model_name"`
|
||||
ProviderID string `gorm:"not null;index;uniqueIndex:idx_provider_model" json:"provider_id"`
|
||||
ModelName string `gorm:"not null;index;uniqueIndex:idx_provider_model" json:"model_name"`
|
||||
Enabled bool `gorm:"default:true" json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
@@ -28,8 +29,8 @@ type Model struct {
|
||||
// UsageStats 用量统计
|
||||
type UsageStats struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
ProviderID string `gorm:"not null;index" json:"provider_id"`
|
||||
ModelName string `gorm:"not null;index" json:"model_name"`
|
||||
ProviderID string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"provider_id"`
|
||||
ModelName string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"model_name"`
|
||||
RequestCount int `gorm:"default:0" json:"request_count"`
|
||||
Date time.Time `gorm:"type:date;not null;uniqueIndex:idx_provider_model_date" json:"date"`
|
||||
}
|
||||
@@ -46,12 +47,3 @@ func (Model) TableName() string {
|
||||
func (UsageStats) TableName() string {
|
||||
return "usage_stats"
|
||||
}
|
||||
|
||||
// MaskAPIKey 掩码 API Key(仅显示最后 4 个字符)
|
||||
func (p *Provider) MaskAPIKey() {
|
||||
if len(p.APIKey) > 4 {
|
||||
p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:]
|
||||
} else {
|
||||
p.APIKey = "***"
|
||||
}
|
||||
}
|
||||
|
||||
106
backend/internal/conversion/adapter.go
Normal file
106
backend/internal/conversion/adapter.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package conversion
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// ProtocolAdapter 协议适配器接口
|
||||
type ProtocolAdapter interface {
|
||||
ProtocolName() string
|
||||
ProtocolVersion() string
|
||||
SupportsPassthrough() bool
|
||||
|
||||
DetectInterfaceType(nativePath string) InterfaceType
|
||||
BuildUrl(nativePath string, interfaceType InterfaceType) string
|
||||
BuildHeaders(provider *TargetProvider) map[string]string
|
||||
SupportsInterface(interfaceType InterfaceType) bool
|
||||
|
||||
DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error)
|
||||
EncodeRequest(req *canonical.CanonicalRequest, provider *TargetProvider) ([]byte, error)
|
||||
DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error)
|
||||
EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error)
|
||||
|
||||
CreateStreamDecoder() StreamDecoder
|
||||
CreateStreamEncoder() StreamEncoder
|
||||
|
||||
EncodeError(err *ConversionError) ([]byte, int)
|
||||
|
||||
DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error)
|
||||
EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error)
|
||||
DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error)
|
||||
EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error)
|
||||
DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error)
|
||||
EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *TargetProvider) ([]byte, error)
|
||||
DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error)
|
||||
EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error)
|
||||
DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error)
|
||||
EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *TargetProvider) ([]byte, error)
|
||||
DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error)
|
||||
EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error)
|
||||
|
||||
// 统一模型 ID 相关方法
|
||||
ExtractUnifiedModelID(nativePath string) (string, error)
|
||||
ExtractModelName(body []byte, ifaceType InterfaceType) (string, error)
|
||||
RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)
|
||||
RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error)
|
||||
}
|
||||
|
||||
// AdapterRegistry 适配器注册表接口
|
||||
type AdapterRegistry interface {
|
||||
Register(adapter ProtocolAdapter) error
|
||||
Get(protocolName string) (ProtocolAdapter, error)
|
||||
ListProtocols() []string
|
||||
}
|
||||
|
||||
// memoryRegistry 基于内存的适配器注册表
|
||||
type memoryRegistry struct {
|
||||
mu sync.RWMutex
|
||||
adapters map[string]ProtocolAdapter
|
||||
}
|
||||
|
||||
// NewMemoryRegistry 创建内存注册表
|
||||
func NewMemoryRegistry() AdapterRegistry {
|
||||
return &memoryRegistry{
|
||||
adapters: make(map[string]ProtocolAdapter),
|
||||
}
|
||||
}
|
||||
|
||||
// Register 注册适配器
|
||||
func (r *memoryRegistry) Register(adapter ProtocolAdapter) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
name := adapter.ProtocolName()
|
||||
if _, exists := r.adapters[name]; exists {
|
||||
return fmt.Errorf("适配器已注册: %s", name)
|
||||
}
|
||||
r.adapters[name] = adapter
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get 获取适配器
|
||||
func (r *memoryRegistry) Get(protocolName string) (ProtocolAdapter, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
adapter, ok := r.adapters[protocolName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("未找到适配器: %s", protocolName)
|
||||
}
|
||||
return adapter, nil
|
||||
}
|
||||
|
||||
// ListProtocols 列出所有已注册协议
|
||||
func (r *memoryRegistry) ListProtocols() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
protocols := make([]string, 0, len(r.adapters))
|
||||
for name := range r.adapters {
|
||||
protocols = append(protocols, name)
|
||||
}
|
||||
return protocols
|
||||
}
|
||||
288
backend/internal/conversion/anthropic/adapter.go
Normal file
288
backend/internal/conversion/anthropic/adapter.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// Adapter implements the protocol adapter for the Anthropic Messages API.
type Adapter struct{}

// NewAdapter constructs a new Anthropic protocol adapter.
func NewAdapter() *Adapter {
	a := &Adapter{}
	return a
}

// ProtocolName reports the adapter's protocol identifier.
func (a *Adapter) ProtocolName() string {
	return "anthropic"
}

// ProtocolVersion reports the default Anthropic API version this adapter targets.
func (a *Adapter) ProtocolVersion() string {
	return "2023-06-01"
}

// SupportsPassthrough reports that same-protocol passthrough is supported.
func (a *Adapter) SupportsPassthrough() bool {
	return true
}
|
||||
|
||||
// DetectInterfaceType 根据路径检测接口类型
|
||||
func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType {
|
||||
switch {
|
||||
case nativePath == "/v1/messages":
|
||||
return conversion.InterfaceTypeChat
|
||||
case nativePath == "/v1/models":
|
||||
return conversion.InterfaceTypeModels
|
||||
case isModelInfoPath(nativePath):
|
||||
return conversion.InterfaceTypeModelInfo
|
||||
default:
|
||||
return conversion.InterfaceTypePassthrough
|
||||
}
|
||||
}
|
||||
|
||||
// isModelInfoPath reports whether path addresses a single model,
// i.e. /v1/models/{id} with a non-empty id (the id may contain slashes).
func isModelInfoPath(path string) bool {
	const prefix = "/v1/models/"
	if !strings.HasPrefix(path, prefix) {
		return false
	}
	return len(path) > len(prefix)
}
|
||||
|
||||
// BuildUrl 根据接口类型构建 URL
|
||||
func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string {
|
||||
switch interfaceType {
|
||||
case conversion.InterfaceTypeChat:
|
||||
return "/v1/messages"
|
||||
case conversion.InterfaceTypeModels:
|
||||
return "/v1/models"
|
||||
default:
|
||||
return nativePath
|
||||
}
|
||||
}
|
||||
|
||||
// BuildHeaders 构建请求头
|
||||
func (a *Adapter) BuildHeaders(provider *conversion.TargetProvider) map[string]string {
|
||||
headers := map[string]string{
|
||||
"x-api-key": provider.APIKey,
|
||||
"anthropic-version": "2023-06-01",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if v, ok := provider.AdapterConfig["anthropic_version"].(string); ok && v != "" {
|
||||
headers["anthropic-version"] = v
|
||||
}
|
||||
if betas, ok := provider.AdapterConfig["anthropic_beta"].([]string); ok && len(betas) > 0 {
|
||||
headers["anthropic-beta"] = strings.Join(betas, ",")
|
||||
} else if betas, ok := provider.AdapterConfig["anthropic_beta"].([]any); ok && len(betas) > 0 {
|
||||
var parts []string
|
||||
for _, b := range betas {
|
||||
if s, ok := b.(string); ok {
|
||||
parts = append(parts, s)
|
||||
}
|
||||
}
|
||||
if len(parts) > 0 {
|
||||
headers["anthropic-beta"] = strings.Join(parts, ",")
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
// SupportsInterface 检查是否支持接口类型
|
||||
func (a *Adapter) SupportsInterface(interfaceType conversion.InterfaceType) bool {
|
||||
switch interfaceType {
|
||||
case conversion.InterfaceTypeChat,
|
||||
conversion.InterfaceTypeModels,
|
||||
conversion.InterfaceTypeModelInfo:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeRequest decodes a native Anthropic chat request into canonical form.
func (a *Adapter) DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error) {
	return decodeRequest(raw)
}

// EncodeRequest encodes a canonical request into the native Anthropic format.
func (a *Adapter) EncodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) {
	return encodeRequest(req, provider)
}

// DecodeResponse decodes a native Anthropic chat response into canonical form.
func (a *Adapter) DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error) {
	return decodeResponse(raw)
}

// EncodeResponse encodes a canonical response into the native Anthropic format.
func (a *Adapter) EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
	return encodeResponse(resp)
}

// CreateStreamDecoder returns a fresh streaming decoder for Anthropic events.
func (a *Adapter) CreateStreamDecoder() conversion.StreamDecoder {
	return NewStreamDecoder()
}

// CreateStreamEncoder returns a fresh streaming encoder for Anthropic events.
func (a *Adapter) CreateStreamEncoder() conversion.StreamEncoder {
	return NewStreamEncoder()
}
|
||||
|
||||
// EncodeError 编码错误
|
||||
func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
|
||||
errType := string(err.Code)
|
||||
statusCode := 500
|
||||
|
||||
errMsg := ErrorResponse{
|
||||
Type: "error",
|
||||
Error: ErrorDetail{
|
||||
Type: errType,
|
||||
Message: err.Message,
|
||||
},
|
||||
}
|
||||
body, marshalErr := json.Marshal(errMsg)
|
||||
if marshalErr != nil {
|
||||
return []byte(`{"type":"error","error":{"type":"internal_error","message":"internal error"}}`), statusCode
|
||||
}
|
||||
return body, statusCode
|
||||
}
|
||||
|
||||
// DecodeModelsResponse decodes a native model-list response into canonical form.
func (a *Adapter) DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error) {
	return decodeModelsResponse(raw)
}

// EncodeModelsResponse encodes a canonical model list into the native format.
func (a *Adapter) EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
	return encodeModelsResponse(list)
}

// DecodeModelInfoResponse decodes a native model-detail response into canonical form.
func (a *Adapter) DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error) {
	return decodeModelInfoResponse(raw)
}

// EncodeModelInfoResponse encodes canonical model details into the native format.
func (a *Adapter) EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
	return encodeModelInfoResponse(info)
}
|
||||
|
||||
// DecodeEmbeddingRequest always fails: Anthropic has no embeddings API.
func (a *Adapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
	return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Embeddings 接口")
}

// EncodeEmbeddingRequest always fails: Anthropic has no embeddings API.
func (a *Adapter) EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *conversion.TargetProvider) ([]byte, error) {
	return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Embeddings 接口")
}

// DecodeEmbeddingResponse always fails: Anthropic has no embeddings API.
func (a *Adapter) DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error) {
	return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Embeddings 接口")
}

// EncodeEmbeddingResponse always fails: Anthropic has no embeddings API.
func (a *Adapter) EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) {
	return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Embeddings 接口")
}

// DecodeRerankRequest always fails: Anthropic has no rerank API.
func (a *Adapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
	return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
}

// EncodeRerankRequest always fails: Anthropic has no rerank API.
func (a *Adapter) EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *conversion.TargetProvider) ([]byte, error) {
	return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
}

// DecodeRerankResponse always fails: Anthropic has no rerank API.
func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) {
	return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
}

// EncodeRerankResponse always fails: Anthropic has no rerank API.
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
	return nil, conversion.NewConversionError(conversion.ErrorCodeInterfaceNotSupported, "Anthropic 不支持 Rerank 接口")
}
|
||||
|
||||
// ExtractUnifiedModelID 从路径中提取统一模型 ID(/v1/models/{provider_id}/{model_name})
|
||||
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
|
||||
if !strings.HasPrefix(nativePath, "/v1/models/") {
|
||||
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
|
||||
}
|
||||
suffix := nativePath[len("/v1/models/"):]
|
||||
if suffix == "" {
|
||||
return "", fmt.Errorf("路径缺少模型 ID")
|
||||
}
|
||||
return suffix, nil
|
||||
}
|
||||
|
||||
// locateModelFieldInRequest 定位请求体中 model 字段的值并提供改写函数
|
||||
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat:
|
||||
raw, exists := m["model"]
|
||||
if !exists {
|
||||
return "", nil, fmt.Errorf("请求体中缺少 model 字段")
|
||||
}
|
||||
var current string
|
||||
if err := json.Unmarshal(raw, ¤t); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
rewriteFunc := func(newModel string) ([]byte, error) {
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
}
|
||||
return current, rewriteFunc, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractModelName extracts the model value from a request body.
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
	model, _, err := locateModelFieldInRequest(body, ifaceType)
	return model, err
}

// RewriteRequestModelName rewrites only the model field of a request body,
// preserving all other fields as-is.
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
	_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
	if err != nil {
		return nil, err
	}
	return rewriteFunc(newModel)
}
|
||||
|
||||
// RewriteResponseModelName 最小化改写响应体中的 model 字段
|
||||
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat:
|
||||
// Chat 响应必须有 model 字段,存在则改写,不存在则添加
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
233
backend/internal/conversion/anthropic/adapter_test.go
Normal file
233
backend/internal/conversion/anthropic/adapter_test.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestAdapter_ProtocolName pins the protocol identifier.
func TestAdapter_ProtocolName(t *testing.T) {
	a := NewAdapter()
	assert.Equal(t, "anthropic", a.ProtocolName())
}

// TestAdapter_ProtocolVersion pins the default API version.
func TestAdapter_ProtocolVersion(t *testing.T) {
	a := NewAdapter()
	assert.Equal(t, "2023-06-01", a.ProtocolVersion())
}

// TestAdapter_SupportsPassthrough verifies same-protocol passthrough is on.
func TestAdapter_SupportsPassthrough(t *testing.T) {
	a := NewAdapter()
	assert.True(t, a.SupportsPassthrough())
}

// TestAdapter_DetectInterfaceType covers path -> interface-type mapping.
func TestAdapter_DetectInterfaceType(t *testing.T) {
	a := NewAdapter()

	tests := []struct {
		name     string
		path     string
		expected conversion.InterfaceType
	}{
		{"聊天消息", "/v1/messages", conversion.InterfaceTypeChat},
		{"模型列表", "/v1/models", conversion.InterfaceTypeModels},
		{"模型详情", "/v1/models/claude-3", conversion.InterfaceTypeModelInfo},
		{"未知路径", "/v1/unknown", conversion.InterfaceTypePassthrough},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := a.DetectInterfaceType(tt.path)
			assert.Equal(t, tt.expected, result)
		})
	}
}

// TestAdapter_APIReferenceNativePaths checks the documented native paths,
// including that un-versioned paths fall through to passthrough.
func TestAdapter_APIReferenceNativePaths(t *testing.T) {
	a := NewAdapter()

	// docs/api_reference/anthropic defines messages and models under /v1.
	tests := []struct {
		path     string
		expected conversion.InterfaceType
	}{
		{"/v1/messages", conversion.InterfaceTypeChat},
		{"/v1/models", conversion.InterfaceTypeModels},
		{"/v1/models/claude-sonnet-4-5", conversion.InterfaceTypeModelInfo},
		{"/messages", conversion.InterfaceTypePassthrough},
		{"/models", conversion.InterfaceTypePassthrough},
	}

	for _, tt := range tests {
		t.Run(tt.path, func(t *testing.T) {
			assert.Equal(t, tt.expected, a.DetectInterfaceType(tt.path))
		})
	}
}

// TestAdapter_BuildUrl covers interface-type -> endpoint mapping.
func TestAdapter_BuildUrl(t *testing.T) {
	a := NewAdapter()

	tests := []struct {
		name          string
		nativePath    string
		interfaceType conversion.InterfaceType
		expected      string
	}{
		{"聊天", "/v1/messages", conversion.InterfaceTypeChat, "/v1/messages"},
		{"模型", "/v1/models", conversion.InterfaceTypeModels, "/v1/models"},
		{"默认透传", "/v1/other", conversion.InterfaceTypePassthrough, "/v1/other"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := a.BuildUrl(tt.nativePath, tt.interfaceType)
			assert.Equal(t, tt.expected, result)
		})
	}
}

// TestAdapter_BuildHeaders_Basic checks the default header set.
func TestAdapter_BuildHeaders_Basic(t *testing.T) {
	a := NewAdapter()
	provider := conversion.NewTargetProvider("https://api.anthropic.com", "sk-ant-test", "claude-3")

	headers := a.BuildHeaders(provider)
	assert.Equal(t, "sk-ant-test", headers["x-api-key"])
	assert.Equal(t, "2023-06-01", headers["anthropic-version"])
	assert.Equal(t, "application/json", headers["Content-Type"])
}

// TestAdapter_BuildHeaders_CustomVersion checks the AdapterConfig override.
func TestAdapter_BuildHeaders_CustomVersion(t *testing.T) {
	a := NewAdapter()
	provider := conversion.NewTargetProvider("https://api.anthropic.com", "sk-ant-test", "claude-3")
	provider.AdapterConfig["anthropic_version"] = "2024-01-01"

	headers := a.BuildHeaders(provider)
	assert.Equal(t, "2024-01-01", headers["anthropic-version"])
}

// TestAdapter_BuildHeaders_AnthropicBeta checks beta flags join with commas.
func TestAdapter_BuildHeaders_AnthropicBeta(t *testing.T) {
	a := NewAdapter()
	provider := conversion.NewTargetProvider("https://api.anthropic.com", "sk-ant-test", "claude-3")
	provider.AdapterConfig["anthropic_beta"] = []string{"prompt-caching-2024-07-31", "max-tokens-3-5-sonnet-2024-07-15"}

	headers := a.BuildHeaders(provider)
	assert.Equal(t, "prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15", headers["anthropic-beta"])
}

// TestAdapter_SupportsInterface covers supported vs. unsupported interfaces.
func TestAdapter_SupportsInterface(t *testing.T) {
	a := NewAdapter()

	tests := []struct {
		name          string
		interfaceType conversion.InterfaceType
		expected      bool
	}{
		{"聊天", conversion.InterfaceTypeChat, true},
		{"模型", conversion.InterfaceTypeModels, true},
		{"模型详情", conversion.InterfaceTypeModelInfo, true},
		{"嵌入", conversion.InterfaceTypeEmbeddings, false},
		{"重排序", conversion.InterfaceTypeRerank, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := a.SupportsInterface(tt.interfaceType)
			assert.Equal(t, tt.expected, result)
		})
	}
}

// TestAdapter_EncodeError pins the Anthropic-style error envelope and the
// fixed 500 status code.
func TestAdapter_EncodeError(t *testing.T) {
	a := NewAdapter()
	convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效")

	body, statusCode := a.EncodeError(convErr)
	require.Equal(t, 500, statusCode)

	var resp ErrorResponse
	require.NoError(t, json.Unmarshal(body, &resp))
	assert.Equal(t, "error", resp.Type)
	assert.Equal(t, "INVALID_INPUT", resp.Error.Type)
	assert.Equal(t, "参数无效", resp.Error.Message)
}

// TestAdapter_UnsupportedEmbedding verifies every embedding entry point
// returns an INTERFACE_NOT_SUPPORTED ConversionError.
func TestAdapter_UnsupportedEmbedding(t *testing.T) {
	a := NewAdapter()

	t.Run("解码嵌入请求", func(t *testing.T) {
		_, err := a.DecodeEmbeddingRequest([]byte(`{}`))
		require.Error(t, err)
		var convErr *conversion.ConversionError
		require.ErrorAs(t, err, &convErr)
		assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
	})

	t.Run("编码嵌入请求", func(t *testing.T) {
		provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
		_, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider)
		require.Error(t, err)
		var convErr *conversion.ConversionError
		require.True(t, errors.As(err, &convErr))
		assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
	})

	t.Run("解码嵌入响应", func(t *testing.T) {
		_, err := a.DecodeEmbeddingResponse([]byte(`{}`))
		require.Error(t, err)
		var convErr *conversion.ConversionError
		require.ErrorAs(t, err, &convErr)
		assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
	})

	t.Run("编码嵌入响应", func(t *testing.T) {
		_, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{})
		require.Error(t, err)
		var convErr *conversion.ConversionError
		require.ErrorAs(t, err, &convErr)
		assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
	})
}

// TestAdapter_UnsupportedRerank verifies every rerank entry point returns an
// INTERFACE_NOT_SUPPORTED ConversionError.
func TestAdapter_UnsupportedRerank(t *testing.T) {
	a := NewAdapter()

	t.Run("解码重排序请求", func(t *testing.T) {
		_, err := a.DecodeRerankRequest([]byte(`{}`))
		require.Error(t, err)
		var convErr *conversion.ConversionError
		require.ErrorAs(t, err, &convErr)
		assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
	})

	t.Run("编码重排序请求", func(t *testing.T) {
		provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
		_, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider)
		require.Error(t, err)
		var convErr *conversion.ConversionError
		require.ErrorAs(t, err, &convErr)
		assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
	})

	t.Run("解码重排序响应", func(t *testing.T) {
		_, err := a.DecodeRerankResponse([]byte(`{}`))
		require.Error(t, err)
		var convErr *conversion.ConversionError
		require.ErrorAs(t, err, &convErr)
		assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
	})

	t.Run("编码重排序响应", func(t *testing.T) {
		_, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{})
		require.Error(t, err)
		var convErr *conversion.ConversionError
		require.ErrorAs(t, err, &convErr)
		assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
	})
}
|
||||
263
backend/internal/conversion/anthropic/adapter_unified_test.go
Normal file
263
backend/internal/conversion/anthropic/adapter_unified_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
// ExtractUnifiedModelID
// ---------------------------------------------------------------------------

// TestExtractUnifiedModelID covers unified-ID extraction, including IDs that
// themselves contain slashes, and the error paths.
func TestExtractUnifiedModelID(t *testing.T) {
	a := NewAdapter()

	t.Run("standard_path", func(t *testing.T) {
		id, err := a.ExtractUnifiedModelID("/v1/models/anthropic/claude-3")
		require.NoError(t, err)
		assert.Equal(t, "anthropic/claude-3", id)
	})

	t.Run("multi_segment_path", func(t *testing.T) {
		id, err := a.ExtractUnifiedModelID("/v1/models/some/deep/nested/model")
		require.NoError(t, err)
		assert.Equal(t, "some/deep/nested/model", id)
	})

	t.Run("single_segment", func(t *testing.T) {
		id, err := a.ExtractUnifiedModelID("/v1/models/claude-3")
		require.NoError(t, err)
		assert.Equal(t, "claude-3", id)
	})

	t.Run("non_model_path", func(t *testing.T) {
		_, err := a.ExtractUnifiedModelID("/v1/messages")
		require.Error(t, err)
	})

	t.Run("empty_suffix", func(t *testing.T) {
		_, err := a.ExtractUnifiedModelID("/v1/models/")
		require.Error(t, err)
	})

	t.Run("models_list_no_slash", func(t *testing.T) {
		_, err := a.ExtractUnifiedModelID("/v1/models")
		require.Error(t, err)
	})

	t.Run("unrelated_path", func(t *testing.T) {
		_, err := a.ExtractUnifiedModelID("/v1/other")
		require.Error(t, err)
	})
}

// ---------------------------------------------------------------------------
// ExtractModelName (Chat only for Anthropic)
// ---------------------------------------------------------------------------

// TestExtractModelName covers model extraction from chat bodies and rejection
// of malformed bodies and non-chat interface types.
func TestExtractModelName(t *testing.T) {
	a := NewAdapter()

	t.Run("chat", func(t *testing.T) {
		body := []byte(`{"model":"anthropic/claude-3","messages":[]}`)
		model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
		require.NoError(t, err)
		assert.Equal(t, "anthropic/claude-3", model)
	})

	t.Run("chat_with_max_tokens", func(t *testing.T) {
		body := []byte(`{"model":"anthropic/claude-3-opus","messages":[{"role":"user","content":"hi"}],"max_tokens":1024}`)
		model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
		require.NoError(t, err)
		assert.Equal(t, "anthropic/claude-3-opus", model)
	})

	t.Run("no_model_field", func(t *testing.T) {
		body := []byte(`{"messages":[]}`)
		_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
		require.Error(t, err)
	})

	t.Run("invalid_json", func(t *testing.T) {
		body := []byte(`{invalid}`)
		_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
		require.Error(t, err)
	})

	t.Run("unsupported_interface_type_embedding", func(t *testing.T) {
		body := []byte(`{"model":"anthropic/claude-3"}`)
		_, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings)
		require.Error(t, err)
	})

	t.Run("unsupported_interface_type_rerank", func(t *testing.T) {
		body := []byte(`{"model":"anthropic/claude-3"}`)
		_, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank)
		require.Error(t, err)
	})
}

// ---------------------------------------------------------------------------
// RewriteRequestModelName (Chat only for Anthropic)
// ---------------------------------------------------------------------------

// TestRewriteRequestModelName verifies the model field is replaced while all
// other fields survive untouched, and error paths are covered.
func TestRewriteRequestModelName(t *testing.T) {
	a := NewAdapter()

	t.Run("chat", func(t *testing.T) {
		body := []byte(`{"model":"anthropic/claude-3","messages":[]}`)
		rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
		require.NoError(t, err)

		var m map[string]interface{}
		require.NoError(t, json.Unmarshal(rewritten, &m))
		assert.Equal(t, "claude-3", m["model"])

		msgs, ok := m["messages"]
		require.True(t, ok)
		msgsArr, ok := msgs.([]interface{})
		require.True(t, ok)
		assert.Len(t, msgsArr, 0)
	})

	t.Run("preserves_unknown_fields", func(t *testing.T) {
		body := []byte(`{"model":"anthropic/claude-3","max_tokens":1024,"temperature":0.7}`)
		rewritten, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
		require.NoError(t, err)

		var m map[string]interface{}
		require.NoError(t, json.Unmarshal(rewritten, &m))
		assert.Equal(t, "claude-3", m["model"])
		assert.Equal(t, 0.7, m["temperature"])

		// max_tokens is encoded as float in JSON numbers
		maxTokens, ok := m["max_tokens"]
		require.True(t, ok)
		assert.Equal(t, float64(1024), maxTokens)
	})

	t.Run("no_model_field", func(t *testing.T) {
		body := []byte(`{"messages":[]}`)
		_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
		require.Error(t, err)
	})

	t.Run("invalid_json", func(t *testing.T) {
		body := []byte(`{invalid}`)
		_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeChat)
		require.Error(t, err)
	})

	t.Run("unsupported_interface_type", func(t *testing.T) {
		body := []byte(`{"model":"anthropic/claude-3"}`)
		_, err := a.RewriteRequestModelName(body, "claude-3", conversion.InterfaceTypeEmbeddings)
		require.Error(t, err)
	})
}

// ---------------------------------------------------------------------------
// RewriteResponseModelName (Chat only for Anthropic)
// ---------------------------------------------------------------------------

// TestRewriteResponseModelName verifies chat responses get the model set
// (added when missing) and non-chat bodies pass through unchanged.
func TestRewriteResponseModelName(t *testing.T) {
	a := NewAdapter()

	t.Run("chat_existing_model", func(t *testing.T) {
		body := []byte(`{"model":"claude-3","content":[],"stop_reason":"end_turn"}`)
		rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
		require.NoError(t, err)

		var m map[string]interface{}
		require.NoError(t, json.Unmarshal(rewritten, &m))
		assert.Equal(t, "anthropic/claude-3", m["model"])

		// other fields preserved
		_, hasContent := m["content"]
		assert.True(t, hasContent)
		assert.Equal(t, "end_turn", m["stop_reason"])
	})

	t.Run("chat_without_model_field_adds_it", func(t *testing.T) {
		body := []byte(`{"content":[],"stop_reason":"end_turn"}`)
		rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
		require.NoError(t, err)

		var m map[string]interface{}
		require.NoError(t, json.Unmarshal(rewritten, &m))
		assert.Equal(t, "anthropic/claude-3", m["model"])
	})

	t.Run("passthrough_returns_body_unchanged", func(t *testing.T) {
		body := []byte(`{"model":"claude-3"}`)
		rewritten, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypePassthrough)
		require.NoError(t, err)
		assert.Equal(t, string(body), string(rewritten))
	})

	t.Run("invalid_json", func(t *testing.T) {
		body := []byte(`{invalid}`)
		_, err := a.RewriteResponseModelName(body, "anthropic/claude-3", conversion.InterfaceTypeChat)
		require.Error(t, err)
	})
}

// ---------------------------------------------------------------------------
// ExtractModelName and RewriteRequest consistency
// ---------------------------------------------------------------------------

// TestExtractModelNameAndRewriteRequestConsistency round-trips a body through
// extract -> rewrite -> extract to prove both target the same field.
func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) {
	a := NewAdapter()

	t.Run("chat_round_trip", func(t *testing.T) {
		original := []byte(`{"model":"anthropic/claude-3","messages":[],"max_tokens":1024}`)

		// Extract the unified model ID from the body
		extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat)
		require.NoError(t, err)
		assert.Equal(t, "anthropic/claude-3", extracted)

		// Rewrite to the native model name
		rewritten, err := a.RewriteRequestModelName(original, "claude-3", conversion.InterfaceTypeChat)
		require.NoError(t, err)

		// Extract again from the rewritten body to verify the same location was targeted
		afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat)
		require.NoError(t, err)
		assert.Equal(t, "claude-3", afterRewrite)

		// Verify other fields are preserved
		var m map[string]interface{}
		require.NoError(t, json.Unmarshal(rewritten, &m))
		assert.Equal(t, float64(1024), m["max_tokens"])
	})
}

// ---------------------------------------------------------------------------
// isModelInfoPath (additional unified model ID cases)
// ---------------------------------------------------------------------------

// TestIsModelInfoPath_UnifiedModelID covers slash-containing model IDs.
func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
	tests := []struct {
		name     string
		path     string
		expected bool
	}{
		{"simple_model_id", "/v1/models/claude-3", true},
		{"unified_model_id_with_slash", "/v1/models/anthropic/claude-3", true},
		{"models_list", "/v1/models", false},
		{"models_list_trailing_slash", "/v1/models/", false},
		{"messages_path", "/v1/messages", false},
		{"deeply_nested", "/v1/models/org/workspace/claude-3-opus", true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
		})
	}
}
|
||||
471
backend/internal/conversion/anthropic/decoder.go
Normal file
471
backend/internal/conversion/anthropic/decoder.go
Normal file
@@ -0,0 +1,471 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// decodeRequest converts a native Anthropic Messages request body into the
// canonical request representation. It validates the mandatory model and
// messages fields and delegates each sub-structure to a dedicated decoder.
func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
	var req MessagesRequest
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 Anthropic 请求失败").WithCause(err)
	}

	// model and messages are required by the Anthropic Messages API.
	if strings.TrimSpace(req.Model) == "" {
		return nil, conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "model 字段不能为空")
	}
	if len(req.Messages) == 0 {
		return nil, conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "messages 字段不能为空")
	}

	system := decodeSystem(req.System)

	// A single native message may expand to several canonical messages,
	// hence the flattening append of decodeMessage's slice result.
	var canonicalMsgs []canonical.CanonicalMessage
	for _, msg := range req.Messages {
		decoded, err := decodeMessage(msg)
		if err != nil {
			return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析消息内容失败").WithCause(err)
		}
		canonicalMsgs = append(canonicalMsgs, decoded...)
	}

	tools := decodeTools(req.Tools)
	toolChoice := decodeToolChoice(req.ToolChoice)
	params := decodeParameters(&req)
	thinking := decodeThinking(req.Thinking, req.OutputConfig)
	outputFormat := decodeOutputFormat(req.OutputConfig)

	// Only an explicit disable_parallel_tool_use=true is translated; any
	// other value leaves the canonical field nil (provider default).
	var parallelToolUse *bool
	if req.DisableParallelToolUse != nil && *req.DisableParallelToolUse {
		val := false
		parallelToolUse = &val
	}

	// metadata.user_id is optional.
	var userID string
	if req.Metadata != nil {
		userID = req.Metadata.UserID
	}

	return &canonical.CanonicalRequest{
		Model:           req.Model,
		System:          system,
		Messages:        canonicalMsgs,
		Tools:           tools,
		ToolChoice:      toolChoice,
		Parameters:      params,
		Thinking:        thinking,
		Stream:          req.Stream,
		UserID:          userID,
		OutputFormat:    outputFormat,
		ParallelToolUse: parallelToolUse,
	}, nil
}
|
||||
|
||||
// decodeSystem 解码系统消息
|
||||
func decodeSystem(system any) any {
|
||||
if system == nil {
|
||||
return nil
|
||||
}
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
case []any:
|
||||
var blocks []canonical.SystemBlock
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if text, ok := m["text"].(string); ok {
|
||||
blocks = append(blocks, canonical.SystemBlock{Text: text})
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(blocks) == 0 {
|
||||
return nil
|
||||
}
|
||||
return blocks
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// decodeMessage converts one Anthropic message into canonical messages.
//
// A user message that mixes regular content with tool_result blocks is
// split in two: the non-result blocks stay a user-role message and the
// results become a separate tool-role message. Assistant messages map 1:1.
// Any other role is silently dropped (returns nil, nil).
func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) {
	switch msg.Role {
	case "user":
		blocks, err := decodeContentBlocks(msg.Content)
		if err != nil {
			return nil, err
		}
		var toolResults []canonical.ContentBlock
		var others []canonical.ContentBlock
		for _, b := range blocks {
			if b.Type == "tool_result" {
				toolResults = append(toolResults, b)
			} else {
				others = append(others, b)
			}
		}
		var result []canonical.CanonicalMessage
		if len(others) > 0 {
			result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: others})
		}
		if len(toolResults) > 0 {
			result = append(result, canonical.CanonicalMessage{Role: canonical.RoleTool, Content: toolResults})
		}
		// Guarantee at least one message so downstream encoders never see
		// an empty expansion.
		if len(result) == 0 {
			result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}})
		}
		return result, nil

	case "assistant":
		blocks, err := decodeContentBlocks(msg.Content)
		if err != nil {
			return nil, err
		}
		if len(blocks) == 0 {
			blocks = append(blocks, canonical.NewTextBlock(""))
		}
		return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil
	}
	return nil, nil
}
|
||||
|
||||
// decodeContentBlocks 解码内容块列表
|
||||
func decodeContentBlocks(content any) ([]canonical.ContentBlock, error) {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(v)}, nil
|
||||
case []any:
|
||||
var blocks []canonical.ContentBlock
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
block, err := decodeSingleContentBlock(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if block != nil {
|
||||
blocks = append(blocks, *block)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(blocks) > 0 {
|
||||
return blocks, nil
|
||||
}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
|
||||
case nil:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
|
||||
default:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// decodeSingleContentBlock 解码单个内容块
|
||||
func decodeSingleContentBlock(m map[string]any) (*canonical.ContentBlock, error) {
|
||||
t, ok := m["type"].(string)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch t {
|
||||
case "text":
|
||||
text, ok := m["text"].(string)
|
||||
if !ok {
|
||||
text = ""
|
||||
}
|
||||
return &canonical.ContentBlock{Type: "text", Text: text}, nil
|
||||
case "tool_use":
|
||||
id, ok := m["id"].(string)
|
||||
if !ok {
|
||||
id = ""
|
||||
}
|
||||
name, ok := m["name"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
input, err := json.Marshal(m["input"])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}, nil
|
||||
case "tool_result":
|
||||
toolUseID, ok := m["tool_use_id"].(string)
|
||||
if !ok {
|
||||
toolUseID = ""
|
||||
}
|
||||
isErr := false
|
||||
if ie, ok := m["is_error"].(bool); ok {
|
||||
isErr = ie
|
||||
}
|
||||
var content json.RawMessage
|
||||
if c, ok := m["content"]; ok {
|
||||
switch cv := c.(type) {
|
||||
case string:
|
||||
content = json.RawMessage(fmt.Sprintf("%q", cv))
|
||||
default:
|
||||
encoded, err := json.Marshal(cv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content = encoded
|
||||
}
|
||||
} else {
|
||||
content = json.RawMessage(`""`)
|
||||
}
|
||||
return &canonical.ContentBlock{
|
||||
Type: "tool_result",
|
||||
ToolUseID: toolUseID,
|
||||
Content: content,
|
||||
IsError: &isErr,
|
||||
}, nil
|
||||
case "thinking":
|
||||
thinking, ok := m["thinking"].(string)
|
||||
if !ok {
|
||||
thinking = ""
|
||||
}
|
||||
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}, nil
|
||||
case "redacted_thinking":
|
||||
// 丢弃
|
||||
return nil, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// decodeTools 解码工具定义
|
||||
func decodeTools(tools []Tool) []canonical.CanonicalTool {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]canonical.CanonicalTool, len(tools))
|
||||
for i, t := range tools {
|
||||
result[i] = canonical.CanonicalTool{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
InputSchema: t.InputSchema,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// decodeToolChoice 解码工具选择
|
||||
func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
||||
if toolChoice == nil {
|
||||
return nil
|
||||
}
|
||||
switch v := toolChoice.(type) {
|
||||
case string:
|
||||
switch v {
|
||||
case "auto":
|
||||
return canonical.NewToolChoiceAuto()
|
||||
case "none":
|
||||
return canonical.NewToolChoiceNone()
|
||||
case "any":
|
||||
return canonical.NewToolChoiceAny()
|
||||
}
|
||||
case map[string]any:
|
||||
t, ok := v["type"].(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
switch t {
|
||||
case "auto":
|
||||
return canonical.NewToolChoiceAuto()
|
||||
case "none":
|
||||
return canonical.NewToolChoiceNone()
|
||||
case "any":
|
||||
return canonical.NewToolChoiceAny()
|
||||
case "tool":
|
||||
name, ok := v["name"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
return canonical.NewToolChoiceNamed(name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeParameters 解码请求参数
|
||||
func decodeParameters(req *MessagesRequest) canonical.RequestParameters {
|
||||
params := canonical.RequestParameters{
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
TopK: req.TopK,
|
||||
}
|
||||
if req.MaxTokens > 0 {
|
||||
val := req.MaxTokens
|
||||
params.MaxTokens = &val
|
||||
}
|
||||
if len(req.StopSequences) > 0 {
|
||||
params.StopSequences = req.StopSequences
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
// decodeThinking 解码思考配置
|
||||
func decodeThinking(thinking *ThinkingConfig, outputConfig *OutputConfig) *canonical.ThinkingConfig {
|
||||
if thinking == nil {
|
||||
return nil
|
||||
}
|
||||
cfg := &canonical.ThinkingConfig{
|
||||
Type: thinking.Type,
|
||||
BudgetTokens: thinking.BudgetTokens,
|
||||
}
|
||||
if outputConfig != nil && outputConfig.Effort != "" {
|
||||
cfg.Effort = outputConfig.Effort
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// decodeOutputFormat 解码输出格式
|
||||
func decodeOutputFormat(outputConfig *OutputConfig) *canonical.OutputFormat {
|
||||
if outputConfig == nil || outputConfig.Format == nil {
|
||||
return nil
|
||||
}
|
||||
if outputConfig.Format.Type == "json_schema" && outputConfig.Format.Schema != nil {
|
||||
return &canonical.OutputFormat{
|
||||
Type: "json_schema",
|
||||
Name: "output",
|
||||
Schema: outputConfig.Format.Schema,
|
||||
Strict: boolPtr(true),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeResponse decodes an Anthropic Messages API response into a
// canonical response.
func decodeResponse(body []byte) (*canonical.CanonicalResponse, error) {
	var resp MessagesResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 Anthropic 响应失败").WithCause(err)
	}

	var blocks []canonical.ContentBlock
	for _, block := range resp.Content {
		switch block.Type {
		case "text":
			blocks = append(blocks, canonical.NewTextBlock(block.Text))
		case "tool_use":
			blocks = append(blocks, canonical.NewToolUseBlock(block.ID, block.Name, block.Input))
		case "thinking":
			blocks = append(blocks, canonical.NewThinkingBlock(block.Thinking))
		case "redacted_thinking":
			// Dropped: opaque provider data with no canonical form.
		}
	}
	// Guarantee at least one content block for downstream consumers.
	if len(blocks) == 0 {
		blocks = append(blocks, canonical.NewTextBlock(""))
	}

	sr := mapStopReason(resp.StopReason)
	usage := canonical.CanonicalUsage{
		InputTokens:  resp.Usage.InputTokens,
		OutputTokens: resp.Usage.OutputTokens,
	}
	// Cache token counts are optional on the wire; copy only when present.
	if resp.Usage.CacheReadInputTokens != nil {
		usage.CacheReadTokens = resp.Usage.CacheReadInputTokens
	}
	if resp.Usage.CacheCreationInputTokens != nil {
		usage.CacheCreationTokens = resp.Usage.CacheCreationInputTokens
	}

	return &canonical.CanonicalResponse{
		ID:         resp.ID,
		Model:      resp.Model,
		Content:    blocks,
		StopReason: &sr,
		Usage:      usage,
	}, nil
}
|
||||
|
||||
// mapStopReason 映射停止原因
|
||||
func mapStopReason(reason string) canonical.StopReason {
|
||||
switch reason {
|
||||
case "end_turn":
|
||||
return canonical.StopReasonEndTurn
|
||||
case "max_tokens":
|
||||
return canonical.StopReasonMaxTokens
|
||||
case "tool_use":
|
||||
return canonical.StopReasonToolUse
|
||||
case "stop_sequence":
|
||||
return canonical.StopReasonStopSequence
|
||||
case "pause_turn":
|
||||
return canonical.StopReason("pause_turn")
|
||||
case "refusal":
|
||||
return canonical.StopReasonRefusal
|
||||
default:
|
||||
return canonical.StopReasonEndTurn
|
||||
}
|
||||
}
|
||||
|
||||
// decodeModelsResponse 解码模型列表响应
|
||||
func decodeModelsResponse(body []byte) (*canonical.CanonicalModelList, error) {
|
||||
var resp ModelsResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
models := make([]canonical.CanonicalModel, len(resp.Data))
|
||||
for i, m := range resp.Data {
|
||||
name := m.DisplayName
|
||||
if name == "" {
|
||||
name = m.ID
|
||||
}
|
||||
models[i] = canonical.CanonicalModel{
|
||||
ID: m.ID,
|
||||
Name: name,
|
||||
Created: parseTimestamp(m.CreatedAt),
|
||||
OwnedBy: "anthropic",
|
||||
}
|
||||
}
|
||||
return &canonical.CanonicalModelList{Models: models}, nil
|
||||
}
|
||||
|
||||
// decodeModelInfoResponse 解码模型详情响应
|
||||
func decodeModelInfoResponse(body []byte) (*canonical.CanonicalModelInfo, error) {
|
||||
var resp ModelInfoResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name := resp.DisplayName
|
||||
if name == "" {
|
||||
name = resp.ID
|
||||
}
|
||||
return &canonical.CanonicalModelInfo{
|
||||
ID: resp.ID,
|
||||
Name: name,
|
||||
Created: parseTimestamp(resp.CreatedAt),
|
||||
OwnedBy: "anthropic",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseTimestamp parses an RFC 3339 timestamp into Unix seconds.
// Empty or malformed input yields 0.
func parseTimestamp(s string) int64 {
	if s == "" {
		return 0
	}
	if parsed, err := time.Parse(time.RFC3339, s); err == nil {
		return parsed.Unix()
	}
	return 0
}
|
||||
|
||||
// formatTimestamp renders a Unix timestamp as RFC 3339 in UTC. A zero
// timestamp maps to the fixed placeholder date 2025-01-01T00:00:00Z.
func formatTimestamp(unix int64) string {
	if unix != 0 {
		return time.Unix(unix, 0).UTC().Format(time.RFC3339)
	}
	placeholder := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
	return placeholder.Format(time.RFC3339)
}
|
||||
|
||||
// boolPtr returns a pointer to a fresh copy of b.
func boolPtr(b bool) *bool {
	value := b
	return &value
}
|
||||
331
backend/internal/conversion/anthropic/decoder_test.go
Normal file
331
backend/internal/conversion/anthropic/decoder_test.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDecodeRequest_Basic: model, max_tokens and a single user message
// survive decoding.
func TestDecodeRequest_Basic(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"messages": [
			{"role": "user", "content": "你好"}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Equal(t, "claude-3", req.Model)
	assert.Len(t, req.Messages, 1)
	assert.Equal(t, canonical.RoleUser, req.Messages[0].Role)
	assert.NotNil(t, req.Parameters.MaxTokens)
	assert.Equal(t, 1024, *req.Parameters.MaxTokens)
}

// TestDecodeRequest_System: a string system prompt is carried through as-is.
func TestDecodeRequest_System(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"system": "你是助手",
		"messages": [
			{"role": "user", "content": "你好"}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Equal(t, "你是助手", req.System)
}

// TestDecodeRequest_SystemBlocks: a block-list system prompt decodes to
// []canonical.SystemBlock preserving order.
func TestDecodeRequest_SystemBlocks(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"system": [{"text": "指令1"}, {"text": "指令2"}],
		"messages": [
			{"role": "user", "content": "你好"}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	blocks, ok := req.System.([]canonical.SystemBlock)
	require.True(t, ok)
	assert.Len(t, blocks, 2)
	assert.Equal(t, "指令1", blocks[0].Text)
}

// TestDecodeRequest_ToolResultSplit: tool_result blocks mixed into a user
// message are split out into a separate tool-role message.
func TestDecodeRequest_ToolResultSplit(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"messages": [
			{
				"role": "user",
				"content": [
					{"type": "text", "text": "查询天气"},
					{"type": "tool_result", "tool_use_id": "tool_1", "content": "晴天"}
				]
			}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	// tool_result blocks inside a user message should become an
	// independent tool-role message.
	assert.Len(t, req.Messages, 2)
	assert.Equal(t, canonical.RoleUser, req.Messages[0].Role)
	assert.Equal(t, canonical.RoleTool, req.Messages[1].Role)
}

// TestDecodeRequest_MissingModel: a missing model field is an INVALID_INPUT
// conversion error.
func TestDecodeRequest_MissingModel(t *testing.T) {
	body := []byte(`{"max_tokens": 1024, "messages": [{"role": "user", "content": "hi"}]}`)
	_, err := decodeRequest(body)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "INVALID_INPUT")
}

// TestDecodeRequest_MissingMessages: an empty messages list is an
// INVALID_INPUT conversion error.
func TestDecodeRequest_MissingMessages(t *testing.T) {
	body := []byte(`{"model": "claude-3", "max_tokens": 1024}`)
	_, err := decodeRequest(body)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "INVALID_INPUT")
}
|
||||
|
||||
// TestDecodeResponse_Basic: ID, model, text content, stop reason and usage
// are decoded from a minimal response.
func TestDecodeResponse_Basic(t *testing.T) {
	body := []byte(`{
		"id": "msg_123",
		"type": "message",
		"role": "assistant",
		"model": "claude-3",
		"content": [{"type": "text", "text": "你好"}],
		"stop_reason": "end_turn",
		"usage": {"input_tokens": 10, "output_tokens": 5}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Equal(t, "msg_123", resp.ID)
	assert.Equal(t, "claude-3", resp.Model)
	assert.Len(t, resp.Content, 1)
	assert.Equal(t, "你好", resp.Content[0].Text)
	assert.Equal(t, canonical.StopReasonEndTurn, *resp.StopReason)
	assert.Equal(t, 10, resp.Usage.InputTokens)
}

// TestDecodeResponse_Thinking: thinking blocks are preserved ahead of the
// text answer, in order.
func TestDecodeResponse_Thinking(t *testing.T) {
	body := []byte(`{
		"id": "msg_456",
		"type": "message",
		"role": "assistant",
		"model": "claude-3",
		"content": [
			{"type": "thinking", "thinking": "思考过程"},
			{"type": "text", "text": "回答"}
		],
		"stop_reason": "end_turn",
		"usage": {"input_tokens": 10, "output_tokens": 20}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Len(t, resp.Content, 2)
	assert.Equal(t, "thinking", resp.Content[0].Type)
	assert.Equal(t, "思考过程", resp.Content[0].Thinking)
	assert.Equal(t, "text", resp.Content[1].Type)
	assert.Equal(t, "回答", resp.Content[1].Text)
}

// TestDecodeModelsResponse: display names and timestamps are decoded, with
// the ID used as a fallback name.
func TestDecodeModelsResponse(t *testing.T) {
	body := []byte(`{
		"data": [
			{"id": "claude-3-opus", "type": "model", "display_name": "Claude 3 Opus", "created_at": "2024-01-15T00:00:00Z"},
			{"id": "claude-3-sonnet", "type": "model", "created_at": "2024-02-01T00:00:00Z"}
		],
		"has_more": false
	}`)

	list, err := decodeModelsResponse(body)
	require.NoError(t, err)
	assert.Len(t, list.Models, 2)
	assert.Equal(t, "claude-3-opus", list.Models[0].ID)
	assert.Equal(t, "Claude 3 Opus", list.Models[0].Name)
	// created_at RFC 3339 → Unix seconds.
	assert.NotEqual(t, int64(0), list.Models[0].Created)
	// Falls back to the ID when display_name is absent.
	assert.Equal(t, "claude-3-sonnet", list.Models[1].Name)
}

// TestDecodeRequest_InvalidJSON: unparseable bodies surface as
// JSON_PARSE_ERROR conversion errors.
func TestDecodeRequest_InvalidJSON(t *testing.T) {
	_, err := decodeRequest([]byte(`invalid json`))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "JSON_PARSE_ERROR")
}
|
||||
|
||||
// TestDecodeRequest_Thinking: an enabled thinking config with a token
// budget is decoded intact.
func TestDecodeRequest_Thinking(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"messages": [{"role": "user", "content": "hi"}],
		"thinking": {"type": "enabled", "budget_tokens": 5000}
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.Thinking)
	assert.Equal(t, "enabled", req.Thinking.Type)
	require.NotNil(t, req.Thinking.BudgetTokens)
	assert.Equal(t, 5000, *req.Thinking.BudgetTokens)
}

// TestDecodeRequest_ThinkingAdaptive: the adaptive type decodes without a
// budget.
func TestDecodeRequest_ThinkingAdaptive(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"messages": [{"role": "user", "content": "hi"}],
		"thinking": {"type": "adaptive"}
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.Thinking)
	assert.Equal(t, "adaptive", req.Thinking.Type)
}

// TestDecodeRequest_OutputConfig: output_config.format of type json_schema
// maps to a canonical OutputFormat carrying the schema.
func TestDecodeRequest_OutputConfig(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"messages": [{"role": "user", "content": "hi"}],
		"output_config": {
			"format": {
				"type": "json_schema",
				"schema": {"type": "object", "properties": {"name": {"type": "string"}}}
			}
		}
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.OutputFormat)
	assert.Equal(t, "json_schema", req.OutputFormat.Type)
	assert.NotNil(t, req.OutputFormat.Schema)
}

// TestDecodeRequest_DisableParallelToolUse: the disable flag inverts into
// ParallelToolUse=false.
func TestDecodeRequest_DisableParallelToolUse(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"messages": [{"role": "user", "content": "hi"}],
		"disable_parallel_tool_use": true
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.ParallelToolUse)
	assert.False(t, *req.ParallelToolUse)
}
|
||||
|
||||
// TestDecodeResponse_ToolUse: tool_use content blocks keep their ID, name
// and input payload.
func TestDecodeResponse_ToolUse(t *testing.T) {
	body := []byte(`{
		"id": "msg_tool",
		"type": "message",
		"role": "assistant",
		"model": "claude-3",
		"content": [
			{"type": "tool_use", "id": "tool_1", "name": "search", "input": {"q": "test"}}
		],
		"stop_reason": "tool_use",
		"usage": {"input_tokens": 10, "output_tokens": 5}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	require.Len(t, resp.Content, 1)
	assert.Equal(t, "tool_use", resp.Content[0].Type)
	assert.Equal(t, "tool_1", resp.Content[0].ID)
	assert.Equal(t, "search", resp.Content[0].Name)
	assert.NotNil(t, resp.Content[0].Input)
}

// TestDecodeResponse_RedactedThinking: redacted_thinking blocks are dropped
// from the decoded content.
func TestDecodeResponse_RedactedThinking(t *testing.T) {
	body := []byte(`{
		"id": "msg_redacted",
		"type": "message",
		"role": "assistant",
		"model": "claude-3",
		"content": [
			{"type": "redacted_thinking", "data": "..."},
			{"type": "text", "text": "回答"}
		],
		"stop_reason": "end_turn",
		"usage": {"input_tokens": 10, "output_tokens": 5}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Len(t, resp.Content, 1)
	assert.Equal(t, "text", resp.Content[0].Type)
	assert.Equal(t, "回答", resp.Content[0].Text)
}

// TestDecodeResponse_StopReasons: table-driven check of the stop_reason
// mapping onto canonical stop reasons.
func TestDecodeResponse_StopReasons(t *testing.T) {
	tests := []struct {
		name   string
		reason string
		want   canonical.StopReason
	}{
		{"end_turn→end_turn", "end_turn", canonical.StopReasonEndTurn},
		{"max_tokens→max_tokens", "max_tokens", canonical.StopReasonMaxTokens},
		{"tool_use→tool_use", "tool_use", canonical.StopReasonToolUse},
		{"stop_sequence→stop_sequence", "stop_sequence", canonical.StopReasonStopSequence},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			body := []byte(fmt.Sprintf(`{
				"id": "msg-1",
				"type": "message",
				"role": "assistant",
				"model": "claude-3",
				"content": [{"type": "text", "text": "ok"}],
				"stop_reason": "%s",
				"usage": {"input_tokens": 1, "output_tokens": 1}
			}`, tt.reason))

			resp, err := decodeResponse(body)
			require.NoError(t, err)
			require.NotNil(t, resp.StopReason)
			assert.Equal(t, tt.want, *resp.StopReason)
		})
	}
}

// TestDecodeResponse_Usage: token counts, including the optional cache-read
// figure, are decoded.
func TestDecodeResponse_Usage(t *testing.T) {
	body := []byte(`{
		"id": "msg_usage",
		"type": "message",
		"role": "assistant",
		"model": "claude-3",
		"content": [{"type": "text", "text": "ok"}],
		"stop_reason": "end_turn",
		"usage": {
			"input_tokens": 100,
			"output_tokens": 50,
			"cache_read_input_tokens": 30
		}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Equal(t, 100, resp.Usage.InputTokens)
	assert.Equal(t, 50, resp.Usage.OutputTokens)
	require.NotNil(t, resp.Usage.CacheReadTokens)
	assert.Equal(t, 30, *resp.Usage.CacheReadTokens)
}
|
||||
449
backend/internal/conversion/anthropic/encoder.go
Normal file
449
backend/internal/conversion/anthropic/encoder.go
Normal file
@@ -0,0 +1,449 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// encodeRequest encodes a canonical request as an Anthropic Messages API
// request body, targeting provider.ModelName.
func encodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) {
	result := map[string]any{
		"model":  provider.ModelName,
		"stream": req.Stream,
	}

	// max_tokens is mandatory on the wire; default to 4096 when the
	// canonical request does not specify one.
	if req.Parameters.MaxTokens != nil {
		result["max_tokens"] = *req.Parameters.MaxTokens
	} else {
		result["max_tokens"] = 4096
	}

	// System prompt (string or block list).
	if req.System != nil {
		result["system"] = encodeSystem(req.System)
	}

	// Conversation messages (role constraints enforced inside).
	result["messages"] = encodeMessages(req.Messages)

	// Optional sampling parameters — emitted only when set.
	if req.Parameters.Temperature != nil {
		result["temperature"] = *req.Parameters.Temperature
	}
	if req.Parameters.TopP != nil {
		result["top_p"] = *req.Parameters.TopP
	}
	if req.Parameters.TopK != nil {
		result["top_k"] = *req.Parameters.TopK
	}
	if len(req.Parameters.StopSequences) > 0 {
		result["stop_sequences"] = req.Parameters.StopSequences
	}

	// Tool definitions.
	if len(req.Tools) > 0 {
		tools := make([]map[string]any, len(req.Tools))
		for i, t := range req.Tools {
			tool := map[string]any{
				"name":         t.Name,
				"input_schema": t.InputSchema,
			}
			if t.Description != "" {
				tool["description"] = t.Description
			}
			tools[i] = tool
		}
		result["tools"] = tools
	}
	if req.ToolChoice != nil {
		result["tool_choice"] = encodeToolChoice(req.ToolChoice)
	}

	// Common fields.
	if req.UserID != "" {
		result["metadata"] = map[string]any{"user_id": req.UserID}
	}
	// Canonical stores the positive form; Anthropic expects a disable flag.
	if req.ParallelToolUse != nil && !*req.ParallelToolUse {
		result["disable_parallel_tool_use"] = true
	}
	if req.Thinking != nil {
		result["thinking"] = encodeThinkingConfig(req.Thinking)
	}

	// output_config groups the structured-output format and the reasoning
	// effort; it is emitted only when at least one of them is present.
	outputConfig := map[string]any{}
	hasOutputConfig := false
	if req.OutputFormat != nil {
		of := encodeOutputFormat(req.OutputFormat)
		if of != nil {
			outputConfig["format"] = of
			hasOutputConfig = true
		}
	}
	if req.Thinking != nil && req.Thinking.Effort != "" {
		outputConfig["effort"] = req.Thinking.Effort
		hasOutputConfig = true
	}
	if hasOutputConfig {
		result["output_config"] = outputConfig
	}

	body, err := json.Marshal(result)
	if err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeEncodingFailure, "编码 Anthropic 请求失败").WithCause(err)
	}
	return body, nil
}
|
||||
|
||||
// encodeSystem 编码系统消息
|
||||
func encodeSystem(system any) any {
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []canonical.SystemBlock:
|
||||
blocks := make([]map[string]any, len(v))
|
||||
for i, b := range v {
|
||||
blocks[i] = map[string]any{"text": b.Text}
|
||||
}
|
||||
return blocks
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// encodeMessages encodes the canonical message list while enforcing the
// Anthropic role constraints:
//   - tool-role messages are folded into the adjacent user message (or
//     become a new user message), since Anthropic carries tool results
//     inside user content;
//   - the first message must have the user role, so an empty placeholder
//     is prepended when needed;
//   - consecutive same-role messages are merged at the end.
func encodeMessages(msgs []canonical.CanonicalMessage) []map[string]any {
	var result []map[string]any

	for _, msg := range msgs {
		switch msg.Role {
		case canonical.RoleUser:
			result = append(result, map[string]any{
				"role":    "user",
				"content": encodeContentBlocks(msg.Content),
			})
		case canonical.RoleAssistant:
			result = append(result, map[string]any{
				"role":    "assistant",
				"content": encodeContentBlocks(msg.Content),
			})
		case canonical.RoleTool:
			// Tool-role messages merge into the neighboring user message.
			toolResults := filterToolResults(msg.Content)
			if len(result) > 0 && result[len(result)-1]["role"] == "user" {
				// Append to the trailing user message when possible.
				lastContent, ok := result[len(result)-1]["content"].([]map[string]any)
				if ok {
					result[len(result)-1]["content"] = append(lastContent, toolResults...)
				} else {
					result[len(result)-1]["content"] = toolResults
				}
			} else {
				result = append(result, map[string]any{
					"role":    "user",
					"content": toolResults,
				})
			}
		}
	}

	// The conversation must open with a user message.
	if len(result) > 0 && result[0]["role"] != "user" {
		result = append([]map[string]any{{"role": "user", "content": []map[string]any{}}}, result...)
	}

	// Collapse consecutive same-role messages.
	result = mergeConsecutiveRoles(result)

	return result
}
|
||||
|
||||
// encodeContentBlocks encodes canonical content blocks into Anthropic
// content maps. Unknown block types are skipped; an all-skipped (or empty)
// input degrades to a single empty text block so the message stays valid.
func encodeContentBlocks(blocks []canonical.ContentBlock) []map[string]any {
	result := make([]map[string]any, 0, len(blocks))
	for _, b := range blocks {
		switch b.Type {
		case "text":
			result = append(result, map[string]any{"type": "text", "text": b.Text})
		case "tool_use":
			m := map[string]any{
				"type":  "tool_use",
				"id":    b.ID,
				"name":  b.Name,
				"input": b.Input,
			}
			// The input must be an object; substitute {} for nil.
			if b.Input == nil {
				m["input"] = map[string]any{}
			}
			result = append(result, m)
		case "tool_result":
			m := map[string]any{
				"type":        "tool_result",
				"tool_use_id": b.ToolUseID,
			}
			if b.Content != nil {
				// A JSON string unwraps to plain text; any other JSON value
				// is forwarded as its raw source text.
				var contentStr string
				if json.Unmarshal(b.Content, &contentStr) == nil {
					m["content"] = contentStr
				} else {
					m["content"] = string(b.Content)
				}
			} else {
				m["content"] = ""
			}
			if b.IsError != nil {
				m["is_error"] = *b.IsError
			}
			result = append(result, m)
		case "thinking":
			result = append(result, map[string]any{"type": "thinking", "thinking": b.Thinking})
		}
	}
	if len(result) == 0 {
		return []map[string]any{{"type": "text", "text": ""}}
	}
	return result
}
|
||||
|
||||
// filterToolResults 过滤工具结果
|
||||
func filterToolResults(blocks []canonical.ContentBlock) []map[string]any {
|
||||
var result []map[string]any
|
||||
for _, b := range blocks {
|
||||
if b.Type == "tool_result" {
|
||||
m := map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": b.ToolUseID,
|
||||
}
|
||||
if b.Content != nil {
|
||||
var contentStr string
|
||||
if json.Unmarshal(b.Content, &contentStr) == nil {
|
||||
m["content"] = contentStr
|
||||
} else {
|
||||
m["content"] = string(b.Content)
|
||||
}
|
||||
} else {
|
||||
m["content"] = ""
|
||||
}
|
||||
if b.IsError != nil {
|
||||
m["is_error"] = *b.IsError
|
||||
}
|
||||
result = append(result, m)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// encodeToolChoice 编码工具选择
|
||||
func encodeToolChoice(choice *canonical.ToolChoice) any {
|
||||
switch choice.Type {
|
||||
case "auto":
|
||||
return map[string]any{"type": "auto"}
|
||||
case "none":
|
||||
return map[string]any{"type": "none"}
|
||||
case "any":
|
||||
return map[string]any{"type": "any"}
|
||||
case "tool":
|
||||
return map[string]any{"type": "tool", "name": choice.Name}
|
||||
}
|
||||
return map[string]any{"type": "auto"}
|
||||
}
|
||||
|
||||
// encodeThinkingConfig 编码思考配置
|
||||
func encodeThinkingConfig(cfg *canonical.ThinkingConfig) map[string]any {
|
||||
switch cfg.Type {
|
||||
case "enabled":
|
||||
m := map[string]any{"type": "enabled"}
|
||||
if cfg.BudgetTokens != nil {
|
||||
m["budget_tokens"] = *cfg.BudgetTokens
|
||||
}
|
||||
return m
|
||||
case "disabled":
|
||||
return map[string]any{"type": "disabled"}
|
||||
case "adaptive":
|
||||
return map[string]any{"type": "adaptive"}
|
||||
}
|
||||
return map[string]any{"type": "disabled"}
|
||||
}
|
||||
|
||||
// encodeOutputFormat 编码输出格式
|
||||
func encodeOutputFormat(format *canonical.OutputFormat) map[string]any {
|
||||
if format == nil {
|
||||
return nil
|
||||
}
|
||||
switch format.Type {
|
||||
case "json_schema":
|
||||
schema := format.Schema
|
||||
if schema == nil {
|
||||
schema = json.RawMessage(`{"type":"object"}`)
|
||||
}
|
||||
return map[string]any{
|
||||
"type": "json_schema",
|
||||
"schema": schema,
|
||||
}
|
||||
case "json_object":
|
||||
return map[string]any{
|
||||
"type": "json_schema",
|
||||
"schema": map[string]any{"type": "object"},
|
||||
}
|
||||
case "text":
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeResponse 将 Canonical 响应编码为 Anthropic 响应
|
||||
func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||||
blocks := make([]map[string]any, 0, len(resp.Content))
|
||||
for _, b := range resp.Content {
|
||||
switch b.Type {
|
||||
case "text":
|
||||
blocks = append(blocks, map[string]any{"type": "text", "text": b.Text})
|
||||
case "tool_use":
|
||||
m := map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": b.ID,
|
||||
"name": b.Name,
|
||||
"input": b.Input,
|
||||
}
|
||||
if b.Input == nil {
|
||||
m["input"] = map[string]any{}
|
||||
}
|
||||
blocks = append(blocks, m)
|
||||
case "thinking":
|
||||
blocks = append(blocks, map[string]any{"type": "thinking", "thinking": b.Thinking})
|
||||
}
|
||||
}
|
||||
|
||||
sr := "end_turn"
|
||||
if resp.StopReason != nil {
|
||||
sr = mapCanonicalStopReason(*resp.StopReason)
|
||||
}
|
||||
|
||||
usage := map[string]any{
|
||||
"input_tokens": resp.Usage.InputTokens,
|
||||
"output_tokens": resp.Usage.OutputTokens,
|
||||
}
|
||||
if resp.Usage.CacheReadTokens != nil {
|
||||
usage["cache_read_input_tokens"] = *resp.Usage.CacheReadTokens
|
||||
}
|
||||
if resp.Usage.CacheCreationTokens != nil {
|
||||
usage["cache_creation_input_tokens"] = *resp.Usage.CacheCreationTokens
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"id": resp.ID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": resp.Model,
|
||||
"content": blocks,
|
||||
"stop_reason": sr,
|
||||
"stop_sequence": nil,
|
||||
"usage": usage,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeEncodingFailure, "编码 Anthropic 响应失败").WithCause(err)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// mapCanonicalStopReason 映射 Canonical 停止原因到 Anthropic
|
||||
func mapCanonicalStopReason(reason canonical.StopReason) string {
|
||||
switch reason {
|
||||
case canonical.StopReasonEndTurn, canonical.StopReasonContentFilter:
|
||||
return "end_turn"
|
||||
case canonical.StopReasonMaxTokens:
|
||||
return "max_tokens"
|
||||
case canonical.StopReasonToolUse:
|
||||
return "tool_use"
|
||||
case canonical.StopReasonStopSequence:
|
||||
return "stop_sequence"
|
||||
case canonical.StopReasonRefusal:
|
||||
return "refusal"
|
||||
default:
|
||||
return "end_turn"
|
||||
}
|
||||
}
|
||||
|
||||
// encodeModelsResponse 编码模型列表响应
|
||||
func encodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
|
||||
data := make([]map[string]any, len(list.Models))
|
||||
for i, m := range list.Models {
|
||||
name := m.Name
|
||||
if name == "" {
|
||||
name = m.ID
|
||||
}
|
||||
data[i] = map[string]any{
|
||||
"id": m.ID,
|
||||
"type": "model",
|
||||
"display_name": name,
|
||||
"created_at": formatTimestamp(m.Created),
|
||||
}
|
||||
}
|
||||
|
||||
var firstID, lastID *string
|
||||
if len(list.Models) > 0 {
|
||||
fid := list.Models[0].ID
|
||||
firstID = &fid
|
||||
lid := list.Models[len(list.Models)-1].ID
|
||||
lastID = &lid
|
||||
}
|
||||
|
||||
return json.Marshal(map[string]any{
|
||||
"data": data,
|
||||
"has_more": false,
|
||||
"first_id": firstID,
|
||||
"last_id": lastID,
|
||||
})
|
||||
}
|
||||
|
||||
// encodeModelInfoResponse 编码模型详情响应
|
||||
func encodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
|
||||
name := info.Name
|
||||
if name == "" {
|
||||
name = info.ID
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"id": info.ID,
|
||||
"type": "model",
|
||||
"display_name": name,
|
||||
"created_at": formatTimestamp(info.Created),
|
||||
})
|
||||
}
|
||||
|
||||
// mergeConsecutiveRoles merges consecutive messages that share the same role.
// Block-list contents are concatenated; string contents are joined. Messages
// whose content shapes are incompatible (e.g. a string followed by a block
// list) are kept as separate entries.
//
// Bug fix: the previous implementation executed `continue` even when the
// content shapes did not match, silently dropping the second message's
// content. Now a message is only consumed when it was actually merged.
func mergeConsecutiveRoles(messages []map[string]any) []map[string]any {
	if len(messages) <= 1 {
		return messages
	}
	var result []map[string]any
	for _, msg := range messages {
		if len(result) > 0 && result[len(result)-1]["role"] == msg["role"] {
			last := result[len(result)-1]
			switch lv := last["content"].(type) {
			case []map[string]any:
				if cv, ok := msg["content"].([]map[string]any); ok {
					last["content"] = append(lv, cv...)
					continue
				}
			case string:
				if cv, ok := msg["content"].(string); ok {
					last["content"] = lv + cv
					continue
				}
			}
			// Incompatible content shapes: fall through and keep the
			// message instead of dropping it.
		}
		result = append(result, msg)
	}
	return result
}
|
||||
366
backend/internal/conversion/anthropic/encoder_test.go
Normal file
366
backend/internal/conversion/anthropic/encoder_test.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestEncodeRequest_Basic verifies that the provider's model override, the
// stream flag, max_tokens, and the message count survive encoding.
func TestEncodeRequest_Basic(t *testing.T) {
	maxTokens := 1024
	req := &canonical.CanonicalRequest{
		Model:      "claude-3",
		Stream:     true,
		Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
		Messages: []canonical.CanonicalMessage{
			{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}},
		},
	}
	provider := conversion.NewTargetProvider("", "key", "my-model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "my-model", result["model"])
	assert.Equal(t, true, result["stream"])
	assert.Equal(t, float64(1024), result["max_tokens"])

	msgs, ok := result["messages"].([]any)
	require.True(t, ok)
	assert.Len(t, msgs, 1)
}
|
||||
|
||||
// TestEncodeRequest_ToolMergeIntoUser verifies that a canonical tool-role
// message is folded into an adjacent user message as a tool_result block.
func TestEncodeRequest_ToolMergeIntoUser(t *testing.T) {
	maxTokens := 1024
	req := &canonical.CanonicalRequest{
		Model:      "claude-3",
		Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
		Messages: []canonical.CanonicalMessage{
			{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("查询")}},
			{Role: canonical.RoleAssistant, Content: []canonical.ContentBlock{canonical.NewToolUseBlock("tool_1", "search", json.RawMessage(`{"q":"test"}`))}},
			{Role: canonical.RoleTool, Content: []canonical.ContentBlock{canonical.NewToolResultBlock("tool_1", "结果", false)}},
		},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	msgs, ok := result["messages"].([]any)
	require.True(t, ok)

	// The tool message should have been merged into the adjacent user message.
	foundToolResult := false
	for _, m := range msgs {
		msgMap, ok := m.(map[string]any)
		require.True(t, ok)
		if msgMap["role"] == "user" {
			content, ok := msgMap["content"].([]any)
			if ok {
				for _, c := range content {
					block, ok := c.(map[string]any)
					require.True(t, ok)
					if block["type"] == "tool_result" {
						foundToolResult = true
					}
				}
			}
		}
	}
	assert.True(t, foundToolResult)
}
|
||||
|
||||
// TestEncodeRequest_FirstUserGuarantee verifies that the encoded message
// list always begins with a user-role message even when the canonical
// conversation starts with an assistant message.
func TestEncodeRequest_FirstUserGuarantee(t *testing.T) {
	maxTokens := 1024
	req := &canonical.CanonicalRequest{
		Model:      "claude-3",
		Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
		Messages: []canonical.CanonicalMessage{
			{Role: canonical.RoleAssistant, Content: []canonical.ContentBlock{canonical.NewTextBlock("前置")}},
			{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}},
		},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	msgs, ok := result["messages"].([]any)
	require.True(t, ok)
	firstMsg, ok := msgs[0].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "user", firstMsg["role"])
}
|
||||
|
||||
// TestEncodeRequest_ThinkingEnabled verifies that an enabled thinking config
// with a token budget is emitted as the `thinking` object.
func TestEncodeRequest_ThinkingEnabled(t *testing.T) {
	budget := 10000
	maxTokens := 8096
	req := &canonical.CanonicalRequest{
		Model:      "claude-3",
		Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		Thinking:   &canonical.ThinkingConfig{Type: "enabled", BudgetTokens: &budget},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	thinking, ok := result["thinking"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "enabled", thinking["type"])
	assert.Equal(t, float64(10000), thinking["budget_tokens"])
}
|
||||
|
||||
// TestEncodeResponse_Basic verifies the Anthropic message envelope fields
// (id/type/role/stop_reason) and the text content block of a basic response.
func TestEncodeResponse_Basic(t *testing.T) {
	sr := canonical.StopReasonEndTurn
	resp := &canonical.CanonicalResponse{
		ID:         "msg_1",
		Model:      "claude-3",
		Content:    []canonical.ContentBlock{canonical.NewTextBlock("你好")},
		StopReason: &sr,
		Usage:      canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
	}

	body, err := encodeResponse(resp)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "msg_1", result["id"])
	assert.Equal(t, "message", result["type"])
	assert.Equal(t, "assistant", result["role"])
	assert.Equal(t, "end_turn", result["stop_reason"])

	content, ok := result["content"].([]any)
	require.True(t, ok)
	assert.Len(t, content, 1)
	block, ok := content[0].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "text", block["type"])
	assert.Equal(t, "你好", block["text"])
}
|
||||
|
||||
// TestEncodeModelsResponse verifies the models-list envelope and that the
// created timestamp is rendered as a formatted string.
func TestEncodeModelsResponse(t *testing.T) {
	ts := time.Date(2024, 3, 15, 0, 0, 0, 0, time.UTC).Unix()
	list := &canonical.CanonicalModelList{
		Models: []canonical.CanonicalModel{
			{ID: "claude-3-opus", Name: "Claude 3 Opus", Created: ts, OwnedBy: "anthropic"},
		},
	}

	body, err := encodeModelsResponse(list)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	data, ok := result["data"].([]any)
	require.True(t, ok)
	assert.Len(t, data, 1)

	model, ok := data[0].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "claude-3-opus", model["id"])
	// created_at should be an RFC3339-formatted string.
	createdAt, ok := model["created_at"].(string)
	assert.True(t, ok)
	assert.Contains(t, createdAt, "2024")
}
|
||||
|
||||
// TestEncodeRequest_ThinkingDisabled verifies that no `thinking` key is
// emitted when the canonical request has no thinking configuration.
func TestEncodeRequest_ThinkingDisabled(t *testing.T) {
	maxTokens := 1024
	req := &canonical.CanonicalRequest{
		Model:      "claude-3",
		Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	_, hasThinking := result["thinking"]
	assert.False(t, hasThinking)
}
|
||||
|
||||
// TestEncodeRequest_ThinkingAdaptive verifies that an adaptive thinking
// config round-trips into the `thinking` object with the adaptive type.
func TestEncodeRequest_ThinkingAdaptive(t *testing.T) {
	maxTokens := 1024
	req := &canonical.CanonicalRequest{
		Model:      "claude-3",
		Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		Thinking:   &canonical.ThinkingConfig{Type: "adaptive"},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	thinking, ok := result["thinking"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "adaptive", thinking["type"])
}
|
||||
|
||||
// TestEncodeRequest_OutputFormat_JSONSchema verifies that a json_schema
// output format is emitted under output_config.format with its schema.
func TestEncodeRequest_OutputFormat_JSONSchema(t *testing.T) {
	maxTokens := 1024
	schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}}}`)
	req := &canonical.CanonicalRequest{
		Model:      "claude-3",
		Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		OutputFormat: &canonical.OutputFormat{
			Type:   "json_schema",
			Schema: schema,
		},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	oc, ok := result["output_config"].(map[string]any)
	require.True(t, ok)
	format, ok := oc["format"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "json_schema", format["type"])
	assert.NotNil(t, format["schema"])
}
|
||||
|
||||
// TestEncodeRequest_OutputFormat_JSON verifies that a json_object output
// format is translated into a json_schema with a permissive object schema.
func TestEncodeRequest_OutputFormat_JSON(t *testing.T) {
	maxTokens := 1024
	req := &canonical.CanonicalRequest{
		Model:      "claude-3",
		Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		OutputFormat: &canonical.OutputFormat{
			Type: "json_object",
		},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	oc, ok := result["output_config"].(map[string]any)
	require.True(t, ok)
	format, ok := oc["format"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "json_schema", format["type"])
	schemaMap, ok := format["schema"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "object", schemaMap["type"])
}
|
||||
|
||||
// TestEncodeRequest_ConsecutiveRoleMerge verifies that two consecutive
// user messages are merged into one message with both content blocks.
func TestEncodeRequest_ConsecutiveRoleMerge(t *testing.T) {
	maxTokens := 1024
	req := &canonical.CanonicalRequest{
		Model:      "claude-3",
		Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
		Messages: []canonical.CanonicalMessage{
			{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("A")}},
			{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("B")}},
		},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	msgs, ok := result["messages"].([]any)
	require.True(t, ok)
	assert.Len(t, msgs, 1)
	userMsg, ok := msgs[0].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "user", userMsg["role"])
	content, ok := userMsg["content"].([]any)
	require.True(t, ok)
	assert.Len(t, content, 2)
}
|
||||
|
||||
// TestEncodeResponse_ContentFilter verifies that a content-filter stop
// reason is mapped to Anthropic's "end_turn".
func TestEncodeResponse_ContentFilter(t *testing.T) {
	sr := canonical.StopReasonContentFilter
	resp := &canonical.CanonicalResponse{
		ID:         "msg-cf",
		Model:      "claude-3",
		Content:    []canonical.ContentBlock{canonical.NewTextBlock("hi")},
		StopReason: &sr,
	}

	body, err := encodeResponse(resp)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "end_turn", result["stop_reason"])
}
|
||||
|
||||
// TestEncodeResponse_ReasoningTokens verifies that canonical reasoning
// token counts are NOT surfaced in the Anthropic usage object.
func TestEncodeResponse_ReasoningTokens(t *testing.T) {
	reasoning := 100
	sr := canonical.StopReasonEndTurn
	resp := &canonical.CanonicalResponse{
		ID:         "msg-rt",
		Model:      "claude-3",
		Content:    []canonical.ContentBlock{canonical.NewTextBlock("hi")},
		StopReason: &sr,
		Usage:      canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5, ReasoningTokens: &reasoning},
	}

	body, err := encodeResponse(resp)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	usage, ok := result["usage"].(map[string]any)
	require.True(t, ok)
	_, hasReasoning := usage["reasoning_tokens"]
	assert.False(t, hasReasoning)
}
|
||||
|
||||
// TestEncodeResponse_ToolUse verifies that a tool_use content block keeps
// its id and name through response encoding.
func TestEncodeResponse_ToolUse(t *testing.T) {
	sr := canonical.StopReasonToolUse
	input := json.RawMessage(`{"q":"test"}`)
	resp := &canonical.CanonicalResponse{
		ID:         "msg-tool",
		Model:      "claude-3",
		Content:    []canonical.ContentBlock{canonical.NewToolUseBlock("tool_1", "search", input)},
		StopReason: &sr,
	}

	body, err := encodeResponse(resp)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	content, ok := result["content"].([]any)
	require.True(t, ok)
	assert.Len(t, content, 1)
	block, ok := content[0].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "tool_use", block["type"])
	assert.Equal(t, "tool_1", block["id"])
	assert.Equal(t, "search", block["name"])
}
|
||||
284
backend/internal/conversion/anthropic/stream_decoder.go
Normal file
284
backend/internal/conversion/anthropic/stream_decoder.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// StreamDecoder decodes Anthropic SSE stream chunks into canonical stream
// events. It is stateful: it tracks suppressed content blocks and carries
// incomplete UTF-8 sequences across chunk boundaries.
type StreamDecoder struct {
	messageStarted   bool                      // set once a message_start event has been processed
	redactedBlocks   map[int]bool              // content-block indexes whose events are suppressed (redacted/server-tool blocks)
	utf8Remainder    []byte                    // trailing bytes of an incomplete UTF-8 sequence, prepended to the next chunk
	accumulatedUsage *canonical.CanonicalUsage // usage captured at message_start; output tokens updated on message_delta
}
|
||||
|
||||
// NewStreamDecoder 创建 Anthropic 流式解码器
|
||||
func NewStreamDecoder() *StreamDecoder {
|
||||
return &StreamDecoder{
|
||||
redactedBlocks: make(map[int]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessChunk parses one raw SSE chunk and returns the canonical stream
// events it contains. Incomplete trailing UTF-8 sequences are buffered and
// prepended to the next chunk so multi-byte characters split across chunks
// are decoded correctly.
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
	data := rawChunk
	if len(d.utf8Remainder) > 0 {
		// Prepend the bytes carried over from the previous chunk.
		data = append(append([]byte{}, d.utf8Remainder...), rawChunk...)
		d.utf8Remainder = nil
	}

	if !utf8.Valid(data) {
		// Trim back to the longest valid UTF-8 prefix and save the tail
		// for the next call.
		validEnd := len(data)
		for !utf8.Valid(data[:validEnd]) {
			validEnd--
		}
		d.utf8Remainder = append(d.utf8Remainder, data[validEnd:]...)
		data = data[:validEnd]
	}

	var events []canonical.CanonicalStreamEvent
	text := string(data)

	// Parse named SSE events: an "event: <type>" line followed by a
	// "data: <json>" line dispatches one event.
	var eventType string
	var eventData string

	for _, line := range strings.Split(text, "\n") {
		line = strings.TrimRight(line, "\r")
		switch {
		case strings.HasPrefix(line, "event: "):
			eventType = strings.TrimPrefix(line, "event: ")
		case strings.HasPrefix(line, "data: "):
			eventData = strings.TrimPrefix(line, "data: ")
			// Only dispatch when both an event name and payload were seen.
			if eventType != "" && eventData != "" {
				chunkEvents := d.processEvent(eventType, []byte(eventData))
				events = append(events, chunkEvents...)
			}
			eventType = ""
			eventData = ""
		case line == "":
			continue
		}
	}

	return events
}
|
||||
|
||||
// Flush 刷新解码器状态
|
||||
func (d *StreamDecoder) Flush() []canonical.CanonicalStreamEvent {
|
||||
return nil
|
||||
}
|
||||
|
||||
// processEvent 处理单个命名 SSE 事件
|
||||
func (d *StreamDecoder) processEvent(eventType string, data []byte) []canonical.CanonicalStreamEvent {
|
||||
switch eventType {
|
||||
case "message_start":
|
||||
return d.processMessageStart(data)
|
||||
case "content_block_start":
|
||||
return d.processContentBlockStart(data)
|
||||
case "content_block_delta":
|
||||
return d.processContentBlockDelta(data)
|
||||
case "content_block_stop":
|
||||
return d.processContentBlockStop(data)
|
||||
case "message_delta":
|
||||
return d.processMessageDelta(data)
|
||||
case "message_stop":
|
||||
return d.processMessageStop(data)
|
||||
case "ping":
|
||||
return []canonical.CanonicalStreamEvent{canonical.NewPingEvent()}
|
||||
case "error":
|
||||
return d.processError(data)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// processMessageStart 处理消息开始事件
|
||||
func (d *StreamDecoder) processMessageStart(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Usage *struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if msgRaw, ok := raw["message"]; ok {
|
||||
if err := json.Unmarshal(msgRaw, &msg); err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
event := canonical.NewMessageStartEvent(msg.ID, msg.Model)
|
||||
if msg.Usage != nil {
|
||||
usage := &canonical.CanonicalUsage{
|
||||
InputTokens: msg.Usage.InputTokens,
|
||||
OutputTokens: msg.Usage.OutputTokens,
|
||||
}
|
||||
event = canonical.NewMessageStartEventWithUsage(msg.ID, msg.Model, usage)
|
||||
d.accumulatedUsage = usage
|
||||
}
|
||||
|
||||
d.messageStarted = true
|
||||
return []canonical.CanonicalStreamEvent{event}
|
||||
}
|
||||
|
||||
// processContentBlockStart 处理内容块开始事件
|
||||
func (d *StreamDecoder) processContentBlockStart(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw struct {
|
||||
Index int `json:"index"`
|
||||
ContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Thinking string `json:"thinking"`
|
||||
Data string `json:"data"`
|
||||
} `json:"content_block"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查需要丢弃的块类型
|
||||
switch raw.ContentBlock.Type {
|
||||
case "redacted_thinking", "server_tool_use", "web_search_tool_result",
|
||||
"code_execution_tool_result":
|
||||
d.redactedBlocks[raw.Index] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
if d.redactedBlocks[raw.Index] {
|
||||
return nil
|
||||
}
|
||||
|
||||
block := canonical.StreamContentBlock{
|
||||
Type: raw.ContentBlock.Type,
|
||||
Text: raw.ContentBlock.Text,
|
||||
ID: raw.ContentBlock.ID,
|
||||
Name: raw.ContentBlock.Name,
|
||||
Thinking: raw.ContentBlock.Thinking,
|
||||
}
|
||||
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewContentBlockStartEvent(raw.Index, block),
|
||||
}
|
||||
}
|
||||
|
||||
// processContentBlockDelta 处理内容块增量事件
|
||||
func (d *StreamDecoder) processContentBlockDelta(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw struct {
|
||||
Index int `json:"index"`
|
||||
Delta struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
PartialJSON string `json:"partial_json"`
|
||||
Thinking string `json:"thinking"`
|
||||
} `json:"delta"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查是否在丢弃的块中
|
||||
if d.redactedBlocks[raw.Index] {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 丢弃协议特有 delta 类型
|
||||
switch raw.Delta.Type {
|
||||
case "citations_delta", "signature_delta":
|
||||
return nil
|
||||
}
|
||||
|
||||
delta := canonical.StreamDelta{
|
||||
Type: raw.Delta.Type,
|
||||
Text: raw.Delta.Text,
|
||||
PartialJSON: raw.Delta.PartialJSON,
|
||||
Thinking: raw.Delta.Thinking,
|
||||
}
|
||||
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewContentBlockDeltaEvent(raw.Index, delta),
|
||||
}
|
||||
}
|
||||
|
||||
// processContentBlockStop 处理内容块结束事件
|
||||
func (d *StreamDecoder) processContentBlockStop(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw struct {
|
||||
Index int `json:"index"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, redacted := d.redactedBlocks[raw.Index]; redacted {
|
||||
delete(d.redactedBlocks, raw.Index)
|
||||
return nil
|
||||
}
|
||||
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewContentBlockStopEvent(raw.Index),
|
||||
}
|
||||
}
|
||||
|
||||
// processMessageDelta handles a message_delta event: it maps the stop
// reason and emits the final output-token count.
func (d *StreamDecoder) processMessageDelta(data []byte) []canonical.CanonicalStreamEvent {
	var raw struct {
		Delta struct {
			StopReason string `json:"stop_reason"`
		} `json:"delta"`
		Usage struct {
			OutputTokens int `json:"output_tokens"`
		} `json:"usage"`
	}
	if err := json.Unmarshal(data, &raw); err != nil {
		return nil
	}

	sr := mapStopReason(raw.Delta.StopReason)
	// NOTE(review): the emitted usage carries only output tokens; input
	// tokens captured at message_start stay in d.accumulatedUsage and are
	// not re-emitted here — confirm downstream encoders merge the two.
	usage := &canonical.CanonicalUsage{
		OutputTokens: raw.Usage.OutputTokens,
	}

	// Keep the accumulated usage in sync with the final output count.
	if d.accumulatedUsage != nil {
		d.accumulatedUsage.OutputTokens = raw.Usage.OutputTokens
	}

	return []canonical.CanonicalStreamEvent{
		canonical.NewMessageDeltaEventWithUsage(sr, usage),
	}
}
|
||||
|
||||
// processMessageStop 处理消息结束事件
|
||||
func (d *StreamDecoder) processMessageStop(data []byte) []canonical.CanonicalStreamEvent {
|
||||
return []canonical.CanonicalStreamEvent{canonical.NewMessageStopEvent()}
|
||||
}
|
||||
|
||||
// processError 处理错误事件
|
||||
func (d *StreamDecoder) processError(data []byte) []canonical.CanonicalStreamEvent {
|
||||
var raw struct {
|
||||
Error struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewErrorEvent("stream_error", fmt.Sprintf("解析错误事件失败: %s", string(data))),
|
||||
}
|
||||
}
|
||||
return []canonical.CanonicalStreamEvent{
|
||||
canonical.NewErrorEvent(raw.Error.Type, raw.Error.Message),
|
||||
}
|
||||
}
|
||||
489
backend/internal/conversion/anthropic/stream_decoder_test.go
Normal file
489
backend/internal/conversion/anthropic/stream_decoder_test.go
Normal file
@@ -0,0 +1,489 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// makeAnthropicEvent renders one named SSE event carrying a JSON payload,
// matching the Anthropic streaming wire format.
func makeAnthropicEvent(eventType string, data any) []byte {
	payload, _ := json.Marshal(data)
	frame := fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, payload)
	return []byte(frame)
}
|
||||
|
||||
// TestStreamDecoder_MessageStart verifies that a message_start SSE event is
// decoded into a canonical message-start event with id and model.
func TestStreamDecoder_MessageStart(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type": "message_start",
		"message": map[string]any{
			"id":    "msg_1",
			"model": "claude-3",
			"usage": map[string]any{"input_tokens": 10, "output_tokens": 0},
		},
	}
	raw := makeAnthropicEvent("message_start", payload)

	events := d.ProcessChunk(raw)
	require.NotEmpty(t, events)
	assert.Equal(t, canonical.EventMessageStart, events[0].Type)
	assert.Equal(t, "msg_1", events[0].Message.ID)
	assert.Equal(t, "claude-3", events[0].Message.Model)
}
|
||||
|
||||
// TestStreamDecoder_ContentBlockDelta verifies that text, input-JSON, and
// thinking deltas each populate the matching canonical delta field.
func TestStreamDecoder_ContentBlockDelta(t *testing.T) {
	d := NewStreamDecoder()

	tests := []struct {
		name       string
		deltaType  string
		deltaData  map[string]any
		checkField string
		checkValue string
	}{
		{
			name:       "text_delta",
			deltaType:  "text_delta",
			deltaData:  map[string]any{"type": "text_delta", "text": "你好"},
			checkField: "text",
			checkValue: "你好",
		},
		{
			name:       "input_json_delta",
			deltaType:  "input_json_delta",
			deltaData:  map[string]any{"type": "input_json_delta", "partial_json": "{\"key\":"},
			checkField: "partial_json",
			checkValue: "{\"key\":",
		},
		{
			name:       "thinking_delta",
			deltaType:  "thinking_delta",
			deltaData:  map[string]any{"type": "thinking_delta", "thinking": "思考中"},
			checkField: "thinking",
			checkValue: "思考中",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			payload := map[string]any{
				"type":  "content_block_delta",
				"index": 0,
				"delta": tt.deltaData,
			}
			raw := makeAnthropicEvent("content_block_delta", payload)

			events := d.ProcessChunk(raw)
			require.NotEmpty(t, events)
			assert.Equal(t, canonical.EventContentBlockDelta, events[0].Type)
			assert.NotNil(t, events[0].Delta)

			switch tt.checkField {
			case "text":
				assert.Equal(t, tt.checkValue, events[0].Delta.Text)
			case "partial_json":
				assert.Equal(t, tt.checkValue, events[0].Delta.PartialJSON)
			case "thinking":
				assert.Equal(t, tt.checkValue, events[0].Delta.Thinking)
			}
		})
	}
}
|
||||
|
||||
// TestStreamDecoder_RedactedThinking verifies that a redacted_thinking
// block start is suppressed and its index recorded for later suppression.
func TestStreamDecoder_RedactedThinking(t *testing.T) {
	d := NewStreamDecoder()

	// A redacted_thinking block start must be suppressed.
	payload := map[string]any{
		"type":  "content_block_start",
		"index": 1,
		"content_block": map[string]any{
			"type": "redacted_thinking",
			"data": "redacted-data",
		},
	}
	raw := makeAnthropicEvent("content_block_start", payload)
	events := d.ProcessChunk(raw)
	assert.Empty(t, events)
	assert.True(t, d.redactedBlocks[1])
}
|
||||
|
||||
// TestStreamDecoder_RedactedBlockStopSuppressed verifies that a stop event
// for a suppressed block emits nothing and clears the suppression entry.
func TestStreamDecoder_RedactedBlockStopSuppressed(t *testing.T) {
	d := NewStreamDecoder()
	d.redactedBlocks[2] = true

	// content_block_stop for a redacted block must return nil.
	payload := map[string]any{
		"type":  "content_block_stop",
		"index": 2,
	}
	raw := makeAnthropicEvent("content_block_stop", payload)

	events := d.ProcessChunk(raw)
	assert.Empty(t, events)
	// The redactedBlocks entry should be cleaned up.
	_, exists := d.redactedBlocks[2]
	assert.False(t, exists)
}
|
||||
|
||||
// TestStreamDecoder_ContentBlockStart verifies decoding of a text
// content_block_start into a canonical event with block type and index.
func TestStreamDecoder_ContentBlockStart(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type":  "content_block_start",
		"index": 0,
		"content_block": map[string]any{
			"type": "text",
			"text": "",
		},
	}
	raw := makeAnthropicEvent("content_block_start", payload)

	events := d.ProcessChunk(raw)
	require.Len(t, events, 1)
	assert.Equal(t, canonical.EventContentBlockStart, events[0].Type)
	require.NotNil(t, events[0].ContentBlock)
	assert.Equal(t, "text", events[0].ContentBlock.Type)
	require.NotNil(t, events[0].Index)
	assert.Equal(t, 0, *events[0].Index)
}

// TestStreamDecoder_ContentBlockStart_ToolUse verifies that a tool_use
// block start carries the tool id and name through to the canonical event.
func TestStreamDecoder_ContentBlockStart_ToolUse(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type":  "content_block_start",
		"index": 1,
		"content_block": map[string]any{
			"type": "tool_use",
			"id":   "toolu_1",
			"name": "search",
		},
	}
	raw := makeAnthropicEvent("content_block_start", payload)

	events := d.ProcessChunk(raw)
	require.Len(t, events, 1)
	assert.Equal(t, canonical.EventContentBlockStart, events[0].Type)
	require.NotNil(t, events[0].ContentBlock)
	assert.Equal(t, "tool_use", events[0].ContentBlock.Type)
	assert.Equal(t, "toolu_1", events[0].ContentBlock.ID)
	assert.Equal(t, "search", events[0].ContentBlock.Name)
}

// TestStreamDecoder_ContentBlockStop verifies decoding of a plain
// content_block_stop event with its block index.
func TestStreamDecoder_ContentBlockStop(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type":  "content_block_stop",
		"index": 0,
	}
	raw := makeAnthropicEvent("content_block_stop", payload)

	events := d.ProcessChunk(raw)
	require.Len(t, events, 1)
	assert.Equal(t, canonical.EventContentBlockStop, events[0].Type)
	require.NotNil(t, events[0].Index)
	assert.Equal(t, 0, *events[0].Index)
}
|
||||
|
||||
// TestStreamDecoder_MessageDelta verifies that stop_reason and usage on a
// message_delta are mapped onto the canonical event.
func TestStreamDecoder_MessageDelta(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type": "message_delta",
		"delta": map[string]any{
			"stop_reason": "end_turn",
		},
		"usage": map[string]any{
			"output_tokens": 42,
		},
	}
	raw := makeAnthropicEvent("message_delta", payload)

	events := d.ProcessChunk(raw)
	require.Len(t, events, 1)
	assert.Equal(t, canonical.EventMessageDelta, events[0].Type)
	require.NotNil(t, events[0].StopReason)
	assert.Equal(t, canonical.StopReasonEndTurn, *events[0].StopReason)
	require.NotNil(t, events[0].Usage)
	assert.Equal(t, 42, events[0].Usage.OutputTokens)
}

// TestStreamDecoder_MessageStop verifies decoding of a bare message_stop.
func TestStreamDecoder_MessageStop(t *testing.T) {
	d := NewStreamDecoder()

	raw := makeAnthropicEvent("message_stop", map[string]any{"type": "message_stop"})

	events := d.ProcessChunk(raw)
	require.Len(t, events, 1)
	assert.Equal(t, canonical.EventMessageStop, events[0].Type)
}

// TestStreamDecoder_Ping verifies decoding of a heartbeat ping event.
func TestStreamDecoder_Ping(t *testing.T) {
	d := NewStreamDecoder()

	raw := makeAnthropicEvent("ping", map[string]any{"type": "ping"})

	events := d.ProcessChunk(raw)
	require.Len(t, events, 1)
	assert.Equal(t, canonical.EventPing, events[0].Type)
}

// TestStreamDecoder_Error verifies that error events carry the error type
// and message through to the canonical event.
func TestStreamDecoder_Error(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type": "error",
		"error": map[string]any{
			"type":    "overloaded_error",
			"message": "服务过载",
		},
	}
	raw := makeAnthropicEvent("error", payload)

	events := d.ProcessChunk(raw)
	require.Len(t, events, 1)
	assert.Equal(t, canonical.EventError, events[0].Type)
	require.NotNil(t, events[0].Error)
	assert.Equal(t, "overloaded_error", events[0].Error.Type)
	assert.Equal(t, "服务过载", events[0].Error.Message)
}
|
||||
|
||||
// TestStreamDecoder_RedactedDeltaSuppressed verifies that deltas addressed
// to a block previously marked redacted emit nothing.
func TestStreamDecoder_RedactedDeltaSuppressed(t *testing.T) {
	d := NewStreamDecoder()
	d.redactedBlocks[1] = true

	payload := map[string]any{
		"type":  "content_block_delta",
		"index": 1,
		"delta": map[string]any{
			"type": "text_delta",
			"text": "被抑制的内容",
		},
	}
	raw := makeAnthropicEvent("content_block_delta", payload)

	events := d.ProcessChunk(raw)
	assert.Empty(t, events)
}

// TestStreamDecoder_ServerToolUse_Suppressed verifies that server_tool_use
// blocks are suppressed and tracked like redacted blocks.
func TestStreamDecoder_ServerToolUse_Suppressed(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type":  "content_block_start",
		"index": 2,
		"content_block": map[string]any{
			"type": "server_tool_use",
			"id":   "server_tool_1",
			"name": "web_search",
		},
	}
	raw := makeAnthropicEvent("content_block_start", payload)
	events := d.ProcessChunk(raw)
	assert.Empty(t, events)
	assert.True(t, d.redactedBlocks[2])
}

// TestStreamDecoder_WebSearchToolResult_Suppressed verifies that
// web_search_tool_result blocks are suppressed and tracked.
func TestStreamDecoder_WebSearchToolResult_Suppressed(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type":  "content_block_start",
		"index": 3,
		"content_block": map[string]any{
			"type":        "web_search_tool_result",
			"tool_use_id": "search_1",
		},
	}
	raw := makeAnthropicEvent("content_block_start", payload)
	events := d.ProcessChunk(raw)
	assert.Empty(t, events)
	assert.True(t, d.redactedBlocks[3])
}

// TestStreamDecoder_CodeExecutionToolResult_Suppressed verifies that
// code_execution_tool_result blocks are suppressed and tracked.
func TestStreamDecoder_CodeExecutionToolResult_Suppressed(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type":  "content_block_start",
		"index": 4,
		"content_block": map[string]any{
			"type": "code_execution_tool_result",
		},
	}
	raw := makeAnthropicEvent("content_block_start", payload)
	events := d.ProcessChunk(raw)
	assert.Empty(t, events)
	assert.True(t, d.redactedBlocks[4])
}
|
||||
|
||||
// TestStreamDecoder_CitationsDelta_Discarded verifies that citations_delta
// fragments are dropped rather than forwarded.
func TestStreamDecoder_CitationsDelta_Discarded(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type":  "content_block_delta",
		"index": 0,
		"delta": map[string]any{
			"type":     "citations_delta",
			"citation": map[string]any{"title": "ref1"},
		},
	}
	raw := makeAnthropicEvent("content_block_delta", payload)
	events := d.ProcessChunk(raw)
	assert.Empty(t, events)
}

// TestStreamDecoder_SignatureDelta_Discarded verifies that signature_delta
// fragments are dropped rather than forwarded.
func TestStreamDecoder_SignatureDelta_Discarded(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type":  "content_block_delta",
		"index": 0,
		"delta": map[string]any{
			"type":      "signature_delta",
			"signature": "sig_123",
		},
	}
	raw := makeAnthropicEvent("content_block_delta", payload)
	events := d.ProcessChunk(raw)
	assert.Empty(t, events)
}

// TestStreamDecoder_UnknownEventType verifies that unrecognized event
// types produce no canonical events.
func TestStreamDecoder_UnknownEventType(t *testing.T) {
	d := NewStreamDecoder()

	raw := makeAnthropicEvent("unknown_event", map[string]any{"type": "unknown_event"})
	events := d.ProcessChunk(raw)
	assert.Empty(t, events)
}

// TestStreamDecoder_InvalidJSON verifies that a malformed data payload is
// dropped without emitting events.
func TestStreamDecoder_InvalidJSON(t *testing.T) {
	d := NewStreamDecoder()

	raw := []byte("event: message_start\ndata: {invalid}\n\n")
	events := d.ProcessChunk(raw)
	assert.Empty(t, events)
}
|
||||
|
||||
// TestStreamDecoder_MultipleEventsInSingleChunk verifies that several SSE
// frames concatenated into one chunk are decoded, in order, into one
// canonical event each.
func TestStreamDecoder_MultipleEventsInSingleChunk(t *testing.T) {
	d := NewStreamDecoder()

	startPayload := map[string]any{
		"type": "message_start",
		"message": map[string]any{
			"id":    "msg_multi",
			"model": "claude-3",
		},
	}
	deltaPayload := map[string]any{
		"type":  "content_block_delta",
		"index": 0,
		"delta": map[string]any{
			"type": "text_delta",
			"text": "Hello",
		},
	}
	stopPayload := map[string]any{"type": "message_stop"}

	var raw []byte
	raw = append(raw, makeAnthropicEvent("message_start", startPayload)...)
	raw = append(raw, makeAnthropicEvent("content_block_delta", deltaPayload)...)
	raw = append(raw, makeAnthropicEvent("message_stop", stopPayload)...)

	events := d.ProcessChunk(raw)
	require.Len(t, events, 3)
	assert.Equal(t, canonical.EventMessageStart, events[0].Type)
	assert.Equal(t, canonical.EventContentBlockDelta, events[1].Type)
	assert.Equal(t, canonical.EventMessageStop, events[2].Type)
}

// TestStreamDecoder_ErrorInvalidJSON verifies that a malformed error frame
// still surfaces as a canonical error event describing the parse failure.
func TestStreamDecoder_ErrorInvalidJSON(t *testing.T) {
	d := NewStreamDecoder()

	raw := []byte("event: error\ndata: {invalid}\n\n")
	events := d.ProcessChunk(raw)
	require.Len(t, events, 1)
	assert.Equal(t, canonical.EventError, events[0].Type)
	assert.Contains(t, events[0].Error.Message, "解析错误事件失败")
}
|
||||
|
||||
// TestStreamDecoder_MessageStartWithUsage verifies that usage embedded in
// message_start is carried onto the canonical message.
func TestStreamDecoder_MessageStartWithUsage(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type": "message_start",
		"message": map[string]any{
			"id":    "msg_usage",
			"model": "claude-3",
			"usage": map[string]any{"input_tokens": 25, "output_tokens": 0},
		},
	}
	raw := makeAnthropicEvent("message_start", payload)

	events := d.ProcessChunk(raw)
	require.Len(t, events, 1)
	assert.Equal(t, canonical.EventMessageStart, events[0].Type)
	require.NotNil(t, events[0].Message.Usage)
	assert.Equal(t, 25, events[0].Message.Usage.InputTokens)
}

// TestStreamDecoder_ThinkingBlockStart verifies decoding of a thinking
// content_block_start into a canonical content block event.
func TestStreamDecoder_ThinkingBlockStart(t *testing.T) {
	d := NewStreamDecoder()

	payload := map[string]any{
		"type":  "content_block_start",
		"index": 0,
		"content_block": map[string]any{
			"type":     "thinking",
			"thinking": "",
		},
	}
	raw := makeAnthropicEvent("content_block_start", payload)

	events := d.ProcessChunk(raw)
	require.Len(t, events, 1)
	assert.Equal(t, canonical.EventContentBlockStart, events[0].Type)
	require.NotNil(t, events[0].ContentBlock)
	assert.Equal(t, "thinking", events[0].ContentBlock.Type)
}

// TestStreamDecoder_MessageDelta_UsageNotAccumulated verifies that
// output_tokens from successive message_delta events replaces the
// previous value instead of being summed.
func TestStreamDecoder_MessageDelta_UsageNotAccumulated(t *testing.T) {
	d := NewStreamDecoder()

	startPayload := map[string]any{
		"type": "message_start",
		"message": map[string]any{
			"id":    "msg_usage_test",
			"model": "claude-3",
			"usage": map[string]any{"input_tokens": 10, "output_tokens": 0},
		},
	}
	deltaPayload1 := map[string]any{
		"type":  "message_delta",
		"delta": map[string]any{"stop_reason": "end_turn"},
		"usage": map[string]any{"output_tokens": 25},
	}

	d.ProcessChunk(makeAnthropicEvent("message_start", startPayload))
	events := d.ProcessChunk(makeAnthropicEvent("message_delta", deltaPayload1))

	require.Len(t, events, 1)
	assert.Equal(t, 25, events[0].Usage.OutputTokens)

	deltaPayload2 := map[string]any{
		"type":  "message_delta",
		"delta": map[string]any{"stop_reason": "end_turn"},
		"usage": map[string]any{"output_tokens": 30},
	}
	events = d.ProcessChunk(makeAnthropicEvent("message_delta", deltaPayload2))
	require.Len(t, events, 1)
	assert.Equal(t, 30, events[0].Usage.OutputTokens, "output_tokens should be replaced, not accumulated")
	assert.Equal(t, 30, d.accumulatedUsage.OutputTokens, "accumulated usage should match last value")
}
|
||||
200
backend/internal/conversion/anthropic/stream_encoder.go
Normal file
200
backend/internal/conversion/anthropic/stream_encoder.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// StreamEncoder encodes canonical stream events as Anthropic-style named
// SSE frames. It holds no state and performs no buffering.
type StreamEncoder struct{}

// NewStreamEncoder returns a ready-to-use Anthropic stream encoder.
func NewStreamEncoder() *StreamEncoder {
	return new(StreamEncoder)
}
|
||||
|
||||
// EncodeEvent converts one canonical stream event into Anthropic named SSE
// frames ("event: <type>\ndata: <json>\n\n"). It dispatches on the event
// type; types with no Anthropic representation yield nil and are dropped.
func (e *StreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
	switch event.Type {
	case canonical.EventMessageStart:
		return e.encodeMessageStart(event)
	case canonical.EventContentBlockStart:
		return e.encodeContentBlockStart(event)
	case canonical.EventContentBlockDelta:
		return e.encodeContentBlockDelta(event)
	case canonical.EventContentBlockStop:
		return e.encodeContentBlockStop(event)
	case canonical.EventMessageDelta:
		return e.encodeMessageDelta(event)
	case canonical.EventMessageStop:
		return e.encodeMessageStop(event)
	case canonical.EventPing:
		return e.encodePing()
	case canonical.EventError:
		return e.encodeError(event)
	}
	return nil
}
|
||||
|
||||
// Flush implements the stream-encoder contract. This encoder emits every
// frame immediately from EncodeEvent, so there is never buffered output
// and Flush always returns nil.
func (e *StreamEncoder) Flush() [][]byte {
	return nil
}
|
||||
|
||||
// encodeMessageStart 编码消息开始事件
|
||||
func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
payload := map[string]any{
|
||||
"type": "message_start",
|
||||
}
|
||||
if event.Message != nil {
|
||||
msg := map[string]any{
|
||||
"id": event.Message.ID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []any{},
|
||||
"model": event.Message.Model,
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
}
|
||||
if event.Message.Usage != nil {
|
||||
msg["usage"] = map[string]any{
|
||||
"input_tokens": event.Message.Usage.InputTokens,
|
||||
"output_tokens": event.Message.Usage.OutputTokens,
|
||||
}
|
||||
} else {
|
||||
msg["usage"] = map[string]any{
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
}
|
||||
}
|
||||
payload["message"] = msg
|
||||
}
|
||||
return e.marshalEvent("message_start", payload)
|
||||
}
|
||||
|
||||
// encodeContentBlockStart 编码内容块开始事件
|
||||
func (e *StreamEncoder) encodeContentBlockStart(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.ContentBlock == nil || event.Index == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cb := map[string]any{
|
||||
"type": event.ContentBlock.Type,
|
||||
}
|
||||
switch event.ContentBlock.Type {
|
||||
case "text":
|
||||
cb["text"] = ""
|
||||
case "tool_use":
|
||||
cb["id"] = event.ContentBlock.ID
|
||||
cb["name"] = event.ContentBlock.Name
|
||||
cb["input"] = map[string]any{}
|
||||
case "thinking":
|
||||
cb["thinking"] = ""
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": *event.Index,
|
||||
"content_block": cb,
|
||||
}
|
||||
return e.marshalEvent("content_block_start", payload)
|
||||
}
|
||||
|
||||
// encodeContentBlockDelta 编码内容块增量事件
|
||||
func (e *StreamEncoder) encodeContentBlockDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.Delta == nil || event.Index == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
delta := map[string]any{
|
||||
"type": event.Delta.Type,
|
||||
}
|
||||
switch canonical.DeltaType(event.Delta.Type) {
|
||||
case canonical.DeltaTypeText:
|
||||
delta["text"] = event.Delta.Text
|
||||
case canonical.DeltaTypeInputJSON:
|
||||
delta["partial_json"] = event.Delta.PartialJSON
|
||||
case canonical.DeltaTypeThinking:
|
||||
delta["thinking"] = event.Delta.Thinking
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": *event.Index,
|
||||
"delta": delta,
|
||||
}
|
||||
return e.marshalEvent("content_block_delta", payload)
|
||||
}
|
||||
|
||||
// encodeContentBlockStop 编码内容块结束事件
|
||||
func (e *StreamEncoder) encodeContentBlockStop(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.Index == nil {
|
||||
return nil
|
||||
}
|
||||
payload := map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": *event.Index,
|
||||
}
|
||||
return e.marshalEvent("content_block_stop", payload)
|
||||
}
|
||||
|
||||
// encodeMessageDelta 编码消息增量事件
|
||||
func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
delta := map[string]any{}
|
||||
if event.StopReason != nil {
|
||||
delta["stop_reason"] = mapCanonicalStopReason(*event.StopReason)
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": delta,
|
||||
}
|
||||
if event.Usage != nil {
|
||||
payload["usage"] = map[string]any{
|
||||
"output_tokens": event.Usage.OutputTokens,
|
||||
}
|
||||
} else {
|
||||
payload["usage"] = map[string]any{
|
||||
"output_tokens": 0,
|
||||
}
|
||||
}
|
||||
return e.marshalEvent("message_delta", payload)
|
||||
}
|
||||
|
||||
// encodeMessageStop 编码消息结束事件
|
||||
func (e *StreamEncoder) encodeMessageStop(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
payload := map[string]any{"type": "message_stop"}
|
||||
return e.marshalEvent("message_stop", payload)
|
||||
}
|
||||
|
||||
// encodePing 编码心跳事件
|
||||
func (e *StreamEncoder) encodePing() [][]byte {
|
||||
payload := map[string]any{"type": "ping"}
|
||||
return e.marshalEvent("ping", payload)
|
||||
}
|
||||
|
||||
// encodeError 编码错误事件
|
||||
func (e *StreamEncoder) encodeError(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.Error == nil {
|
||||
return nil
|
||||
}
|
||||
payload := map[string]any{
|
||||
"type": "error",
|
||||
"error": map[string]any{
|
||||
"type": event.Error.Type,
|
||||
"message": event.Error.Message,
|
||||
},
|
||||
}
|
||||
return e.marshalEvent("error", payload)
|
||||
}
|
||||
|
||||
// marshalEvent 序列化为 Anthropic 命名 SSE 事件
|
||||
func (e *StreamEncoder) marshalEvent(eventType string, payload map[string]any) [][]byte {
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return [][]byte{[]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, data))}
|
||||
}
|
||||
298
backend/internal/conversion/anthropic/stream_encoder_test.go
Normal file
298
backend/internal/conversion/anthropic/stream_encoder_test.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestStreamEncoder_MessageStart verifies the full message_start frame
// shape when no usage is supplied: skeleton fields plus zeroed usage.
func TestStreamEncoder_MessageStart(t *testing.T) {
	e := NewStreamEncoder()
	event := canonical.NewMessageStartEvent("msg_1", "claude-3")

	chunks := e.EncodeEvent(event)
	require.Len(t, chunks, 1)

	s := string(chunks[0])
	assert.True(t, strings.HasPrefix(s, "event: message_start\n"))
	assert.Contains(t, s, "data: ")

	// Extract and decode the data line of the SSE frame.
	var payload map[string]any
	lines := strings.Split(s, "\n")
	for _, l := range lines {
		if strings.HasPrefix(l, "data: ") {
			require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
			break
		}
	}

	msg, ok := payload["message"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "msg_1", msg["id"])
	assert.Equal(t, "message", msg["type"])
	assert.Equal(t, "assistant", msg["role"])
	assert.Equal(t, []any{}, msg["content"])
	assert.Equal(t, "claude-3", msg["model"])
	assert.Nil(t, msg["stop_reason"])
	assert.Nil(t, msg["stop_sequence"])

	usage, okU := msg["usage"].(map[string]any)
	require.True(t, okU)
	assert.Equal(t, float64(0), usage["input_tokens"])
	assert.Equal(t, float64(0), usage["output_tokens"])
}

// TestStreamEncoder_MessageStart_WithUsage verifies that supplied usage
// values are carried into the message_start frame.
func TestStreamEncoder_MessageStart_WithUsage(t *testing.T) {
	e := NewStreamEncoder()
	event := canonical.NewMessageStartEventWithUsage("msg_2", "gpt-4", &canonical.CanonicalUsage{InputTokens: 100, OutputTokens: 50})

	chunks := e.EncodeEvent(event)
	require.Len(t, chunks, 1)

	s := string(chunks[0])
	// Extract and decode the data line of the SSE frame.
	var payload map[string]any
	lines := strings.Split(s, "\n")
	for _, l := range lines {
		if strings.HasPrefix(l, "data: ") {
			require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
			break
		}
	}

	msg, ok := payload["message"].(map[string]any)
	require.True(t, ok)
	usage, okU := msg["usage"].(map[string]any)
	require.True(t, okU)
	assert.Equal(t, float64(100), usage["input_tokens"])
	assert.Equal(t, float64(50), usage["output_tokens"])
}
|
||||
|
||||
// TestStreamEncoder_ContentBlockDelta verifies the frame name, text
// payload, and JSON validity of an encoded text delta.
func TestStreamEncoder_ContentBlockDelta(t *testing.T) {
	e := NewStreamEncoder()
	event := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{Type: "text_delta", Text: "你好"})

	chunks := e.EncodeEvent(event)
	require.Len(t, chunks, 1)

	s := string(chunks[0])
	assert.True(t, strings.HasPrefix(s, "event: content_block_delta\n"))
	assert.Contains(t, s, "你好")

	// Validate the JSON payload format.
	lines := strings.Split(s, "\n")
	var dataLine string
	for _, l := range lines {
		if strings.HasPrefix(l, "data: ") {
			dataLine = strings.TrimPrefix(l, "data: ")
			break
		}
	}
	var payload map[string]any
	require.NoError(t, json.Unmarshal([]byte(dataLine), &payload))
	assert.Equal(t, "content_block_delta", payload["type"])
}

// TestStreamEncoder_MessageStop verifies the frame name of message_stop.
func TestStreamEncoder_MessageStop(t *testing.T) {
	e := NewStreamEncoder()
	event := canonical.NewMessageStopEvent()

	chunks := e.EncodeEvent(event)
	require.Len(t, chunks, 1)

	s := string(chunks[0])
	assert.True(t, strings.HasPrefix(s, "event: message_stop\n"))
}

// TestStreamEncoder_ContentBlockStart_Text verifies that a text block
// start encodes with the correct content_block type.
func TestStreamEncoder_ContentBlockStart_Text(t *testing.T) {
	e := NewStreamEncoder()
	event := canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "text", Text: ""})

	chunks := e.EncodeEvent(event)
	require.Len(t, chunks, 1)

	s := string(chunks[0])
	assert.True(t, strings.HasPrefix(s, "event: content_block_start\n"))
	assert.Contains(t, s, "data: ")

	// Extract and decode the data line of the SSE frame.
	var payload map[string]any
	lines := strings.Split(s, "\n")
	for _, l := range lines {
		if strings.HasPrefix(l, "data: ") {
			require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
			break
		}
	}
	cb, ok := payload["content_block"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "text", cb["type"])
}
|
||||
|
||||
// TestStreamEncoder_ContentBlockStart_ToolUse verifies that a tool_use
// block start carries id and name into the encoded frame.
func TestStreamEncoder_ContentBlockStart_ToolUse(t *testing.T) {
	e := NewStreamEncoder()
	event := canonical.NewContentBlockStartEvent(1, canonical.StreamContentBlock{
		Type: "tool_use",
		ID:   "toolu_1",
		Name: "search",
	})

	chunks := e.EncodeEvent(event)
	require.Len(t, chunks, 1)

	s := string(chunks[0])
	assert.Contains(t, s, "toolu_1")
	assert.Contains(t, s, "search")

	// Extract and decode the data line of the SSE frame.
	var payload map[string]any
	lines := strings.Split(s, "\n")
	for _, l := range lines {
		if strings.HasPrefix(l, "data: ") {
			require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
			break
		}
	}
	cb, ok := payload["content_block"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "tool_use", cb["type"])
	assert.Equal(t, "toolu_1", cb["id"])
	assert.Equal(t, "search", cb["name"])
}

// TestStreamEncoder_ContentBlockStart_Thinking verifies that a thinking
// block start encodes with the correct content_block type.
func TestStreamEncoder_ContentBlockStart_Thinking(t *testing.T) {
	e := NewStreamEncoder()
	event := canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "thinking", Thinking: ""})

	chunks := e.EncodeEvent(event)
	require.Len(t, chunks, 1)

	s := string(chunks[0])
	assert.Contains(t, s, "thinking")

	// Extract and decode the data line of the SSE frame.
	var payload map[string]any
	lines := strings.Split(s, "\n")
	for _, l := range lines {
		if strings.HasPrefix(l, "data: ") {
			require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
			break
		}
	}
	cb, ok := payload["content_block"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "thinking", cb["type"])
}

// TestStreamEncoder_ContentBlockStop verifies the frame name of
// content_block_stop for an explicit index.
func TestStreamEncoder_ContentBlockStop(t *testing.T) {
	e := NewStreamEncoder()
	idx := 2
	event := canonical.CanonicalStreamEvent{
		Type:  canonical.EventContentBlockStop,
		Index: &idx,
	}

	chunks := e.EncodeEvent(event)
	require.Len(t, chunks, 1)

	s := string(chunks[0])
	assert.True(t, strings.HasPrefix(s, "event: content_block_stop\n"))
	assert.Contains(t, s, "content_block_stop")
}
|
||||
|
||||
// TestStreamEncoder_MessageDelta_WithStopReason verifies that stop_reason
// is encoded and that usage is always present (zeroed when absent).
func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
	e := NewStreamEncoder()
	sr := canonical.StopReasonEndTurn
	event := canonical.CanonicalStreamEvent{
		Type:       canonical.EventMessageDelta,
		StopReason: &sr,
	}

	chunks := e.EncodeEvent(event)
	require.Len(t, chunks, 1)

	s := string(chunks[0])
	assert.Contains(t, s, "stop_reason")

	// Extract and decode the data line of the SSE frame.
	var payload map[string]any
	lines := strings.Split(s, "\n")
	for _, l := range lines {
		if strings.HasPrefix(l, "data: ") {
			require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
			break
		}
	}
	delta, okd := payload["delta"].(map[string]any)
	require.True(t, okd)
	assert.Equal(t, "end_turn", delta["stop_reason"])

	usage, oku := payload["usage"].(map[string]any)
	require.True(t, oku, "message_delta SHALL always include usage")
	assert.Equal(t, float64(0), usage["output_tokens"])
}

// TestStreamEncoder_MessageDelta_WithUsage verifies that supplied usage
// output_tokens are carried into the message_delta frame.
func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
	e := NewStreamEncoder()
	usage := canonical.CanonicalUsage{OutputTokens: 88}
	event := canonical.CanonicalStreamEvent{
		Type:  canonical.EventMessageDelta,
		Usage: &usage,
	}

	chunks := e.EncodeEvent(event)
	require.Len(t, chunks, 1)

	s := string(chunks[0])
	assert.Contains(t, s, "output_tokens")

	// Extract and decode the data line of the SSE frame.
	var payload map[string]any
	lines := strings.Split(s, "\n")
	for _, l := range lines {
		if strings.HasPrefix(l, "data: ") {
			require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
			break
		}
	}
	u, oku := payload["usage"].(map[string]any)
	require.True(t, oku)
	assert.Equal(t, float64(88), u["output_tokens"])
}
|
||||
|
||||
// TestStreamEncoder_Ping verifies the frame name and payload of ping.
func TestStreamEncoder_Ping(t *testing.T) {
	e := NewStreamEncoder()
	event := canonical.NewPingEvent()

	chunks := e.EncodeEvent(event)
	require.Len(t, chunks, 1)

	s := string(chunks[0])
	assert.True(t, strings.HasPrefix(s, "event: ping\n"))
	assert.Contains(t, s, "ping")
}

// TestStreamEncoder_Error verifies that the error type and message appear
// in the encoded error frame.
func TestStreamEncoder_Error(t *testing.T) {
	e := NewStreamEncoder()
	event := canonical.NewErrorEvent("overloaded_error", "服务过载")

	chunks := e.EncodeEvent(event)
	require.Len(t, chunks, 1)

	s := string(chunks[0])
	assert.True(t, strings.HasPrefix(s, "event: error\n"))
	assert.Contains(t, s, "overloaded_error")
	assert.Contains(t, s, "服务过载")
}

// TestStreamEncoder_Flush_ReturnsNil verifies Flush emits nothing (the
// encoder is unbuffered).
func TestStreamEncoder_Flush_ReturnsNil(t *testing.T) {
	e := NewStreamEncoder()
	chunks := e.Flush()
	assert.Nil(t, chunks)
}

// TestStreamEncoder_UnknownEvent_ReturnsNil verifies unknown event types
// are dropped.
func TestStreamEncoder_UnknownEvent_ReturnsNil(t *testing.T) {
	e := NewStreamEncoder()
	event := canonical.CanonicalStreamEvent{Type: "unknown_event_type"}
	chunks := e.EncodeEvent(event)
	assert.Nil(t, chunks)
}
|
||||
485
backend/internal/conversion/anthropic/supplemental_test.go
Normal file
485
backend/internal/conversion/anthropic/supplemental_test.go
Normal file
@@ -0,0 +1,485 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDecodeTools verifies that tool definitions (with and without a
// description) are decoded from the request body.
func TestDecodeTools(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"messages": [{"role": "user", "content": "hi"}],
		"tools": [
			{"name": "search", "description": "Search", "input_schema": {"type":"object"}},
			{"name": "calc", "input_schema": {"type":"object"}}
		]
	}`)
	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Len(t, req.Tools, 2)
	assert.Equal(t, "search", req.Tools[0].Name)
	assert.Equal(t, "Search", req.Tools[0].Description)
	assert.Equal(t, "calc", req.Tools[1].Name)
}

// TestDecodeToolChoice verifies both the string ("auto"/"none"/"any") and
// object ({"type": ..., "name": ...}) forms of tool_choice.
func TestDecodeToolChoice(t *testing.T) {
	tests := []struct {
		name     string
		jsonBody string
		wantType string
		wantName string
	}{
		{
			"auto string",
			`{"model":"claude-3","max_tokens":1024,"messages":[{"role":"user","content":"hi"}],"tool_choice":"auto"}`,
			"auto", "",
		},
		{
			"none string",
			`{"model":"claude-3","max_tokens":1024,"messages":[{"role":"user","content":"hi"}],"tool_choice":"none"}`,
			"none", "",
		},
		{
			"any string",
			`{"model":"claude-3","max_tokens":1024,"messages":[{"role":"user","content":"hi"}],"tool_choice":"any"}`,
			"any", "",
		},
		{
			"tool object",
			`{"model":"claude-3","max_tokens":1024,"messages":[{"role":"user","content":"hi"}],"tool_choice":{"type":"tool","name":"search"}}`,
			"tool", "search",
		},
		{
			"auto object",
			`{"model":"claude-3","max_tokens":1024,"messages":[{"role":"user","content":"hi"}],"tool_choice":{"type":"auto"}}`,
			"auto", "",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req, err := decodeRequest([]byte(tt.jsonBody))
			require.NoError(t, err)
			require.NotNil(t, req.ToolChoice)
			assert.Equal(t, tt.wantType, req.ToolChoice.Type)
			assert.Equal(t, tt.wantName, req.ToolChoice.Name)
		})
	}
}
|
||||
|
||||
// TestDecodeParameters_TopK verifies that top_k and stop_sequences are
// decoded into the request parameters.
func TestDecodeParameters_TopK(t *testing.T) {
	topK := 10
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"messages": [{"role": "user", "content": "hi"}],
		"top_k": 10,
		"stop_sequences": ["STOP"]
	}`)
	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.Parameters.TopK)
	assert.Equal(t, topK, *req.Parameters.TopK)
	assert.Equal(t, []string{"STOP"}, req.Parameters.StopSequences)
}

// TestDecodeRequest_MetadataUserID verifies that metadata.user_id is
// lifted onto the request's UserID field.
func TestDecodeRequest_MetadataUserID(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"messages": [{"role": "user", "content": "hi"}],
		"metadata": {"user_id": "user-123"}
	}`)
	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Equal(t, "user-123", req.UserID)
}

// TestDecodeSystem_Empty verifies that an empty system string decodes to
// a nil System field.
func TestDecodeSystem_Empty(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"system": "",
		"messages": [{"role": "user", "content": "hi"}]
	}`)
	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Nil(t, req.System)
}

// TestDecodeSystem_Nil verifies that an absent system field decodes to a
// nil System field.
func TestDecodeSystem_Nil(t *testing.T) {
	body := []byte(`{
		"model": "claude-3",
		"max_tokens": 1024,
		"messages": [{"role": "user", "content": "hi"}]
	}`)
	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Nil(t, req.System)
}
|
||||
|
||||
func TestDecodeThinking_WithEffort(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "claude-3",
|
||||
"max_tokens": 1024,
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"thinking": {"type": "enabled", "budget_tokens": 5000},
|
||||
"output_config": {"effort": "high"}
|
||||
}`)
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, req.Thinking)
|
||||
assert.Equal(t, "enabled", req.Thinking.Type)
|
||||
assert.Equal(t, "high", req.Thinking.Effort)
|
||||
}
|
||||
|
||||
func TestDecodeOutputFormat_NilOutputConfig(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "claude-3",
|
||||
"max_tokens": 1024,
|
||||
"messages": [{"role": "user", "content": "hi"}]
|
||||
}`)
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, req.OutputFormat)
|
||||
}
|
||||
|
||||
func TestDecodeMessage_UserWithOnlyToolResults(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "claude-3",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": [{"type": "tool_use", "id": "t1", "name": "fn", "input": {}}]},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "tool_result", "tool_use_id": "t1", "content": "result"}]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
lastMsg := req.Messages[len(req.Messages)-1]
|
||||
assert.Equal(t, canonical.RoleTool, lastMsg.Role)
|
||||
assert.Equal(t, "t1", lastMsg.Content[0].ToolUseID)
|
||||
}
|
||||
|
||||
func TestDecodeContentBlocks_Nil(t *testing.T) {
|
||||
blocks, err := decodeContentBlocks(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, blocks, 1)
|
||||
assert.Equal(t, "", blocks[0].Text)
|
||||
}
|
||||
|
||||
func TestDecodeContentBlocks_String(t *testing.T) {
|
||||
blocks, err := decodeContentBlocks("hello")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, blocks, 1)
|
||||
assert.Equal(t, "hello", blocks[0].Text)
|
||||
}
|
||||
|
||||
func TestParseTimestamp(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want int64
|
||||
}{
|
||||
{"valid RFC3339", "2024-01-15T00:00:00Z", time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC).Unix()},
|
||||
{"empty", "", 0},
|
||||
{"invalid", "not-a-date", 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, parseTimestamp(tt.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeToolChoice(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
choice *canonical.ToolChoice
|
||||
want map[string]any
|
||||
}{
|
||||
{"auto", canonical.NewToolChoiceAuto(), map[string]any{"type": "auto"}},
|
||||
{"none", canonical.NewToolChoiceNone(), map[string]any{"type": "none"}},
|
||||
{"any", canonical.NewToolChoiceAny(), map[string]any{"type": "any"}},
|
||||
{"tool", canonical.NewToolChoiceNamed("search"), map[string]any{"type": "tool", "name": "search"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := encodeToolChoice(tt.choice)
|
||||
r, ok := result.(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, tt.want["type"], r["type"])
|
||||
assert.Equal(t, tt.want["name"], r["name"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeThinkingConfig(t *testing.T) {
|
||||
budget := 5000
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *canonical.ThinkingConfig
|
||||
want map[string]any
|
||||
}{
|
||||
{"enabled", &canonical.ThinkingConfig{Type: "enabled", BudgetTokens: &budget}, map[string]any{"type": "enabled", "budget_tokens": float64(5000)}},
|
||||
{"disabled", &canonical.ThinkingConfig{Type: "disabled"}, map[string]any{"type": "disabled"}},
|
||||
{"adaptive", &canonical.ThinkingConfig{Type: "adaptive"}, map[string]any{"type": "adaptive"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := encodeThinkingConfig(tt.cfg)
|
||||
assert.Equal(t, tt.want["type"], result["type"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeRequest_PublicFields(t *testing.T) {
|
||||
maxTokens := 1024
|
||||
parallel := false
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "claude-3",
|
||||
Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
|
||||
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
|
||||
UserID: "user-123",
|
||||
ParallelToolUse: ¶llel,
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, map[string]any{"user_id": "user-123"}, result["metadata"])
|
||||
assert.Equal(t, true, result["disable_parallel_tool_use"])
|
||||
}
|
||||
|
||||
func TestEncodeRequest_DefaultMaxTokens(t *testing.T) {
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "claude-3",
|
||||
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, float64(4096), result["max_tokens"])
|
||||
}
|
||||
|
||||
func TestEncodeRequest_TopK(t *testing.T) {
|
||||
maxTokens := 1024
|
||||
topK := 10
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "claude-3",
|
||||
Parameters: canonical.RequestParameters{MaxTokens: &maxTokens, TopK: &topK},
|
||||
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, float64(10), result["top_k"])
|
||||
}
|
||||
|
||||
func TestEncodeRequest_WithTools(t *testing.T) {
|
||||
maxTokens := 1024
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "claude-3",
|
||||
Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
|
||||
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
|
||||
Tools: []canonical.CanonicalTool{
|
||||
{Name: "search", Description: "Search things", InputSchema: json.RawMessage(`{"type":"object"}`)},
|
||||
},
|
||||
ToolChoice: canonical.NewToolChoiceAuto(),
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
tools, okt := result["tools"].([]any)
|
||||
require.True(t, okt)
|
||||
assert.Len(t, tools, 1)
|
||||
tool, okt2 := tools[0].(map[string]any)
|
||||
require.True(t, okt2)
|
||||
assert.Equal(t, "search", tool["name"])
|
||||
assert.Equal(t, "Search things", tool["description"])
|
||||
tc, oktc := result["tool_choice"].(map[string]any)
|
||||
require.True(t, oktc)
|
||||
assert.Equal(t, "auto", tc["type"])
|
||||
}
|
||||
|
||||
func TestEncodeRequest_ThinkingWithEffort(t *testing.T) {
|
||||
maxTokens := 1024
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "claude-3",
|
||||
Parameters: canonical.RequestParameters{MaxTokens: &maxTokens},
|
||||
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
|
||||
Thinking: &canonical.ThinkingConfig{Type: "enabled", Effort: "high"},
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
oc, ok := result["output_config"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "high", oc["effort"])
|
||||
}
|
||||
|
||||
func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
|
||||
cacheRead := 30
|
||||
cacheCreation := 10
|
||||
sr := canonical.StopReasonEndTurn
|
||||
resp := &canonical.CanonicalResponse{
|
||||
ID: "msg-1",
|
||||
Model: "claude-3",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
|
||||
StopReason: &sr,
|
||||
Usage: canonical.CanonicalUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CacheReadTokens: &cacheRead,
|
||||
CacheCreationTokens: &cacheCreation,
|
||||
},
|
||||
}
|
||||
|
||||
body, err := encodeResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
usage, oku := result["usage"].(map[string]any)
|
||||
require.True(t, oku)
|
||||
assert.Equal(t, float64(100), usage["input_tokens"])
|
||||
assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
|
||||
assert.Equal(t, float64(10), usage["cache_creation_input_tokens"])
|
||||
}
|
||||
|
||||
func TestEncodeResponse_StopReasons(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stopReason canonical.StopReason
|
||||
want string
|
||||
}{
|
||||
{"end_turn", canonical.StopReasonEndTurn, "end_turn"},
|
||||
{"max_tokens", canonical.StopReasonMaxTokens, "max_tokens"},
|
||||
{"tool_use", canonical.StopReasonToolUse, "tool_use"},
|
||||
{"stop_sequence", canonical.StopReasonStopSequence, "stop_sequence"},
|
||||
{"refusal", canonical.StopReasonRefusal, "refusal"},
|
||||
{"content_filter→end_turn", canonical.StopReasonContentFilter, "end_turn"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sr := tt.stopReason
|
||||
resp := &canonical.CanonicalResponse{
|
||||
ID: "r1",
|
||||
Model: "claude-3",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
|
||||
StopReason: &sr,
|
||||
}
|
||||
body, err := encodeResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, tt.want, result["stop_reason"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeSystem_SystemBlocks(t *testing.T) {
|
||||
result := encodeSystem([]canonical.SystemBlock{{Text: "part1"}, {Text: "part2"}})
|
||||
blocks, ok := result.([]map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, blocks, 2)
|
||||
assert.Equal(t, "part1", blocks[0]["text"])
|
||||
}
|
||||
|
||||
func TestEncodeModelInfoResponse(t *testing.T) {
|
||||
info := &canonical.CanonicalModelInfo{
|
||||
ID: "claude-3-opus",
|
||||
Name: "Claude 3 Opus",
|
||||
Created: time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC).Unix(),
|
||||
}
|
||||
|
||||
body, err := encodeModelInfoResponse(info)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "claude-3-opus", result["id"])
|
||||
assert.Equal(t, "Claude 3 Opus", result["display_name"])
|
||||
}
|
||||
|
||||
func TestDecodeModelInfoResponse(t *testing.T) {
|
||||
body := []byte(`{"id":"claude-3-opus","type":"model","display_name":"Claude 3 Opus","created_at":"2024-01-15T00:00:00Z"}`)
|
||||
info, err := decodeModelInfoResponse(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "claude-3-opus", info.ID)
|
||||
assert.Equal(t, "Claude 3 Opus", info.Name)
|
||||
assert.NotEqual(t, int64(0), info.Created)
|
||||
}
|
||||
|
||||
func TestDecodeResponse_PauseTurn(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"id": "msg-1", "type": "message", "role": "assistant", "model": "claude-3",
|
||||
"content": [{"type": "text", "text": "ok"}],
|
||||
"stop_reason": "pause_turn",
|
||||
"usage": {"input_tokens": 1, "output_tokens": 1}
|
||||
}`)
|
||||
resp, err := decodeResponse(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, canonical.StopReason("pause_turn"), *resp.StopReason)
|
||||
}
|
||||
|
||||
func TestEncodeResponse_NoStopReason(t *testing.T) {
|
||||
resp := &canonical.CanonicalResponse{
|
||||
ID: "msg-1",
|
||||
Model: "claude-3",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
|
||||
}
|
||||
|
||||
body, err := encodeResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "end_turn", result["stop_reason"])
|
||||
}
|
||||
|
||||
func TestDecodeRequest_MaxTokensZero(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "claude-3",
|
||||
"max_tokens": 0,
|
||||
"messages": [{"role": "user", "content": "hi"}]
|
||||
}`)
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, req.Parameters.MaxTokens)
|
||||
}
|
||||
183
backend/internal/conversion/anthropic/types.go
Normal file
183
backend/internal/conversion/anthropic/types.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// MessagesRequest is an Anthropic Messages API request body.
type MessagesRequest struct {
	Model                  string           `json:"model"`
	Messages               []Message        `json:"messages"`
	System                 any              `json:"system,omitempty"` // string or block list on the wire
	MaxTokens              int              `json:"max_tokens"`
	Temperature            *float64         `json:"temperature,omitempty"`
	TopP                   *float64         `json:"top_p,omitempty"`
	TopK                   *int             `json:"top_k,omitempty"`
	StopSequences          []string         `json:"stop_sequences,omitempty"`
	Stream                 bool             `json:"stream,omitempty"`
	Tools                  []Tool           `json:"tools,omitempty"`
	ToolChoice             any              `json:"tool_choice,omitempty"` // string shorthand or object form
	Metadata               *RequestMetadata `json:"metadata,omitempty"`
	Thinking               *ThinkingConfig  `json:"thinking,omitempty"`
	OutputConfig           *OutputConfig    `json:"output_config,omitempty"`
	DisableParallelToolUse *bool            `json:"disable_parallel_tool_use,omitempty"`
	Container              any              `json:"container,omitempty"`
}

// RequestMetadata is the request metadata object.
type RequestMetadata struct {
	UserID string `json:"user_id,omitempty"`
}

// ThinkingConfig is the extended-thinking configuration.
type ThinkingConfig struct {
	Type         string `json:"type"`
	BudgetTokens *int   `json:"budget_tokens,omitempty"`
	Display      string `json:"display,omitempty"`
}

// OutputConfig is the output configuration.
type OutputConfig struct {
	Format *OutputFormatConfig `json:"format,omitempty"`
	Effort string              `json:"effort,omitempty"`
}

// OutputFormatConfig is the output format configuration.
type OutputFormatConfig struct {
	Type   string          `json:"type"`
	Schema json.RawMessage `json:"schema,omitempty"`
}

// Message is a single Anthropic conversation message.
type Message struct {
	Role    string `json:"role"`
	Content any    `json:"content"` // string or content-block list on the wire
}

// TextContent is a text content block.
type TextContent struct {
	Type string `json:"type"`
	Text string `json:"text"`
}

// ToolUseContent is a tool-use (tool call) content block.
type ToolUseContent struct {
	Type  string          `json:"type"`
	ID    string          `json:"id"`
	Name  string          `json:"name"`
	Input json.RawMessage `json:"input"`
}

// ToolResultContent is a tool-result content block.
type ToolResultContent struct {
	Type      string `json:"type"`
	ToolUseID string `json:"tool_use_id"`
	Content   any    `json:"content"`
	IsError   *bool  `json:"is_error,omitempty"`
}

// ThinkingContent is a thinking content block.
type ThinkingContent struct {
	Type     string `json:"type"`
	Thinking string `json:"thinking"`
}

// RedactedThinkingContent is a redacted-thinking content block.
type RedactedThinkingContent struct {
	Type string `json:"type"`
	Data string `json:"data"`
}

// Tool is an Anthropic tool definition.
type Tool struct {
	Name        string          `json:"name"`
	Description string          `json:"description,omitempty"`
	InputSchema json.RawMessage `json:"input_schema"`
}

// MessagesResponse is an Anthropic Messages API response body.
type MessagesResponse struct {
	ID           string         `json:"id"`
	Type         string         `json:"type"`
	Role         string         `json:"role"`
	Model        string         `json:"model"`
	Content      []ContentBlock `json:"content"`
	StopReason   string         `json:"stop_reason"`
	StopSequence *string        `json:"stop_sequence,omitempty"`
	StopDetails  any            `json:"stop_details,omitempty"`
	Container    any            `json:"container,omitempty"`
	Usage        ResponseUsage  `json:"usage"`
}

// ContentBlock is an Anthropic response content block; the populated fields
// depend on Type.
type ContentBlock struct {
	Type     string          `json:"type"`
	Text     string          `json:"text,omitempty"`
	ID       string          `json:"id,omitempty"`
	Name     string          `json:"name,omitempty"`
	Input    json.RawMessage `json:"input,omitempty"`
	Thinking string          `json:"thinking,omitempty"`
	Data     string          `json:"data,omitempty"`
}

// ResponseUsage is the token-usage section of a response.
type ResponseUsage struct {
	InputTokens              int  `json:"input_tokens"`
	OutputTokens             int  `json:"output_tokens"`
	CacheReadInputTokens     *int `json:"cache_read_input_tokens,omitempty"`
	CacheCreationInputTokens *int `json:"cache_creation_input_tokens,omitempty"`
}

// ModelsResponse is the Anthropic model-list response.
type ModelsResponse struct {
	Data    []ModelItem `json:"data"`
	HasMore bool        `json:"has_more"`
	FirstID *string     `json:"first_id,omitempty"`
	LastID  *string     `json:"last_id,omitempty"`
}

// ModelItem is a single entry in the model list.
type ModelItem struct {
	ID          string `json:"id"`
	Type        string `json:"type"`
	DisplayName string `json:"display_name,omitempty"`
	CreatedAt   string `json:"created_at,omitempty"`
}

// ModelInfoResponse is the Anthropic model-detail response.
type ModelInfoResponse struct {
	ID          string `json:"id"`
	Type        string `json:"type"`
	DisplayName string `json:"display_name,omitempty"`
	CreatedAt   string `json:"created_at,omitempty"`
}

// EmbeddingRequest exists only for interface compatibility; Anthropic does
// not support embeddings.
type EmbeddingRequest struct{}

// EmbeddingResponse exists only for interface compatibility; Anthropic does
// not support embeddings.
type EmbeddingResponse struct{}

// RerankRequest exists only for interface compatibility; Anthropic does not
// support reranking.
type RerankRequest struct{}

// RerankResponse exists only for interface compatibility; Anthropic does not
// support reranking.
type RerankResponse struct{}

// ErrorResponse is an Anthropic error response envelope.
type ErrorResponse struct {
	Type  string      `json:"type"`
	Error ErrorDetail `json:"error"`
}

// ErrorDetail carries the error type and human-readable message.
type ErrorDetail struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}

// SSEEvent is one server-sent event: its event name and raw JSON data payload.
type SSEEvent struct {
	EventType string
	Data      json.RawMessage
}
|
||||
71
backend/internal/conversion/canonical/extended.go
Normal file
71
backend/internal/conversion/canonical/extended.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package canonical
|
||||
|
||||
// CanonicalModel is the provider-neutral model descriptor.
type CanonicalModel struct {
	ID      string `json:"id"`
	Name    string `json:"name,omitempty"`
	Created int64  `json:"created,omitempty"`
	OwnedBy string `json:"owned_by,omitempty"`
}

// CanonicalModelList is the provider-neutral model list.
type CanonicalModelList struct {
	Models []CanonicalModel `json:"models"`
}

// CanonicalModelInfo is the provider-neutral model detail.
type CanonicalModelInfo struct {
	ID      string `json:"id"`
	Name    string `json:"name,omitempty"`
	Created int64  `json:"created,omitempty"`
	OwnedBy string `json:"owned_by,omitempty"`
}

// CanonicalEmbeddingRequest is the provider-neutral embedding request.
type CanonicalEmbeddingRequest struct {
	Model          string `json:"model"`
	Input          any    `json:"input"` // string or []string
	EncodingFormat string `json:"encoding_format,omitempty"`
	Dimensions     *int   `json:"dimensions,omitempty"`
}

// CanonicalEmbeddingResponse is the provider-neutral embedding response.
type CanonicalEmbeddingResponse struct {
	Data  []EmbeddingData `json:"data"`
	Model string          `json:"model"`
	Usage EmbeddingUsage  `json:"usage"`
}

// EmbeddingData is a single embedding result item.
type EmbeddingData struct {
	Index     int `json:"index"`
	Embedding any `json:"embedding"` // []float64 or a base64 string, depending on the encoding format
}

// EmbeddingUsage is the embedding token usage.
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}

// CanonicalRerankRequest is the provider-neutral rerank request.
type CanonicalRerankRequest struct {
	Model           string   `json:"model"`
	Query           string   `json:"query"`
	Documents       []string `json:"documents"`
	TopN            *int     `json:"top_n,omitempty"`
	ReturnDocuments *bool    `json:"return_documents,omitempty"`
}

// CanonicalRerankResponse is the provider-neutral rerank response.
type CanonicalRerankResponse struct {
	Results []RerankResult `json:"results"`
	Model   string         `json:"model"`
}

// RerankResult is a single rerank result item.
type RerankResult struct {
	Index          int     `json:"index"`
	RelevanceScore float64 `json:"relevance_score"`
	Document       *string `json:"document,omitempty"`
}
|
||||
156
backend/internal/conversion/canonical/stream.go
Normal file
156
backend/internal/conversion/canonical/stream.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package canonical
|
||||
|
||||
// StreamEventType enumerates the streaming event types.
type StreamEventType string

const (
	EventMessageStart      StreamEventType = "message_start"
	EventContentBlockStart StreamEventType = "content_block_start"
	EventContentBlockDelta StreamEventType = "content_block_delta"
	EventContentBlockStop  StreamEventType = "content_block_stop"
	EventMessageDelta      StreamEventType = "message_delta"
	EventMessageStop       StreamEventType = "message_stop"
	EventError             StreamEventType = "error"
	EventPing              StreamEventType = "ping"
)

// DeltaType enumerates the delta (incremental content) types.
type DeltaType string

const (
	DeltaTypeText      DeltaType = "text_delta"
	DeltaTypeInputJSON DeltaType = "input_json_delta"
	DeltaTypeThinking  DeltaType = "thinking_delta"
)

// StreamDelta is the streaming-delta union; the populated field depends on Type.
type StreamDelta struct {
	Type        string `json:"type"`
	Text        string `json:"text,omitempty"`
	PartialJSON string `json:"partial_json,omitempty"`
	Thinking    string `json:"thinking,omitempty"`
}

// StreamContentBlock is the streaming content-block union; the populated
// fields depend on Type.
type StreamContentBlock struct {
	Type     string `json:"type"`
	Text     string `json:"text,omitempty"`
	ID       string `json:"id,omitempty"`
	Name     string `json:"name,omitempty"`
	Thinking string `json:"thinking,omitempty"`
}

// CanonicalStreamEvent is the provider-neutral streaming event union; which
// fields are set depends on Type.
type CanonicalStreamEvent struct {
	Type StreamEventType `json:"type"`

	// MessageStartEvent
	Message *StreamMessage `json:"message,omitempty"`

	// ContentBlockStartEvent / ContentBlockDeltaEvent / ContentBlockStopEvent
	Index        *int                `json:"index,omitempty"`
	ContentBlock *StreamContentBlock `json:"content_block,omitempty"`
	Delta        *StreamDelta        `json:"delta,omitempty"`

	// MessageDeltaEvent
	StopReason *StopReason     `json:"stop_reason,omitempty"`
	Usage      *CanonicalUsage `json:"usage,omitempty"`

	// ErrorEvent
	Error *StreamError `json:"error,omitempty"`
}

// StreamMessage is the message summary carried by message_start events.
type StreamMessage struct {
	ID    string          `json:"id"`
	Model string          `json:"model"`
	Usage *CanonicalUsage `json:"usage,omitempty"`
}

// StreamError is the error payload carried by error events.
type StreamError struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}
|
||||
|
||||
// NewMessageStartEvent 创建消息开始事件
|
||||
func NewMessageStartEvent(id, model string) CanonicalStreamEvent {
|
||||
return CanonicalStreamEvent{
|
||||
Type: EventMessageStart,
|
||||
Message: &StreamMessage{ID: id, Model: model},
|
||||
}
|
||||
}
|
||||
|
||||
// NewMessageStartEventWithUsage 创建带用量的消息开始事件
|
||||
func NewMessageStartEventWithUsage(id, model string, usage *CanonicalUsage) CanonicalStreamEvent {
|
||||
return CanonicalStreamEvent{
|
||||
Type: EventMessageStart,
|
||||
Message: &StreamMessage{ID: id, Model: model, Usage: usage},
|
||||
}
|
||||
}
|
||||
|
||||
// NewContentBlockStartEvent 创建内容块开始事件
|
||||
func NewContentBlockStartEvent(index int, block StreamContentBlock) CanonicalStreamEvent {
|
||||
idx := index
|
||||
return CanonicalStreamEvent{
|
||||
Type: EventContentBlockStart,
|
||||
Index: &idx,
|
||||
ContentBlock: &block,
|
||||
}
|
||||
}
|
||||
|
||||
// NewContentBlockDeltaEvent 创建内容块增量事件
|
||||
func NewContentBlockDeltaEvent(index int, delta StreamDelta) CanonicalStreamEvent {
|
||||
idx := index
|
||||
return CanonicalStreamEvent{
|
||||
Type: EventContentBlockDelta,
|
||||
Index: &idx,
|
||||
Delta: &delta,
|
||||
}
|
||||
}
|
||||
|
||||
// NewContentBlockStopEvent 创建内容块结束事件
|
||||
func NewContentBlockStopEvent(index int) CanonicalStreamEvent {
|
||||
idx := index
|
||||
return CanonicalStreamEvent{
|
||||
Type: EventContentBlockStop,
|
||||
Index: &idx,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMessageDeltaEvent 创建消息增量事件
|
||||
func NewMessageDeltaEvent(stopReason StopReason) CanonicalStreamEvent {
|
||||
sr := stopReason
|
||||
return CanonicalStreamEvent{
|
||||
Type: EventMessageDelta,
|
||||
StopReason: &sr,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMessageDeltaEventWithUsage 创建带用量的消息增量事件
|
||||
func NewMessageDeltaEventWithUsage(stopReason StopReason, usage *CanonicalUsage) CanonicalStreamEvent {
|
||||
sr := stopReason
|
||||
return CanonicalStreamEvent{
|
||||
Type: EventMessageDelta,
|
||||
StopReason: &sr,
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMessageStopEvent 创建消息结束事件
|
||||
func NewMessageStopEvent() CanonicalStreamEvent {
|
||||
return CanonicalStreamEvent{Type: EventMessageStop}
|
||||
}
|
||||
|
||||
// NewErrorEvent 创建错误事件
|
||||
func NewErrorEvent(errType, message string) CanonicalStreamEvent {
|
||||
return CanonicalStreamEvent{
|
||||
Type: EventError,
|
||||
Error: &StreamError{Type: errType, Message: message},
|
||||
}
|
||||
}
|
||||
|
||||
// NewPingEvent 创建心跳事件
|
||||
func NewPingEvent() CanonicalStreamEvent {
|
||||
return CanonicalStreamEvent{Type: EventPing}
|
||||
}
|
||||
208
backend/internal/conversion/canonical/types.go
Normal file
208
backend/internal/conversion/canonical/types.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package canonical
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// MessageRole enumerates the canonical message roles.
type MessageRole string

const (
	RoleSystem    MessageRole = "system"
	RoleUser      MessageRole = "user"
	RoleAssistant MessageRole = "assistant"
	RoleTool      MessageRole = "tool"
)

// StopReason enumerates the canonical stop reasons.
type StopReason string

const (
	StopReasonEndTurn       StopReason = "end_turn"
	StopReasonMaxTokens     StopReason = "max_tokens"
	StopReasonToolUse       StopReason = "tool_use"
	StopReasonStopSequence  StopReason = "stop_sequence"
	StopReasonContentFilter StopReason = "content_filter"
	StopReasonRefusal       StopReason = "refusal"
)
|
||||
|
||||
// SystemBlock is a single block of a structured system prompt.
type SystemBlock struct {
	Text string `json:"text"`
}

// ContentBlock is a discriminated union keyed by the "type" field, covering
// text, tool_use, tool_result and thinking blocks.
type ContentBlock struct {
	Type string `json:"type"`

	// TextBlock
	Text string `json:"text,omitempty"`

	// ToolUseBlock
	ID    string          `json:"id,omitempty"`
	Name  string          `json:"name,omitempty"`
	Input json.RawMessage `json:"input,omitempty"`

	// ToolResultBlock
	ToolUseID string          `json:"tool_use_id,omitempty"`
	Content   json.RawMessage `json:"content,omitempty"`
	IsError   *bool           `json:"is_error,omitempty"`

	// ThinkingBlock
	Thinking string `json:"thinking,omitempty"`
}

// NewTextBlock creates a text content block.
func NewTextBlock(text string) ContentBlock {
	return ContentBlock{Type: "text", Text: text}
}

// NewToolUseBlock creates a tool-use content block.
func NewToolUseBlock(id, name string, input json.RawMessage) ContentBlock {
	return ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}
}

// NewToolResultBlock creates a tool-result content block whose plain-text
// content is stored as a JSON string in Content.
//
// Fix: the original built the value with fmt.Sprintf("%q", content), which
// produces Go string-literal quoting. Go escapes such as \x01 (emitted for
// control characters) are not valid JSON escapes, so the resulting RawMessage
// could be unparseable JSON. json.Marshal always produces valid JSON
// (e.g. \u0001) and is identical to %q for ordinary printable text.
func NewToolResultBlock(toolUseID string, content string, isError bool) ContentBlock {
	raw, err := json.Marshal(content)
	if err != nil {
		// Marshaling a string cannot fail in practice; keep the old
		// formatting as a defensive fallback.
		raw = json.RawMessage(fmt.Sprintf("%q", content))
	}
	errFlag := &isError
	return ContentBlock{
		Type:      "tool_result",
		ToolUseID: toolUseID,
		Content:   raw,
		IsError:   errFlag,
	}
}

// NewThinkingBlock creates a thinking content block.
func NewThinkingBlock(thinking string) ContentBlock {
	return ContentBlock{Type: "thinking", Thinking: thinking}
}
|
||||
|
||||
// CanonicalMessage is a provider-neutral conversation message.
type CanonicalMessage struct {
	Role    MessageRole    `json:"role"`
	Content []ContentBlock `json:"content"`
}

// CanonicalTool is a provider-neutral tool definition.
type CanonicalTool struct {
	Name        string          `json:"name"`
	Description string          `json:"description,omitempty"`
	InputSchema json.RawMessage `json:"input_schema"`
}
|
||||
|
||||
// ToolChoice is the tool-selection union: Type is one of "auto", "none",
// "any" or "tool"; Name is set only when Type is "tool".
type ToolChoice struct {
	Type string `json:"type"`
	Name string `json:"name,omitempty"`
}

// newToolChoice builds a ToolChoice with the given discriminator and name.
func newToolChoice(kind, name string) *ToolChoice {
	return &ToolChoice{Type: kind, Name: name}
}

// NewToolChoiceAuto returns the "auto" tool choice.
func NewToolChoiceAuto() *ToolChoice { return newToolChoice("auto", "") }

// NewToolChoiceNone returns the "none" tool choice.
func NewToolChoiceNone() *ToolChoice { return newToolChoice("none", "") }

// NewToolChoiceAny returns the "any" tool choice.
func NewToolChoiceAny() *ToolChoice { return newToolChoice("any", "") }

// NewToolChoiceNamed returns a tool choice that forces the named tool.
func NewToolChoiceNamed(name string) *ToolChoice { return newToolChoice("tool", name) }
|
||||
|
||||
// RequestParameters holds the provider-neutral sampling parameters; nil
// pointers mean "not set".
type RequestParameters struct {
	MaxTokens        *int      `json:"max_tokens,omitempty"`
	Temperature      *float64  `json:"temperature,omitempty"`
	TopP             *float64  `json:"top_p,omitempty"`
	TopK             *int      `json:"top_k,omitempty"`
	FrequencyPenalty *float64  `json:"frequency_penalty,omitempty"`
	PresencePenalty  *float64  `json:"presence_penalty,omitempty"`
	StopSequences    []string  `json:"stop_sequences,omitempty"`
}

// ThinkingConfig is the provider-neutral extended-thinking configuration.
type ThinkingConfig struct {
	Type         string `json:"type"`
	BudgetTokens *int   `json:"budget_tokens,omitempty"`
	Effort       string `json:"effort,omitempty"`
}

// OutputFormat is the structured-output format union.
type OutputFormat struct {
	Type   string          `json:"type"`
	Name   string          `json:"name,omitempty"`
	Schema json.RawMessage `json:"schema,omitempty"`
	Strict *bool           `json:"strict,omitempty"`
}
|
||||
|
||||
// CanonicalRequest 规范请求
|
||||
type CanonicalRequest struct {
|
||||
Model string `json:"model"`
|
||||
System any `json:"system,omitempty"` // nil, string, or []SystemBlock
|
||||
Messages []CanonicalMessage `json:"messages"`
|
||||
Tools []CanonicalTool `json:"tools,omitempty"`
|
||||
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
||||
Parameters RequestParameters `json:"parameters"`
|
||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
OutputFormat *OutputFormat `json:"output_format,omitempty"`
|
||||
ParallelToolUse *bool `json:"parallel_tool_use,omitempty"`
|
||||
}
|
||||
|
||||
// CanonicalUsage 规范用量
|
||||
type CanonicalUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheReadTokens *int `json:"cache_read_tokens,omitempty"`
|
||||
CacheCreationTokens *int `json:"cache_creation_tokens,omitempty"`
|
||||
ReasoningTokens *int `json:"reasoning_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// CanonicalResponse 规范响应
|
||||
type CanonicalResponse struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
StopReason *StopReason `json:"stop_reason,omitempty"`
|
||||
Usage CanonicalUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// GetSystemString 获取系统消息字符串
|
||||
func (r *CanonicalRequest) GetSystemString() string {
|
||||
switch v := r.System.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []SystemBlock:
|
||||
var result string
|
||||
for i, b := range v {
|
||||
if i > 0 {
|
||||
result += "\n\n"
|
||||
}
|
||||
result += b.Text
|
||||
}
|
||||
return result
|
||||
case nil:
|
||||
return ""
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// SetSystemString 设置系统消息字符串
|
||||
func (r *CanonicalRequest) SetSystemString(s string) {
|
||||
if s == "" {
|
||||
r.System = nil
|
||||
} else {
|
||||
r.System = s
|
||||
}
|
||||
}
|
||||
114
backend/internal/conversion/canonical/types_test.go
Normal file
114
backend/internal/conversion/canonical/types_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package canonical

import (
	"encoding/json"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestGetSystemString covers every supported shape of CanonicalRequest.System.
func TestGetSystemString(t *testing.T) {
	cases := []struct {
		name   string
		system any
		want   string
	}{
		{"string", "hello", "hello"},
		{"nil", nil, ""},
		{"empty string", "", ""},
		{"system blocks", []SystemBlock{{Text: "part1"}, {Text: "part2"}}, "part1\n\npart2"},
		{"single block", []SystemBlock{{Text: "only"}}, "only"},
		{"other type", 123, "123"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			request := &CanonicalRequest{System: tc.system}
			assert.Equal(t, tc.want, request.GetSystemString())
		})
	}
}

// TestSetSystemString checks that an empty string clears the field.
func TestSetSystemString(t *testing.T) {
	request := &CanonicalRequest{}

	request.SetSystemString("hello")
	assert.Equal(t, "hello", request.System)

	request.SetSystemString("")
	assert.Nil(t, request.System)
}

func TestNewTextBlock(t *testing.T) {
	block := NewTextBlock("hello")
	assert.Equal(t, "text", block.Type)
	assert.Equal(t, "hello", block.Text)
}

func TestNewToolUseBlock(t *testing.T) {
	args := json.RawMessage(`{"key":"val"}`)
	block := NewToolUseBlock("id-1", "tool_name", args)
	assert.Equal(t, "tool_use", block.Type)
	assert.Equal(t, "id-1", block.ID)
	assert.Equal(t, "tool_name", block.Name)
	assert.Equal(t, args, block.Input)
}

func TestNewToolResultBlock(t *testing.T) {
	block := NewToolResultBlock("tool-1", "result", false)
	assert.Equal(t, "tool_result", block.Type)
	assert.Equal(t, "tool-1", block.ToolUseID)
	assert.NotNil(t, block.IsError)
	assert.False(t, *block.IsError)
}

func TestNewThinkingBlock(t *testing.T) {
	block := NewThinkingBlock("thought")
	assert.Equal(t, "thinking", block.Type)
	assert.Equal(t, "thought", block.Thinking)
}

// TestNewToolChoice pins all four ToolChoice constructors at once.
func TestNewToolChoice(t *testing.T) {
	assert.Equal(t, &ToolChoice{Type: "auto"}, NewToolChoiceAuto())
	assert.Equal(t, &ToolChoice{Type: "none"}, NewToolChoiceNone())
	assert.Equal(t, &ToolChoice{Type: "any"}, NewToolChoiceAny())
	assert.Equal(t, &ToolChoice{Type: "tool", Name: "fn"}, NewToolChoiceNamed("fn"))
}

// TestCanonicalRequest_RoundTrip marshals and unmarshals a request and checks
// the fields that survive the round trip.
func TestCanonicalRequest_RoundTrip(t *testing.T) {
	original := &CanonicalRequest{
		Model:    "gpt-4",
		System:   "system prompt",
		Messages: []CanonicalMessage{{Role: RoleUser, Content: []ContentBlock{NewTextBlock("hi")}}},
		Stream:   true,
	}

	encoded, err := json.Marshal(original)
	require.NoError(t, err)

	var roundTripped CanonicalRequest
	require.NoError(t, json.Unmarshal(encoded, &roundTripped))
	assert.Equal(t, "gpt-4", roundTripped.Model)
	assert.Equal(t, "system prompt", roundTripped.System)
	assert.True(t, roundTripped.Stream)
}

// TestCanonicalResponse_RoundTrip does the same for the response type,
// including the pointer-valued StopReason.
func TestCanonicalResponse_RoundTrip(t *testing.T) {
	reason := StopReasonEndTurn
	original := &CanonicalResponse{
		ID:         "resp-1",
		Model:      "gpt-4",
		Content:    []ContentBlock{NewTextBlock("hello")},
		StopReason: &reason,
		Usage:      CanonicalUsage{InputTokens: 10, OutputTokens: 5},
	}

	encoded, err := json.Marshal(original)
	require.NoError(t, err)

	var roundTripped CanonicalResponse
	require.NoError(t, json.Unmarshal(encoded, &roundTripped))
	assert.Equal(t, "resp-1", roundTripped.ID)
	assert.Equal(t, StopReasonEndTurn, *roundTripped.StopReason)
}
|
||||
453
backend/internal/conversion/engine.go
Normal file
453
backend/internal/conversion/engine.go
Normal file
@@ -0,0 +1,453 @@
|
||||
package conversion
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// HTTPRequestSpec HTTP 请求规格
|
||||
type HTTPRequestSpec struct {
|
||||
URL string `json:"url"`
|
||||
Method string `json:"method"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Body []byte `json:"body"`
|
||||
}
|
||||
|
||||
// HTTPResponseSpec HTTP 响应规格
|
||||
type HTTPResponseSpec struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Body []byte `json:"body"`
|
||||
}
|
||||
|
||||
// ConversionEngine 转换引擎门面
|
||||
type ConversionEngine struct {
|
||||
registry AdapterRegistry
|
||||
middlewareChain *MiddlewareChain
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewConversionEngine 创建转换引擎
|
||||
func NewConversionEngine(registry AdapterRegistry, logger *zap.Logger) *ConversionEngine {
|
||||
return &ConversionEngine{
|
||||
registry: registry,
|
||||
middlewareChain: NewMiddlewareChain(),
|
||||
logger: pkglogger.WithModule(logger, "conversion.engine"),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterAdapter 注册协议适配器
|
||||
func (e *ConversionEngine) RegisterAdapter(adapter ProtocolAdapter) error {
|
||||
return e.registry.Register(adapter)
|
||||
}
|
||||
|
||||
// GetRegistry 返回注册表(供外部使用)
|
||||
func (e *ConversionEngine) GetRegistry() AdapterRegistry {
|
||||
return e.registry
|
||||
}
|
||||
|
||||
// Use 添加中间件
|
||||
func (e *ConversionEngine) Use(mw ConversionMiddleware) {
|
||||
e.middlewareChain.Use(mw)
|
||||
}
|
||||
|
||||
// IsPassthrough 判断是否同协议透传
|
||||
func (e *ConversionEngine) IsPassthrough(clientProtocol, providerProtocol string) bool {
|
||||
if clientProtocol != providerProtocol {
|
||||
return false
|
||||
}
|
||||
adapter, err := e.registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return adapter.SupportsPassthrough()
|
||||
}
|
||||
|
||||
// ConvertHttpRequest 转换 HTTP 请求
|
||||
func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtocol, providerProtocol string, provider *TargetProvider) (*HTTPRequestSpec, error) {
|
||||
nativePath, rawQuery := splitRequestPath(spec.URL)
|
||||
|
||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||
providerAdapter, err := e.registry.Get(providerProtocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Smart Passthrough: 同协议时最小化改写 model 字段
|
||||
interfaceType := providerAdapter.DetectInterfaceType(nativePath)
|
||||
rewrittenBody := spec.Body
|
||||
|
||||
// 对于 Chat/Embedding/Rerank 接口,改写请求体中的 model 字段
|
||||
if interfaceType == InterfaceTypeChat || interfaceType == InterfaceTypeEmbeddings || interfaceType == InterfaceTypeRerank {
|
||||
if len(spec.Body) > 0 && provider.ModelName != "" {
|
||||
rewrittenBody, err = providerAdapter.RewriteRequestModelName(spec.Body, provider.ModelName, interfaceType)
|
||||
if err != nil {
|
||||
e.logger.Warn("Smart Passthrough 改写请求失败,使用原始请求体",
|
||||
zap.Error(err),
|
||||
zap.String("interface", string(interfaceType)))
|
||||
rewrittenBody = spec.Body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
|
||||
providerURL = appendRawQuery(providerURL, rawQuery)
|
||||
|
||||
return &HTTPRequestSpec{
|
||||
URL: joinBaseURL(provider.BaseURL, providerURL),
|
||||
Method: spec.Method,
|
||||
Headers: providerAdapter.BuildHeaders(provider),
|
||||
Body: rewrittenBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
clientAdapter, err := e.registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("未找到客户端适配器 %s: %w", clientProtocol, err)
|
||||
}
|
||||
providerAdapter, err := e.registry.Get(providerProtocol)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("未找到服务端适配器 %s: %w", providerProtocol, err)
|
||||
}
|
||||
|
||||
interfaceType := clientAdapter.DetectInterfaceType(nativePath)
|
||||
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
|
||||
providerURL = appendRawQuery(providerURL, rawQuery)
|
||||
providerHeaders := providerAdapter.BuildHeaders(provider)
|
||||
providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &HTTPRequestSpec{
|
||||
URL: joinBaseURL(provider.BaseURL, providerURL),
|
||||
Method: spec.Method,
|
||||
Headers: providerHeaders,
|
||||
Body: providerBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ConvertHttpResponse 转换 HTTP 响应,modelOverride 用于跨协议场景覆写 model 字段
|
||||
func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProtocol, providerProtocol string, interfaceType InterfaceType, modelOverride string) (*HTTPResponseSpec, error) {
|
||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||
// Smart Passthrough: 同协议时最小化改写 model 字段
|
||||
if modelOverride != "" && len(spec.Body) > 0 {
|
||||
adapter, getErr := e.registry.Get(clientProtocol)
|
||||
if getErr == nil {
|
||||
rewrittenBody, rewriteErr := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
|
||||
if rewriteErr != nil {
|
||||
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
|
||||
zap.Error(rewriteErr),
|
||||
zap.String("interface", string(interfaceType)))
|
||||
} else {
|
||||
return &HTTPResponseSpec{
|
||||
StatusCode: spec.StatusCode,
|
||||
Headers: spec.Headers,
|
||||
Body: rewrittenBody,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return &spec, nil
|
||||
}
|
||||
|
||||
clientAdapter, err := e.registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
providerAdapter, err := e.registry.Get(providerProtocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
convertedBody, err := e.convertResponseBody(interfaceType, clientAdapter, providerAdapter, spec.Body, modelOverride)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &HTTPResponseSpec{
|
||||
StatusCode: spec.StatusCode,
|
||||
Headers: spec.Headers,
|
||||
Body: convertedBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateStreamConverter 创建流式转换器,modelOverride 用于跨协议场景覆写 model 字段
|
||||
func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtocol string, modelOverride string, interfaceType InterfaceType) (StreamConverter, error) {
|
||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||
// Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段
|
||||
if modelOverride != "" {
|
||||
adapter, getErr := e.registry.Get(clientProtocol)
|
||||
if getErr == nil {
|
||||
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
|
||||
}
|
||||
}
|
||||
return NewPassthroughStreamConverter(), nil
|
||||
}
|
||||
|
||||
providerAdapter, err := e.registry.Get(providerProtocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientAdapter, err := e.registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx := ConversionContext{
|
||||
ConversionID: uuid.New().String(),
|
||||
InterfaceType: interfaceType,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
return NewCanonicalStreamConverterWithMiddleware(
|
||||
providerAdapter.CreateStreamDecoder(),
|
||||
clientAdapter.CreateStreamEncoder(),
|
||||
e.middlewareChain,
|
||||
ctx,
|
||||
clientProtocol,
|
||||
providerProtocol,
|
||||
modelOverride,
|
||||
), nil
|
||||
}
|
||||
|
||||
// convertBody 转换请求体
|
||||
func (e *ConversionEngine) convertBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||
switch interfaceType {
|
||||
case InterfaceTypeChat:
|
||||
return e.convertChatBody(clientAdapter, providerAdapter, provider, body)
|
||||
case InterfaceTypeModels, InterfaceTypeModelInfo:
|
||||
return body, nil
|
||||
case InterfaceTypeEmbeddings:
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) {
|
||||
return body, nil
|
||||
}
|
||||
return e.convertEmbeddingBody(clientAdapter, providerAdapter, provider, body)
|
||||
case InterfaceTypeRerank:
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) {
|
||||
return body, nil
|
||||
}
|
||||
return e.convertRerankBody(clientAdapter, providerAdapter, provider, body)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
// convertResponseBody 转换响应体,modelOverride 非空时在 canonical 层面覆写 Model 字段
|
||||
func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
switch interfaceType {
|
||||
case InterfaceTypeChat:
|
||||
return e.convertChatResponseBody(clientAdapter, providerAdapter, body, modelOverride)
|
||||
case InterfaceTypeModels:
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeModels) || !providerAdapter.SupportsInterface(InterfaceTypeModels) {
|
||||
return body, nil
|
||||
}
|
||||
return e.convertModelsResponseBody(clientAdapter, providerAdapter, body)
|
||||
case InterfaceTypeModelInfo:
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeModelInfo) || !providerAdapter.SupportsInterface(InterfaceTypeModelInfo) {
|
||||
return body, nil
|
||||
}
|
||||
return e.convertModelInfoResponseBody(clientAdapter, providerAdapter, body)
|
||||
case InterfaceTypeEmbeddings:
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeEmbeddings) || !providerAdapter.SupportsInterface(InterfaceTypeEmbeddings) {
|
||||
return body, nil
|
||||
}
|
||||
return e.convertEmbeddingResponseBody(clientAdapter, providerAdapter, body, modelOverride)
|
||||
case InterfaceTypeRerank:
|
||||
if !clientAdapter.SupportsInterface(InterfaceTypeRerank) || !providerAdapter.SupportsInterface(InterfaceTypeRerank) {
|
||||
return body, nil
|
||||
}
|
||||
return e.convertRerankResponseBody(clientAdapter, providerAdapter, body, modelOverride)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||
canonicalReq, err := clientAdapter.DecodeRequest(body)
|
||||
if err != nil {
|
||||
return nil, NewRequestJSONParseError("解码请求失败", err)
|
||||
}
|
||||
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
canonicalReq, err = e.middlewareChain.Apply(canonicalReq, clientAdapter.ProtocolName(), providerAdapter.ProtocolName(), ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if containsUnsupportedMultimodal(canonicalReq) {
|
||||
return nil, NewConversionError(ErrorCodeUnsupportedMultimodal, "跨协议暂不支持多模态内容")
|
||||
}
|
||||
|
||||
encoded, err := providerAdapter.EncodeRequest(canonicalReq, provider)
|
||||
if err != nil {
|
||||
return nil, NewConversionError(ErrorCodeEncodingFailure, "编码请求失败").WithCause(err)
|
||||
}
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
canonicalResp, err := providerAdapter.DecodeResponse(body)
|
||||
if err != nil {
|
||||
return nil, NewResponseJSONParseError("解码响应失败", err)
|
||||
}
|
||||
if modelOverride != "" {
|
||||
canonicalResp.Model = modelOverride
|
||||
}
|
||||
encoded, err := clientAdapter.EncodeResponse(canonicalResp)
|
||||
if err != nil {
|
||||
return nil, NewConversionError(ErrorCodeEncodingFailure, "编码响应失败").WithCause(err)
|
||||
}
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
models, err := providerAdapter.DecodeModelsResponse(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
encoded, err := clientAdapter.EncodeModelsResponse(models)
|
||||
if err != nil {
|
||||
e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
info, err := providerAdapter.DecodeModelInfoResponse(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
encoded, err := clientAdapter.EncodeModelInfoResponse(info)
|
||||
if err != nil {
|
||||
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||
req, err := clientAdapter.DecodeEmbeddingRequest(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return providerAdapter.EncodeEmbeddingRequest(req, provider)
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
resp, err := providerAdapter.DecodeEmbeddingResponse(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
if modelOverride != "" {
|
||||
resp.Model = modelOverride
|
||||
}
|
||||
return clientAdapter.EncodeEmbeddingResponse(resp)
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||
req, err := clientAdapter.DecodeRerankRequest(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return providerAdapter.EncodeRerankRequest(req, provider)
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
resp, decodeErr := providerAdapter.DecodeRerankResponse(body)
|
||||
if decodeErr == nil {
|
||||
if modelOverride != "" {
|
||||
resp.Model = modelOverride
|
||||
}
|
||||
return clientAdapter.EncodeRerankResponse(resp)
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// DetectInterfaceType 检测接口类型
|
||||
func (e *ConversionEngine) DetectInterfaceType(nativePath, clientProtocol string) (InterfaceType, error) {
|
||||
adapter, err := e.registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
return InterfaceTypePassthrough, err
|
||||
}
|
||||
nativePath, _ = splitRequestPath(nativePath)
|
||||
return adapter.DetectInterfaceType(nativePath), nil
|
||||
}
|
||||
|
||||
// EncodeError 使用客户端适配器编码错误
|
||||
func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol string) ([]byte, int, error) {
|
||||
adapter, adapterErr := e.registry.Get(clientProtocol)
|
||||
if adapterErr != nil {
|
||||
fallback := map[string]any{
|
||||
"error": map[string]string{
|
||||
"message": err.Error(),
|
||||
"type": "internal_error",
|
||||
},
|
||||
}
|
||||
body, marshalErr := json.Marshal(fallback)
|
||||
if marshalErr == nil {
|
||||
return body, 500, nil
|
||||
}
|
||||
|
||||
return []byte(`{"error":{"message":"internal error","type":"internal_error"}}`), 500, nil
|
||||
}
|
||||
body, statusCode := adapter.EncodeError(err)
|
||||
return body, statusCode, nil
|
||||
}
|
||||
|
||||
// splitRequestPath separates a request path into its path and raw-query parts;
// the query is empty when the path carries no "?".
func splitRequestPath(rawPath string) (string, string) {
	if path, query, ok := strings.Cut(rawPath, "?"); ok {
		return path, query
	}
	return rawPath, ""
}

// appendRawQuery re-attaches a raw query string to path, choosing "?" or "&"
// depending on whether path already carries a query.
func appendRawQuery(path, rawQuery string) string {
	switch {
	case rawQuery == "":
		return path
	case strings.Contains(path, "?"):
		return path + "&" + rawQuery
	default:
		return path + "?" + rawQuery
	}
}

// joinBaseURL joins a provider base URL and a request path with exactly one
// slash between them; either side may be empty.
func joinBaseURL(baseURL, path string) string {
	switch {
	case baseURL == "":
		return path
	case path == "":
		return baseURL
	default:
		return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/")
	}
}
|
||||
|
||||
func containsUnsupportedMultimodal(req *canonical.CanonicalRequest) bool {
|
||||
if req == nil {
|
||||
return false
|
||||
}
|
||||
for _, msg := range req.Messages {
|
||||
for _, block := range msg.Content {
|
||||
switch block.Type {
|
||||
case "image", "audio", "video", "file":
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
63
backend/internal/conversion/engine_adapter_test.go
Normal file
63
backend/internal/conversion/engine_adapter_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package conversion_test

import (
	"testing"

	"nex/backend/internal/conversion"
	"nex/backend/internal/conversion/anthropic"
	"nex/backend/internal/conversion/openai"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestConvertHttpRequest_SameProtocolUsesAdapterBuildURL verifies that the
// same-protocol passthrough path builds the upstream URL via the adapter's
// BuildUrl rather than naive string concatenation.
func TestConvertHttpRequest_SameProtocolUsesAdapterBuildURL(t *testing.T) {
	cases := []struct {
		name             string
		adapter          conversion.ProtocolAdapter
		clientProtocol   string
		providerProtocol string
		baseURL          string
		nativePath       string
		expectedURL      string
		body             []byte
	}{
		{
			name:             "openai base url includes version path",
			adapter:          openai.NewAdapter(),
			clientProtocol:   "openai",
			providerProtocol: "openai",
			baseURL:          "http://example.com/v1",
			nativePath:       "/chat/completions",
			expectedURL:      "http://example.com/v1/chat/completions",
			body:             []byte(`{"model":"gpt-4","messages":[]}`),
		},
		{
			name:             "anthropic native path keeps v1",
			adapter:          anthropic.NewAdapter(),
			clientProtocol:   "anthropic",
			providerProtocol: "anthropic",
			baseURL:          "http://example.com",
			nativePath:       "/v1/messages",
			expectedURL:      "http://example.com/v1/messages",
			body:             []byte(`{"model":"claude","messages":[]}`),
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			registry := conversion.NewMemoryRegistry()
			engine := conversion.NewConversionEngine(registry, zap.NewNop())
			require.NoError(t, registry.Register(tc.adapter))

			converted, err := engine.ConvertHttpRequest(conversion.HTTPRequestSpec{
				URL:    tc.nativePath,
				Method: "POST",
				Body:   tc.body,
			}, tc.clientProtocol, tc.providerProtocol, conversion.NewTargetProvider(tc.baseURL, "key", "upstream-model"))
			require.NoError(t, err)
			assert.Equal(t, tc.expectedURL, converted.URL)
		})
	}
}
|
||||
379
backend/internal/conversion/engine_supplemental_test.go
Normal file
379
backend/internal/conversion/engine_supplemental_test.go
Normal file
@@ -0,0 +1,379 @@
|
||||
package conversion

import (
	"encoding/json"
	"errors"
	"testing"

	"nex/backend/internal/conversion/canonical"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestConversionError_WithProviderProtocol(t *testing.T) {
	convErr := NewConversionError(ErrorCodeInvalidInput, "test").WithProviderProtocol("anthropic")
	assert.Equal(t, "anthropic", convErr.ProviderProtocol)
}

func TestConversionError_WithInterfaceType(t *testing.T) {
	convErr := NewConversionError(ErrorCodeInvalidInput, "test").WithInterfaceType("CHAT")
	assert.Equal(t, "CHAT", convErr.InterfaceType)
}

// TestConversionError_FullBuilder exercises the whole builder chain at once.
func TestConversionError_FullBuilder(t *testing.T) {
	convErr := NewConversionError(ErrorCodeInvalidInput, "bad").
		WithClientProtocol("openai").
		WithProviderProtocol("anthropic").
		WithInterfaceType("CHAT").
		WithDetail("field", "model").
		WithCause(errors.New("root"))

	assert.Equal(t, ErrorCodeInvalidInput, convErr.Code)
	assert.Equal(t, "openai", convErr.ClientProtocol)
	assert.Equal(t, "anthropic", convErr.ProviderProtocol)
	assert.Equal(t, "CHAT", convErr.InterfaceType)
	assert.Equal(t, "model", convErr.Details["field"])
	assert.Equal(t, "root", convErr.Cause.Error())
}

// TestEngine_Use verifies that a registered middleware actually runs during
// cross-protocol request conversion.
func TestEngine_Use(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, zap.NewNop())

	invoked := false
	engine.Use(&testMiddleware{fn: func(req *canonical.CanonicalRequest, cp, pp string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) {
		invoked = true
		return req, nil
	}})

	clientAdapter := newMockAdapter("client", false)
	clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
		return &canonical.CanonicalRequest{Model: "test"}, nil
	}
	providerAdapter := newMockAdapter("provider", false)
	providerAdapter.encodeReqFn = func(req *canonical.CanonicalRequest, p *TargetProvider) ([]byte, error) {
		return json.Marshal(req)
	}
	_ = engine.RegisterAdapter(clientAdapter)
	_ = engine.RegisterAdapter(providerAdapter)

	_, err := engine.ConvertHttpRequest(HTTPRequestSpec{
		URL: "/chat/completions", Method: "POST", Body: []byte(`{}`),
	}, "client", "provider", NewTargetProvider("https://example.com", "key", "model"))
	require.NoError(t, err)
	assert.True(t, invoked)
}

func TestConvertHttpRequest_DecodeError(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, zap.NewNop())

	clientAdapter := newMockAdapter("client", false)
	clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
		return nil, errors.New("decode failed")
	}
	_ = engine.RegisterAdapter(clientAdapter)
	_ = engine.RegisterAdapter(newMockAdapter("provider", false))

	_, err := engine.ConvertHttpRequest(HTTPRequestSpec{
		URL: "/chat/completions", Method: "POST", Body: []byte(`{}`),
	}, "client", "provider", NewTargetProvider("", "", ""))
	assert.Error(t, err)
}

func TestConvertHttpRequest_EncodeError(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, zap.NewNop())

	_ = engine.RegisterAdapter(newMockAdapter("client", false))
	providerAdapter := newMockAdapter("provider", false)
	providerAdapter.encodeReqFn = func(req *canonical.CanonicalRequest, p *TargetProvider) ([]byte, error) {
		return nil, errors.New("encode failed")
	}
	_ = engine.RegisterAdapter(providerAdapter)

	_, err := engine.ConvertHttpRequest(HTTPRequestSpec{
		URL: "/chat/completions", Method: "POST", Body: []byte(`{}`),
	}, "client", "provider", NewTargetProvider("", "", ""))
	assert.Error(t, err)
}

func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, zap.NewNop())

	clientAdapter := newMockAdapter("client", false)
	clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
		return json.Marshal(map[string]string{"id": resp.ID})
	}
	providerAdapter := newMockAdapter("provider", false)
	providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
		return &canonical.CanonicalResponse{ID: "resp-1", Model: "test"}, nil
	}
	_ = engine.RegisterAdapter(clientAdapter)
	_ = engine.RegisterAdapter(providerAdapter)

	converted, err := engine.ConvertHttpResponse(HTTPResponseSpec{
		StatusCode: 200, Body: []byte(`{"id":"resp-1"}`),
	}, "client", "provider", InterfaceTypeChat, "")
	require.NoError(t, err)
	assert.Equal(t, 200, converted.StatusCode)
	assert.Contains(t, string(converted.Body), "resp-1")
}

func TestConvertHttpResponse_DecodeError(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, zap.NewNop())

	providerAdapter := newMockAdapter("provider", false)
	providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
		return nil, errors.New("decode error")
	}
	_ = engine.RegisterAdapter(providerAdapter)
	_ = engine.RegisterAdapter(newMockAdapter("client", false))

	_, err := engine.ConvertHttpResponse(HTTPResponseSpec{Body: []byte(`{}`)}, "client", "provider", InterfaceTypeChat, "")
	assert.Error(t, err)
}

func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, zap.NewNop())

	clientAdapter := newMockAdapter("client", false)
	clientAdapter.ifaceType = InterfaceTypeEmbeddings
	clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
	clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
		return &canonical.CanonicalRequest{Model: "test"}, nil
	}
	providerAdapter := newMockAdapter("provider", false)
	providerAdapter.ifaceType = InterfaceTypeEmbeddings
	providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
	_ = engine.RegisterAdapter(clientAdapter)
	_ = engine.RegisterAdapter(providerAdapter)

	converted, err := engine.ConvertHttpRequest(HTTPRequestSpec{
		URL: "/v1/embeddings", Method: "POST", Body: []byte(`{"model":"text-embedding","input":"hello"}`),
	}, "client", "provider", NewTargetProvider("https://example.com", "key", "model"))
	require.NoError(t, err)
	assert.NotNil(t, converted)
}
|
||||
|
||||
func TestConvertHttpRequest_RerankInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.ifaceType = InterfaceTypeRerank
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.ifaceType = InterfaceTypeRerank
|
||||
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
result, err := engine.ConvertHttpRequest(HTTPRequestSpec{
|
||||
URL: "/v1/rerank", Method: "POST", Body: []byte(`{"model":"rerank","query":"test","documents":["a"]}`),
|
||||
}, "client", "provider", NewTargetProvider("https://example.com", "key", "model"))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"object":"list","data":[],"model":"test"}`),
|
||||
}, "client", "provider", InterfaceTypeEmbeddings, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func TestConvertHttpResponse_RerankInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"results":[],"model":"test"}`),
|
||||
}, "client", "provider", InterfaceTypeRerank, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.ifaceType = InterfaceTypeModels
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.ifaceType = InterfaceTypeModels
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
body := []byte(`{"object":"list","data":[]}`)
|
||||
result, err := engine.ConvertHttpRequest(HTTPRequestSpec{
|
||||
URL: "/models", Method: "GET", Body: body,
|
||||
}, "client", "provider", NewTargetProvider("https://example.com", "key", ""))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result.Body)
|
||||
}
|
||||
|
||||
func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModels: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModels: true}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"object":"list","data":[]}`),
|
||||
}, "client", "provider", InterfaceTypeModels, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModelInfo: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModelInfo: true}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
result, err := engine.ConvertHttpResponse(HTTPResponseSpec{
|
||||
StatusCode: 200, Body: []byte(`{"id":"gpt-4","object":"model"}`),
|
||||
}, "client", "provider", InterfaceTypeModelInfo, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func TestRegistry_ListProtocols(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
_ = registry.Register(newMockAdapter("openai", true))
|
||||
_ = registry.Register(newMockAdapter("anthropic", true))
|
||||
|
||||
protocols := registry.ListProtocols()
|
||||
assert.Len(t, protocols, 2)
|
||||
assert.Contains(t, protocols, "openai")
|
||||
assert.Contains(t, protocols, "anthropic")
|
||||
}
|
||||
|
||||
func TestRegistry_ConcurrentAccess(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
done := make(chan bool, 2)
|
||||
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
_ = registry.Register(newMockAdapter("proto-"+string(rune(i)), true))
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
_, _ = registry.Get("proto-" + string(rune(i)))
|
||||
}
|
||||
_ = registry.ListProtocols()
|
||||
done <- true
|
||||
}()
|
||||
|
||||
<-done
|
||||
<-done
|
||||
}
|
||||
|
||||
func TestNewConversionContext(t *testing.T) {
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
assert.NotEmpty(t, ctx.ConversionID)
|
||||
assert.Equal(t, InterfaceTypeChat, ctx.InterfaceType)
|
||||
assert.NotNil(t, ctx.Metadata)
|
||||
}
|
||||
|
||||
// testMiddleware is a request middleware stub whose behavior is injected via fn.
type testMiddleware struct {
	// fn, when set, replaces the default pass-through Intercept behavior.
	fn func(req *canonical.CanonicalRequest, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalRequest, error)
}

// Intercept delegates to fn when present; otherwise it returns the request unchanged.
func (m *testMiddleware) Intercept(req *canonical.CanonicalRequest, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) {
	if m.fn != nil {
		return m.fn(req, clientProtocol, providerProtocol, ctx)
	}
	return req, nil
}

// InterceptStreamEvent is a no-op: stream events pass through untouched.
func (m *testMiddleware) InterceptStreamEvent(event *canonical.CanonicalStreamEvent, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalStreamEvent, error) {
	return event, nil
}
|
||||
|
||||
// Referencing json.Marshal keeps the encoding/json import in use even if
// edits remove its other call sites in this file.
var _ = json.Marshal
|
||||
|
||||
func TestConvertEmbeddingBody_DecodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.decodeEmbeddingReqFn = func(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
|
||||
return nil, errors.New("decode embedding failed")
|
||||
}
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
body := []byte(`{"model":"text-embedding","input":"hello"}`)
|
||||
result, err := engine.convertEmbeddingBody(clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
|
||||
func TestConvertRerankBody_DecodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.decodeRerankReqFn = func(raw []byte) (*canonical.CanonicalRerankRequest, error) {
|
||||
return nil, errors.New("decode rerank failed")
|
||||
}
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
body := []byte(`{"model":"rerank","query":"test","documents":["a"]}`)
|
||||
result, err := engine.convertRerankBody(clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
|
||||
func TestConvertBody_UnknownInterfaceType(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
_ = engine.RegisterAdapter(clientAdapter)
|
||||
_ = engine.RegisterAdapter(providerAdapter)
|
||||
|
||||
body := []byte(`{"test":"data"}`)
|
||||
result, err := engine.convertBody(InterfaceType("UNKNOWN"), clientAdapter, providerAdapter, NewTargetProvider("", "", ""), body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
733
backend/internal/conversion/engine_test.go
Normal file
733
backend/internal/conversion/engine_test.go
Normal file
@@ -0,0 +1,733 @@
|
||||
package conversion
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// mockProtocolAdapter is a configurable fake protocol adapter for engine tests.
// Each *Fn field, when non-nil, overrides the corresponding method's default
// behavior, letting individual tests inject success/failure paths.
type mockProtocolAdapter struct {
	protocolName string                 // name reported by ProtocolName
	passthrough  bool                   // value reported by SupportsPassthrough
	ifaceType    InterfaceType          // value returned by DetectInterfaceType
	supportsIface map[InterfaceType]bool // per-interface override for SupportsInterface
	decodeReqFn  func([]byte) (*canonical.CanonicalRequest, error)
	encodeReqFn  func(*canonical.CanonicalRequest, *TargetProvider) ([]byte, error)
	decodeRespFn func([]byte) (*canonical.CanonicalResponse, error)
	encodeRespFn func(*canonical.CanonicalResponse) ([]byte, error)
	streamDecoderFn func() StreamDecoder
	streamEncoderFn func() StreamEncoder
	rewriteReqFn  func([]byte, string, InterfaceType) ([]byte, error)
	rewriteRespFn func([]byte, string, InterfaceType) ([]byte, error)
	decodeEmbeddingReqFn func([]byte) (*canonical.CanonicalEmbeddingRequest, error)
	decodeRerankReqFn    func([]byte) (*canonical.CanonicalRerankRequest, error)
}
|
||||
|
||||
func newMockAdapter(name string, passthrough bool) *mockProtocolAdapter {
|
||||
return &mockProtocolAdapter{
|
||||
protocolName: name,
|
||||
passthrough: passthrough,
|
||||
ifaceType: InterfaceTypeChat,
|
||||
supportsIface: map[InterfaceType]bool{},
|
||||
}
|
||||
}
|
||||
|
||||
// ProtocolName returns the name this mock was created with.
func (m *mockProtocolAdapter) ProtocolName() string { return m.protocolName }

// ProtocolVersion is fixed for the mock.
func (m *mockProtocolAdapter) ProtocolVersion() string { return "1.0" }

// SupportsPassthrough reports the flag passed to newMockAdapter.
func (m *mockProtocolAdapter) SupportsPassthrough() bool { return m.passthrough }

// DetectInterfaceType ignores the path and returns the configured type.
func (m *mockProtocolAdapter) DetectInterfaceType(nativePath string) InterfaceType {
	return m.ifaceType
}

// BuildUrl echoes the native path unchanged.
func (m *mockProtocolAdapter) BuildUrl(nativePath string, interfaceType InterfaceType) string {
	return nativePath
}

// BuildHeaders emits a bearer-token Authorization header from the provider key.
func (m *mockProtocolAdapter) BuildHeaders(provider *TargetProvider) map[string]string {
	return map[string]string{"Authorization": "Bearer " + provider.APIKey}
}

// SupportsInterface consults the per-test override table; absent an entry,
// only the chat interface is reported as supported.
func (m *mockProtocolAdapter) SupportsInterface(interfaceType InterfaceType) bool {
	if v, ok := m.supportsIface[interfaceType]; ok {
		return v
	}
	return interfaceType == InterfaceTypeChat
}
|
||||
|
||||
func (m *mockProtocolAdapter) DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error) {
|
||||
if m.decodeReqFn != nil {
|
||||
return m.decodeReqFn(raw)
|
||||
}
|
||||
req := &canonical.CanonicalRequest{}
|
||||
_ = json.Unmarshal(raw, req)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) EncodeRequest(req *canonical.CanonicalRequest, provider *TargetProvider) ([]byte, error) {
|
||||
if m.encodeReqFn != nil {
|
||||
return m.encodeReqFn(req, provider)
|
||||
}
|
||||
return json.Marshal(req)
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error) {
|
||||
if m.decodeRespFn != nil {
|
||||
return m.decodeRespFn(raw)
|
||||
}
|
||||
resp := &canonical.CanonicalResponse{}
|
||||
_ = json.Unmarshal(raw, resp)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (m *mockProtocolAdapter) EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||||
if m.encodeRespFn != nil {
|
||||
return m.encodeRespFn(resp)
|
||||
}
|
||||
return json.Marshal(resp)
|
||||
}
|
||||
|
||||
// CreateStreamDecoder returns the injected decoder factory's product, or a
// no-op decoder when none is configured.
func (m *mockProtocolAdapter) CreateStreamDecoder() StreamDecoder {
	if m.streamDecoderFn != nil {
		return m.streamDecoderFn()
	}
	return &noopStreamDecoder{}
}

// CreateStreamEncoder returns the injected encoder factory's product, or a
// no-op encoder when none is configured.
func (m *mockProtocolAdapter) CreateStreamEncoder() StreamEncoder {
	if m.streamEncoderFn != nil {
		return m.streamEncoderFn()
	}
	return &noopStreamEncoder{}
}

// EncodeError always emits a fixed JSON body with HTTP 400.
func (m *mockProtocolAdapter) EncodeError(err *ConversionError) ([]byte, int) {
	return []byte(`{"error":"mock"}`), 400
}
|
||||
|
||||
// DecodeModelsResponse ignores its input and returns an empty model list.
func (m *mockProtocolAdapter) DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error) {
	return &canonical.CanonicalModelList{}, nil
}

// EncodeModelsResponse serializes the model list as JSON.
func (m *mockProtocolAdapter) EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
	return json.Marshal(list)
}

// DecodeModelInfoResponse ignores its input and returns empty model info.
func (m *mockProtocolAdapter) DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error) {
	return &canonical.CanonicalModelInfo{}, nil
}

// EncodeModelInfoResponse serializes the model info as JSON.
func (m *mockProtocolAdapter) EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
	return json.Marshal(info)
}
|
||||
|
||||
// DecodeEmbeddingRequest delegates to the injected hook, or returns an empty
// canonical embedding request when none is configured.
func (m *mockProtocolAdapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
	if m.decodeEmbeddingReqFn != nil {
		return m.decodeEmbeddingReqFn(raw)
	}
	return &canonical.CanonicalEmbeddingRequest{}, nil
}

// EncodeEmbeddingRequest serializes the embedding request as JSON.
func (m *mockProtocolAdapter) EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *TargetProvider) ([]byte, error) {
	return json.Marshal(req)
}

// DecodeEmbeddingResponse ignores its input and returns an empty response.
func (m *mockProtocolAdapter) DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error) {
	return &canonical.CanonicalEmbeddingResponse{}, nil
}

// EncodeEmbeddingResponse serializes the embedding response as JSON.
func (m *mockProtocolAdapter) EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) {
	return json.Marshal(resp)
}
|
||||
|
||||
// DecodeRerankRequest delegates to the injected hook, or returns an empty
// canonical rerank request when none is configured.
func (m *mockProtocolAdapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
	if m.decodeRerankReqFn != nil {
		return m.decodeRerankReqFn(raw)
	}
	return &canonical.CanonicalRerankRequest{}, nil
}

// EncodeRerankRequest serializes the rerank request as JSON.
func (m *mockProtocolAdapter) EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *TargetProvider) ([]byte, error) {
	return json.Marshal(req)
}

// DecodeRerankResponse ignores its input and returns an empty response.
func (m *mockProtocolAdapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) {
	return &canonical.CanonicalRerankResponse{}, nil
}

// EncodeRerankResponse serializes the rerank response as JSON.
func (m *mockProtocolAdapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
	return json.Marshal(resp)
}
|
||||
|
||||
// ExtractUnifiedModelID always reports "no unified model ID in path".
func (m *mockProtocolAdapter) ExtractUnifiedModelID(nativePath string) (string, error) {
	return "", nil
}

// ExtractModelName always reports "no model name in body".
func (m *mockProtocolAdapter) ExtractModelName(body []byte, ifaceType InterfaceType) (string, error) {
	return "", nil
}

// RewriteRequestModelName delegates to the injected hook, or returns the
// body unchanged when none is configured.
func (m *mockProtocolAdapter) RewriteRequestModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
	if m.rewriteReqFn != nil {
		return m.rewriteReqFn(body, newModel, ifaceType)
	}
	return body, nil
}

// RewriteResponseModelName delegates to the injected hook, or returns the
// body unchanged when none is configured.
func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
	if m.rewriteRespFn != nil {
		return m.rewriteRespFn(body, newModel, ifaceType)
	}
	return body, nil
}
|
||||
|
||||
// noopStreamDecoder is a stream decoder that produces no events.
type noopStreamDecoder struct{}

// ProcessChunk discards the chunk and emits nothing.
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
	return nil
}

// Flush has nothing buffered, so it emits nothing.
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }

// noopStreamEncoder is a stream encoder that produces no output frames.
type noopStreamEncoder struct{}

// EncodeEvent discards the event and emits nothing.
func (e *noopStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte { return nil }

// Flush has nothing buffered, so it emits nothing.
func (e *noopStreamEncoder) Flush() [][]byte { return nil }
||||
|
||||
// ============ 测试用例 ============
|
||||
|
||||
func TestNewConversionEngine(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
assert.NotNil(t, engine)
|
||||
assert.Equal(t, registry, engine.GetRegistry())
|
||||
}
|
||||
|
||||
// TestNewConversionEngine_LoggerInjection checks the engine always ends up
// with a usable, namespaced logger.
func TestNewConversionEngine_LoggerInjection(t *testing.T) {
	// NOTE(review): despite the subtest name, this passes zap.NewNop(), not
	// nil — the nil-logger fallback path is never actually exercised here.
	// Consider passing nil if NewConversionEngine is meant to tolerate it.
	t.Run("nil_logger_uses_global", func(t *testing.T) {
		registry := NewMemoryRegistry()
		engine := NewConversionEngine(registry, zap.NewNop())
		assert.NotNil(t, engine.logger)
	})

	t.Run("custom_logger", func(t *testing.T) {
		registry := NewMemoryRegistry()
		customLogger := zap.NewNop()
		engine := NewConversionEngine(registry, customLogger)
		assert.NotNil(t, engine.logger)
		// The engine is expected to namespace the injected logger.
		assert.Contains(t, engine.logger.Name(), "conversion.engine")
	})
}
|
||||
|
||||
func TestRegisterAdapter(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
adapter := newMockAdapter("test-proto", true)
|
||||
err := engine.RegisterAdapter(adapter)
|
||||
require.NoError(t, err)
|
||||
|
||||
protocols := registry.ListProtocols()
|
||||
assert.Contains(t, protocols, "test-proto")
|
||||
}
|
||||
|
||||
func TestIsPassthrough_SameProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
adapter := newMockAdapter("openai", true)
|
||||
_ = engine.RegisterAdapter(adapter)
|
||||
|
||||
assert.True(t, engine.IsPassthrough("openai", "openai"))
|
||||
}
|
||||
|
||||
func TestIsPassthrough_DifferentProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
_ = engine.RegisterAdapter(newMockAdapter("anthropic", true))
|
||||
|
||||
assert.False(t, engine.IsPassthrough("openai", "anthropic"))
|
||||
}
|
||||
|
||||
func TestIsPassthrough_NoPassthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("custom", false))
|
||||
|
||||
assert.False(t, engine.IsPassthrough("custom", "custom"))
|
||||
}
|
||||
|
||||
func TestDetectInterfaceType(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
adapter := newMockAdapter("test", true)
|
||||
adapter.ifaceType = InterfaceTypeChat
|
||||
_ = engine.RegisterAdapter(adapter)
|
||||
|
||||
ifaceType, err := engine.DetectInterfaceType("/chat/completions", "test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, InterfaceTypeChat, ifaceType)
|
||||
}
|
||||
|
||||
func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
_, err := engine.DetectInterfaceType("/v1/chat", "nonexistent")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// TestConvertHttpRequest_Passthrough exercises same-protocol smart
// passthrough: the URL is rebuilt from the adapter's BuildUrl and the model
// name in the body is rewritten from the unified ID to the provider model.
func TestConvertHttpRequest_Passthrough(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, zap.NewNop())
	openaiAdapter := &buildURLMockAdapter{
		mockProtocolAdapter: newMockAdapter("openai", true),
		buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
			if interfaceType == InterfaceTypeChat {
				return "/chat/completions"
			}
			return nativePath
		},
	}
	openaiAdapter.ifaceType = InterfaceTypeChat
	openaiAdapter.supportsIface[InterfaceTypeChat] = true
	// Rewrite hook swaps in the provider model name ("gpt-4" below).
	openaiAdapter.rewriteReqFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
		return []byte(`{"model":"` + newModel + `","messages":[{"role":"user","content":"hi"}]}`), nil
	}
	_ = engine.RegisterAdapter(openaiAdapter)

	provider := NewTargetProvider("https://api.openai.com/v1", "sk-test", "gpt-4")
	spec := HTTPRequestSpec{
		URL:    "/v1/chat/completions",
		Method: "POST",
		Body:   []byte(`{"model":"openai/gpt-4","messages":[{"role":"user","content":"hi"}]}`),
	}

	result, err := engine.ConvertHttpRequest(spec, "openai", "openai", provider)
	require.NoError(t, err)
	// Base URL + adapter-built path; body model rewritten to the bare name.
	assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
	assert.JSONEq(t, `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`, string(result.Body))
}
|
||||
|
||||
// TestConvertHttpRequest_CrossProtocol verifies the decode-canonical-encode
// path between two different protocols: the client adapter decodes, the
// provider adapter encodes using the target provider's model name.
func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, zap.NewNop())

	clientAdapter := newMockAdapter("client-proto", false)
	clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
		return &canonical.CanonicalRequest{
			Model:    "test-model",
			Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		}, nil
	}
	_ = engine.RegisterAdapter(clientAdapter)

	providerAdapter := newMockAdapter("provider-proto", false)
	// Encoder observes the provider's configured model, not the request's.
	providerAdapter.encodeReqFn = func(req *canonical.CanonicalRequest, p *TargetProvider) ([]byte, error) {
		return json.Marshal(map[string]any{"model": p.ModelName})
	}
	_ = engine.RegisterAdapter(providerAdapter)

	provider := NewTargetProvider("https://example.com", "key", "my-model")
	spec := HTTPRequestSpec{
		URL:    "/v1/chat",
		Method: "POST",
		Body:   []byte(`{"model":"test"}`),
	}

	result, err := engine.ConvertHttpRequest(spec, "client-proto", "provider-proto", provider)
	require.NoError(t, err)
	assert.Contains(t, result.URL, "https://example.com")
	assert.NotNil(t, result.Body)
}
|
||||
|
||||
// TestConvertHttpRequest_UsesProviderAdapterBuildURL checks that the target
// URL is produced by the PROVIDER adapter's BuildUrl in both conversion
// directions (OpenAI→Anthropic and Anthropic→OpenAI).
func TestConvertHttpRequest_UsesProviderAdapterBuildURL(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, zap.NewNop())
	openaiAdapter := &buildURLMockAdapter{
		mockProtocolAdapter: newMockAdapter("openai", true),
		buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
			if interfaceType == InterfaceTypeChat {
				return "/chat/completions"
			}
			return nativePath
		},
	}
	openaiAdapter.ifaceType = InterfaceTypeChat
	openaiAdapter.supportsIface[InterfaceTypeChat] = true
	openaiAdapter.rewriteReqFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
		return []byte(`{"model":"` + newModel + `"}`), nil
	}
	require.NoError(t, registry.Register(openaiAdapter))

	anthropicAdapter := &buildURLMockAdapter{
		mockProtocolAdapter: newMockAdapter("anthropic", false),
		buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
			if interfaceType == InterfaceTypeChat {
				return "/v1/messages"
			}
			return nativePath
		},
	}
	anthropicAdapter.ifaceType = InterfaceTypeChat
	anthropicAdapter.supportsIface[InterfaceTypeChat] = true
	require.NoError(t, registry.Register(anthropicAdapter))

	t.Run("OpenAI to Anthropic", func(t *testing.T) {
		provider := NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
		spec := HTTPRequestSpec{
			URL:    "/v1/chat/completions",
			Method: "POST",
			Body:   []byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"max_tokens":16}`),
		}

		result, err := engine.ConvertHttpRequest(spec, "openai", "anthropic", provider)
		require.NoError(t, err)
		// Path comes from the anthropic (provider) adapter.
		assert.Equal(t, "https://api.anthropic.com/v1/messages", result.URL)
	})

	t.Run("Anthropic to OpenAI", func(t *testing.T) {
		provider := NewTargetProvider("https://api.openai.com/v1", "key", "gpt-4")
		spec := HTTPRequestSpec{
			URL:    "/v1/messages",
			Method: "POST",
			Body:   []byte(`{"model":"p1/claude-3","max_tokens":16,"messages":[{"role":"user","content":"hi"}]}`),
		}

		result, err := engine.ConvertHttpRequest(spec, "anthropic", "openai", provider)
		require.NoError(t, err)
		// Path comes from the openai (provider) adapter.
		assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
	})
}
|
||||
|
||||
// buildURLMockAdapter wraps mockProtocolAdapter with an overridable BuildUrl,
// so tests can control the provider-side path construction.
type buildURLMockAdapter struct {
	*mockProtocolAdapter
	// buildURLFn, when set, replaces the embedded adapter's BuildUrl.
	buildURLFn func(string, InterfaceType) string
}

// BuildUrl delegates to buildURLFn when set, otherwise to the embedded mock.
func (m *buildURLMockAdapter) BuildUrl(nativePath string, interfaceType InterfaceType) string {
	if m.buildURLFn != nil {
		return m.buildURLFn(nativePath, interfaceType)
	}
	return m.mockProtocolAdapter.BuildUrl(nativePath, interfaceType)
}
|
||||
|
||||
func TestConvertHttpResponse_Passthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
|
||||
spec := HTTPResponseSpec{
|
||||
StatusCode: 200,
|
||||
Body: []byte(`{"id":"123"}`),
|
||||
}
|
||||
|
||||
result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat, "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 200, result.StatusCode)
|
||||
assert.Equal(t, spec.Body, result.Body)
|
||||
}
|
||||
|
||||
func TestCreateStreamConverter_Passthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
|
||||
converter, err := engine.CreateStreamConverter("openai", "openai", "", InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
_, ok := converter.(*PassthroughStreamConverter)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func TestCreateStreamConverter_Canonical(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
||||
_ = engine.RegisterAdapter(newMockAdapter("provider", false))
|
||||
|
||||
converter, err := engine.CreateStreamConverter("client", "provider", "", InterfaceTypeChat)
|
||||
require.NoError(t, err)
|
||||
_, ok := converter.(*CanonicalStreamConverter)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func TestEncodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
|
||||
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
|
||||
body, statusCode, err := engine.EncodeError(convErr, "openai")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 400, statusCode)
|
||||
assert.NotNil(t, body)
|
||||
}
|
||||
|
||||
func TestEncodeError_NonExistentProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
|
||||
body, statusCode, err := engine.EncodeError(convErr, "nonexistent")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 500, statusCode)
|
||||
assert.Contains(t, string(body), "测试错误")
|
||||
}
|
||||
|
||||
func TestRegistry_DuplicateRegistration(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
adapter := newMockAdapter("openai", true)
|
||||
|
||||
err := registry.Register(adapter)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = registry.Register(adapter)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "适配器已注册")
|
||||
}
|
||||
|
||||
func TestRegistry_GetNonExistent(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
|
||||
_, err := registry.Get("nonexistent")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "未找到适配器")
|
||||
}
|
||||
|
||||
// ============ modelOverride 测试 ============
|
||||
|
||||
// TestConvertHttpResponse_ModelOverride_CrossProtocol verifies that in a
// cross-protocol conversion the response model is overridden with the
// caller-supplied unified ID ("provider/gpt-4") instead of the provider's
// native model name.
func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, zap.NewNop())

	// Client encoder emits only the model field so the override is observable.
	clientAdapter := newMockAdapter("client", false)
	clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
		return json.Marshal(map[string]any{"model": resp.Model})
	}
	_ = engine.RegisterAdapter(clientAdapter)

	// Provider decoder always yields a fixed native model name.
	providerAdapter := newMockAdapter("provider", false)
	providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
		return &canonical.CanonicalResponse{ID: "test", Model: "native-model", Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}, nil
	}
	_ = engine.RegisterAdapter(providerAdapter)

	spec := HTTPResponseSpec{
		StatusCode: 200,
		Body:       []byte(`{"model":"native-model"}`),
	}

	result, err := engine.ConvertHttpResponse(spec, "client", "provider", InterfaceTypeChat, "provider/gpt-4")
	require.NoError(t, err)

	var resp map[string]interface{}
	require.NoError(t, json.Unmarshal(result.Body, &resp))
	assert.Equal(t, "provider/gpt-4", resp["model"])
}
|
||||
|
||||
// TestConvertHttpResponse_ModelOverride_SameProtocol verifies that a
// same-protocol conversion takes the Smart Passthrough path: the body is not
// re-encoded from canonical form, only its model field is rewritten in place
// (the rest of the payload, e.g. "id", survives untouched).
func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, zap.NewNop())

	// Mock "openai" adapter (passthrough-capable) whose rewrite hook swaps
	// the model field in the raw response body.
	openaiAdapter := newMockAdapter("openai", true)
	openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
		var m map[string]json.RawMessage
		if err := json.Unmarshal(body, &m); err != nil {
			return nil, err
		}
		m["model"], _ = json.Marshal(newModel)
		return json.Marshal(m)
	}
	_ = engine.RegisterAdapter(openaiAdapter)

	spec := HTTPResponseSpec{
		StatusCode: 200,
		Body:       []byte(`{"id":"resp-1","model":"gpt-4"}`),
	}

	result, err := engine.ConvertHttpResponse(spec, "openai", "openai", InterfaceTypeChat, "openai/gpt-4")
	require.NoError(t, err)

	// Model rewritten to the unified ID; sibling fields preserved.
	var resp map[string]interface{}
	require.NoError(t, json.Unmarshal(result.Body, &resp))
	assert.Equal(t, "openai/gpt-4", resp["model"])
	assert.Equal(t, "resp-1", resp["id"])
}
|
||||
|
||||
// TestCreateStreamConverter_ModelOverride_SmartPassthrough verifies that a
// same-protocol streaming conversion yields a SmartPassthroughStreamConverter
// and that the JSON payload of each SSE "data:" frame has its model field
// rewritten to the unified model ID.
func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, zap.NewNop())

	// Mock adapter whose response-rewrite hook swaps the model field in place.
	openaiAdapter := newMockAdapter("openai", true)
	openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
		var m map[string]json.RawMessage
		if err := json.Unmarshal(body, &m); err != nil {
			return nil, err
		}
		m["model"], _ = json.Marshal(newModel)
		return json.Marshal(m)
	}
	_ = engine.RegisterAdapter(openaiAdapter)

	converter, err := engine.CreateStreamConverter("openai", "openai", "openai/gpt-4", InterfaceTypeChat)
	require.NoError(t, err)

	// Same client/provider protocol should select the passthrough converter.
	_, ok := converter.(*SmartPassthroughStreamConverter)
	assert.True(t, ok)

	// Feed one SSE frame and check that the data JSON was rewritten.
	chunks := converter.ProcessChunk([]byte(`data: {"model":"gpt-4","choices":[]}` + "\n\n"))
	require.Len(t, chunks, 1)

	var resp map[string]interface{}
	payload := strings.TrimPrefix(strings.TrimSpace(string(chunks[0])), "data: ")
	require.NoError(t, json.Unmarshal([]byte(payload), &resp))
	assert.Equal(t, "openai/gpt-4", resp["model"])
}
|
||||
|
||||
// TestCreateStreamConverter_ModelOverride_CrossProtocol verifies that a
// cross-protocol streaming conversion uses the CanonicalStreamConverter and
// that the modelOverride replaces Message.Model on decoded stream events
// before the client adapter re-encodes them.
func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, zap.NewNop())

	// Provider adapter: decodes any chunk into a fixed three-event stream
	// whose message_start carries the provider-native model name.
	providerAdapter := newMockAdapter("provider", false)
	providerAdapter.streamDecoderFn = func() StreamDecoder {
		return &engineTestStreamDecoder{
			processFn: func(raw []byte) []canonical.CanonicalStreamEvent {
				return []canonical.CanonicalStreamEvent{
					canonical.NewMessageStartEvent("msg-1", "native-model"),
					canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "text", Text: "hi"}),
					canonical.NewMessageStopEvent(),
				}
			},
		}
	}
	_ = engine.RegisterAdapter(providerAdapter)

	// Client adapter: encodes each event to JSON, exposing Message.Model when present.
	clientAdapter := newMockAdapter("client", false)
	clientAdapter.streamEncoderFn = func() StreamEncoder {
		return &engineTestStreamEncoder{
			encodeFn: func(event canonical.CanonicalStreamEvent) [][]byte {
				if event.Message != nil {
					data, _ := json.Marshal(map[string]string{
						"type":  string(event.Type),
						"model": event.Message.Model,
					})
					return [][]byte{data}
				}
				data, _ := json.Marshal(map[string]string{"type": string(event.Type)})
				return [][]byte{data}
			},
		}
	}
	_ = engine.RegisterAdapter(clientAdapter)

	converter, err := engine.CreateStreamConverter("client", "provider", "provider/gpt-4", InterfaceTypeChat)
	require.NoError(t, err)

	// Different protocols must select the canonical (full-conversion) converter.
	_, ok := converter.(*CanonicalStreamConverter)
	assert.True(t, ok)

	// Process one chunk and verify the model is overridden with the unified ID.
	chunks := converter.ProcessChunk([]byte("raw"))
	require.Len(t, chunks, 3) // message_start + content_block_start + message_stop

	var startEvent map[string]string
	require.NoError(t, json.Unmarshal(chunks[0], &startEvent))
	assert.Equal(t, "provider/gpt-4", startEvent["model"], "跨协议流式中 modelOverride 应覆写 Message.Model")
}
|
||||
|
||||
// TestCreateStreamConverter_ModelOverride_CrossProtocol_Empty verifies that
// an empty modelOverride leaves the provider-native model name untouched in
// cross-protocol streaming conversion.
func TestCreateStreamConverter_ModelOverride_CrossProtocol_Empty(t *testing.T) {
	registry := NewMemoryRegistry()
	engine := NewConversionEngine(registry, zap.NewNop())

	// Provider adapter: emits a single message_start event with the native model.
	providerAdapter := newMockAdapter("provider", false)
	providerAdapter.streamDecoderFn = func() StreamDecoder {
		return &engineTestStreamDecoder{
			processFn: func(raw []byte) []canonical.CanonicalStreamEvent {
				return []canonical.CanonicalStreamEvent{
					canonical.NewMessageStartEvent("msg-1", "native-model"),
				}
			},
		}
	}
	_ = engine.RegisterAdapter(providerAdapter)

	// Client adapter: encodes only events that carry a Message, exposing its model.
	clientAdapter := newMockAdapter("client", false)
	clientAdapter.streamEncoderFn = func() StreamEncoder {
		return &engineTestStreamEncoder{
			encodeFn: func(event canonical.CanonicalStreamEvent) [][]byte {
				if event.Message != nil {
					data, _ := json.Marshal(map[string]string{
						"model": event.Message.Model,
					})
					return [][]byte{data}
				}
				return nil
			},
		}
	}
	_ = engine.RegisterAdapter(clientAdapter)

	// Empty modelOverride: no rewrite should happen.
	converter, err := engine.CreateStreamConverter("client", "provider", "", InterfaceTypeChat)
	require.NoError(t, err)

	chunks := converter.ProcessChunk([]byte("raw"))
	require.Len(t, chunks, 1)

	var resp map[string]string
	require.NoError(t, json.Unmarshal(chunks[0], &resp))
	assert.Equal(t, "native-model", resp["model"], "modelOverride 为空时不应覆写")
}
|
||||
|
||||
// engineTestStreamDecoder is a controllable StreamDecoder for engine tests;
// behavior is injected through function fields, and a nil field yields nil.
type engineTestStreamDecoder struct {
	processFn func([]byte) []canonical.CanonicalStreamEvent // invoked by ProcessChunk
	flushFn   func() []canonical.CanonicalStreamEvent       // invoked by Flush
}

// ProcessChunk delegates to processFn when set; otherwise returns nil.
func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.CanonicalStreamEvent {
	if d.processFn != nil {
		return d.processFn(raw)
	}
	return nil
}

// Flush delegates to flushFn when set; otherwise returns nil.
func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
	if d.flushFn != nil {
		return d.flushFn()
	}
	return nil
}
|
||||
|
||||
// engineTestStreamEncoder is a controllable StreamEncoder for engine tests;
// behavior is injected through function fields, and a nil field yields nil.
type engineTestStreamEncoder struct {
	encodeFn func(canonical.CanonicalStreamEvent) [][]byte // invoked by EncodeEvent
	flushFn  func() [][]byte                               // invoked by Flush
}

// EncodeEvent delegates to encodeFn when set; otherwise returns nil.
func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
	if e.encodeFn != nil {
		return e.encodeFn(event)
	}
	return nil
}

// Flush delegates to flushFn when set; otherwise returns nil.
func (e *engineTestStreamEncoder) Flush() [][]byte {
	if e.flushFn != nil {
		return e.flushFn()
	}
	return nil
}
|
||||
104
backend/internal/conversion/errors.go
Normal file
104
backend/internal/conversion/errors.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package conversion
|
||||
|
||||
import "fmt"
|
||||
|
||||
// ErrorCode enumerates the machine-readable categories a protocol
// conversion failure can fall into.
type ErrorCode string

const (
	ErrorCodeInvalidInput          ErrorCode = "INVALID_INPUT"
	ErrorCodeMissingRequiredField  ErrorCode = "MISSING_REQUIRED_FIELD"
	ErrorCodeIncompatibleFeature   ErrorCode = "INCOMPATIBLE_FEATURE"
	ErrorCodeFieldMappingFailure   ErrorCode = "FIELD_MAPPING_FAILURE"
	ErrorCodeToolCallParseError    ErrorCode = "TOOL_CALL_PARSE_ERROR"
	ErrorCodeJSONParseError        ErrorCode = "JSON_PARSE_ERROR"
	ErrorCodeStreamStateError      ErrorCode = "STREAM_STATE_ERROR"
	ErrorCodeUTF8DecodeError       ErrorCode = "UTF8_DECODE_ERROR"
	ErrorCodeProtocolConstraint    ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
	ErrorCodeEncodingFailure       ErrorCode = "ENCODING_FAILURE"
	ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
	ErrorCodeUnsupportedMultimodal ErrorCode = "UNSUPPORTED_MULTIMODAL"
)

// Well-known Details keys and values identifying which phase of a
// conversion produced the error.
const (
	ErrorDetailPhase   = "phase"    // Details key holding the phase value
	ErrorPhaseRequest  = "request"  // error occurred while handling the request
	ErrorPhaseResponse = "response" // error occurred while handling the response
)

// ConversionError describes a protocol conversion failure with a stable
// error code, optional protocol/interface context, free-form details, and
// an optional wrapped cause. Use NewConversionError plus the With* builders.
type ConversionError struct {
	Code             ErrorCode      // stable machine-readable category
	Message          string         // human-readable description
	ClientProtocol   string         // protocol the client spoke (optional)
	ProviderProtocol string         // protocol the provider speaks (optional)
	InterfaceType    string         // interface type being converted (optional)
	Details          map[string]any // additional structured context
	Cause            error          // wrapped underlying error, if any
}

// NewConversionError creates a ConversionError with an initialized (empty)
// Details map so WithDetail can be chained immediately.
func NewConversionError(code ErrorCode, message string) *ConversionError {
	return &ConversionError{
		Code:    code,
		Message: message,
		Details: make(map[string]any),
	}
}

// NewRequestJSONParseError creates a JSON-parse error tagged with the
// request phase and wrapping cause.
func NewRequestJSONParseError(message string, cause error) *ConversionError {
	return NewConversionError(ErrorCodeJSONParseError, message).
		WithDetail(ErrorDetailPhase, ErrorPhaseRequest).
		WithCause(cause)
}

// NewResponseJSONParseError creates a JSON-parse error tagged with the
// response phase and wrapping cause.
func NewResponseJSONParseError(message string, cause error) *ConversionError {
	return NewConversionError(ErrorCodeJSONParseError, message).
		WithDetail(ErrorDetailPhase, ErrorPhaseResponse).
		WithCause(cause)
}

// WithClientProtocol records the client-side protocol name and returns the
// receiver for chaining.
func (e *ConversionError) WithClientProtocol(protocol string) *ConversionError {
	e.ClientProtocol = protocol
	return e
}

// WithProviderProtocol records the provider-side protocol name and returns
// the receiver for chaining.
func (e *ConversionError) WithProviderProtocol(protocol string) *ConversionError {
	e.ProviderProtocol = protocol
	return e
}

// WithInterfaceType records the interface type and returns the receiver for
// chaining.
func (e *ConversionError) WithInterfaceType(ifaceType string) *ConversionError {
	e.InterfaceType = ifaceType
	return e
}

// WithDetail attaches a structured detail and returns the receiver for
// chaining. The map is lazily initialized so this is safe even on a
// ConversionError built as a struct literal (previously that panicked on the
// nil Details map).
func (e *ConversionError) WithDetail(key string, value any) *ConversionError {
	if e.Details == nil {
		e.Details = make(map[string]any)
	}
	e.Details[key] = value
	return e
}

// WithCause records the wrapped underlying error and returns the receiver
// for chaining.
func (e *ConversionError) WithCause(cause error) *ConversionError {
	e.Cause = cause
	return e
}

// Error implements the error interface: "[CODE] message" with the cause
// appended when present.
func (e *ConversionError) Error() string {
	if e.Cause != nil {
		return fmt.Sprintf("[%s] %s: %v", e.Code, e.Message, e.Cause)
	}
	return fmt.Sprintf("[%s] %s", e.Code, e.Message)
}

// Unwrap exposes the cause for errors.Is / errors.As chains.
func (e *ConversionError) Unwrap() error {
	return e.Cause
}
|
||||
45
backend/internal/conversion/errors_test.go
Normal file
45
backend/internal/conversion/errors_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package conversion
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConversionError_Builder(t *testing.T) {
|
||||
cause := errors.New("原始错误")
|
||||
err := NewConversionError(ErrorCodeInvalidInput, "输入无效").
|
||||
WithClientProtocol("openai").
|
||||
WithDetail("field", "model").
|
||||
WithCause(cause)
|
||||
|
||||
assert.Equal(t, ErrorCodeInvalidInput, err.Code)
|
||||
assert.Equal(t, "openai", err.ClientProtocol)
|
||||
assert.Equal(t, "输入无效", err.Message)
|
||||
assert.Equal(t, "model", err.Details["field"])
|
||||
assert.Equal(t, cause, err.Cause)
|
||||
}
|
||||
|
||||
func TestConversionError_Unwrap(t *testing.T) {
|
||||
cause := errors.New("根本原因")
|
||||
err := NewConversionError(ErrorCodeJSONParseError, "解析失败").WithCause(cause)
|
||||
|
||||
unwrapped := err.Unwrap()
|
||||
assert.Equal(t, cause, unwrapped)
|
||||
}
|
||||
|
||||
func TestConversionError_Error_WithCause(t *testing.T) {
|
||||
err := NewConversionError(ErrorCodeInvalidInput, "输入无效").WithCause(errors.New("原因"))
|
||||
msg := err.Error()
|
||||
assert.Contains(t, msg, "INVALID_INPUT")
|
||||
assert.Contains(t, msg, "输入无效")
|
||||
assert.Contains(t, msg, "原因")
|
||||
}
|
||||
|
||||
func TestConversionError_Error_WithoutCause(t *testing.T) {
|
||||
err := NewConversionError(ErrorCodeInvalidInput, "输入无效")
|
||||
msg := err.Error()
|
||||
assert.Contains(t, msg, "INVALID_INPUT")
|
||||
assert.Contains(t, msg, "输入无效")
|
||||
}
|
||||
13
backend/internal/conversion/interface.go
Normal file
13
backend/internal/conversion/interface.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package conversion
|
||||
|
||||
// InterfaceType identifies which API surface a request or response targets
// during protocol conversion.
type InterfaceType string

const (
	InterfaceTypeChat        InterfaceType = "CHAT"        // chat completions
	InterfaceTypeModels      InterfaceType = "MODELS"      // model listing
	InterfaceTypeModelInfo   InterfaceType = "MODEL_INFO"  // single-model details
	InterfaceTypeEmbeddings  InterfaceType = "EMBEDDINGS"  // embedding generation
	InterfaceTypeRerank      InterfaceType = "RERANK"      // document reranking
	InterfaceTypePassthrough InterfaceType = "PASSTHROUGH" // unrecognized path, forwarded as-is
)
|
||||
76
backend/internal/conversion/middleware.go
Normal file
76
backend/internal/conversion/middleware.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package conversion
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ConversionMiddleware is a hook that can inspect or rewrite a canonical
// request (and each canonical stream event) as it flows between a client
// protocol and a provider protocol.
type ConversionMiddleware interface {
	// Intercept may transform the request; a non-nil error aborts the chain.
	Intercept(req *canonical.CanonicalRequest, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalRequest, error)
	// InterceptStreamEvent is the streaming counterpart of Intercept.
	InterceptStreamEvent(event *canonical.CanonicalStreamEvent, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalStreamEvent, error)
}

// ConversionContext carries per-conversion identity and scratch state shared
// by all middlewares in one chain invocation.
type ConversionContext struct {
	ConversionID  string         // unique ID for this conversion
	InterfaceType InterfaceType  // API surface being converted
	Timestamp     time.Time      // context creation time (UTC)
	Metadata      map[string]any // free-form scratch space for middlewares
}
|
||||
|
||||
// NewConversionContext 创建转换上下文
|
||||
func NewConversionContext(ifaceType InterfaceType) *ConversionContext {
|
||||
return &ConversionContext{
|
||||
ConversionID: uuid.New().String(),
|
||||
InterfaceType: ifaceType,
|
||||
Timestamp: time.Now().UTC(),
|
||||
Metadata: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
// MiddlewareChain holds an ordered list of ConversionMiddleware and applies
// them sequentially.
// NOTE(review): no locking here — appears to assume all Use calls happen
// before Apply is invoked concurrently; confirm with callers.
type MiddlewareChain struct {
	middlewares []ConversionMiddleware
}

// NewMiddlewareChain creates an empty middleware chain.
func NewMiddlewareChain() *MiddlewareChain {
	return &MiddlewareChain{
		middlewares: make([]ConversionMiddleware, 0),
	}
}

// Use appends a middleware; registration order is execution order.
func (c *MiddlewareChain) Use(mw ConversionMiddleware) {
	c.middlewares = append(c.middlewares, mw)
}
|
||||
|
||||
// Apply 对请求按顺序执行所有中间件
|
||||
func (c *MiddlewareChain) Apply(req *canonical.CanonicalRequest, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) {
|
||||
result := req
|
||||
for _, mw := range c.middlewares {
|
||||
var err error
|
||||
result, err = mw.Intercept(result, clientProtocol, providerProtocol, ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ApplyStreamEvent 对流式事件按顺序执行所有中间件
|
||||
func (c *MiddlewareChain) ApplyStreamEvent(event *canonical.CanonicalStreamEvent, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalStreamEvent, error) {
|
||||
result := event
|
||||
for _, mw := range c.middlewares {
|
||||
var err error
|
||||
result, err = mw.InterceptStreamEvent(result, clientProtocol, providerProtocol, ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
85
backend/internal/conversion/middleware_test.go
Normal file
85
backend/internal/conversion/middleware_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package conversion
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// recordingMiddleware appends its name to a shared slice on every call so
// tests can assert invocation order; a non-nil err makes both hooks fail.
type recordingMiddleware struct {
	name    string
	records *[]string // shared call log across middlewares under test
	err     error     // returned from both hooks when non-nil
}

// Intercept records the call under the middleware's name, then either fails
// with the configured error or passes the request through unchanged.
func (m *recordingMiddleware) Intercept(req *canonical.CanonicalRequest, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) {
	*m.records = append(*m.records, m.name)
	if m.err != nil {
		return nil, m.err
	}
	return req, nil
}

// InterceptStreamEvent records the call with a "stream:" prefix, then either
// fails with the configured error or passes the event through unchanged.
func (m *recordingMiddleware) InterceptStreamEvent(event *canonical.CanonicalStreamEvent, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalStreamEvent, error) {
	*m.records = append(*m.records, "stream:"+m.name)
	if m.err != nil {
		return nil, m.err
	}
	return event, nil
}
|
||||
|
||||
func TestMiddlewareChain_Empty(t *testing.T) {
|
||||
chain := NewMiddlewareChain()
|
||||
req := &canonical.CanonicalRequest{Model: "test"}
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
|
||||
result, err := chain.Apply(req, "a", "b", ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test", result.Model)
|
||||
}
|
||||
|
||||
func TestMiddlewareChain_Order(t *testing.T) {
|
||||
var records []string
|
||||
chain := NewMiddlewareChain()
|
||||
chain.Use(&recordingMiddleware{name: "first", records: &records})
|
||||
chain.Use(&recordingMiddleware{name: "second", records: &records})
|
||||
chain.Use(&recordingMiddleware{name: "third", records: &records})
|
||||
|
||||
req := &canonical.CanonicalRequest{Model: "test"}
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
_, err := chain.Apply(req, "a", "b", ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"first", "second", "third"}, records)
|
||||
}
|
||||
|
||||
func TestMiddlewareChain_ErrorInterrupt(t *testing.T) {
|
||||
var records []string
|
||||
chain := NewMiddlewareChain()
|
||||
chain.Use(&recordingMiddleware{name: "first", records: &records})
|
||||
chain.Use(&recordingMiddleware{name: "second", records: &records, err: errors.New("中断")})
|
||||
chain.Use(&recordingMiddleware{name: "third", records: &records})
|
||||
|
||||
req := &canonical.CanonicalRequest{Model: "test"}
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
_, err := chain.Apply(req, "a", "b", ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "中断", err.Error())
|
||||
assert.Equal(t, []string{"first", "second"}, records)
|
||||
}
|
||||
|
||||
func TestMiddlewareChain_ApplyStreamEvent(t *testing.T) {
|
||||
var records []string
|
||||
chain := NewMiddlewareChain()
|
||||
chain.Use(&recordingMiddleware{name: "mw1", records: &records})
|
||||
|
||||
event := canonical.NewMessageStartEvent("id", "model")
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
result, err := chain.ApplyStreamEvent(&event, "a", "b", ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, canonical.EventMessageStart, result.Type)
|
||||
assert.Equal(t, []string{"stream:mw1"}, records)
|
||||
}
|
||||
316
backend/internal/conversion/openai/adapter.go
Normal file
316
backend/internal/conversion/openai/adapter.go
Normal file
@@ -0,0 +1,316 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// Adapter implements the OpenAI protocol adapter. It holds no state, so a
// single instance may be shared freely.
type Adapter struct{}

// NewAdapter creates an OpenAI adapter.
func NewAdapter() *Adapter {
	return &Adapter{}
}

// ProtocolName returns the protocol identifier used for registry lookups.
func (a *Adapter) ProtocolName() string { return "openai" }

// ProtocolVersion returns the protocol version; the OpenAI protocol is
// unversioned here, so it is the empty string.
func (a *Adapter) ProtocolVersion() string { return "" }

// SupportsPassthrough reports that same-protocol smart passthrough is allowed.
func (a *Adapter) SupportsPassthrough() bool { return true }
|
||||
|
||||
// DetectInterfaceType 根据路径检测接口类型
|
||||
func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType {
|
||||
switch {
|
||||
case nativePath == "/v1/chat/completions":
|
||||
return conversion.InterfaceTypeChat
|
||||
case nativePath == "/v1/models":
|
||||
return conversion.InterfaceTypeModels
|
||||
case isModelInfoPath(nativePath):
|
||||
return conversion.InterfaceTypeModelInfo
|
||||
case nativePath == "/v1/embeddings":
|
||||
return conversion.InterfaceTypeEmbeddings
|
||||
case nativePath == "/v1/rerank":
|
||||
return conversion.InterfaceTypeRerank
|
||||
default:
|
||||
return conversion.InterfaceTypePassthrough
|
||||
}
|
||||
}
|
||||
|
||||
// isModelInfoPath reports whether path is a model-detail path of the form
// /v1/models/{id}, where id must be non-empty and may itself contain
// slashes (e.g. "openai/gpt-4").
func isModelInfoPath(path string) bool {
	const prefix = "/v1/models/"
	return strings.HasPrefix(path, prefix) && len(path) > len(prefix)
}
|
||||
|
||||
// BuildUrl 根据接口类型构建 URL
|
||||
func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.InterfaceType) string {
|
||||
switch interfaceType {
|
||||
case conversion.InterfaceTypeChat:
|
||||
return "/chat/completions"
|
||||
case conversion.InterfaceTypeModels:
|
||||
return "/models"
|
||||
case conversion.InterfaceTypeModelInfo:
|
||||
if modelID, err := a.ExtractUnifiedModelID(nativePath); err == nil {
|
||||
return "/models/" + modelID
|
||||
}
|
||||
return nativePath
|
||||
case conversion.InterfaceTypeEmbeddings:
|
||||
return "/embeddings"
|
||||
case conversion.InterfaceTypeRerank:
|
||||
return "/rerank"
|
||||
default:
|
||||
return nativePath
|
||||
}
|
||||
}
|
||||
|
||||
// BuildHeaders 构建请求头
|
||||
func (a *Adapter) BuildHeaders(provider *conversion.TargetProvider) map[string]string {
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + provider.APIKey,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if org, ok := provider.AdapterConfig["organization"].(string); ok && org != "" {
|
||||
headers["OpenAI-Organization"] = org
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
// SupportsInterface 检查是否支持接口类型
|
||||
func (a *Adapter) SupportsInterface(interfaceType conversion.InterfaceType) bool {
|
||||
switch interfaceType {
|
||||
case conversion.InterfaceTypeChat,
|
||||
conversion.InterfaceTypeModels,
|
||||
conversion.InterfaceTypeModelInfo,
|
||||
conversion.InterfaceTypeEmbeddings,
|
||||
conversion.InterfaceTypeRerank:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeRequest decodes a native OpenAI chat request into canonical form.
func (a *Adapter) DecodeRequest(raw []byte) (*canonical.CanonicalRequest, error) {
	return decodeRequest(raw)
}

// EncodeRequest encodes a canonical request into a native OpenAI request.
func (a *Adapter) EncodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) {
	return encodeRequest(req, provider)
}

// DecodeResponse decodes a native OpenAI chat response into canonical form.
func (a *Adapter) DecodeResponse(raw []byte) (*canonical.CanonicalResponse, error) {
	return decodeResponse(raw)
}

// EncodeResponse encodes a canonical response into a native OpenAI response.
func (a *Adapter) EncodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
	return encodeResponse(resp)
}

// CreateStreamDecoder returns a new per-stream SSE decoder; one decoder per
// response stream, as it carries parsing state.
func (a *Adapter) CreateStreamDecoder() conversion.StreamDecoder {
	return NewStreamDecoder()
}

// CreateStreamEncoder returns a new per-stream SSE encoder.
func (a *Adapter) CreateStreamEncoder() conversion.StreamEncoder {
	return NewStreamEncoder()
}
|
||||
|
||||
// EncodeError renders a ConversionError as an OpenAI-style error response
// body plus an HTTP status code.
// NOTE(review): the status is always 500 here, even when the code maps to
// "invalid_request_error" (which OpenAI itself reports with 4xx) — confirm
// this is intentional.
func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
	errType := mapErrorCode(err.Code)
	statusCode := 500

	errMsg := ErrorResponse{
		Error: ErrorDetail{
			Message: err.Message,
			Type:    errType,
			Param:   nil,
			Code:    string(err.Code),
		},
	}
	body, marshalErr := json.Marshal(errMsg)
	if marshalErr != nil {
		// Marshalling the error struct failed: fall back to a hand-written payload.
		return []byte(`{"error":{"message":"internal error","type":"internal_error","code":"INTERNAL_ERROR"}}`), statusCode
	}
	return body, statusCode
}
|
||||
|
||||
// mapErrorCode maps an internal conversion error code onto an OpenAI error
// "type": caller-input problems become "invalid_request_error", everything
// else (stream state, UTF-8, encoding, unsupported interface, ...) is
// reported as "server_error".
func mapErrorCode(code conversion.ErrorCode) string {
	switch code {
	case conversion.ErrorCodeInvalidInput,
		conversion.ErrorCodeMissingRequiredField,
		conversion.ErrorCodeIncompatibleFeature,
		conversion.ErrorCodeToolCallParseError,
		conversion.ErrorCodeJSONParseError,
		conversion.ErrorCodeProtocolConstraint,
		conversion.ErrorCodeFieldMappingFailure:
		return "invalid_request_error"
	default:
		return "server_error"
	}
}
|
||||
|
||||
// DecodeModelsResponse decodes a native model-list response into canonical form.
func (a *Adapter) DecodeModelsResponse(raw []byte) (*canonical.CanonicalModelList, error) {
	return decodeModelsResponse(raw)
}

// EncodeModelsResponse encodes a canonical model list as a native response.
func (a *Adapter) EncodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
	return encodeModelsResponse(list)
}

// DecodeModelInfoResponse decodes a native model-detail response into canonical form.
func (a *Adapter) DecodeModelInfoResponse(raw []byte) (*canonical.CanonicalModelInfo, error) {
	return decodeModelInfoResponse(raw)
}

// EncodeModelInfoResponse encodes canonical model details as a native response.
func (a *Adapter) EncodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
	return encodeModelInfoResponse(info)
}

// DecodeEmbeddingRequest decodes a native embedding request into canonical form.
func (a *Adapter) DecodeEmbeddingRequest(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
	return decodeEmbeddingRequest(raw)
}

// EncodeEmbeddingRequest encodes a canonical embedding request natively.
func (a *Adapter) EncodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *conversion.TargetProvider) ([]byte, error) {
	return encodeEmbeddingRequest(req, provider)
}

// DecodeEmbeddingResponse decodes a native embedding response into canonical form.
func (a *Adapter) DecodeEmbeddingResponse(raw []byte) (*canonical.CanonicalEmbeddingResponse, error) {
	return decodeEmbeddingResponse(raw)
}

// EncodeEmbeddingResponse encodes a canonical embedding response natively.
func (a *Adapter) EncodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) {
	return encodeEmbeddingResponse(resp)
}

// DecodeRerankRequest decodes a native rerank request into canonical form.
func (a *Adapter) DecodeRerankRequest(raw []byte) (*canonical.CanonicalRerankRequest, error) {
	return decodeRerankRequest(raw)
}

// EncodeRerankRequest encodes a canonical rerank request natively.
func (a *Adapter) EncodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *conversion.TargetProvider) ([]byte, error) {
	return encodeRerankRequest(req, provider)
}

// DecodeRerankResponse decodes a native rerank response into canonical form.
func (a *Adapter) DecodeRerankResponse(raw []byte) (*canonical.CanonicalRerankResponse, error) {
	return decodeRerankResponse(raw)
}

// EncodeRerankResponse encodes a canonical rerank response natively.
func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
	return encodeRerankResponse(resp)
}
|
||||
|
||||
// ExtractUnifiedModelID 从路径中提取统一模型 ID(/v1/models/{provider_id}/{model_name})
|
||||
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
|
||||
if !strings.HasPrefix(nativePath, "/v1/models/") {
|
||||
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
|
||||
}
|
||||
suffix := nativePath[len("/v1/models/"):]
|
||||
if suffix == "" {
|
||||
return "", fmt.Errorf("路径缺少模型 ID")
|
||||
}
|
||||
return suffix, nil
|
||||
}
|
||||
|
||||
// locateModelFieldInRequest finds the top-level "model" field in a request
// body and returns its current value plus a rewrite function that
// re-serializes the body with a replacement value. Only interface types
// whose requests carry a top-level "model" field (chat, embeddings, rerank)
// are supported.
//
// The rewrite closure re-marshals the whole top-level object from the
// raw-message map, so unknown sibling fields are preserved byte-for-byte but
// top-level key order may change.
func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType) (string, func(string) ([]byte, error), error) {
	// Decode only the top level; values stay as raw JSON.
	var m map[string]json.RawMessage
	if err := json.Unmarshal(body, &m); err != nil {
		return "", nil, err
	}

	switch ifaceType {
	case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
		raw, exists := m["model"]
		if !exists {
			return "", nil, fmt.Errorf("请求体中缺少 model 字段")
		}
		var current string
		if err := json.Unmarshal(raw, &current); err != nil {
			return "", nil, err
		}
		// The closure captures the decoded map m and mutates it in place.
		rewriteFunc := func(newModel string) ([]byte, error) {
			encodedModel, err := json.Marshal(newModel)
			if err != nil {
				return nil, err
			}
			m["model"] = encodedModel
			return json.Marshal(m)
		}
		return current, rewriteFunc, nil
	default:
		return "", nil, fmt.Errorf("不支持的接口类型: %s", ifaceType)
	}
}
|
||||
|
||||
// ExtractModelName 从请求体中提取 model 值
|
||||
func (a *Adapter) ExtractModelName(body []byte, ifaceType conversion.InterfaceType) (string, error) {
|
||||
model, _, err := locateModelFieldInRequest(body, ifaceType)
|
||||
return model, err
|
||||
}
|
||||
|
||||
// RewriteRequestModelName 最小化改写请求体中的 model 字段
|
||||
func (a *Adapter) RewriteRequestModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||
_, rewriteFunc, err := locateModelFieldInRequest(body, ifaceType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rewriteFunc(newModel)
|
||||
}
|
||||
|
||||
// RewriteResponseModelName 最小化改写响应体中的 model 字段
|
||||
func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceType conversion.InterfaceType) ([]byte, error) {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
|
||||
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
case conversion.InterfaceTypeRerank:
|
||||
// Rerank 响应:存在 model 字段则改写,不存在则不添加
|
||||
if _, exists := m["model"]; exists {
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
}
|
||||
return json.Marshal(m)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
205
backend/internal/conversion/openai/adapter_test.go
Normal file
205
backend/internal/conversion/openai/adapter_test.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAdapter_ProtocolName(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
assert.Equal(t, "openai", a.ProtocolName())
|
||||
}
|
||||
|
||||
func TestAdapter_SupportsPassthrough(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
assert.True(t, a.SupportsPassthrough())
|
||||
}
|
||||
|
||||
// TestAdapter_DetectInterfaceType checks path-to-interface-type routing for the
// versioned (/v1/...) endpoints; unknown paths fall back to passthrough.
func TestAdapter_DetectInterfaceType(t *testing.T) {
	a := NewAdapter()

	tests := []struct {
		name     string
		path     string
		expected conversion.InterfaceType
	}{
		{"聊天补全", "/v1/chat/completions", conversion.InterfaceTypeChat},
		{"模型列表", "/v1/models", conversion.InterfaceTypeModels},
		{"模型详情", "/v1/models/openai/gpt-4", conversion.InterfaceTypeModelInfo},
		{"嵌入接口", "/v1/embeddings", conversion.InterfaceTypeEmbeddings},
		{"重排序接口", "/v1/rerank", conversion.InterfaceTypeRerank},
		{"未知路径", "/unknown", conversion.InterfaceTypePassthrough},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := a.DetectInterfaceType(tt.path)
			assert.Equal(t, tt.expected, result)
		})
	}
}
|
||||
|
||||
// TestAdapter_OldPathsBecomePassthrough ensures legacy un-versioned paths
// (no /v1 prefix) are no longer routed to typed interfaces.
func TestAdapter_OldPathsBecomePassthrough(t *testing.T) {
	a := NewAdapter()

	tests := []struct {
		path     string
		expected conversion.InterfaceType
	}{
		{"/chat/completions", conversion.InterfaceTypePassthrough},
		{"/models", conversion.InterfaceTypePassthrough},
		{"/models/gpt-4.1", conversion.InterfaceTypePassthrough},
		{"/embeddings", conversion.InterfaceTypePassthrough},
		{"/rerank", conversion.InterfaceTypePassthrough},
	}

	for _, tt := range tests {
		t.Run(tt.path, func(t *testing.T) {
			assert.Equal(t, tt.expected, a.DetectInterfaceType(tt.path))
		})
	}
}
|
||||
|
||||
// TestAdapter_BuildUrl checks mapping of native /v1 paths to upstream paths
// for each interface type; passthrough keeps the path unchanged.
func TestAdapter_BuildUrl(t *testing.T) {
	a := NewAdapter()

	tests := []struct {
		name          string
		nativePath    string
		interfaceType conversion.InterfaceType
		expected      string
	}{
		{"聊天", "/v1/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"},
		{"模型", "/v1/models", conversion.InterfaceTypeModels, "/models"},
		{"模型详情", "/v1/models/openai/gpt-4", conversion.InterfaceTypeModelInfo, "/models/openai/gpt-4"},
		{"复杂模型详情", "/v1/models/azure/accounts/org/models/gpt-4", conversion.InterfaceTypeModelInfo, "/models/azure/accounts/org/models/gpt-4"},
		{"嵌入", "/v1/embeddings", conversion.InterfaceTypeEmbeddings, "/embeddings"},
		{"重排序", "/v1/rerank", conversion.InterfaceTypeRerank, "/rerank"},
		{"默认透传", "/other", conversion.InterfaceTypePassthrough, "/other"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := a.BuildUrl(tt.nativePath, tt.interfaceType)
			assert.Equal(t, tt.expected, result)
		})
	}
}
|
||||
|
||||
// TestAdapter_BuildHeaders verifies bearer auth and content-type headers, and
// that OpenAI-Organization is emitted only when configured on the provider.
func TestAdapter_BuildHeaders(t *testing.T) {
	a := NewAdapter()

	t.Run("基本头", func(t *testing.T) {
		provider := conversion.NewTargetProvider("https://api.openai.com", "sk-test123", "gpt-4")
		headers := a.BuildHeaders(provider)
		assert.Equal(t, "Bearer sk-test123", headers["Authorization"])
		assert.Equal(t, "application/json", headers["Content-Type"])
		_, hasOrg := headers["OpenAI-Organization"]
		assert.False(t, hasOrg)
	})

	t.Run("带组织", func(t *testing.T) {
		provider := conversion.NewTargetProvider("https://api.openai.com", "sk-test123", "gpt-4")
		provider.AdapterConfig["organization"] = "org-abc"
		headers := a.BuildHeaders(provider)
		assert.Equal(t, "org-abc", headers["OpenAI-Organization"])
	})
}
|
||||
|
||||
// TestAdapter_SupportsInterface checks that all typed interfaces are
// supported while passthrough is reported as unsupported.
func TestAdapter_SupportsInterface(t *testing.T) {
	a := NewAdapter()

	tests := []struct {
		name          string
		interfaceType conversion.InterfaceType
		expected      bool
	}{
		{"聊天", conversion.InterfaceTypeChat, true},
		{"模型", conversion.InterfaceTypeModels, true},
		{"模型详情", conversion.InterfaceTypeModelInfo, true},
		{"嵌入", conversion.InterfaceTypeEmbeddings, true},
		{"重排序", conversion.InterfaceTypeRerank, true},
		{"透传", conversion.InterfaceTypePassthrough, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := a.SupportsInterface(tt.interfaceType)
			assert.Equal(t, tt.expected, result)
		})
	}
}
|
||||
|
||||
// TestIsModelInfoPath checks detection of /v1/models/{id} detail paths,
// including dotted model names, nested IDs, and near-miss prefixes.
func TestIsModelInfoPath(t *testing.T) {
	tests := []struct {
		name     string
		path     string
		expected bool
	}{
		{"model_info", "/v1/models/openai/gpt-4", true},
		{"model_info_with_dots", "/v1/models/openai/gpt-4.1-preview", true},
		{"models_list", "/v1/models", false},
		{"nested_path", "/v1/models/azure/accounts/org-123/models/gpt-4", true},
		{"empty_suffix", "/v1/models/", false},
		{"unrelated", "/v1/chat/completions", false},
		{"partial_prefix", "/model", false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
		})
	}
}
|
||||
|
||||
// TestAdapter_ExtractUnifiedModelID verifies the unified model ID is the full
// path suffix after /v1/models/, and that list paths are rejected.
func TestAdapter_ExtractUnifiedModelID(t *testing.T) {
	a := NewAdapter()

	t.Run("标准路径", func(t *testing.T) {
		modelID, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
		require.NoError(t, err)
		assert.Equal(t, "openai/gpt-4", modelID)
	})

	t.Run("复杂路径", func(t *testing.T) {
		modelID, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
		require.NoError(t, err)
		assert.Equal(t, "azure/accounts/org/models/gpt-4", modelID)
	})

	t.Run("非模型详情路径报错", func(t *testing.T) {
		_, err := a.ExtractUnifiedModelID("/v1/models")
		require.Error(t, err)
	})
}
|
||||
|
||||
// TestAdapter_EncodeError_InvalidInput pins the OpenAI-style error envelope
// produced for an invalid-input conversion error.
//
// NOTE(review): invalid input encodes as HTTP 500 here while the body type is
// "invalid_request_error"; OpenAI itself uses 400 for that type — confirm the
// 500 status is intentional in EncodeError.
func TestAdapter_EncodeError_InvalidInput(t *testing.T) {
	a := NewAdapter()
	convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效")

	body, statusCode := a.EncodeError(convErr)
	require.Equal(t, 500, statusCode)

	var resp ErrorResponse
	require.NoError(t, json.Unmarshal(body, &resp))
	assert.Equal(t, "参数无效", resp.Error.Message)
	assert.Equal(t, "invalid_request_error", resp.Error.Type)
}
|
||||
|
||||
// TestAdapter_EncodeError_ServerError pins the error envelope for an internal
// stream-state error: HTTP 500 with type "server_error".
func TestAdapter_EncodeError_ServerError(t *testing.T) {
	a := NewAdapter()
	convErr := conversion.NewConversionError(conversion.ErrorCodeStreamStateError, "流状态错误")

	body, statusCode := a.EncodeError(convErr)
	require.Equal(t, 500, statusCode)

	var resp ErrorResponse
	require.NoError(t, json.Unmarshal(body, &resp))
	assert.Equal(t, "server_error", resp.Error.Type)
	assert.Equal(t, "流状态错误", resp.Error.Message)
}
|
||||
360
backend/internal/conversion/openai/adapter_unified_test.go
Normal file
360
backend/internal/conversion/openai/adapter_unified_test.go
Normal file
@@ -0,0 +1,360 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractUnifiedModelID
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestExtractUnifiedModelID exercises ExtractUnifiedModelID across valid
// single/multi-segment model paths and every rejected path shape.
func TestExtractUnifiedModelID(t *testing.T) {
	a := NewAdapter()

	t.Run("standard_path", func(t *testing.T) {
		id, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
		require.NoError(t, err)
		assert.Equal(t, "openai/gpt-4", id)
	})

	t.Run("multi_segment_path", func(t *testing.T) {
		id, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
		require.NoError(t, err)
		assert.Equal(t, "azure/accounts/org/models/gpt-4", id)
	})

	t.Run("single_segment", func(t *testing.T) {
		id, err := a.ExtractUnifiedModelID("/v1/models/gpt-4")
		require.NoError(t, err)
		assert.Equal(t, "gpt-4", id)
	})

	t.Run("non_model_path", func(t *testing.T) {
		_, err := a.ExtractUnifiedModelID("/v1/chat/completions")
		require.Error(t, err)
	})

	t.Run("empty_suffix", func(t *testing.T) {
		_, err := a.ExtractUnifiedModelID("/v1/models/")
		require.Error(t, err)
	})

	t.Run("models_list_no_slash", func(t *testing.T) {
		_, err := a.ExtractUnifiedModelID("/v1/models")
		require.Error(t, err)
	})

	t.Run("unrelated_path", func(t *testing.T) {
		_, err := a.ExtractUnifiedModelID("/other")
		require.Error(t, err)
	})
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractModelName
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestExtractModelName exercises ExtractModelName for the chat, embedding and
// rerank interfaces, plus error cases: missing field, bad JSON, and an
// interface type that carries no model field.
func TestExtractModelName(t *testing.T) {
	a := NewAdapter()

	t.Run("chat", func(t *testing.T) {
		body := []byte(`{"model":"openai/gpt-4","messages":[]}`)
		model, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
		require.NoError(t, err)
		assert.Equal(t, "openai/gpt-4", model)
	})

	t.Run("embedding", func(t *testing.T) {
		body := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
		model, err := a.ExtractModelName(body, conversion.InterfaceTypeEmbeddings)
		require.NoError(t, err)
		assert.Equal(t, "openai/text-embedding", model)
	})

	t.Run("rerank", func(t *testing.T) {
		body := []byte(`{"model":"openai/rerank","query":"test"}`)
		model, err := a.ExtractModelName(body, conversion.InterfaceTypeRerank)
		require.NoError(t, err)
		assert.Equal(t, "openai/rerank", model)
	})

	t.Run("no_model_field", func(t *testing.T) {
		body := []byte(`{"messages":[]}`)
		_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
		require.Error(t, err)
	})

	t.Run("invalid_json", func(t *testing.T) {
		body := []byte(`{invalid}`)
		_, err := a.ExtractModelName(body, conversion.InterfaceTypeChat)
		require.Error(t, err)
	})

	t.Run("unsupported_interface_type", func(t *testing.T) {
		body := []byte(`{"model":"openai/gpt-4"}`)
		_, err := a.ExtractModelName(body, conversion.InterfaceTypePassthrough)
		require.Error(t, err)
	})
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RewriteRequestModelName
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestRewriteRequestModelName verifies the request rewrite replaces only the
// model field, preserving sibling (including unknown) fields, and surfaces
// errors for missing fields, bad JSON, and unsupported interface types.
func TestRewriteRequestModelName(t *testing.T) {
	a := NewAdapter()

	t.Run("chat", func(t *testing.T) {
		body := []byte(`{"model":"openai/gpt-4","messages":[]}`)
		rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
		require.NoError(t, err)

		var m map[string]interface{}
		require.NoError(t, json.Unmarshal(rewritten, &m))
		assert.Equal(t, "gpt-4", m["model"])

		// messages field preserved
		msgs, ok := m["messages"]
		require.True(t, ok)
		msgsArr, ok := msgs.([]interface{})
		require.True(t, ok)
		assert.Len(t, msgsArr, 0)
	})

	t.Run("preserves_unknown_fields", func(t *testing.T) {
		body := []byte(`{"model":"openai/gpt-4","temperature":0.7}`)
		rewritten, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
		require.NoError(t, err)

		var m map[string]interface{}
		require.NoError(t, json.Unmarshal(rewritten, &m))
		assert.Equal(t, "gpt-4", m["model"])
		assert.Equal(t, 0.7, m["temperature"])
	})

	t.Run("embedding", func(t *testing.T) {
		body := []byte(`{"model":"openai/text-embedding","input":"hello"}`)
		rewritten, err := a.RewriteRequestModelName(body, "text-embedding", conversion.InterfaceTypeEmbeddings)
		require.NoError(t, err)

		var m map[string]interface{}
		require.NoError(t, json.Unmarshal(rewritten, &m))
		assert.Equal(t, "text-embedding", m["model"])
		assert.Equal(t, "hello", m["input"])
	})

	t.Run("rerank", func(t *testing.T) {
		body := []byte(`{"model":"openai/rerank","query":"test"}`)
		rewritten, err := a.RewriteRequestModelName(body, "rerank", conversion.InterfaceTypeRerank)
		require.NoError(t, err)

		var m map[string]interface{}
		require.NoError(t, json.Unmarshal(rewritten, &m))
		assert.Equal(t, "rerank", m["model"])
		assert.Equal(t, "test", m["query"])
	})

	t.Run("no_model_field", func(t *testing.T) {
		body := []byte(`{"messages":[]}`)
		_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
		require.Error(t, err)
	})

	t.Run("invalid_json", func(t *testing.T) {
		body := []byte(`{invalid}`)
		_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypeChat)
		require.Error(t, err)
	})

	t.Run("unsupported_interface_type", func(t *testing.T) {
		body := []byte(`{"model":"openai/gpt-4"}`)
		_, err := a.RewriteRequestModelName(body, "gpt-4", conversion.InterfaceTypePassthrough)
		require.Error(t, err)
	})
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RewriteResponseModelName
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestRewriteResponseModelName verifies per-interface response rewrite rules:
// chat/embedding always set (adding when missing), rerank rewrites only an
// existing field, passthrough returns the body untouched.
func TestRewriteResponseModelName(t *testing.T) {
	a := NewAdapter()

	t.Run("chat_existing_model", func(t *testing.T) {
		body := []byte(`{"model":"gpt-4","choices":[]}`)
		rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
		require.NoError(t, err)

		var m map[string]interface{}
		require.NoError(t, json.Unmarshal(rewritten, &m))
		assert.Equal(t, "openai/gpt-4", m["model"])

		choices, ok := m["choices"]
		require.True(t, ok)
		choicesArr, ok := choices.([]interface{})
		require.True(t, ok)
		assert.Len(t, choicesArr, 0)
	})

	t.Run("chat_without_model_field", func(t *testing.T) {
		body := []byte(`{"choices":[]}`)
		rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
		require.NoError(t, err)

		var m map[string]interface{}
		require.NoError(t, json.Unmarshal(rewritten, &m))
		assert.Equal(t, "openai/gpt-4", m["model"])

		choices, ok := m["choices"]
		require.True(t, ok)
		choicesArr, ok := choices.([]interface{})
		require.True(t, ok)
		assert.Len(t, choicesArr, 0)
	})

	t.Run("rerank_existing_model", func(t *testing.T) {
		body := []byte(`{"model":"rerank","results":[]}`)
		rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank)
		require.NoError(t, err)

		var m map[string]interface{}
		require.NoError(t, json.Unmarshal(rewritten, &m))
		assert.Equal(t, "openai/rerank", m["model"])
	})

	t.Run("rerank_without_model_field_should_not_add", func(t *testing.T) {
		body := []byte(`{"results":[]}`)
		rewritten, err := a.RewriteResponseModelName(body, "openai/rerank", conversion.InterfaceTypeRerank)
		require.NoError(t, err)

		var m map[string]interface{}
		require.NoError(t, json.Unmarshal(rewritten, &m))
		_, hasModel := m["model"]
		assert.False(t, hasModel, "rerank response without model field should not have one added")
	})

	t.Run("embedding_existing_model", func(t *testing.T) {
		body := []byte(`{"model":"text-embedding","data":[]}`)
		rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings)
		require.NoError(t, err)

		var m map[string]interface{}
		require.NoError(t, json.Unmarshal(rewritten, &m))
		assert.Equal(t, "openai/text-embedding", m["model"])
	})

	t.Run("embedding_without_model_field_adds", func(t *testing.T) {
		body := []byte(`{"data":[]}`)
		rewritten, err := a.RewriteResponseModelName(body, "openai/text-embedding", conversion.InterfaceTypeEmbeddings)
		require.NoError(t, err)

		var m map[string]interface{}
		require.NoError(t, json.Unmarshal(rewritten, &m))
		assert.Equal(t, "openai/text-embedding", m["model"])
	})

	t.Run("passthrough_returns_body_unchanged", func(t *testing.T) {
		body := []byte(`{"model":"gpt-4"}`)
		rewritten, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypePassthrough)
		require.NoError(t, err)
		assert.Equal(t, string(body), string(rewritten))
	})

	t.Run("invalid_json", func(t *testing.T) {
		body := []byte(`{invalid}`)
		_, err := a.RewriteResponseModelName(body, "openai/gpt-4", conversion.InterfaceTypeChat)
		require.Error(t, err)
	})
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ExtractModelName and RewriteRequest consistency
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestExtractModelNameAndRewriteRequestConsistency round-trips each interface:
// extract the model, rewrite it, then extract again — proving both functions
// target the same field while other fields survive.
func TestExtractModelNameAndRewriteRequestConsistency(t *testing.T) {
	a := NewAdapter()

	t.Run("chat_round_trip", func(t *testing.T) {
		original := []byte(`{"model":"openai/gpt-4","messages":[],"temperature":0.7}`)

		// Extract the unified model ID from the body
		extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeChat)
		require.NoError(t, err)
		assert.Equal(t, "openai/gpt-4", extracted)

		// Rewrite to the native model name
		rewritten, err := a.RewriteRequestModelName(original, "gpt-4", conversion.InterfaceTypeChat)
		require.NoError(t, err)

		// Extract again from the rewritten body to verify the same location was targeted
		afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeChat)
		require.NoError(t, err)
		assert.Equal(t, "gpt-4", afterRewrite)

		// Verify other fields are preserved
		var m map[string]interface{}
		require.NoError(t, json.Unmarshal(rewritten, &m))
		assert.Equal(t, 0.7, m["temperature"])
	})

	t.Run("embedding_round_trip", func(t *testing.T) {
		original := []byte(`{"model":"openai/text-embedding","input":"hello"}`)

		extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeEmbeddings)
		require.NoError(t, err)
		assert.Equal(t, "openai/text-embedding", extracted)

		rewritten, err := a.RewriteRequestModelName(original, "text-embedding", conversion.InterfaceTypeEmbeddings)
		require.NoError(t, err)

		afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeEmbeddings)
		require.NoError(t, err)
		assert.Equal(t, "text-embedding", afterRewrite)
	})

	t.Run("rerank_round_trip", func(t *testing.T) {
		original := []byte(`{"model":"openai/rerank","query":"test"}`)

		extracted, err := a.ExtractModelName(original, conversion.InterfaceTypeRerank)
		require.NoError(t, err)
		assert.Equal(t, "openai/rerank", extracted)

		rewritten, err := a.RewriteRequestModelName(original, "rerank", conversion.InterfaceTypeRerank)
		require.NoError(t, err)

		afterRewrite, err := a.ExtractModelName(rewritten, conversion.InterfaceTypeRerank)
		require.NoError(t, err)
		assert.Equal(t, "rerank", afterRewrite)
	})
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isModelInfoPath (additional unified model ID cases)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestIsModelInfoPath_UnifiedModelID adds unified-model-ID cases (slashes in
// the ID, deep nesting) on top of the basic isModelInfoPath tests.
func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
	tests := []struct {
		name     string
		path     string
		expected bool
	}{
		{"simple_model_id", "/v1/models/gpt-4", true},
		{"unified_model_id_with_slash", "/v1/models/openai/gpt-4", true},
		{"models_list", "/v1/models", false},
		{"models_list_trailing_slash", "/v1/models/", false},
		{"chat_completions", "/v1/chat/completions", false},
		{"deeply_nested", "/v1/models/azure/eastus/deployments/my-dept/models/gpt-4", true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.expected, isModelInfoPath(tt.path))
		})
	}
}
|
||||
696
backend/internal/conversion/openai/decoder.go
Normal file
696
backend/internal/conversion/openai/decoder.go
Normal file
@@ -0,0 +1,696 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// decodeRequest decodes an OpenAI chat-completion request body into the
// canonical request representation.
//
// Pipeline: parse JSON → validate required fields (model, messages) →
// normalize deprecated fields in place → hoist system/developer messages into
// the canonical System slot → map messages, tools, tool choice, sampling
// parameters, response format and reasoning config via the decode* helpers.
// Returns a ConversionError with code JSONParseError or InvalidInput on bad
// input.
func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
	var req ChatCompletionRequest
	if err := json.Unmarshal(body, &req); err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 OpenAI 请求失败").WithCause(err)
	}

	// Reject requests missing the mandatory fields early.
	if strings.TrimSpace(req.Model) == "" {
		return nil, conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "model 字段不能为空")
	}
	if len(req.Messages) == 0 {
		return nil, conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "messages 字段不能为空")
	}

	// Fold deprecated request fields into their modern equivalents (mutates req).
	decodeDeprecatedFields(&req)

	// system/developer messages are pulled out; the rest keep their order.
	system, messages := decodeSystemPrompt(req.Messages)

	var canonicalMsgs []canonical.CanonicalMessage
	for _, msg := range messages {
		decoded, err := decodeMessage(msg)
		if err != nil {
			return nil, err
		}
		// One OpenAI message may expand to zero or more canonical messages.
		canonicalMsgs = append(canonicalMsgs, decoded...)
	}

	tools := decodeTools(req.Tools)
	toolChoice := decodeToolChoice(req.ToolChoice)
	params := decodeParameters(&req)
	outputFormat := decodeOutputFormat(req.ResponseFormat)
	thinking := decodeThinking(req.ReasoningEffort)

	// parallel_tool_calls is tri-state: nil means "not specified".
	var parallelToolUse *bool
	if req.ParallelToolCalls != nil {
		parallelToolUse = req.ParallelToolCalls
	}

	return &canonical.CanonicalRequest{
		Model:           req.Model,
		System:          system,
		Messages:        canonicalMsgs,
		Tools:           tools,
		ToolChoice:      toolChoice,
		Parameters:      params,
		Thinking:        thinking,
		Stream:          req.Stream,
		UserID:          req.User,
		OutputFormat:    outputFormat,
		ParallelToolUse: parallelToolUse,
	}, nil
}
|
||||
|
||||
// decodeSystemPrompt 提取 system 和 developer 消息
|
||||
func decodeSystemPrompt(messages []Message) (any, []Message) {
|
||||
var systemParts []string
|
||||
var remaining []Message
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" || msg.Role == "developer" {
|
||||
text := extractText(msg.Content)
|
||||
if text != "" {
|
||||
systemParts = append(systemParts, text)
|
||||
}
|
||||
} else {
|
||||
remaining = append(remaining, msg)
|
||||
}
|
||||
}
|
||||
|
||||
if len(systemParts) == 0 {
|
||||
return nil, remaining
|
||||
}
|
||||
return strings.Join(systemParts, "\n\n"), remaining
|
||||
}
|
||||
|
||||
// extractText flattens message content to plain text. Strings pass through;
// part arrays contribute only their "text" parts, concatenated without a
// separator; nil yields ""; anything else is formatted with %v.
func extractText(content any) string {
	switch v := content.(type) {
	case nil:
		return ""
	case string:
		return v
	case []any:
		var b strings.Builder
		for _, item := range v {
			part, ok := item.(map[string]any)
			if !ok {
				continue
			}
			if kind, _ := part["type"].(string); kind != "text" {
				continue
			}
			if text, ok := part["text"].(string); ok {
				b.WriteString(text)
			}
		}
		return b.String()
	default:
		return fmt.Sprintf("%v", v)
	}
}
|
||||
|
||||
// decodeMessage 解码 OpenAI 消息
|
||||
func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) {
|
||||
switch msg.Role {
|
||||
case "user":
|
||||
blocks := decodeUserContent(msg.Content)
|
||||
return []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: blocks}}, nil
|
||||
|
||||
case "assistant":
|
||||
var blocks []canonical.ContentBlock
|
||||
// 处理 content
|
||||
if msg.Content != nil {
|
||||
switch v := msg.Content.(type) {
|
||||
case string:
|
||||
if v != "" {
|
||||
blocks = append(blocks, canonical.NewTextBlock(v))
|
||||
}
|
||||
default:
|
||||
parts := decodeContentParts(msg.Content)
|
||||
for _, p := range parts {
|
||||
if p.Type == "text" {
|
||||
blocks = append(blocks, canonical.NewTextBlock(p.Text))
|
||||
} else if p.Type == "refusal" {
|
||||
blocks = append(blocks, canonical.NewTextBlock(p.Refusal))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// refusal 顶层字段
|
||||
if msg.Refusal != "" {
|
||||
blocks = append(blocks, canonical.NewTextBlock(msg.Refusal))
|
||||
}
|
||||
// reasoning_content 非标准字段
|
||||
if msg.ReasoningContent != "" {
|
||||
blocks = append(blocks, canonical.NewThinkingBlock(msg.ReasoningContent))
|
||||
}
|
||||
// tool_calls
|
||||
for _, tc := range msg.ToolCalls {
|
||||
var input json.RawMessage
|
||||
if tc.Type == "custom" && tc.Custom != nil {
|
||||
input = json.RawMessage(fmt.Sprintf("%q", tc.Custom.Input))
|
||||
} else if tc.Function != nil {
|
||||
parsed := json.RawMessage(tc.Function.Arguments)
|
||||
if !json.Valid(parsed) {
|
||||
parsed = json.RawMessage("{}")
|
||||
}
|
||||
input = parsed
|
||||
} else {
|
||||
input = json.RawMessage("{}")
|
||||
}
|
||||
name := ""
|
||||
if tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
} else if tc.Custom != nil {
|
||||
name = tc.Custom.Name
|
||||
}
|
||||
blocks = append(blocks, canonical.NewToolUseBlock(tc.ID, name, input))
|
||||
}
|
||||
// 已废弃 function_call
|
||||
if msg.FunctionCall != nil {
|
||||
input := json.RawMessage(msg.FunctionCall.Arguments)
|
||||
if !json.Valid(input) {
|
||||
input = json.RawMessage("{}")
|
||||
}
|
||||
blocks = append(blocks, canonical.NewToolUseBlock(generateID(), msg.FunctionCall.Name, input))
|
||||
}
|
||||
if len(blocks) == 0 {
|
||||
blocks = append(blocks, canonical.NewTextBlock(""))
|
||||
}
|
||||
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil
|
||||
|
||||
case "tool":
|
||||
content := extractText(msg.Content)
|
||||
isErr := false
|
||||
block := canonical.ContentBlock{
|
||||
Type: "tool_result",
|
||||
ToolUseID: msg.ToolCallID,
|
||||
Content: json.RawMessage(fmt.Sprintf("%q", content)),
|
||||
IsError: &isErr,
|
||||
}
|
||||
return []canonical.CanonicalMessage{{Role: canonical.RoleTool, Content: []canonical.ContentBlock{block}}}, nil
|
||||
|
||||
case "function":
|
||||
content := extractText(msg.Content)
|
||||
isErr := false
|
||||
block := canonical.ContentBlock{
|
||||
Type: "tool_result",
|
||||
ToolUseID: msg.Name,
|
||||
Content: json.RawMessage(fmt.Sprintf("%q", content)),
|
||||
IsError: &isErr,
|
||||
}
|
||||
return []canonical.CanonicalMessage{{Role: canonical.RoleTool, Content: []canonical.ContentBlock{block}}}, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// decodeUserContent 解码用户内容
|
||||
func decodeUserContent(content any) []canonical.ContentBlock {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(v)}
|
||||
case []any:
|
||||
var blocks []canonical.ContentBlock
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
t, ok := m["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch t {
|
||||
case "text":
|
||||
text, ok := m["text"].(string)
|
||||
if !ok {
|
||||
text = ""
|
||||
}
|
||||
blocks = append(blocks, canonical.NewTextBlock(text))
|
||||
case "image_url":
|
||||
blocks = append(blocks, canonical.ContentBlock{Type: "image"})
|
||||
case "input_audio":
|
||||
blocks = append(blocks, canonical.ContentBlock{Type: "audio"})
|
||||
case "file":
|
||||
blocks = append(blocks, canonical.ContentBlock{Type: "file"})
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(blocks) > 0 {
|
||||
return blocks
|
||||
}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
||||
case nil:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
||||
default:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}
|
||||
}
|
||||
}
|
||||
|
||||
// contentPart is an intermediate representation of one structured content
// part after decoding; only "text" and "refusal" parts are captured.
type contentPart struct {
	Type    string // part kind: "text" or "refusal"
	Text    string // populated when Type == "text"
	Refusal string // populated when Type == "refusal"
}
|
||||
|
||||
// decodeContentParts 解码内容部分
|
||||
func decodeContentParts(content any) []contentPart {
|
||||
parts, ok := content.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
var result []contentPart
|
||||
for _, item := range parts {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
t, ok := m["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch t {
|
||||
case "text":
|
||||
text, ok := m["text"].(string)
|
||||
if !ok {
|
||||
text = ""
|
||||
}
|
||||
result = append(result, contentPart{Type: "text", Text: text})
|
||||
case "refusal":
|
||||
refusal, ok := m["refusal"].(string)
|
||||
if !ok {
|
||||
refusal = ""
|
||||
}
|
||||
result = append(result, contentPart{Type: "refusal", Refusal: refusal})
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// decodeTools 解码工具定义
|
||||
func decodeTools(tools []Tool) []canonical.CanonicalTool {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
var result []canonical.CanonicalTool
|
||||
for _, tool := range tools {
|
||||
if tool.Type == "function" && tool.Function != nil {
|
||||
result = append(result, canonical.CanonicalTool{
|
||||
Name: tool.Function.Name,
|
||||
Description: tool.Function.Description,
|
||||
InputSchema: tool.Function.Parameters,
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// decodeToolChoice converts the OpenAI tool_choice value (string or object)
// into the canonical tool choice.
//
// String forms: "auto" / "none" / "required" map to Auto / None / Any; any
// other string falls through and yields nil.
//
// Object forms, keyed by "type":
//   - "function" / "custom": a named choice built from the nested object's
//     name (empty name when the field is absent or not a string).
//   - "allowed_tools": mode "required" maps to Any; any other mode, or a
//     missing payload object, maps to Auto.
//   - any other type value falls through the switch and yields nil.
//
// Malformed payloads (missing "type", nested object not a map) and nil input
// yield nil, meaning "no explicit tool choice".
func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
	if toolChoice == nil {
		return nil
	}
	switch v := toolChoice.(type) {
	case string:
		switch v {
		case "auto":
			return canonical.NewToolChoiceAuto()
		case "none":
			return canonical.NewToolChoiceNone()
		case "required":
			return canonical.NewToolChoiceAny()
		}
	case map[string]any:
		t, ok := v["type"].(string)
		if !ok {
			return nil
		}
		switch t {
		case "function":
			if fn, ok := v["function"].(map[string]any); ok {
				name, ok := fn["name"].(string)
				if !ok {
					name = ""
				}
				return canonical.NewToolChoiceNamed(name)
			}
		case "custom":
			if custom, ok := v["custom"].(map[string]any); ok {
				name, ok := custom["name"].(string)
				if !ok {
					name = ""
				}
				return canonical.NewToolChoiceNamed(name)
			}
		case "allowed_tools":
			if at, ok := v["allowed_tools"].(map[string]any); ok {
				mode, ok := at["mode"].(string)
				if !ok {
					mode = ""
				}
				if mode == "required" {
					return canonical.NewToolChoiceAny()
				}
				return canonical.NewToolChoiceAuto()
			}
			// allowed_tools without a payload object still means "auto".
			return canonical.NewToolChoiceAuto()
		}
	}
	return nil
}
|
||||
|
||||
// decodeParameters 解码请求参数
|
||||
func decodeParameters(req *ChatCompletionRequest) canonical.RequestParameters {
|
||||
params := canonical.RequestParameters{
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
FrequencyPenalty: req.FrequencyPenalty,
|
||||
PresencePenalty: req.PresencePenalty,
|
||||
}
|
||||
if req.MaxCompletionTokens != nil {
|
||||
params.MaxTokens = req.MaxCompletionTokens
|
||||
} else if req.MaxTokens != nil {
|
||||
params.MaxTokens = req.MaxTokens
|
||||
}
|
||||
if req.Stop != nil {
|
||||
params.StopSequences = normalizeStop(req.Stop)
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
// normalizeStop normalizes the OpenAI "stop" parameter, which may arrive as
// a single string, a JSON array ([]any after unmarshalling), or an
// already-typed []string. Empty strings are dropped, and nil is returned
// when nothing remains so the field can be omitted downstream.
func normalizeStop(stop any) []string {
	switch v := stop.(type) {
	case string:
		if v == "" {
			return nil
		}
		return []string{v}
	case []any:
		var result []string
		for _, s := range v {
			if str, ok := s.(string); ok && str != "" {
				result = append(result, str)
			}
		}
		if len(result) == 0 {
			return nil
		}
		return result
	case []string:
		// Apply the same filtering as the []any branch: previously a typed
		// slice was returned verbatim, so empty strings leaked through and
		// an empty non-nil slice could be returned.
		var result []string
		for _, str := range v {
			if str != "" {
				result = append(result, str)
			}
		}
		if len(result) == 0 {
			return nil
		}
		return result
	}
	return nil
}
|
||||
|
||||
// decodeOutputFormat 解码输出格式
|
||||
func decodeOutputFormat(format *ResponseFormat) *canonical.OutputFormat {
|
||||
if format == nil {
|
||||
return nil
|
||||
}
|
||||
switch format.Type {
|
||||
case "json_object":
|
||||
return &canonical.OutputFormat{Type: "json_object"}
|
||||
case "json_schema":
|
||||
if format.JSONSchema != nil {
|
||||
return &canonical.OutputFormat{
|
||||
Type: "json_schema",
|
||||
Name: format.JSONSchema.Name,
|
||||
Schema: format.JSONSchema.Schema,
|
||||
Strict: format.JSONSchema.Strict,
|
||||
}
|
||||
}
|
||||
return &canonical.OutputFormat{Type: "json_schema"}
|
||||
case "text":
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeThinking 解码推理配置
|
||||
func decodeThinking(reasoningEffort string) *canonical.ThinkingConfig {
|
||||
if reasoningEffort == "" {
|
||||
return nil
|
||||
}
|
||||
if reasoningEffort == "none" {
|
||||
return &canonical.ThinkingConfig{Type: "disabled"}
|
||||
}
|
||||
effort := reasoningEffort
|
||||
if effort == "minimal" {
|
||||
effort = "low"
|
||||
}
|
||||
return &canonical.ThinkingConfig{Type: "enabled", Effort: effort}
|
||||
}
|
||||
|
||||
// decodeDeprecatedFields upgrades the deprecated OpenAI "functions" /
// "function_call" fields in place to their modern "tools" / "tool_choice"
// equivalents so the rest of the decoder only handles one shape. Modern
// fields always win: legacy values are consulted only when the modern
// field is absent.
func decodeDeprecatedFields(req *ChatCompletionRequest) {
	if len(req.Tools) == 0 && len(req.Functions) > 0 {
		// Wrap each legacy function definition in a function-type tool.
		req.Tools = make([]Tool, len(req.Functions))
		for i, f := range req.Functions {
			req.Tools[i] = Tool{
				Type: "function",
				Function: &FunctionDef{
					Name:        f.Name,
					Description: f.Description,
					Parameters:  f.Parameters,
				},
			}
		}
	}
	if req.ToolChoice == nil && req.FunctionCall != nil {
		switch v := req.FunctionCall.(type) {
		case string:
			// Only "none" and "auto" are recognized; other strings are
			// deliberately dropped (ToolChoice stays nil).
			switch v {
			case "none":
				req.ToolChoice = "none"
			case "auto":
				req.ToolChoice = "auto"
			}
		case map[string]any:
			// {"name": ...} becomes a named-function tool_choice object.
			if name, ok := v["name"].(string); ok {
				req.ToolChoice = map[string]any{
					"type":     "function",
					"function": map[string]any{"name": name},
				}
			}
		}
	}
}
|
||||
|
||||
// decodeResponse decodes an OpenAI chat-completion response body into the
// canonical response. Only the first choice is converted; its text,
// refusal, reasoning, and tool-call parts become canonical content blocks,
// and the content list is never left empty (a blank text block is
// substituted).
func decodeResponse(body []byte) (*canonical.CanonicalResponse, error) {
	var resp ChatCompletionResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析 OpenAI 响应失败").WithCause(err)
	}

	// No choices at all: synthesize a minimal response with empty text so
	// callers always see at least one content block.
	if len(resp.Choices) == 0 {
		return &canonical.CanonicalResponse{
			ID:      resp.ID,
			Model:   resp.Model,
			Content: []canonical.ContentBlock{canonical.NewTextBlock("")},
			Usage:   canonical.CanonicalUsage{},
		}, nil
	}

	// Only the first choice is used; additional choices are ignored.
	choice := resp.Choices[0]
	var blocks []canonical.ContentBlock

	if choice.Message != nil {
		if choice.Message.Content != nil {
			text := extractText(choice.Message.Content)
			if text != "" {
				blocks = append(blocks, canonical.NewTextBlock(text))
			}
		}
		// A refusal is surfaced as plain text content.
		if choice.Message.Refusal != "" {
			blocks = append(blocks, canonical.NewTextBlock(choice.Message.Refusal))
		}
		if choice.Message.ReasoningContent != "" {
			blocks = append(blocks, canonical.NewThinkingBlock(choice.Message.ReasoningContent))
		}
		for _, tc := range choice.Message.ToolCalls {
			var input json.RawMessage
			name := ""
			if tc.Type == "custom" && tc.Custom != nil {
				// Custom tool input is a raw string; quote it so the
				// canonical input is valid JSON.
				input = json.RawMessage(fmt.Sprintf("%q", tc.Custom.Input))
				name = tc.Custom.Name
			} else if tc.Function != nil {
				input = json.RawMessage(tc.Function.Arguments)
				// Providers sometimes emit malformed argument JSON; fall
				// back to an empty object rather than propagating it.
				if !json.Valid(input) {
					input = json.RawMessage("{}")
				}
				name = tc.Function.Name
			} else {
				input = json.RawMessage("{}")
			}
			blocks = append(blocks, canonical.NewToolUseBlock(tc.ID, name, input))
		}
	}

	if len(blocks) == 0 {
		blocks = append(blocks, canonical.NewTextBlock(""))
	}

	// finish_reason is optional; a nil pointer means the provider omitted it.
	var stopReason *canonical.StopReason
	if choice.FinishReason != nil {
		sr := mapFinishReason(*choice.FinishReason)
		stopReason = &sr
	}

	return &canonical.CanonicalResponse{
		ID:         resp.ID,
		Model:      resp.Model,
		Content:    blocks,
		StopReason: stopReason,
		Usage:      decodeUsage(resp.Usage),
	}, nil
}
|
||||
|
||||
// mapFinishReason 映射结束原因
|
||||
func mapFinishReason(reason string) canonical.StopReason {
|
||||
switch reason {
|
||||
case "stop":
|
||||
return canonical.StopReasonEndTurn
|
||||
case "length":
|
||||
return canonical.StopReasonMaxTokens
|
||||
case "tool_calls":
|
||||
return canonical.StopReasonToolUse
|
||||
case "function_call":
|
||||
return canonical.StopReasonToolUse
|
||||
case "content_filter":
|
||||
return canonical.StopReasonContentFilter
|
||||
default:
|
||||
return canonical.StopReasonEndTurn
|
||||
}
|
||||
}
|
||||
|
||||
// decodeUsage 解码用量
|
||||
func decodeUsage(usage *Usage) canonical.CanonicalUsage {
|
||||
if usage == nil {
|
||||
return canonical.CanonicalUsage{}
|
||||
}
|
||||
result := canonical.CanonicalUsage{
|
||||
InputTokens: usage.PromptTokens,
|
||||
OutputTokens: usage.CompletionTokens,
|
||||
}
|
||||
if usage.PromptTokensDetails != nil && usage.PromptTokensDetails.CachedTokens > 0 {
|
||||
val := usage.PromptTokensDetails.CachedTokens
|
||||
result.CacheReadTokens = &val
|
||||
}
|
||||
if usage.CompletionTokensDetails != nil && usage.CompletionTokensDetails.ReasoningTokens > 0 {
|
||||
val := usage.CompletionTokensDetails.ReasoningTokens
|
||||
result.ReasoningTokens = &val
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// decodeModelsResponse 解码模型列表响应
|
||||
func decodeModelsResponse(body []byte) (*canonical.CanonicalModelList, error) {
|
||||
var resp ModelsResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
models := make([]canonical.CanonicalModel, len(resp.Data))
|
||||
for i, m := range resp.Data {
|
||||
models[i] = canonical.CanonicalModel{
|
||||
ID: m.ID,
|
||||
Name: m.ID,
|
||||
Created: m.Created,
|
||||
OwnedBy: m.OwnedBy,
|
||||
}
|
||||
}
|
||||
return &canonical.CanonicalModelList{Models: models}, nil
|
||||
}
|
||||
|
||||
// decodeModelInfoResponse 解码模型详情响应
|
||||
func decodeModelInfoResponse(body []byte) (*canonical.CanonicalModelInfo, error) {
|
||||
var resp ModelInfoResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &canonical.CanonicalModelInfo{
|
||||
ID: resp.ID,
|
||||
Name: resp.ID,
|
||||
Created: resp.Created,
|
||||
OwnedBy: resp.OwnedBy,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// decodeEmbeddingRequest 解码嵌入请求
|
||||
func decodeEmbeddingRequest(body []byte) (*canonical.CanonicalEmbeddingRequest, error) {
|
||||
var req EmbeddingRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &canonical.CanonicalEmbeddingRequest{
|
||||
Model: req.Model,
|
||||
Input: req.Input,
|
||||
EncodingFormat: req.EncodingFormat,
|
||||
Dimensions: req.Dimensions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// decodeEmbeddingResponse 解码嵌入响应
|
||||
func decodeEmbeddingResponse(body []byte) (*canonical.CanonicalEmbeddingResponse, error) {
|
||||
var resp EmbeddingResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data := make([]canonical.EmbeddingData, len(resp.Data))
|
||||
for i, d := range resp.Data {
|
||||
data[i] = canonical.EmbeddingData{Index: d.Index, Embedding: d.Embedding}
|
||||
}
|
||||
return &canonical.CanonicalEmbeddingResponse{
|
||||
Data: data,
|
||||
Model: resp.Model,
|
||||
Usage: canonical.EmbeddingUsage{
|
||||
PromptTokens: resp.Usage.PromptTokens,
|
||||
TotalTokens: resp.Usage.TotalTokens,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// decodeRerankRequest 解码重排序请求
|
||||
func decodeRerankRequest(body []byte) (*canonical.CanonicalRerankRequest, error) {
|
||||
var req RerankRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &canonical.CanonicalRerankRequest{
|
||||
Model: req.Model,
|
||||
Query: req.Query,
|
||||
Documents: req.Documents,
|
||||
TopN: req.TopN,
|
||||
ReturnDocuments: req.ReturnDocuments,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// decodeRerankResponse 解码重排序响应
|
||||
func decodeRerankResponse(body []byte) (*canonical.CanonicalRerankResponse, error) {
|
||||
var resp RerankResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results := make([]canonical.RerankResult, len(resp.Results))
|
||||
for i, r := range resp.Results {
|
||||
results[i] = canonical.RerankResult{
|
||||
Index: r.Index,
|
||||
RelevanceScore: r.RelevanceScore,
|
||||
Document: r.Document,
|
||||
}
|
||||
}
|
||||
return &canonical.CanonicalRerankResponse{Results: results, Model: resp.Model}, nil
|
||||
}
|
||||
|
||||
// generateID returns a process-unique identifier of the form "call_<n>",
// used for synthesized tool-call IDs.
func generateID() string {
	return fmt.Sprintf("call_%d", generateCounter())
}

// idCounter backs generateCounter. The typed atomic (Go 1.19+) makes the
// concurrency-safe access pattern explicit instead of relying on callers
// remembering to use the bare atomic.AddInt64 helper.
var idCounter atomic.Int64

// generateCounter atomically increments and returns the shared ID counter.
func generateCounter() int64 {
	return idCounter.Add(1)
}
|
||||
433
backend/internal/conversion/openai/decoder_test.go
Normal file
433
backend/internal/conversion/openai/decoder_test.go
Normal file
@@ -0,0 +1,433 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDecodeRequest_BasicChat verifies that a minimal chat request decodes
// model, message role, and temperature into the canonical request.
func TestDecodeRequest_BasicChat(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [
			{"role": "user", "content": "你好"}
		],
		"temperature": 0.7
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Equal(t, "gpt-4", req.Model)
	assert.Len(t, req.Messages, 1)
	assert.Equal(t, canonical.RoleUser, req.Messages[0].Role)
	assert.NotNil(t, req.Parameters.Temperature)
	assert.Equal(t, 0.7, *req.Parameters.Temperature)
}
|
||||
|
||||
// TestDecodeRequest_SystemAndDeveloper verifies that system and developer
// messages are merged (blank-line separated) into the canonical system
// prompt and removed from the message list.
func TestDecodeRequest_SystemAndDeveloper(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [
			{"role": "system", "content": "你是助手"},
			{"role": "developer", "content": "额外指令"},
			{"role": "user", "content": "你好"}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Equal(t, "你是助手\n\n额外指令", req.System)
	assert.Len(t, req.Messages, 1)
	assert.Equal(t, canonical.RoleUser, req.Messages[0].Role)
}
|
||||
|
||||
// TestDecodeRequest_ToolCalls verifies that an assistant message carrying a
// function tool call decodes into a canonical tool_use content block.
func TestDecodeRequest_ToolCalls(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [
			{"role": "user", "content": "天气"},
			{
				"role": "assistant",
				"tool_calls": [{
					"id": "call_123",
					"type": "function",
					"function": {"name": "get_weather", "arguments": "{\"city\":\"北京\"}"}
				}]
			}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Len(t, req.Messages, 2)
	assistantMsg := req.Messages[1]
	assert.Equal(t, canonical.RoleAssistant, assistantMsg.Role)
	found := false
	for _, b := range assistantMsg.Content {
		if b.Type == "tool_use" {
			found = true
			assert.Equal(t, "call_123", b.ID)
			assert.Equal(t, "get_weather", b.Name)
		}
	}
	assert.True(t, found)
}
|
||||
|
||||
// TestDecodeRequest_ToolMessage verifies that a tool-role message is linked
// back to its originating call via tool_call_id.
func TestDecodeRequest_ToolMessage(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [
			{"role": "user", "content": "天气"},
			{
				"role": "assistant",
				"tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{}"}}]
			},
			{
				"role": "tool",
				"tool_call_id": "call_1",
				"content": "晴天 25°C"
			}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	toolMsg := req.Messages[2]
	assert.Equal(t, canonical.RoleTool, toolMsg.Role)
	assert.Equal(t, "call_1", toolMsg.Content[0].ToolUseID)
}
|
||||
|
||||
// TestDecodeRequest_MissingModel verifies that a request without a model is
// rejected with an INVALID_INPUT conversion error.
func TestDecodeRequest_MissingModel(t *testing.T) {
	body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
	_, err := decodeRequest(body)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "INVALID_INPUT")
}
|
||||
|
||||
// TestDecodeRequest_MissingMessages verifies that a request with no
// messages is rejected with an INVALID_INPUT conversion error.
func TestDecodeRequest_MissingMessages(t *testing.T) {
	body := []byte(`{"model":"gpt-4"}`)
	_, err := decodeRequest(body)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "INVALID_INPUT")
}
|
||||
|
||||
// TestDecodeRequest_DeprecatedFunctions verifies that the legacy
// "functions" field is upgraded to canonical tools.
func TestDecodeRequest_DeprecatedFunctions(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [{"role": "user", "content": "test"}],
		"functions": [{
			"name": "get_weather",
			"description": "获取天气",
			"parameters": {"type":"object","properties":{"city":{"type":"string"}}}
		}]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Len(t, req.Tools, 1)
	assert.Equal(t, "get_weather", req.Tools[0].Name)
}
|
||||
|
||||
// TestDecodeResponse_Basic verifies text content, stop reason, and token
// usage decoding for a plain chat-completion response.
func TestDecodeResponse_Basic(t *testing.T) {
	body := []byte(`{
		"id": "chatcmpl-123",
		"model": "gpt-4",
		"choices": [{
			"index": 0,
			"message": {"role": "assistant", "content": "你好"},
			"finish_reason": "stop"
		}],
		"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Equal(t, "chatcmpl-123", resp.ID)
	assert.Equal(t, "gpt-4", resp.Model)
	assert.Len(t, resp.Content, 1)
	assert.Equal(t, "你好", resp.Content[0].Text)
	assert.NotNil(t, resp.StopReason)
	assert.Equal(t, canonical.StopReasonEndTurn, *resp.StopReason)
	assert.Equal(t, 10, resp.Usage.InputTokens)
	assert.Equal(t, 5, resp.Usage.OutputTokens)
}
|
||||
|
||||
// TestDecodeResponse_ToolCalls verifies that a tool-call response yields a
// canonical tool_use block and a tool_use stop reason.
func TestDecodeResponse_ToolCalls(t *testing.T) {
	body := []byte(`{
		"id": "chatcmpl-456",
		"model": "gpt-4",
		"choices": [{
			"index": 0,
			"message": {
				"role": "assistant",
				"tool_calls": [{
					"id": "call_abc",
					"type": "function",
					"function": {"name": "search", "arguments": "{\"q\":\"test\"}"}
				}]
			},
			"finish_reason": "tool_calls"
		}]
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	found := false
	for _, b := range resp.Content {
		if b.Type == "tool_use" {
			found = true
			assert.Equal(t, "call_abc", b.ID)
			assert.Equal(t, "search", b.Name)
		}
	}
	assert.True(t, found)
	assert.Equal(t, canonical.StopReasonToolUse, *resp.StopReason)
}
|
||||
|
||||
// TestDecodeResponse_Thinking verifies that reasoning_content decodes into
// a separate canonical thinking block after the text block.
func TestDecodeResponse_Thinking(t *testing.T) {
	body := []byte(`{
		"id": "chatcmpl-789",
		"model": "gpt-4",
		"choices": [{
			"index": 0,
			"message": {
				"role": "assistant",
				"content": "回答",
				"reasoning_content": "思考过程"
			},
			"finish_reason": "stop"
		}]
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Len(t, resp.Content, 2)
	assert.Equal(t, "回答", resp.Content[0].Text)
	assert.Equal(t, "thinking", resp.Content[1].Type)
	assert.Equal(t, "思考过程", resp.Content[1].Thinking)
}
|
||||
|
||||
// TestDecodeModelsResponse verifies decoding of a model-list payload,
// including IDs and creation timestamps.
func TestDecodeModelsResponse(t *testing.T) {
	body := []byte(`{
		"object": "list",
		"data": [
			{"id": "gpt-4", "object": "model", "created": 1700000000, "owned_by": "openai"},
			{"id": "gpt-3.5-turbo", "object": "model", "created": 1700000001, "owned_by": "openai"}
		]
	}`)

	list, err := decodeModelsResponse(body)
	require.NoError(t, err)
	assert.Len(t, list.Models, 2)
	assert.Equal(t, "gpt-4", list.Models[0].ID)
	assert.Equal(t, "gpt-3.5-turbo", list.Models[1].ID)
	assert.Equal(t, int64(1700000000), list.Models[0].Created)
}
|
||||
|
||||
// TestDecodeRequest_InvalidJSON verifies that malformed JSON is rejected
// with a JSON_PARSE_ERROR conversion error.
func TestDecodeRequest_InvalidJSON(t *testing.T) {
	_, err := decodeRequest([]byte(`invalid json`))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "JSON_PARSE_ERROR")
}
|
||||
|
||||
// TestDecodeRequest_Parameters verifies that every sampling/limit parameter
// (including max_completion_tokens and stop sequences) reaches the
// canonical parameter set.
func TestDecodeRequest_Parameters(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [{"role": "user", "content": "hi"}],
		"temperature": 0.5,
		"max_completion_tokens": 2048,
		"top_p": 0.9,
		"frequency_penalty": 0.1,
		"presence_penalty": 0.2,
		"stop": ["STOP"]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.NotNil(t, req.Parameters.Temperature)
	assert.Equal(t, 0.5, *req.Parameters.Temperature)
	assert.NotNil(t, req.Parameters.MaxTokens)
	assert.Equal(t, 2048, *req.Parameters.MaxTokens)
	assert.NotNil(t, req.Parameters.TopP)
	assert.Equal(t, 0.9, *req.Parameters.TopP)
	assert.NotNil(t, req.Parameters.FrequencyPenalty)
	assert.Equal(t, 0.1, *req.Parameters.FrequencyPenalty)
	assert.NotNil(t, req.Parameters.PresencePenalty)
	assert.Equal(t, 0.2, *req.Parameters.PresencePenalty)
	assert.Equal(t, []string{"STOP"}, req.Parameters.StopSequences)
}
|
||||
|
||||
// TestDecodeRequest_ToolChoice is a table test covering the string modes
// ("auto"/"none"/"required") and the named-function object form.
func TestDecodeRequest_ToolChoice(t *testing.T) {
	tests := []struct {
		name     string
		jsonBody string
		want     *canonical.ToolChoice
	}{
		{
			name:     "auto",
			jsonBody: `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"tool_choice":"auto"}`,
			want:     canonical.NewToolChoiceAuto(),
		},
		{
			name:     "none",
			jsonBody: `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"tool_choice":"none"}`,
			want:     canonical.NewToolChoiceNone(),
		},
		{
			name:     "required",
			jsonBody: `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"tool_choice":"required"}`,
			want:     canonical.NewToolChoiceAny(),
		},
		{
			name:     "named",
			jsonBody: `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"tool_choice":{"type":"function","function":{"name":"x"}}}`,
			want:     canonical.NewToolChoiceNamed("x"),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req, err := decodeRequest([]byte(tt.jsonBody))
			require.NoError(t, err)
			require.NotNil(t, req.ToolChoice)
			assert.Equal(t, tt.want.Type, req.ToolChoice.Type)
			assert.Equal(t, tt.want.Name, req.ToolChoice.Name)
		})
	}
}
|
||||
|
||||
// TestDecodeRequest_OutputFormat_JSONSchema verifies that a json_schema
// response_format decodes with name, schema, and strict flag intact.
func TestDecodeRequest_OutputFormat_JSONSchema(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [{"role": "user", "content": "hi"}],
		"response_format": {
			"type": "json_schema",
			"json_schema": {
				"name": "my_schema",
				"schema": {"type":"object","properties":{"name":{"type":"string"}}},
				"strict": true
			}
		}
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.OutputFormat)
	assert.Equal(t, "json_schema", req.OutputFormat.Type)
	assert.Equal(t, "my_schema", req.OutputFormat.Name)
	assert.NotNil(t, req.OutputFormat.Schema)
	require.NotNil(t, req.OutputFormat.Strict)
	assert.True(t, *req.OutputFormat.Strict)
}
|
||||
|
||||
// TestDecodeRequest_OutputFormat_JSON verifies that the json_object
// response_format is preserved in the canonical output format.
func TestDecodeRequest_OutputFormat_JSON(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [{"role": "user", "content": "hi"}],
		"response_format": {"type": "json_object"}
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	require.NotNil(t, req.OutputFormat)
	assert.Equal(t, "json_object", req.OutputFormat.Type)
}
|
||||
|
||||
// TestDecodeResponse_StopReasons is a table test for the finish_reason →
// canonical stop-reason mapping.
func TestDecodeResponse_StopReasons(t *testing.T) {
	tests := []struct {
		name         string
		finishReason string
		want         canonical.StopReason
	}{
		{"stop→end_turn", "stop", canonical.StopReasonEndTurn},
		{"length→max_tokens", "length", canonical.StopReasonMaxTokens},
		{"tool_calls→tool_use", "tool_calls", canonical.StopReasonToolUse},
		{"content_filter→content_filter", "content_filter", canonical.StopReasonContentFilter},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			body := []byte(fmt.Sprintf(`{
				"id": "resp-1",
				"model": "gpt-4",
				"choices": [{"index": 0, "message": {"role": "assistant", "content": "ok"}, "finish_reason": "%s"}],
				"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
			}`, tt.finishReason))

			resp, err := decodeResponse(body)
			require.NoError(t, err)
			require.NotNil(t, resp.StopReason)
			assert.Equal(t, tt.want, *resp.StopReason)
		})
	}
}
|
||||
|
||||
// TestDecodeResponse_Usage verifies token accounting decoding, including
// the optional cached-token detail.
func TestDecodeResponse_Usage(t *testing.T) {
	body := []byte(`{
		"id": "resp-1",
		"model": "gpt-4",
		"choices": [{"index": 0, "message": {"role": "assistant", "content": "ok"}, "finish_reason": "stop"}],
		"usage": {
			"prompt_tokens": 100,
			"completion_tokens": 50,
			"total_tokens": 150,
			"prompt_tokens_details": {"cached_tokens": 80}
		}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	assert.Equal(t, 100, resp.Usage.InputTokens)
	assert.Equal(t, 50, resp.Usage.OutputTokens)
	require.NotNil(t, resp.Usage.CacheReadTokens)
	assert.Equal(t, 80, *resp.Usage.CacheReadTokens)
}
|
||||
|
||||
// TestDecodeResponse_Refusal verifies that a refusal message surfaces as a
// text content block.
func TestDecodeResponse_Refusal(t *testing.T) {
	body := []byte(`{
		"id": "resp-1",
		"model": "gpt-4",
		"choices": [{
			"index": 0,
			"message": {"role": "assistant", "content": null, "refusal": "我拒绝回答"},
			"finish_reason": "stop"
		}],
		"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
	}`)

	resp, err := decodeResponse(body)
	require.NoError(t, err)
	found := false
	for _, b := range resp.Content {
		if b.Text == "我拒绝回答" {
			found = true
		}
	}
	assert.True(t, found)
}
|
||||
|
||||
// TestDecodeRequest_AssistantContentArray verifies that assistant content
// given as a parts array decodes into canonical text blocks.
func TestDecodeRequest_AssistantContentArray(t *testing.T) {
	body := []byte(`{
		"model": "gpt-4",
		"messages": [
			{"role": "user", "content": "hello"},
			{
				"role": "assistant",
				"content": [{"type": "text", "text": "hello back"}]
			}
		]
	}`)

	req, err := decodeRequest(body)
	require.NoError(t, err)
	assert.Len(t, req.Messages, 2)
	assistantMsg := req.Messages[1]
	assert.Equal(t, canonical.RoleAssistant, assistantMsg.Role)
	assert.Len(t, assistantMsg.Content, 1)
	assert.Equal(t, "text", assistantMsg.Content[0].Type)
	assert.Equal(t, "hello back", assistantMsg.Content[0].Text)
}
|
||||
521
backend/internal/conversion/openai/encoder.go
Normal file
521
backend/internal/conversion/openai/encoder.go
Normal file
@@ -0,0 +1,521 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// encodeRequest encodes a canonical request as an OpenAI chat-completion
// request body for the given target provider. Optional fields (tools,
// tool_choice, user, response_format, parallel_tool_calls,
// reasoning_effort) are emitted only when set.
func encodeRequest(req *canonical.CanonicalRequest, provider *conversion.TargetProvider) ([]byte, error) {
	result := map[string]any{
		"model":  provider.ModelName,
		"stream": req.Stream,
	}

	// System prompt plus conversation messages.
	messages := encodeSystemAndMessages(req)
	result["messages"] = messages

	// Sampling/limit parameters.
	encodeParametersInto(req, result)

	// Tool definitions, always emitted in function form.
	if len(req.Tools) > 0 {
		tools := make([]map[string]any, len(req.Tools))
		for i, t := range req.Tools {
			tools[i] = map[string]any{
				"type": "function",
				"function": map[string]any{
					"name":        t.Name,
					"description": t.Description,
					"parameters":  t.InputSchema,
				},
			}
		}
		result["tools"] = tools
	}
	if req.ToolChoice != nil {
		result["tool_choice"] = encodeToolChoice(req.ToolChoice)
	}

	// Common optional fields.
	if req.UserID != "" {
		result["user"] = req.UserID
	}
	if req.OutputFormat != nil {
		result["response_format"] = encodeOutputFormat(req.OutputFormat)
	}
	if req.ParallelToolUse != nil {
		result["parallel_tool_calls"] = *req.ParallelToolUse
	}
	if req.Thinking != nil {
		// "disabled" maps to "none"; an enabled config without an explicit
		// effort level defaults to "medium".
		switch req.Thinking.Type {
		case "disabled":
			result["reasoning_effort"] = "none"
		default:
			if req.Thinking.Effort != "" {
				result["reasoning_effort"] = req.Thinking.Effort
			} else {
				result["reasoning_effort"] = "medium"
			}
		}
	}

	body, err := json.Marshal(result)
	if err != nil {
		return nil, conversion.NewConversionError(conversion.ErrorCodeEncodingFailure, "编码 OpenAI 请求失败").WithCause(err)
	}
	return body, nil
}
|
||||
|
||||
// encodeSystemAndMessages builds the OpenAI "messages" array: the canonical
// system prompt (a plain string or a list of system blocks joined with
// blank lines) becomes a leading system message, followed by the converted
// conversation messages.
func encodeSystemAndMessages(req *canonical.CanonicalRequest) []map[string]any {
	var messages []map[string]any

	// System prompt: either form, empty text is skipped entirely.
	switch v := req.System.(type) {
	case string:
		if v != "" {
			messages = append(messages, map[string]any{
				"role":    "system",
				"content": v,
			})
		}
	case []canonical.SystemBlock:
		var parts []string
		for _, b := range v {
			parts = append(parts, b.Text)
		}
		text := strings.Join(parts, "\n\n")
		if text != "" {
			messages = append(messages, map[string]any{
				"role":    "system",
				"content": text,
			})
		}
	}

	// Conversation messages; one canonical message may expand to several
	// OpenAI messages (or none for unknown roles).
	for _, msg := range req.Messages {
		encoded := encodeMessage(msg)
		messages = append(messages, encoded...)
	}

	// Collapse runs of same-role messages (see mergeConsecutiveRoles).
	return mergeConsecutiveRoles(messages)
}
|
||||
|
||||
// encodeMessage converts one canonical message into zero or more OpenAI
// chat messages. User messages keep their content blocks; assistant
// messages split text from tool_use blocks (tool calls force content to
// joined text or null); tool messages emit a tool_result block. Unknown
// roles yield nil.
func encodeMessage(msg canonical.CanonicalMessage) []map[string]any {
	switch msg.Role {
	case canonical.RoleUser:
		return []map[string]any{{
			"role":    "user",
			"content": encodeUserContent(msg.Content),
		}}
	case canonical.RoleAssistant:
		m := map[string]any{"role": "assistant"}
		var textParts []string
		var toolUses []canonical.ContentBlock

		// Partition the blocks; block types other than text/tool_use
		// (e.g. thinking) are not forwarded.
		for _, b := range msg.Content {
			switch b.Type {
			case "text":
				textParts = append(textParts, b.Text)
			case "tool_use":
				toolUses = append(toolUses, b)
			}
		}

		if len(toolUses) > 0 {
			// With tool calls present, content is the joined text or an
			// explicit null.
			if len(textParts) > 0 {
				m["content"] = strings.Join(textParts, "")
			} else {
				m["content"] = nil
			}
			tcs := make([]map[string]any, len(toolUses))
			for i, tu := range toolUses {
				tcs[i] = map[string]any{
					"id":   tu.ID,
					"type": "function",
					"function": map[string]any{
						"name":      tu.Name,
						"arguments": string(tu.Input),
					},
				}
			}
			m["tool_calls"] = tcs
		} else if len(textParts) > 0 {
			m["content"] = strings.Join(textParts, "")
		} else {
			m["content"] = ""
		}
		return []map[string]any{m}

	case canonical.RoleTool:
		for _, b := range msg.Content {
			if b.Type == "tool_result" {
				// Tool-result content may be a JSON-encoded string or
				// arbitrary JSON; prefer the decoded string form.
				var contentStr string
				if b.Content != nil {
					var s string
					if json.Unmarshal(b.Content, &s) == nil {
						contentStr = s
					} else {
						contentStr = string(b.Content)
					}
				}
				// NOTE(review): only the first tool_result block is
				// emitted; additional results in the same message are
				// dropped — confirm this is intended.
				return []map[string]any{{
					"role":         "tool",
					"tool_call_id": b.ToolUseID,
					"content":      contentStr,
				}}
			}
		}
	}
	return nil
}
|
||||
|
||||
// encodeUserContent renders a user message's content blocks for OpenAI: a
// single text block collapses to a plain string, otherwise a parts array is
// built. Unknown block types are dropped; if everything is dropped an empty
// string is returned.
func encodeUserContent(blocks []canonical.ContentBlock) any {
	if len(blocks) == 1 && blocks[0].Type == "text" {
		return blocks[0].Text
	}
	parts := make([]map[string]any, 0, len(blocks))
	for _, b := range blocks {
		switch b.Type {
		case "text":
			parts = append(parts, map[string]any{"type": "text", "text": b.Text})
		case "image":
			// NOTE(review): only the part type is emitted — the media
			// payload (URL/data) of image/audio/file blocks is not copied
			// over; confirm whether the payload should be forwarded here.
			parts = append(parts, map[string]any{"type": "image_url"})
		case "audio":
			parts = append(parts, map[string]any{"type": "input_audio"})
		case "file":
			parts = append(parts, map[string]any{"type": "file"})
		}
	}
	if len(parts) == 0 {
		return ""
	}
	return parts
}
|
||||
|
||||
// encodeToolChoice 编码工具选择
|
||||
func encodeToolChoice(choice *canonical.ToolChoice) any {
|
||||
switch choice.Type {
|
||||
case "auto":
|
||||
return "auto"
|
||||
case "none":
|
||||
return "none"
|
||||
case "any":
|
||||
return "required"
|
||||
case "tool":
|
||||
return map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": choice.Name,
|
||||
},
|
||||
}
|
||||
}
|
||||
return "auto"
|
||||
}
|
||||
|
||||
// encodeParametersInto 编码参数到结果 map
|
||||
func encodeParametersInto(req *canonical.CanonicalRequest, result map[string]any) {
|
||||
if req.Parameters.MaxTokens != nil {
|
||||
result["max_completion_tokens"] = *req.Parameters.MaxTokens
|
||||
}
|
||||
if req.Parameters.Temperature != nil {
|
||||
result["temperature"] = *req.Parameters.Temperature
|
||||
}
|
||||
if req.Parameters.TopP != nil {
|
||||
result["top_p"] = *req.Parameters.TopP
|
||||
}
|
||||
if req.Parameters.FrequencyPenalty != nil {
|
||||
result["frequency_penalty"] = *req.Parameters.FrequencyPenalty
|
||||
}
|
||||
if req.Parameters.PresencePenalty != nil {
|
||||
result["presence_penalty"] = *req.Parameters.PresencePenalty
|
||||
}
|
||||
if len(req.Parameters.StopSequences) > 0 {
|
||||
result["stop"] = req.Parameters.StopSequences
|
||||
}
|
||||
}
|
||||
|
||||
// encodeOutputFormat 编码输出格式
|
||||
func encodeOutputFormat(format *canonical.OutputFormat) map[string]any {
|
||||
switch format.Type {
|
||||
case "json_object":
|
||||
return map[string]any{"type": "json_object"}
|
||||
case "json_schema":
|
||||
m := map[string]any{"type": "json_schema"}
|
||||
schema := map[string]any{
|
||||
"name": format.Name,
|
||||
}
|
||||
if format.Schema != nil {
|
||||
schema["schema"] = format.Schema
|
||||
}
|
||||
if format.Strict != nil {
|
||||
schema["strict"] = *format.Strict
|
||||
}
|
||||
m["json_schema"] = schema
|
||||
return m
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeResponse 将 Canonical 响应编码为 OpenAI 响应
|
||||
func encodeResponse(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||||
var textParts []string
|
||||
var thinkingParts []string
|
||||
var toolUses []canonical.ContentBlock
|
||||
|
||||
for _, b := range resp.Content {
|
||||
switch b.Type {
|
||||
case "text":
|
||||
textParts = append(textParts, b.Text)
|
||||
case "thinking":
|
||||
thinkingParts = append(thinkingParts, b.Thinking)
|
||||
case "tool_use":
|
||||
toolUses = append(toolUses, b)
|
||||
}
|
||||
}
|
||||
|
||||
message := map[string]any{"role": "assistant"}
|
||||
if len(toolUses) > 0 {
|
||||
if len(textParts) > 0 {
|
||||
message["content"] = strings.Join(textParts, "")
|
||||
} else {
|
||||
message["content"] = nil
|
||||
}
|
||||
tcs := make([]map[string]any, len(toolUses))
|
||||
for i, tu := range toolUses {
|
||||
tcs[i] = map[string]any{
|
||||
"id": tu.ID,
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": tu.Name,
|
||||
"arguments": string(tu.Input),
|
||||
},
|
||||
}
|
||||
}
|
||||
message["tool_calls"] = tcs
|
||||
} else if len(textParts) > 0 {
|
||||
message["content"] = strings.Join(textParts, "")
|
||||
} else {
|
||||
message["content"] = ""
|
||||
}
|
||||
|
||||
if len(thinkingParts) > 0 {
|
||||
message["reasoning_content"] = strings.Join(thinkingParts, "")
|
||||
}
|
||||
|
||||
var finishReason *string
|
||||
if resp.StopReason != nil {
|
||||
fr := mapCanonicalToFinishReason(*resp.StopReason)
|
||||
finishReason = &fr
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"id": resp.ID,
|
||||
"object": "chat.completion",
|
||||
"created": time.Now().Unix(),
|
||||
"model": resp.Model,
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"message": message,
|
||||
"finish_reason": finishReason,
|
||||
}},
|
||||
"usage": encodeUsage(resp.Usage),
|
||||
}
|
||||
|
||||
body, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeEncodingFailure, "编码 OpenAI 响应失败").WithCause(err)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// mapCanonicalToFinishReason 映射 Canonical 停止原因到 OpenAI finish_reason
|
||||
func mapCanonicalToFinishReason(reason canonical.StopReason) string {
|
||||
switch reason {
|
||||
case canonical.StopReasonEndTurn:
|
||||
return "stop"
|
||||
case canonical.StopReasonMaxTokens:
|
||||
return "length"
|
||||
case canonical.StopReasonToolUse:
|
||||
return "tool_calls"
|
||||
case canonical.StopReasonContentFilter:
|
||||
return "content_filter"
|
||||
case canonical.StopReasonStopSequence:
|
||||
return "stop"
|
||||
case canonical.StopReasonRefusal:
|
||||
return "stop"
|
||||
default:
|
||||
return "stop"
|
||||
}
|
||||
}
|
||||
|
||||
// encodeUsage 编码用量
|
||||
func encodeUsage(usage canonical.CanonicalUsage) map[string]any {
|
||||
result := map[string]any{
|
||||
"prompt_tokens": usage.InputTokens,
|
||||
"completion_tokens": usage.OutputTokens,
|
||||
"total_tokens": usage.InputTokens + usage.OutputTokens,
|
||||
}
|
||||
if usage.CacheReadTokens != nil && *usage.CacheReadTokens > 0 {
|
||||
result["prompt_tokens_details"] = map[string]any{
|
||||
"cached_tokens": *usage.CacheReadTokens,
|
||||
}
|
||||
}
|
||||
if usage.ReasoningTokens != nil && *usage.ReasoningTokens > 0 {
|
||||
result["completion_tokens_details"] = map[string]any{
|
||||
"reasoning_tokens": *usage.ReasoningTokens,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// encodeModelsResponse 编码模型列表响应
|
||||
func encodeModelsResponse(list *canonical.CanonicalModelList) ([]byte, error) {
|
||||
data := make([]map[string]any, len(list.Models))
|
||||
for i, m := range list.Models {
|
||||
created := int64(0)
|
||||
if m.Created != 0 {
|
||||
created = m.Created
|
||||
}
|
||||
ownedBy := "unknown"
|
||||
if m.OwnedBy != "" {
|
||||
ownedBy = m.OwnedBy
|
||||
}
|
||||
data[i] = map[string]any{
|
||||
"id": m.ID,
|
||||
"object": "model",
|
||||
"created": created,
|
||||
"owned_by": ownedBy,
|
||||
}
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"object": "list",
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
// encodeModelInfoResponse 编码模型详情响应
|
||||
func encodeModelInfoResponse(info *canonical.CanonicalModelInfo) ([]byte, error) {
|
||||
created := int64(0)
|
||||
if info.Created != 0 {
|
||||
created = info.Created
|
||||
}
|
||||
ownedBy := "unknown"
|
||||
if info.OwnedBy != "" {
|
||||
ownedBy = info.OwnedBy
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"id": info.ID,
|
||||
"object": "model",
|
||||
"created": created,
|
||||
"owned_by": ownedBy,
|
||||
})
|
||||
}
|
||||
|
||||
// encodeEmbeddingRequest 编码嵌入请求
|
||||
func encodeEmbeddingRequest(req *canonical.CanonicalEmbeddingRequest, provider *conversion.TargetProvider) ([]byte, error) {
|
||||
result := map[string]any{
|
||||
"model": provider.ModelName,
|
||||
"input": req.Input,
|
||||
}
|
||||
if req.EncodingFormat != "" {
|
||||
result["encoding_format"] = req.EncodingFormat
|
||||
}
|
||||
if req.Dimensions != nil {
|
||||
result["dimensions"] = *req.Dimensions
|
||||
}
|
||||
return json.Marshal(result)
|
||||
}
|
||||
|
||||
// encodeEmbeddingResponse 编码嵌入响应
|
||||
func encodeEmbeddingResponse(resp *canonical.CanonicalEmbeddingResponse) ([]byte, error) {
|
||||
data := make([]map[string]any, len(resp.Data))
|
||||
for i, d := range resp.Data {
|
||||
data[i] = map[string]any{
|
||||
"index": d.Index,
|
||||
"embedding": d.Embedding,
|
||||
}
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"object": "list",
|
||||
"data": data,
|
||||
"model": resp.Model,
|
||||
"usage": resp.Usage,
|
||||
})
|
||||
}
|
||||
|
||||
// encodeRerankRequest 编码重排序请求
|
||||
func encodeRerankRequest(req *canonical.CanonicalRerankRequest, provider *conversion.TargetProvider) ([]byte, error) {
|
||||
result := map[string]any{
|
||||
"model": provider.ModelName,
|
||||
"query": req.Query,
|
||||
"documents": req.Documents,
|
||||
}
|
||||
if req.TopN != nil {
|
||||
result["top_n"] = *req.TopN
|
||||
}
|
||||
if req.ReturnDocuments != nil {
|
||||
result["return_documents"] = *req.ReturnDocuments
|
||||
}
|
||||
return json.Marshal(result)
|
||||
}
|
||||
|
||||
// encodeRerankResponse 编码重排序响应
|
||||
func encodeRerankResponse(resp *canonical.CanonicalRerankResponse) ([]byte, error) {
|
||||
results := make([]map[string]any, len(resp.Results))
|
||||
for i, r := range resp.Results {
|
||||
m := map[string]any{
|
||||
"index": r.Index,
|
||||
"relevance_score": r.RelevanceScore,
|
||||
}
|
||||
if r.Document != nil {
|
||||
m["document"] = *r.Document
|
||||
}
|
||||
results[i] = m
|
||||
}
|
||||
return json.Marshal(map[string]any{
|
||||
"results": results,
|
||||
"model": resp.Model,
|
||||
})
|
||||
}
|
||||
|
||||
// mergeConsecutiveRoles collapses runs of consecutive messages sharing the
// same role into a single message by concatenating their contents. String
// contents are joined; part-list ([]any) contents are appended. When two
// adjacent same-role messages carry contents of different kinds (one string,
// one part list, or any other type) they cannot be concatenated, so the
// second message is kept as its own entry instead of being silently dropped
// (the previous implementation lost such messages entirely).
func mergeConsecutiveRoles(messages []map[string]any) []map[string]any {
	if len(messages) <= 1 {
		return messages
	}
	merged := make([]map[string]any, 0, len(messages))
	for _, msg := range messages {
		if len(merged) == 0 {
			merged = append(merged, msg)
			continue
		}
		last := merged[len(merged)-1]
		if last["role"] != msg["role"] {
			merged = append(merged, msg)
			continue
		}
		switch prev := last["content"].(type) {
		case string:
			if curr, ok := msg["content"].(string); ok {
				last["content"] = prev + curr
				continue
			}
		case []any:
			if curr, ok := msg["content"].([]any); ok {
				// Build a fresh slice so we never append into a backing
				// array that the caller's original messages still share.
				combined := make([]any, 0, len(prev)+len(curr))
				combined = append(combined, prev...)
				combined = append(combined, curr...)
				last["content"] = combined
				continue
			}
		}
		// Mismatched content kinds: preserve the message rather than lose it.
		merged = append(merged, msg)
	}
	return merged
}
|
||||
372
backend/internal/conversion/openai/encoder_test.go
Normal file
372
backend/internal/conversion/openai/encoder_test.go
Normal file
@@ -0,0 +1,372 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestEncodeRequest_Basic verifies that the provider's model name overrides
// the canonical model, the stream flag is preserved, and the message list
// survives the round trip through encodeRequest.
func TestEncodeRequest_Basic(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:    "gpt-4",
		Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		Stream:   true,
	}
	provider := conversion.NewTargetProvider("", "key", "my-model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	// The target provider's model name, not the canonical request's, must win.
	assert.Equal(t, "my-model", result["model"])
	assert.Equal(t, true, result["stream"])

	msgs, ok := result["messages"].([]any)
	require.True(t, ok)
	assert.Len(t, msgs, 1)
}

// TestEncodeRequest_SystemInjection verifies that a canonical system prompt
// is injected as a leading message with role "system".
func TestEncodeRequest_SystemInjection(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:    "gpt-4",
		System:   "你是助手",
		Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	msgs, ok := result["messages"].([]any)
	require.True(t, ok)
	// One injected system message plus the original user message.
	assert.Len(t, msgs, 2)
	firstMsg, ok := msgs[0].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "system", firstMsg["role"])
	assert.Equal(t, "你是助手", firstMsg["content"])
}

// TestEncodeRequest_ToolCalls verifies that an assistant tool_use block is
// encoded as an OpenAI tool_calls entry carrying the original call ID.
func TestEncodeRequest_ToolCalls(t *testing.T) {
	input := json.RawMessage(`{"city":"北京"}`)
	req := &canonical.CanonicalRequest{
		Model: "gpt-4",
		Messages: []canonical.CanonicalMessage{
			{
				Role: canonical.RoleAssistant,
				Content: []canonical.ContentBlock{
					canonical.NewToolUseBlock("call_1", "get_weather", input),
				},
			},
		},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	msgs, ok := result["messages"].([]any)
	require.True(t, ok)
	assistantMsg, ok := msgs[0].(map[string]any)
	require.True(t, ok)
	toolCalls, ok := assistantMsg["tool_calls"].([]any)
	require.True(t, ok)
	assert.Len(t, toolCalls, 1)
	tc, ok := toolCalls[0].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "call_1", tc["id"])
}

// TestEncodeRequest_Thinking verifies that an enabled thinking config with an
// effort level is translated into OpenAI's "reasoning_effort" field.
func TestEncodeRequest_Thinking(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:    "gpt-4",
		Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		Thinking: &canonical.ThinkingConfig{Type: "enabled", Effort: "high"},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "high", result["reasoning_effort"])
}
|
||||
|
||||
// TestEncodeResponse_Basic verifies the chat.completion envelope: id, object
// marker, joined text content, and the mapped "stop" finish_reason.
func TestEncodeResponse_Basic(t *testing.T) {
	sr := canonical.StopReasonEndTurn
	resp := &canonical.CanonicalResponse{
		ID:         "resp-1",
		Model:      "gpt-4",
		Content:    []canonical.ContentBlock{canonical.NewTextBlock("你好")},
		StopReason: &sr,
		Usage:      canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
	}

	body, err := encodeResponse(resp)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "resp-1", result["id"])
	assert.Equal(t, "chat.completion", result["object"])

	choices, ok := result["choices"].([]any)
	require.True(t, ok)
	choice, ok := choices[0].(map[string]any)
	require.True(t, ok)
	msg, ok := choice["message"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "你好", msg["content"])
	assert.Equal(t, "stop", choice["finish_reason"])
}

// TestEncodeResponse_ToolUse verifies that a tool_use content block becomes a
// tool_calls entry on the assistant message.
func TestEncodeResponse_ToolUse(t *testing.T) {
	sr := canonical.StopReasonToolUse
	input := json.RawMessage(`{"q":"test"}`)
	resp := &canonical.CanonicalResponse{
		ID:         "resp-2",
		Model:      "gpt-4",
		Content:    []canonical.ContentBlock{canonical.NewToolUseBlock("call_1", "search", input)},
		StopReason: &sr,
	}

	body, err := encodeResponse(resp)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	choices, okc := result["choices"].([]any)
	require.True(t, okc)
	msgMap, okm := choices[0].(map[string]any)
	require.True(t, okm)
	msg, okmsg := msgMap["message"].(map[string]any)
	require.True(t, okmsg)
	tcs, ok := msg["tool_calls"].([]any)
	require.True(t, ok)
	assert.Len(t, tcs, 1)
}

// TestEncodeModelsResponse verifies the OpenAI "list" envelope and that each
// model is emitted as a data entry.
func TestEncodeModelsResponse(t *testing.T) {
	list := &canonical.CanonicalModelList{
		Models: []canonical.CanonicalModel{
			{ID: "gpt-4", Created: 1700000000, OwnedBy: "openai"},
			{ID: "gpt-3.5-turbo", Created: 1700000001, OwnedBy: "openai"},
		},
	}

	body, err := encodeModelsResponse(list)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "list", result["object"])
	data, okd := result["data"].([]any)
	require.True(t, okd)
	assert.Len(t, data, 2)
}

// TestMergeConsecutiveRoles verifies that runs of same-role messages are
// collapsed with their string contents concatenated in order.
func TestMergeConsecutiveRoles(t *testing.T) {
	messages := []map[string]any{
		{"role": "user", "content": "A"},
		{"role": "user", "content": "B"},
		{"role": "assistant", "content": "C"},
		{"role": "assistant", "content": "D"},
	}

	result := mergeConsecutiveRoles(messages)
	assert.Len(t, result, 2)
	assert.Equal(t, "AB", result[0]["content"])
	assert.Equal(t, "CD", result[1]["content"])
}

// TestMergeConsecutiveRoles_NotOverwriting verifies that merging concatenates
// contents rather than letting the later message overwrite the earlier one.
func TestMergeConsecutiveRoles_NotOverwriting(t *testing.T) {
	messages := []map[string]any{
		{"role": "user", "content": "你好"},
		{"role": "user", "content": "世界"},
	}

	result := mergeConsecutiveRoles(messages)
	assert.Len(t, result, 1)
	assert.Equal(t, "你好世界", result[0]["content"])
}
|
||||
|
||||
// TestEncodeRequest_ToolChoice_Auto verifies that canonical "auto" maps to
// the string "auto" in the tool_choice field.
func TestEncodeRequest_ToolChoice_Auto(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:      "gpt-4",
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		ToolChoice: canonical.NewToolChoiceAuto(),
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "auto", result["tool_choice"])
}

// TestEncodeRequest_ToolChoice_None verifies that canonical "none" maps to
// the string "none".
func TestEncodeRequest_ToolChoice_None(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:      "gpt-4",
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		ToolChoice: canonical.NewToolChoiceNone(),
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "none", result["tool_choice"])
}

// TestEncodeRequest_ToolChoice_Required verifies that canonical "any" maps to
// OpenAI's "required".
func TestEncodeRequest_ToolChoice_Required(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:      "gpt-4",
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		ToolChoice: canonical.NewToolChoiceAny(),
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, "required", result["tool_choice"])
}

// TestEncodeRequest_ToolChoice_Named verifies that a named tool choice
// becomes a {"type":"function","function":{"name":...}} selector object.
func TestEncodeRequest_ToolChoice_Named(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:      "gpt-4",
		Messages:   []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		ToolChoice: canonical.NewToolChoiceNamed("my_func"),
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	tc, ok := result["tool_choice"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "function", tc["type"])
	fn, ok := tc["function"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "my_func", fn["name"])
}

// TestEncodeRequest_OutputFormat_JSONSchema verifies that a json_schema
// output format is encoded as a response_format with the nested schema.
func TestEncodeRequest_OutputFormat_JSONSchema(t *testing.T) {
	schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}}}`)
	req := &canonical.CanonicalRequest{
		Model:    "gpt-4",
		Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		OutputFormat: &canonical.OutputFormat{
			Type:   "json_schema",
			Name:   "my_schema",
			Schema: schema,
		},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	rf, ok := result["response_format"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "json_schema", rf["type"])
	js, ok := rf["json_schema"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "my_schema", js["name"])
	assert.NotNil(t, js["schema"])
}

// TestEncodeRequest_OutputFormat_Text verifies that a request without an
// output format omits the response_format field entirely.
func TestEncodeRequest_OutputFormat_Text(t *testing.T) {
	req := &canonical.CanonicalRequest{
		Model:    "gpt-4",
		Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	_, hasResponseFormat := result["response_format"]
	assert.False(t, hasResponseFormat)
}
|
||||
|
||||
// TestEncodeResponse_Thinking verifies that text and thinking blocks are
// split into "content" and the non-standard "reasoning_content" fields.
func TestEncodeResponse_Thinking(t *testing.T) {
	sr := canonical.StopReasonEndTurn
	resp := &canonical.CanonicalResponse{
		ID:    "resp-thinking",
		Model: "gpt-4",
		Content: []canonical.ContentBlock{
			canonical.NewTextBlock("回答"),
			canonical.NewThinkingBlock("思考过程"),
		},
		StopReason: &sr,
		Usage:      canonical.CanonicalUsage{InputTokens: 10, OutputTokens: 5},
	}

	body, err := encodeResponse(resp)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	choices, okch := result["choices"].([]any)
	require.True(t, okch)
	msgMap, okmm := choices[0].(map[string]any)
	require.True(t, okmm)
	msg, okmsg := msgMap["message"].(map[string]any)
	require.True(t, okmsg)
	assert.Equal(t, "回答", msg["content"])
	assert.Equal(t, "思考过程", msg["reasoning_content"])
}

// TestEncodeRequest_Parameters verifies that sampling parameters are copied
// to their OpenAI field names (note max_tokens -> max_completion_tokens) and
// that stop sequences pass through as a list.
func TestEncodeRequest_Parameters(t *testing.T) {
	temp := 0.5
	maxTokens := 2048
	topP := 0.9
	req := &canonical.CanonicalRequest{
		Model:    "gpt-4",
		Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
		Parameters: canonical.RequestParameters{
			Temperature:   &temp,
			MaxTokens:     &maxTokens,
			TopP:          &topP,
			StopSequences: []string{"STOP", "END"},
		},
	}
	provider := conversion.NewTargetProvider("", "key", "model")

	body, err := encodeRequest(req, provider)
	require.NoError(t, err)

	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))
	assert.Equal(t, temp, result["temperature"])
	// JSON numbers decode as float64, hence the conversion.
	assert.Equal(t, float64(maxTokens), result["max_completion_tokens"])
	assert.Equal(t, topP, result["top_p"])
	stop, ok := result["stop"].([]any)
	require.True(t, ok)
	assert.Len(t, stop, 2)
	assert.Equal(t, "STOP", stop[0])
	assert.Equal(t, "END", stop[1])
}
|
||||
230
backend/internal/conversion/openai/stream_decoder.go
Normal file
230
backend/internal/conversion/openai/stream_decoder.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package openai
|
||||
|
||||
import (
	"encoding/json"
	"fmt"
	"sort"
	"strings"
	"unicode/utf8"

	"nex/backend/internal/conversion/canonical"
)
|
||||
|
||||
// StreamDecoder incrementally converts an OpenAI SSE chat-completion stream
// into canonical stream events. It is stateful: one decoder per stream.
type StreamDecoder struct {
	messageStarted     bool             // MessageStart has been emitted for this stream
	openBlocks         map[int]string   // block index -> block type ("text", "thinking", "tool_use_<n>")
	textBlockIndex     int              // index of the open text block, -1 when none
	thinkingBlockIndex int              // index of the open thinking block, -1 when none
	refusalBlockIndex  int              // index of the open refusal (text) block, -1 when none
	toolCallIDMap      map[int]string   // OpenAI tool-call index -> call ID
	toolCallNameMap    map[int]string   // OpenAI tool-call index -> function name
	nextToolCallIdx    int              // NOTE(review): written nowhere in this file — possibly dead
	utf8Remainder      []byte           // trailing bytes of a rune split across chunks
	accumulatedUsage   *canonical.CanonicalUsage // usage from a trailing usage-only chunk
}
|
||||
|
||||
// NewStreamDecoder 创建 OpenAI 流式解码器
|
||||
func NewStreamDecoder() *StreamDecoder {
|
||||
return &StreamDecoder{
|
||||
openBlocks: make(map[int]string),
|
||||
toolCallIDMap: make(map[int]string),
|
||||
toolCallNameMap: make(map[int]string),
|
||||
textBlockIndex: -1,
|
||||
thinkingBlockIndex: -1,
|
||||
refusalBlockIndex: -1,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessChunk splits a raw SSE chunk into "data: " lines and feeds each
// JSON payload to processDataChunk, returning the canonical events produced.
// A "[DONE]" sentinel closes all open content blocks and ends processing.
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
	// Prepend any UTF-8 remainder stashed by the previous call.
	// NOTE(review): the remainder was clipped from the *middle* of a JSON
	// payload, so gluing it to the front of the next raw chunk places it
	// before that line's "data: " prefix — the HasPrefix check below will
	// then skip the line entirely. Confirm whether partial-rune chunk
	// boundaries occur upstream; if they do, this path loses data.
	data := rawChunk
	if len(d.utf8Remainder) > 0 {
		data = append(d.utf8Remainder, rawChunk...)
		d.utf8Remainder = nil
	}

	var events []canonical.CanonicalStreamEvent

	// Parse SSE "data:" lines; anything else (comments, blank lines) is skipped.
	lines := strings.Split(string(data), "\n")
	for _, line := range lines {
		line = strings.TrimSpace(line)
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		payload := strings.TrimPrefix(line, "data: ")

		if payload == "[DONE]" {
			// End of stream: emit stop events for any still-open blocks.
			events = append(events, d.flushOpenBlocks()...)
			return events
		}

		chunkEvents := d.processDataChunk([]byte(payload))
		events = append(events, chunkEvents...)
	}

	return events
}
|
||||
|
||||
// Flush 刷新解码器状态
|
||||
func (d *StreamDecoder) Flush() []canonical.CanonicalStreamEvent {
|
||||
return nil
|
||||
}
|
||||
|
||||
// processDataChunk parses one SSE "data:" payload (a single OpenAI
// chat.completion.chunk JSON object) and translates it into canonical
// stream events, lazily opening content blocks as deltas arrive.
func (d *StreamDecoder) processDataChunk(data []byte) []canonical.CanonicalStreamEvent {
	// Check UTF-8 completeness: if the payload ends mid-rune, clip the
	// invalid tail and stash it for the next call.
	// NOTE(review): a clipped payload is no longer complete JSON, so the
	// Unmarshal below fails and the whole chunk is silently dropped; see
	// also the remainder-handling note in ProcessChunk.
	if !utf8.Valid(data) {
		validEnd := len(data)
		for !utf8.Valid(data[:validEnd]) {
			validEnd--
		}
		d.utf8Remainder = append(d.utf8Remainder, data[validEnd:]...)
		data = data[:validEnd]
	}

	var chunk StreamChunk
	if err := json.Unmarshal(data, &chunk); err != nil {
		// Malformed payloads are ignored rather than aborting the stream.
		return nil
	}

	var events []canonical.CanonicalStreamEvent

	// First chunk: emit MessageStart exactly once per stream.
	if !d.messageStarted {
		events = append(events, canonical.NewMessageStartEvent(chunk.ID, chunk.Model))
		d.messageStarted = true
	}

	for _, choice := range chunk.Choices {
		if choice.Delta == nil {
			continue
		}
		delta := choice.Delta

		// Plain text content.
		if delta.Content != nil {
			text := ""
			switch v := delta.Content.(type) {
			case string:
				text = v
			default:
				// Non-string content is unexpected here; keep its printed
				// form rather than dropping it.
				text = fmt.Sprintf("%v", v)
			}
			if text != "" {
				// Open the text block lazily on the first non-empty delta.
				if _, ok := d.openBlocks[d.textBlockIndex]; !ok || d.textBlockIndex < 0 {
					d.textBlockIndex = d.allocateBlockIndex()
					d.openBlocks[d.textBlockIndex] = "text"
					events = append(events, canonical.NewContentBlockStartEvent(d.textBlockIndex,
						canonical.StreamContentBlock{Type: "text", Text: ""}))
				}
				events = append(events, canonical.NewContentBlockDeltaEvent(d.textBlockIndex,
					canonical.StreamDelta{Type: string(canonical.DeltaTypeText), Text: text}))
			}
		}

		// reasoning_content (non-standard field used by reasoning providers).
		if delta.ReasoningContent != "" {
			if _, ok := d.openBlocks[d.thinkingBlockIndex]; !ok || d.thinkingBlockIndex < 0 {
				d.thinkingBlockIndex = d.allocateBlockIndex()
				d.openBlocks[d.thinkingBlockIndex] = "thinking"
				events = append(events, canonical.NewContentBlockStartEvent(d.thinkingBlockIndex,
					canonical.StreamContentBlock{Type: "thinking", Thinking: ""}))
			}
			events = append(events, canonical.NewContentBlockDeltaEvent(d.thinkingBlockIndex,
				canonical.StreamDelta{Type: string(canonical.DeltaTypeThinking), Thinking: delta.ReasoningContent}))
		}

		// refusal deltas are surfaced as an ordinary text block.
		if delta.Refusal != "" {
			if _, ok := d.openBlocks[d.refusalBlockIndex]; !ok || d.refusalBlockIndex < 0 {
				d.refusalBlockIndex = d.allocateBlockIndex()
				d.openBlocks[d.refusalBlockIndex] = "text"
				events = append(events, canonical.NewContentBlockStartEvent(d.refusalBlockIndex,
					canonical.StreamContentBlock{Type: "text", Text: ""}))
			}
			events = append(events, canonical.NewContentBlockDeltaEvent(d.refusalBlockIndex,
				canonical.StreamDelta{Type: string(canonical.DeltaTypeText), Text: delta.Refusal}))
		}

		// Incremental tool calls.
		if len(delta.ToolCalls) > 0 {
			for _, tc := range delta.ToolCalls {
				tcIdx := 0
				if tc.Index != nil {
					tcIdx = *tc.Index
				}

				if tc.ID != "" {
					// A non-empty ID marks the start of a new tool call block.
					d.toolCallIDMap[tcIdx] = tc.ID
					if tc.Function != nil {
						d.toolCallNameMap[tcIdx] = tc.Function.Name
					}
					blockIdx := d.allocateBlockIndex()
					d.openBlocks[blockIdx] = fmt.Sprintf("tool_use_%d", tcIdx)
					name := d.toolCallNameMap[tcIdx]
					events = append(events, canonical.NewContentBlockStartEvent(blockIdx,
						canonical.StreamContentBlock{Type: "tool_use", ID: tc.ID, Name: name}))
				}

				// Locate the open block for this tool call's index.
				blockIdx := d.findToolUseBlockIndex(tcIdx)
				if tc.Function != nil && tc.Function.Arguments != "" {
					events = append(events, canonical.NewContentBlockDeltaEvent(blockIdx,
						canonical.StreamDelta{Type: string(canonical.DeltaTypeInputJSON), PartialJSON: tc.Function.Arguments}))
				}
			}
		}

		// finish_reason: close all open blocks and terminate the message.
		if choice.FinishReason != nil && *choice.FinishReason != "" {
			events = append(events, d.flushOpenBlocks()...)
			sr := mapFinishReason(*choice.FinishReason)
			events = append(events, canonical.NewMessageDeltaEventWithUsage(sr, nil))
			events = append(events, canonical.NewMessageStopEvent())
		}
	}

	// Usage-only chunk (empty choices) — presumably produced when the client
	// requested stream usage reporting; TODO confirm against the caller.
	if len(chunk.Choices) == 0 && chunk.Usage != nil {
		usage := decodeUsage(chunk.Usage)
		d.accumulatedUsage = &usage
		events = append(events, canonical.NewMessageDeltaEventWithUsage(canonical.StopReasonEndTurn, &usage))
	}

	return events
}
|
||||
|
||||
// allocateBlockIndex 分配 block 索引
|
||||
func (d *StreamDecoder) allocateBlockIndex() int {
|
||||
maxIdx := -1
|
||||
for k := range d.openBlocks {
|
||||
if k > maxIdx {
|
||||
maxIdx = k
|
||||
}
|
||||
}
|
||||
return maxIdx + 1
|
||||
}
|
||||
|
||||
// findToolUseBlockIndex 查找 tool use block 索引
|
||||
func (d *StreamDecoder) findToolUseBlockIndex(tcIdx int) int {
|
||||
key := fmt.Sprintf("tool_use_%d", tcIdx)
|
||||
for blockIdx, typ := range d.openBlocks {
|
||||
if typ == key {
|
||||
return blockIdx
|
||||
}
|
||||
}
|
||||
return d.allocateBlockIndex()
|
||||
}
|
||||
|
||||
// flushOpenBlocks 关闭所有 open blocks
|
||||
func (d *StreamDecoder) flushOpenBlocks() []canonical.CanonicalStreamEvent {
|
||||
var events []canonical.CanonicalStreamEvent
|
||||
for idx := range d.openBlocks {
|
||||
events = append(events, canonical.NewContentBlockStopEvent(idx))
|
||||
}
|
||||
d.openBlocks = make(map[int]string)
|
||||
return events
|
||||
}
|
||||
472
backend/internal/conversion/openai/stream_decoder_test.go
Normal file
472
backend/internal/conversion/openai/stream_decoder_test.go
Normal file
@@ -0,0 +1,472 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// makeSSEData wraps a JSON payload in an SSE "data:" frame terminated by a
// blank line, mirroring what an OpenAI-compatible server emits.
func makeSSEData(payload string) []byte {
	frame := make([]byte, 0, len(payload)+8)
	frame = append(frame, "data: "...)
	frame = append(frame, payload...)
	frame = append(frame, "\n\n"...)
	return frame
}
|
||||
|
||||
// TestStreamDecoder_BasicText verifies that the first chunk produces a
// MessageStart carrying the chunk ID and that a content delta yields a
// canonical text_delta event with the original text.
func TestStreamDecoder_BasicText(t *testing.T) {
	d := NewStreamDecoder()

	chunk := map[string]any{
		"id":     "chatcmpl-1",
		"object": "chat.completion.chunk",
		"model":  "gpt-4",
		"choices": []any{
			map[string]any{
				"index": 0,
				"delta": map[string]any{"content": "你好"},
			},
		},
	}
	data, _ := json.Marshal(chunk)
	raw := makeSSEData(string(data))

	events := d.ProcessChunk(raw)
	require.NotEmpty(t, events)

	foundStart := false
	foundDelta := false
	for _, e := range events {
		if e.Type == canonical.EventMessageStart {
			foundStart = true
			assert.Equal(t, "chatcmpl-1", e.Message.ID)
		}
		if e.Type == canonical.EventContentBlockDelta && e.Delta != nil {
			foundDelta = true
			assert.Equal(t, "text_delta", e.Delta.Type)
			assert.Equal(t, "你好", e.Delta.Text)
		}
	}
	assert.True(t, foundStart)
	assert.True(t, foundDelta)
}

// TestStreamDecoder_ToolCalls verifies that a tool_calls delta with an ID
// opens a tool_use content block carrying the call ID and function name.
func TestStreamDecoder_ToolCalls(t *testing.T) {
	d := NewStreamDecoder()

	idx := 0
	chunk := map[string]any{
		"id":    "chatcmpl-1",
		"model": "gpt-4",
		"choices": []any{
			map[string]any{
				"index": 0,
				"delta": map[string]any{
					"tool_calls": []any{
						map[string]any{
							"index": &idx,
							"id":    "call_1",
							"type":  "function",
							"function": map[string]any{
								"name":      "get_weather",
								"arguments": "{\"city\":\"北京\"}",
							},
						},
					},
				},
			},
		},
	}
	data, _ := json.Marshal(chunk)
	raw := makeSSEData(string(data))

	events := d.ProcessChunk(raw)
	require.NotEmpty(t, events)

	found := false
	for _, e := range events {
		if e.Type == canonical.EventContentBlockStart && e.ContentBlock != nil && e.ContentBlock.Type == "tool_use" {
			found = true
			assert.Equal(t, "call_1", e.ContentBlock.ID)
			assert.Equal(t, "get_weather", e.ContentBlock.Name)
		}
	}
	assert.True(t, found)
}

// TestStreamDecoder_Thinking verifies that the non-standard
// reasoning_content delta is translated into a thinking_delta event.
func TestStreamDecoder_Thinking(t *testing.T) {
	d := NewStreamDecoder()

	chunk := map[string]any{
		"id":    "chatcmpl-1",
		"model": "gpt-4",
		"choices": []any{
			map[string]any{
				"index": 0,
				"delta": map[string]any{
					"reasoning_content": "思考中",
				},
			},
		},
	}
	data, _ := json.Marshal(chunk)
	raw := makeSSEData(string(data))

	events := d.ProcessChunk(raw)
	found := false
	for _, e := range events {
		if e.Type == canonical.EventContentBlockDelta && e.Delta != nil && e.Delta.Type == "thinking_delta" {
			found = true
			assert.Equal(t, "思考中", e.Delta.Thinking)
		}
	}
	assert.True(t, found)
}
|
||||
|
||||
func TestStreamDecoder_FinishReason(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
raw := makeSSEData(string(data))
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
foundStop := false
|
||||
foundMsgStop := false
|
||||
for _, e := range events {
|
||||
if e.Type == canonical.EventMessageDelta && e.StopReason != nil {
|
||||
foundStop = true
|
||||
assert.Equal(t, canonical.StopReasonEndTurn, *e.StopReason)
|
||||
}
|
||||
if e.Type == canonical.EventMessageStop {
|
||||
foundMsgStop = true
|
||||
}
|
||||
}
|
||||
assert.True(t, foundStop)
|
||||
assert.True(t, foundMsgStop)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_DoneSignal(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
// 先发送一个文本 chunk
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{"content": "hi"},
|
||||
},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
raw := append(makeSSEData(string(data)), []byte("data: [DONE]\n\n")...)
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
// 应该包含 block stop 事件([DONE] 触发 flushOpenBlocks)
|
||||
foundBlockStop := false
|
||||
for _, e := range events {
|
||||
if e.Type == canonical.EventContentBlockStop {
|
||||
foundBlockStop = true
|
||||
}
|
||||
}
|
||||
assert.True(t, foundBlockStop)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_RefusalReuse(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
// 连续两个 refusal delta chunk
|
||||
for _, text := range []string{"拒绝", "原因"} {
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-1",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{"refusal": text},
|
||||
},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
raw := makeSSEData(string(data))
|
||||
events := d.ProcessChunk(raw)
|
||||
_ = events
|
||||
}
|
||||
|
||||
// 检查只创建了一个 text block(refusal 复用同一个 block)
|
||||
assert.Contains(t, d.openBlocks, d.refusalBlockIndex)
|
||||
}
|
||||
|
||||
// makeChunkSSE marshals chunk to JSON and wraps it in an SSE "data:" frame.
func makeChunkSSE(chunk map[string]any) []byte {
	payload, _ := json.Marshal(chunk) // test helper: marshal of map[string]any with plain values cannot fail
	frame := append([]byte("data: "), payload...)
	return append(frame, '\n', '\n')
}
|
||||
|
||||
func TestStreamDecoder_UsageChunk(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-usage",
|
||||
"object": "chat.completion.chunk",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{},
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
},
|
||||
}
|
||||
raw := makeChunkSSE(chunk)
|
||||
|
||||
events := d.ProcessChunk(raw)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
found := false
|
||||
for _, e := range events {
|
||||
if e.Type == canonical.EventMessageDelta {
|
||||
found = true
|
||||
require.NotNil(t, e.Usage)
|
||||
assert.Equal(t, 100, e.Usage.InputTokens)
|
||||
assert.Equal(t, 50, e.Usage.OutputTokens)
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_MultipleToolCalls(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
idx0 := 0
|
||||
chunk1 := map[string]any{
|
||||
"id": "chatcmpl-mt",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{
|
||||
"tool_calls": []any{
|
||||
map[string]any{
|
||||
"index": &idx0,
|
||||
"id": "call_a",
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "func_a",
|
||||
"arguments": "{}",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
idx1 := 1
|
||||
chunk2 := map[string]any{
|
||||
"id": "chatcmpl-mt",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{
|
||||
"tool_calls": []any{
|
||||
map[string]any{
|
||||
"index": &idx1,
|
||||
"id": "call_b",
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "func_b",
|
||||
"arguments": "{}",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
events1 := d.ProcessChunk(makeChunkSSE(chunk1))
|
||||
require.NotEmpty(t, events1)
|
||||
|
||||
events2 := d.ProcessChunk(makeChunkSSE(chunk2))
|
||||
require.NotEmpty(t, events2)
|
||||
|
||||
blockIndices := map[int]bool{}
|
||||
for _, e := range append(events1, events2...) {
|
||||
if e.Type == canonical.EventContentBlockStart && e.ContentBlock != nil && e.ContentBlock.Type == "tool_use" {
|
||||
require.NotNil(t, e.Index)
|
||||
blockIndices[*e.Index] = true
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 2, len(blockIndices), "两个 tool call 应分配不同的 block 索引")
|
||||
}
|
||||
|
||||
func TestStreamDecoder_Flush(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
result := d.Flush()
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_MultipleChunks_Text(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk1 := map[string]any{
|
||||
"id": "chatcmpl-multi",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{"content": "你好"},
|
||||
},
|
||||
},
|
||||
}
|
||||
chunk2 := map[string]any{
|
||||
"id": "chatcmpl-multi",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{"content": "世界"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
raw := append(makeChunkSSE(chunk1), makeChunkSSE(chunk2)...)
|
||||
events := d.ProcessChunk(raw)
|
||||
|
||||
deltas := []string{}
|
||||
for _, e := range events {
|
||||
if e.Type == canonical.EventContentBlockDelta && e.Delta != nil && e.Delta.Type == "text_delta" {
|
||||
deltas = append(deltas, e.Delta.Text)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, []string{"你好", "世界"}, deltas)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_UTF8Truncation(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": "chatcmpl-utf8",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{"content": "你"},
|
||||
},
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
sseData := []byte("data: " + string(data) + "\n\n")
|
||||
|
||||
mid := len(sseData) - 5
|
||||
part1 := sseData[:mid]
|
||||
part2 := sseData[mid:]
|
||||
|
||||
events1 := d.ProcessChunk(part1)
|
||||
for _, e := range events1 {
|
||||
if e.Type == canonical.EventContentBlockDelta && e.Delta != nil {
|
||||
assert.Equal(t, "你", e.Delta.Text)
|
||||
}
|
||||
}
|
||||
|
||||
events2 := d.ProcessChunk(part2)
|
||||
_ = events2
|
||||
}
|
||||
|
||||
func TestStreamDecoder_ToolCallSubsequentDelta(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
idx := 0
|
||||
chunk1 := map[string]any{
|
||||
"id": "chatcmpl-tc",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{
|
||||
"tool_calls": []any{
|
||||
map[string]any{
|
||||
"index": &idx,
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "get_weather",
|
||||
"arguments": "",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
chunk2 := map[string]any{
|
||||
"id": "chatcmpl-tc",
|
||||
"model": "gpt-4",
|
||||
"choices": []any{
|
||||
map[string]any{
|
||||
"index": 0,
|
||||
"delta": map[string]any{
|
||||
"tool_calls": []any{
|
||||
map[string]any{
|
||||
"index": &idx,
|
||||
"function": map[string]any{
|
||||
"arguments": "{\"city\":\"Beijing\"}",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
events1 := d.ProcessChunk(makeChunkSSE(chunk1))
|
||||
require.NotEmpty(t, events1)
|
||||
|
||||
events2 := d.ProcessChunk(makeChunkSSE(chunk2))
|
||||
require.NotEmpty(t, events2)
|
||||
|
||||
foundInputJSON := false
|
||||
for _, e := range events2 {
|
||||
if e.Type == canonical.EventContentBlockDelta && e.Delta != nil && e.Delta.Type == "input_json_delta" {
|
||||
foundInputJSON = true
|
||||
assert.Equal(t, "{\"city\":\"Beijing\"}", e.Delta.PartialJSON)
|
||||
}
|
||||
}
|
||||
assert.True(t, foundInputJSON, "subsequent tool call delta should emit input_json_delta")
|
||||
}
|
||||
|
||||
func TestStreamDecoder_InvalidJSON(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
raw := []byte("data: {invalid json}\n\n")
|
||||
events := d.ProcessChunk(raw)
|
||||
assert.Nil(t, events)
|
||||
}
|
||||
|
||||
func TestStreamDecoder_NonDataLines(t *testing.T) {
|
||||
d := NewStreamDecoder()
|
||||
|
||||
raw := []byte(": comment line\ndata: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n")
|
||||
events := d.ProcessChunk(raw)
|
||||
require.NotEmpty(t, events)
|
||||
found := false
|
||||
for _, e := range events {
|
||||
if e.Type == canonical.EventContentBlockDelta && e.Delta != nil {
|
||||
found = true
|
||||
assert.Equal(t, "hi", e.Delta.Text)
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
}
|
||||
212
backend/internal/conversion/openai/stream_encoder.go
Normal file
212
backend/internal/conversion/openai/stream_encoder.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// StreamEncoder converts canonical stream events into OpenAI
// chat-completion SSE chunks.
type StreamEncoder struct {
	// bufferedStart holds the most recent content_block_start event; it is
	// merged into the first following delta instead of being emitted on its
	// own, since OpenAI chunks have no standalone block-start representation.
	bufferedStart *canonical.CanonicalStreamEvent
	// toolCallIndexMap maps a tool-call ID to the OpenAI tool_calls array
	// index assigned when its block start was buffered.
	toolCallIndexMap map[string]int
	// nextToolCallIndex is the next OpenAI tool_calls index to hand out.
	nextToolCallIndex int
}
|
||||
|
||||
// NewStreamEncoder 创建 OpenAI 流式编码器
|
||||
func NewStreamEncoder() *StreamEncoder {
|
||||
return &StreamEncoder{
|
||||
toolCallIndexMap: make(map[string]int),
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeEvent 编码 Canonical 事件为 SSE chunk
|
||||
func (e *StreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
switch event.Type {
|
||||
case canonical.EventMessageStart:
|
||||
return e.encodeMessageStart(event)
|
||||
case canonical.EventContentBlockStart:
|
||||
return e.bufferBlockStart(event)
|
||||
case canonical.EventContentBlockDelta:
|
||||
return e.encodeContentBlockDelta(event)
|
||||
case canonical.EventContentBlockStop:
|
||||
return nil
|
||||
case canonical.EventMessageDelta:
|
||||
return e.encodeMessageDelta(event)
|
||||
case canonical.EventMessageStop:
|
||||
return [][]byte{[]byte("data: [DONE]\n\n")}
|
||||
case canonical.EventPing, canonical.EventError:
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush flushes any buffered output. This encoder never holds completed
// chunks back (block starts are merged into the next delta, not queued),
// so there is nothing to emit here.
func (e *StreamEncoder) Flush() [][]byte {
	return nil
}
|
||||
|
||||
// encodeMessageStart 编码消息开始事件
|
||||
func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
id := ""
|
||||
model := ""
|
||||
if event.Message != nil {
|
||||
id = event.Message.ID
|
||||
model = event.Message.Model
|
||||
}
|
||||
|
||||
chunk := map[string]any{
|
||||
"id": id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": time.Now().Unix(),
|
||||
"model": model,
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"delta": map[string]any{"role": "assistant"},
|
||||
}},
|
||||
}
|
||||
|
||||
return e.marshalChunk(chunk)
|
||||
}
|
||||
|
||||
// bufferBlockStart 缓冲 block start 事件
|
||||
func (e *StreamEncoder) bufferBlockStart(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
e.bufferedStart = &event
|
||||
if event.ContentBlock != nil && event.ContentBlock.Type == "tool_use" {
|
||||
idx := e.nextToolCallIndex
|
||||
e.nextToolCallIndex++
|
||||
if event.ContentBlock.ID != "" {
|
||||
e.toolCallIndexMap[event.ContentBlock.ID] = idx
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeContentBlockDelta 编码内容块增量事件
|
||||
func (e *StreamEncoder) encodeContentBlockDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if event.Delta == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch canonical.DeltaType(event.Delta.Type) {
|
||||
case canonical.DeltaTypeText:
|
||||
return e.encodeTextDelta(event)
|
||||
case canonical.DeltaTypeInputJSON:
|
||||
return e.encodeInputJSONDelta(event)
|
||||
case canonical.DeltaTypeThinking:
|
||||
return e.encodeThinkingDelta(event)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeTextDelta 编码文本增量
|
||||
func (e *StreamEncoder) encodeTextDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
delta := map[string]any{
|
||||
"content": event.Delta.Text,
|
||||
}
|
||||
if e.bufferedStart != nil {
|
||||
e.bufferedStart = nil
|
||||
}
|
||||
return e.encodeDelta(delta)
|
||||
}
|
||||
|
||||
// encodeInputJSONDelta encodes a tool-call arguments delta.
//
// The first delta after a buffered tool_use block start also carries the
// call id, type, and function name; subsequent deltas carry only arguments.
//
// NOTE(review): the first delta resolves the tool_calls index via
// toolCallIndexMap (encoder-assigned), while subsequent deltas use the
// canonical event.Index directly. These agree only when canonical block
// indices and OpenAI tool_calls indices coincide — confirm with upstream
// producers that this invariant holds.
func (e *StreamEncoder) encodeInputJSONDelta(event canonical.CanonicalStreamEvent) [][]byte {
	if e.bufferedStart != nil && e.bufferedStart.ContentBlock != nil {
		// First delta: merge in id and name from the buffered block start.
		start := e.bufferedStart.ContentBlock
		tcIdx := 0
		if start.ID != "" {
			tcIdx = e.toolCallIndexMap[start.ID]
		}
		delta := map[string]any{
			"tool_calls": []map[string]any{{
				"index": tcIdx,
				"id":    start.ID,
				"type":  "function",
				"function": map[string]any{
					"name":      start.Name,
					"arguments": event.Delta.PartialJSON,
				},
			}},
		}
		e.bufferedStart = nil
		return e.encodeDelta(delta)
	}

	// Subsequent delta: arguments only.
	// The canonical event index maps directly onto the OpenAI tool_calls index.
	tcIdx := 0
	if event.Index != nil {
		tcIdx = *event.Index
	}
	delta := map[string]any{
		"tool_calls": []map[string]any{{
			"index": tcIdx,
			"function": map[string]any{
				"arguments": event.Delta.PartialJSON,
			},
		}},
	}
	return e.encodeDelta(delta)
}
|
||||
|
||||
// encodeThinkingDelta 编码思考增量
|
||||
func (e *StreamEncoder) encodeThinkingDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
delta := map[string]any{
|
||||
"reasoning_content": event.Delta.Thinking,
|
||||
}
|
||||
if e.bufferedStart != nil {
|
||||
e.bufferedStart = nil
|
||||
}
|
||||
return e.encodeDelta(delta)
|
||||
}
|
||||
|
||||
// encodeMessageDelta 编码消息增量事件
|
||||
func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
var chunks [][]byte
|
||||
|
||||
if event.StopReason != nil {
|
||||
fr := mapCanonicalToFinishReason(*event.StopReason)
|
||||
chunk := map[string]any{
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"delta": map[string]any{},
|
||||
"finish_reason": fr,
|
||||
}},
|
||||
}
|
||||
chunks = append(chunks, e.marshalChunk(chunk)...)
|
||||
}
|
||||
|
||||
if event.Usage != nil {
|
||||
chunk := map[string]any{
|
||||
"choices": []map[string]any{},
|
||||
"usage": encodeUsage(*event.Usage),
|
||||
}
|
||||
chunks = append(chunks, e.marshalChunk(chunk)...)
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
// encodeDelta 编码 delta 到 SSE chunk
|
||||
func (e *StreamEncoder) encodeDelta(delta map[string]any) [][]byte {
|
||||
chunk := map[string]any{
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
}},
|
||||
}
|
||||
return e.marshalChunk(chunk)
|
||||
}
|
||||
|
||||
// marshalChunk serializes chunk as a single SSE "data:" frame.
//
// NOTE(review): a marshal error is swallowed and yields nil (the chunk is
// silently dropped from the stream). map[string]any with JSON-safe values
// should never fail to marshal, but confirm this best-effort behavior is
// intended rather than, e.g., emitting an error event.
func (e *StreamEncoder) marshalChunk(chunk map[string]any) [][]byte {
	data, err := json.Marshal(chunk)
	if err != nil {
		return nil
	}
	return [][]byte{[]byte(fmt.Sprintf("data: %s\n\n", data))}
}
|
||||
289
backend/internal/conversion/openai/stream_encoder_test.go
Normal file
289
backend/internal/conversion/openai/stream_encoder_test.go
Normal file
@@ -0,0 +1,289 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStreamEncoder_MessageStart(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewMessageStartEvent("chatcmpl-1", "gpt-4")
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "data: "))
|
||||
assert.Contains(t, s, "chatcmpl-1")
|
||||
assert.Contains(t, s, "chat.completion.chunk")
|
||||
|
||||
var payload map[string]any
|
||||
data := strings.TrimPrefix(s, "data: ")
|
||||
data = strings.TrimRight(data, "\n")
|
||||
require.NoError(t, json.Unmarshal([]byte(data), &payload))
|
||||
choices, okch := payload["choices"].([]any)
|
||||
require.True(t, okch)
|
||||
msgMap, okmm := choices[0].(map[string]any)
|
||||
require.True(t, okmm)
|
||||
delta, okd := msgMap["delta"].(map[string]any)
|
||||
require.True(t, okd)
|
||||
assert.Equal(t, "assistant", delta["role"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_TextDelta(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{Type: "text_delta", Text: "你好"})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "你好")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageStop(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewMessageStopEvent()
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
assert.Equal(t, "data: [DONE]\n\n", string(chunks[0]))
|
||||
}
|
||||
|
||||
func TestStreamEncoder_Buffering(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
|
||||
// ContentBlockStart 应被缓冲,不输出
|
||||
startEvent := canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{Type: "text", Text: ""})
|
||||
chunks := e.EncodeEvent(startEvent)
|
||||
assert.Nil(t, chunks)
|
||||
assert.NotNil(t, e.bufferedStart)
|
||||
|
||||
// 第一个 delta 触发输出(清空缓冲)
|
||||
deltaEvent := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{Type: "text_delta", Text: "hello"})
|
||||
chunks = e.EncodeEvent(deltaEvent)
|
||||
require.NotEmpty(t, chunks)
|
||||
assert.Nil(t, e.bufferedStart)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ContentBlockStop_ReturnsNil(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
idx := 0
|
||||
event := canonical.CanonicalStreamEvent{
|
||||
Type: canonical.EventContentBlockStop,
|
||||
Index: &idx,
|
||||
}
|
||||
chunks := e.EncodeEvent(event)
|
||||
assert.Nil(t, chunks)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_Ping_ReturnsNil(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewPingEvent()
|
||||
chunks := e.EncodeEvent(event)
|
||||
assert.Nil(t, chunks)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_Error_ReturnsNil(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewErrorEvent("test_error", "测试错误")
|
||||
chunks := e.EncodeEvent(event)
|
||||
assert.Nil(t, chunks)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_Flush_ReturnsNil(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
chunks := e.Flush()
|
||||
assert.Nil(t, chunks)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ThinkingDelta(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{
|
||||
Type: string(canonical.DeltaTypeThinking),
|
||||
Thinking: "思考内容",
|
||||
})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "reasoning_content")
|
||||
assert.Contains(t, s, "思考内容")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_InputJSONDelta(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
|
||||
e.EncodeEvent(canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: "call_1",
|
||||
Name: "get_weather",
|
||||
}))
|
||||
|
||||
event := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{
|
||||
Type: string(canonical.DeltaTypeInputJSON),
|
||||
PartialJSON: "{\"city\":\"北京\"}",
|
||||
})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.NotEmpty(t, chunks)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "tool_calls")
|
||||
assert.Contains(t, s, "北京")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
sr := canonical.StopReasonEndTurn
|
||||
event := canonical.CanonicalStreamEvent{
|
||||
Type: canonical.EventMessageDelta,
|
||||
StopReason: &sr,
|
||||
}
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.NotEmpty(t, chunks)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "finish_reason")
|
||||
assert.Contains(t, s, "stop")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
usage := canonical.CanonicalUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
}
|
||||
event := canonical.CanonicalStreamEvent{
|
||||
Type: canonical.EventMessageDelta,
|
||||
Usage: &usage,
|
||||
}
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.NotEmpty(t, chunks)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "usage")
|
||||
assert.Contains(t, s, "prompt_tokens")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_InputJSONDelta_SubsequentDelta(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
|
||||
e.EncodeEvent(canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: "call_1",
|
||||
Name: "get_weather",
|
||||
}))
|
||||
|
||||
e.EncodeEvent(canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{
|
||||
Type: string(canonical.DeltaTypeInputJSON),
|
||||
PartialJSON: "{\"city\":",
|
||||
}))
|
||||
|
||||
event := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{
|
||||
Type: string(canonical.DeltaTypeInputJSON),
|
||||
PartialJSON: "\"Beijing\"}",
|
||||
})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.NotEmpty(t, chunks)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "tool_calls")
|
||||
assert.Contains(t, s, "Beijing")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageStart_NilMessage(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.CanonicalStreamEvent{Type: canonical.EventMessageStart}
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, "chat.completion.chunk")
|
||||
}
|
||||
|
||||
func TestStreamEncoder_UnknownEvent_ReturnsNil(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.CanonicalStreamEvent{Type: "unknown_type"}
|
||||
chunks := e.EncodeEvent(event)
|
||||
assert.Nil(t, chunks)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ContentBlockDelta_NilDelta(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.CanonicalStreamEvent{Type: canonical.EventContentBlockDelta}
|
||||
chunks := e.EncodeEvent(event)
|
||||
assert.Nil(t, chunks)
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MultiToolCall_IndexMapping(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
|
||||
e.EncodeEvent(canonical.NewContentBlockStartEvent(0, canonical.StreamContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: "call_1",
|
||||
Name: "get_weather",
|
||||
}))
|
||||
|
||||
firstDelta := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{
|
||||
Type: string(canonical.DeltaTypeInputJSON),
|
||||
PartialJSON: `{"city":"北京"}`,
|
||||
})
|
||||
chunks := e.EncodeEvent(firstDelta)
|
||||
require.NotEmpty(t, chunks)
|
||||
s := string(chunks[0])
|
||||
assert.Contains(t, s, `"index":0`)
|
||||
assert.Contains(t, s, "get_weather")
|
||||
assert.Contains(t, s, "北京")
|
||||
|
||||
e.EncodeEvent(canonical.NewContentBlockStopEvent(0))
|
||||
|
||||
e.EncodeEvent(canonical.NewContentBlockStartEvent(1, canonical.StreamContentBlock{
|
||||
Type: "tool_use",
|
||||
ID: "call_2",
|
||||
Name: "get_time",
|
||||
}))
|
||||
|
||||
secondDelta := canonical.NewContentBlockDeltaEvent(1, canonical.StreamDelta{
|
||||
Type: string(canonical.DeltaTypeInputJSON),
|
||||
PartialJSON: `{"tz":"Asia/Shanghai"}`,
|
||||
})
|
||||
chunks = e.EncodeEvent(secondDelta)
|
||||
require.NotEmpty(t, chunks)
|
||||
s = string(chunks[0])
|
||||
assert.Contains(t, s, `"index":1`)
|
||||
assert.Contains(t, s, "get_time")
|
||||
assert.Contains(t, s, "Asia/Shanghai")
|
||||
|
||||
subsequentDelta0 := canonical.NewContentBlockDeltaEvent(0, canonical.StreamDelta{
|
||||
Type: string(canonical.DeltaTypeInputJSON),
|
||||
PartialJSON: `"more_data"`,
|
||||
})
|
||||
chunks = e.EncodeEvent(subsequentDelta0)
|
||||
require.NotEmpty(t, chunks)
|
||||
s = string(chunks[0])
|
||||
assert.Contains(t, s, `"index":0`)
|
||||
assert.NotContains(t, s, "get_weather")
|
||||
assert.Contains(t, s, "more_data")
|
||||
|
||||
subsequentDelta1 := canonical.NewContentBlockDeltaEvent(1, canonical.StreamDelta{
|
||||
Type: string(canonical.DeltaTypeInputJSON),
|
||||
PartialJSON: `"more_time"`,
|
||||
})
|
||||
chunks = e.EncodeEvent(subsequentDelta1)
|
||||
require.NotEmpty(t, chunks)
|
||||
s = string(chunks[0])
|
||||
assert.Contains(t, s, `"index":1`)
|
||||
assert.Contains(t, s, "more_time")
|
||||
}
|
||||
438
backend/internal/conversion/openai/supplemental_test.go
Normal file
438
backend/internal/conversion/openai/supplemental_test.go
Normal file
@@ -0,0 +1,438 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDecodeEmbeddingRequest(t *testing.T) {
|
||||
body := []byte(`{"model":"text-embedding-3-small","input":"hello world","encoding_format":"float","dimensions":256}`)
|
||||
req, err := decodeEmbeddingRequest(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "text-embedding-3-small", req.Model)
|
||||
assert.Equal(t, "hello world", req.Input)
|
||||
assert.Equal(t, "float", req.EncodingFormat)
|
||||
require.NotNil(t, req.Dimensions)
|
||||
assert.Equal(t, 256, *req.Dimensions)
|
||||
}
|
||||
|
||||
func TestDecodeEmbeddingRequest_ArrayInput(t *testing.T) {
|
||||
body := []byte(`{"model":"text-embedding","input":["hello","world"]}`)
|
||||
req, err := decodeEmbeddingRequest(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "text-embedding", req.Model)
|
||||
inputArr, ok := req.Input.([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, inputArr, 2)
|
||||
}
|
||||
|
||||
func TestDecodeEmbeddingRequest_InvalidJSON(t *testing.T) {
|
||||
_, err := decodeEmbeddingRequest([]byte(`invalid`))
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDecodeEmbeddingResponse(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"object": "list",
|
||||
"data": [{"index": 0, "embedding": [0.1, 0.2, 0.3]}],
|
||||
"model": "text-embedding-3-small",
|
||||
"usage": {"prompt_tokens": 5, "total_tokens": 5}
|
||||
}`)
|
||||
resp, err := decodeEmbeddingResponse(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "text-embedding-3-small", resp.Model)
|
||||
assert.Len(t, resp.Data, 1)
|
||||
assert.Equal(t, 0, resp.Data[0].Index)
|
||||
assert.Equal(t, 5, resp.Usage.PromptTokens)
|
||||
}
|
||||
|
||||
func TestDecodeRerankRequest(t *testing.T) {
|
||||
topN := 3
|
||||
returnDocs := true
|
||||
body := []byte(`{"model":"rerank-1","query":"what is AI","documents":["doc1","doc2"],"top_n":3,"return_documents":true}`)
|
||||
req, err := decodeRerankRequest(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "rerank-1", req.Model)
|
||||
assert.Equal(t, "what is AI", req.Query)
|
||||
assert.Equal(t, []string{"doc1", "doc2"}, req.Documents)
|
||||
require.NotNil(t, req.TopN)
|
||||
assert.Equal(t, topN, *req.TopN)
|
||||
require.NotNil(t, req.ReturnDocuments)
|
||||
assert.Equal(t, returnDocs, *req.ReturnDocuments)
|
||||
}
|
||||
|
||||
func TestDecodeRerankResponse(t *testing.T) {
|
||||
doc := "relevant doc"
|
||||
body := []byte(`{
|
||||
"results": [{"index": 0, "relevance_score": 0.95, "document": "relevant doc"}],
|
||||
"model": "rerank-1"
|
||||
}`)
|
||||
resp, err := decodeRerankResponse(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "rerank-1", resp.Model)
|
||||
assert.Len(t, resp.Results, 1)
|
||||
assert.Equal(t, 0, resp.Results[0].Index)
|
||||
assert.InDelta(t, 0.95, resp.Results[0].RelevanceScore, 0.001)
|
||||
require.NotNil(t, resp.Results[0].Document)
|
||||
assert.Equal(t, doc, *resp.Results[0].Document)
|
||||
}
|
||||
|
||||
func TestDecodeModelInfoResponse(t *testing.T) {
|
||||
body := []byte(`{"id":"gpt-4","object":"model","created":1700000000,"owned_by":"openai"}`)
|
||||
info, err := decodeModelInfoResponse(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4", info.ID)
|
||||
assert.Equal(t, int64(1700000000), info.Created)
|
||||
assert.Equal(t, "openai", info.OwnedBy)
|
||||
}
|
||||
|
||||
func TestEncodeEmbeddingRequest(t *testing.T) {
|
||||
req := &canonical.CanonicalEmbeddingRequest{
|
||||
Model: "text-embedding-3-small",
|
||||
Input: "hello",
|
||||
EncodingFormat: "float",
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "my-embedding-model")
|
||||
|
||||
body, err := encodeEmbeddingRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "my-embedding-model", result["model"])
|
||||
assert.Equal(t, "hello", result["input"])
|
||||
assert.Equal(t, "float", result["encoding_format"])
|
||||
}
|
||||
|
||||
func TestEncodeEmbeddingRequest_WithDimensions(t *testing.T) {
|
||||
dims := 256
|
||||
req := &canonical.CanonicalEmbeddingRequest{
|
||||
Model: "text-embedding",
|
||||
Input: "test",
|
||||
Dimensions: &dims,
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeEmbeddingRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, float64(256), result["dimensions"])
|
||||
}
|
||||
|
||||
func TestEncodeEmbeddingResponse(t *testing.T) {
|
||||
resp := &canonical.CanonicalEmbeddingResponse{
|
||||
Data: []canonical.EmbeddingData{{Index: 0, Embedding: []float64{0.1, 0.2}}},
|
||||
Model: "text-embedding",
|
||||
Usage: canonical.EmbeddingUsage{PromptTokens: 3, TotalTokens: 3},
|
||||
}
|
||||
|
||||
body, err := encodeEmbeddingResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "list", result["object"])
|
||||
assert.Equal(t, "text-embedding", result["model"])
|
||||
}
|
||||
|
||||
func TestEncodeRerankRequest(t *testing.T) {
|
||||
topN := 5
|
||||
req := &canonical.CanonicalRerankRequest{
|
||||
Model: "rerank-1",
|
||||
Query: "what is AI",
|
||||
Documents: []string{"doc1", "doc2"},
|
||||
TopN: &topN,
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "my-rerank-model")
|
||||
|
||||
body, err := encodeRerankRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "my-rerank-model", result["model"])
|
||||
assert.Equal(t, "what is AI", result["query"])
|
||||
}
|
||||
|
||||
func TestEncodeRerankResponse(t *testing.T) {
|
||||
doc := "relevant passage"
|
||||
resp := &canonical.CanonicalRerankResponse{
|
||||
Results: []canonical.RerankResult{
|
||||
{Index: 0, RelevanceScore: 0.95, Document: &doc},
|
||||
},
|
||||
Model: "rerank-1",
|
||||
}
|
||||
|
||||
body, err := encodeRerankResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "rerank-1", result["model"])
|
||||
results, okr := result["results"].([]any)
|
||||
require.True(t, okr)
|
||||
assert.Len(t, results, 1)
|
||||
}
|
||||
|
||||
func TestEncodeModelInfoResponse(t *testing.T) {
|
||||
info := &canonical.CanonicalModelInfo{
|
||||
ID: "gpt-4",
|
||||
Name: "GPT-4",
|
||||
Created: 1700000000,
|
||||
OwnedBy: "openai",
|
||||
}
|
||||
|
||||
body, err := encodeModelInfoResponse(info)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "gpt-4", result["id"])
|
||||
assert.Equal(t, "model", result["object"])
|
||||
}
|
||||
|
||||
func TestDecodeEmbeddingResponse_InvalidJSON(t *testing.T) {
|
||||
_, err := decodeEmbeddingResponse([]byte(`invalid`))
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDecodeRerankRequest_InvalidJSON(t *testing.T) {
|
||||
_, err := decodeRerankRequest([]byte(`invalid`))
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDecodeRerankResponse_InvalidJSON(t *testing.T) {
|
||||
_, err := decodeRerankResponse([]byte(`invalid`))
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDecodeModelInfoResponse_InvalidJSON(t *testing.T) {
|
||||
_, err := decodeModelInfoResponse([]byte(`invalid`))
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_ThinkingNone(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"none"}`)
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, req.Thinking)
|
||||
assert.Equal(t, "disabled", req.Thinking.Type)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_ThinkingMinimal(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"minimal"}`)
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, req.Thinking)
|
||||
assert.Equal(t, "enabled", req.Thinking.Type)
|
||||
assert.Equal(t, "low", req.Thinking.Effort)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_OutputFormat_Text(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"response_format":{"type":"text"}}`)
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, req.OutputFormat)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_DeprecatedFunctionCall(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"function_call":"auto","functions":[{"name":"fn1","parameters":{}}]}`)
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "auto", req.ToolChoice.Type)
|
||||
assert.Len(t, req.Tools, 1)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_FunctionMessage(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "function", "name": "get_weather", "content": "sunny"}
|
||||
]
|
||||
}`)
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, req.Messages, 2)
|
||||
assert.Equal(t, canonical.RoleTool, req.Messages[1].Role)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_StopString(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stop":"END"}`)
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"END"}, req.Parameters.StopSequences)
|
||||
}
|
||||
|
||||
func TestDecodeRequest_StopEmptyString(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stop":""}`)
|
||||
req, err := decodeRequest(body)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, req.Parameters.StopSequences)
|
||||
}
|
||||
|
||||
func TestDecodeResponse_EmptyChoices(t *testing.T) {
|
||||
body := []byte(`{"id":"resp-1","model":"gpt-4","choices":[],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`)
|
||||
resp, err := decodeResponse(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "resp-1", resp.ID)
|
||||
assert.Len(t, resp.Content, 1)
|
||||
assert.Equal(t, "", resp.Content[0].Text)
|
||||
}
|
||||
|
||||
func TestDecodeResponse_FunctionCallFinishReason(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"id":"r1","model":"gpt-4",
|
||||
"choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"function_call"}],
|
||||
"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}
|
||||
}`)
|
||||
resp, err := decodeResponse(body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, canonical.StopReasonToolUse, *resp.StopReason)
|
||||
}
|
||||
|
||||
func TestEncodeRequest_DisabledThinking(t *testing.T) {
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "gpt-4",
|
||||
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
|
||||
Thinking: &canonical.ThinkingConfig{Type: "disabled"},
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "none", result["reasoning_effort"])
|
||||
}
|
||||
|
||||
func TestEncodeRequest_OutputFormat_JSONObject(t *testing.T) {
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "gpt-4",
|
||||
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
|
||||
OutputFormat: &canonical.OutputFormat{Type: "json_object"},
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
rf, ok := result["response_format"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "json_object", rf["type"])
|
||||
}
|
||||
|
||||
func TestEncodeRequest_PublicFields(t *testing.T) {
|
||||
parallel := true
|
||||
req := &canonical.CanonicalRequest{
|
||||
Model: "gpt-4",
|
||||
Messages: []canonical.CanonicalMessage{{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("hi")}}},
|
||||
UserID: "user-123",
|
||||
ParallelToolUse: ¶llel,
|
||||
}
|
||||
provider := conversion.NewTargetProvider("", "key", "model")
|
||||
|
||||
body, err := encodeRequest(req, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "user-123", result["user"])
|
||||
assert.Equal(t, true, result["parallel_tool_calls"])
|
||||
}
|
||||
|
||||
func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
|
||||
cache := 80
|
||||
reasoning := 20
|
||||
sr := canonical.StopReasonEndTurn
|
||||
resp := &canonical.CanonicalResponse{
|
||||
ID: "r1",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
|
||||
StopReason: &sr,
|
||||
Usage: canonical.CanonicalUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CacheReadTokens: &cache,
|
||||
ReasoningTokens: &reasoning,
|
||||
},
|
||||
}
|
||||
|
||||
body, err := encodeResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
usage, oku := result["usage"].(map[string]any)
|
||||
require.True(t, oku)
|
||||
assert.Equal(t, float64(100), usage["prompt_tokens"])
|
||||
ptd, ok := usage["prompt_tokens_details"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, float64(80), ptd["cached_tokens"])
|
||||
ctd, ok := usage["completion_tokens_details"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, float64(20), ctd["reasoning_tokens"])
|
||||
}
|
||||
|
||||
func TestEncodeResponse_StopReasons(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stopReason canonical.StopReason
|
||||
want string
|
||||
}{
|
||||
{"end_turn→stop", canonical.StopReasonEndTurn, "stop"},
|
||||
{"max_tokens→length", canonical.StopReasonMaxTokens, "length"},
|
||||
{"tool_use→tool_calls", canonical.StopReasonToolUse, "tool_calls"},
|
||||
{"content_filter→content_filter", canonical.StopReasonContentFilter, "content_filter"},
|
||||
{"stop_sequence→stop", canonical.StopReasonStopSequence, "stop"},
|
||||
{"refusal→stop", canonical.StopReasonRefusal, "stop"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sr := tt.stopReason
|
||||
resp := &canonical.CanonicalResponse{
|
||||
ID: "r1",
|
||||
Model: "gpt-4",
|
||||
Content: []canonical.ContentBlock{canonical.NewTextBlock("ok")},
|
||||
StopReason: &sr,
|
||||
Usage: canonical.CanonicalUsage{},
|
||||
}
|
||||
body, err := encodeResponse(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
choices, okch := result["choices"].([]any)
|
||||
require.True(t, okch)
|
||||
choice, okc := choices[0].(map[string]any)
|
||||
require.True(t, okc)
|
||||
assert.Equal(t, tt.want, choice["finish_reason"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapErrorCode_AllCodes(t *testing.T) {
|
||||
assert.Equal(t, "invalid_request_error", mapErrorCode(conversion.ErrorCodeInvalidInput))
|
||||
assert.Equal(t, "invalid_request_error", mapErrorCode(conversion.ErrorCodeMissingRequiredField))
|
||||
assert.Equal(t, "invalid_request_error", mapErrorCode(conversion.ErrorCodeIncompatibleFeature))
|
||||
assert.Equal(t, "invalid_request_error", mapErrorCode(conversion.ErrorCodeFieldMappingFailure))
|
||||
assert.Equal(t, "invalid_request_error", mapErrorCode(conversion.ErrorCodeToolCallParseError))
|
||||
assert.Equal(t, "invalid_request_error", mapErrorCode(conversion.ErrorCodeJSONParseError))
|
||||
assert.Equal(t, "invalid_request_error", mapErrorCode(conversion.ErrorCodeProtocolConstraint))
|
||||
assert.Equal(t, "server_error", mapErrorCode(conversion.ErrorCodeStreamStateError))
|
||||
assert.Equal(t, "server_error", mapErrorCode(conversion.ErrorCodeUTF8DecodeError))
|
||||
assert.Equal(t, "server_error", mapErrorCode(conversion.ErrorCodeEncodingFailure))
|
||||
assert.Equal(t, "server_error", mapErrorCode(conversion.ErrorCodeInterfaceNotSupported))
|
||||
}
|
||||
245
backend/internal/conversion/openai/types.go
Normal file
245
backend/internal/conversion/openai/types.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package openai
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// ChatCompletionRequest is the OpenAI Chat Completions request body.
type ChatCompletionRequest struct {
	Model               string          `json:"model"`
	Messages            []Message       `json:"messages"`
	Tools               []Tool          `json:"tools,omitempty"`
	ToolChoice          any             `json:"tool_choice,omitempty"` // string mode or structured choice; kept untyped
	MaxTokens           *int            `json:"max_tokens,omitempty"`
	MaxCompletionTokens *int            `json:"max_completion_tokens,omitempty"`
	Temperature         *float64        `json:"temperature,omitempty"`
	TopP                *float64        `json:"top_p,omitempty"`
	FrequencyPenalty    *float64        `json:"frequency_penalty,omitempty"`
	PresencePenalty     *float64        `json:"presence_penalty,omitempty"`
	Stop                any             `json:"stop,omitempty"` // string or array of strings; kept untyped
	Stream              bool            `json:"stream,omitempty"`
	StreamOptions       *StreamOptions  `json:"stream_options,omitempty"`
	User                string          `json:"user,omitempty"`
	ResponseFormat      *ResponseFormat `json:"response_format,omitempty"`
	ParallelToolCalls   *bool           `json:"parallel_tool_calls,omitempty"`
	ReasoningEffort     string          `json:"reasoning_effort,omitempty"`
	N                   *int            `json:"n,omitempty"`
	Seed                *int            `json:"seed,omitempty"`
	Logprobs            *bool           `json:"logprobs,omitempty"`
	TopLogprobs         *int            `json:"top_logprobs,omitempty"`

	// Deprecated fields, retained for clients still using the pre-tools API.
	Functions    []FunctionDef `json:"functions,omitempty"`
	FunctionCall any           `json:"function_call,omitempty"` // string mode or named function; kept untyped
}

// Message is a single OpenAI chat message.
type Message struct {
	Role             string     `json:"role"`
	Content          any        `json:"content"` // string or content-part array; kept untyped
	Name             string     `json:"name,omitempty"`
	ToolCalls        []ToolCall `json:"tool_calls,omitempty"`
	ToolCallID       string     `json:"tool_call_id,omitempty"`
	Refusal          string     `json:"refusal,omitempty"`
	ReasoningContent string     `json:"reasoning_content,omitempty"`

	// Deprecated: legacy single function-call form of ToolCalls.
	FunctionCall *FunctionCallMsg `json:"function_call,omitempty"`
}

// ToolCall is an OpenAI tool invocation emitted by the assistant.
type ToolCall struct {
	Index    *int          `json:"index,omitempty"` // present in streaming deltas
	ID       string        `json:"id,omitempty"`
	Type     string        `json:"type,omitempty"`
	Function *FunctionCall `json:"function,omitempty"`
	Custom   *CustomTool   `json:"custom,omitempty"`
}

// FunctionCall carries the function name and its raw JSON argument string.
type FunctionCall struct {
	Name      string `json:"name,omitempty"`
	Arguments string `json:"arguments,omitempty"`
}

// CustomTool is a non-function custom tool invocation.
type CustomTool struct {
	Name  string `json:"name"`
	Input string `json:"input"`
}

// FunctionCallMsg is the deprecated message-level function call payload.
type FunctionCallMsg struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}

// Tool is an OpenAI tool definition attached to a request.
type Tool struct {
	Type     string       `json:"type"`
	Function *FunctionDef `json:"function,omitempty"`
}

// FunctionDef describes a callable function exposed to the model.
type FunctionDef struct {
	Name        string          `json:"name"`
	Description string          `json:"description,omitempty"`
	Parameters  json.RawMessage `json:"parameters,omitempty"` // JSON Schema, passed through verbatim
	Strict      *bool           `json:"strict,omitempty"`
}

// ResponseFormat selects the OpenAI response format (text / json_object /
// json_schema).
type ResponseFormat struct {
	Type       string         `json:"type"`
	JSONSchema *JSONSchemaDef `json:"json_schema,omitempty"`
}

// JSONSchemaDef is the schema payload of a json_schema response format.
type JSONSchemaDef struct {
	Name   string          `json:"name"`
	Schema json.RawMessage `json:"schema,omitempty"` // JSON Schema, passed through verbatim
	Strict *bool           `json:"strict,omitempty"`
}

// StreamOptions holds streaming-specific request options.
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage,omitempty"`
}

// ChatCompletionResponse is the non-streaming OpenAI Chat Completions response.
type ChatCompletionResponse struct {
	ID                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	Choices           []Choice `json:"choices"`
	Usage             *Usage   `json:"usage,omitempty"`
	SystemFingerprint string   `json:"system_fingerprint,omitempty"`
	ServiceTier       string   `json:"service_tier,omitempty"`
}

// Choice is one completion alternative; Message is set for non-streaming
// responses, Delta for stream chunks.
type Choice struct {
	Index        int      `json:"index"`
	Message      *Message `json:"message,omitempty"`
	Delta        *Message `json:"delta,omitempty"`
	FinishReason *string  `json:"finish_reason"` // no omitempty: null must be serialized explicitly
	Logprobs     any      `json:"logprobs,omitempty"`
}

// Usage is the OpenAI token accounting block.
type Usage struct {
	PromptTokens            int                      `json:"prompt_tokens"`
	CompletionTokens        int                      `json:"completion_tokens"`
	TotalTokens             int                      `json:"total_tokens"`
	PromptTokensDetails     *PromptTokensDetails     `json:"prompt_tokens_details,omitempty"`
	CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
}

// PromptTokensDetails breaks down the prompt token count.
type PromptTokensDetails struct {
	CachedTokens int `json:"cached_tokens,omitempty"`
	AudioTokens  int `json:"audio_tokens,omitempty"`
}

// CompletionTokensDetails breaks down the completion token count.
type CompletionTokensDetails struct {
	ReasoningTokens          int `json:"reasoning_tokens,omitempty"`
	AudioTokens              int `json:"audio_tokens,omitempty"`
	AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"`
	RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"`
}

// StreamChunk is one SSE chunk of a streaming chat completion.
type StreamChunk struct {
	ID                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	Choices           []Choice `json:"choices"`
	Usage             *Usage   `json:"usage,omitempty"`
	SystemFingerprint string   `json:"system_fingerprint,omitempty"`
}

// ModelsResponse is the OpenAI model list envelope.
type ModelsResponse struct {
	Object string      `json:"object"`
	Data   []ModelItem `json:"data"`
}

// ModelItem is one entry of the model list.
type ModelItem struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// ModelInfoResponse is the single-model detail response.
type ModelInfoResponse struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// EmbeddingRequest is the OpenAI embeddings request body.
type EmbeddingRequest struct {
	Model          string `json:"model"`
	Input          any    `json:"input"` // string or array; kept untyped
	EncodingFormat string `json:"encoding_format,omitempty"`
	Dimensions     *int   `json:"dimensions,omitempty"`
}

// EmbeddingResponse is the OpenAI embeddings response body.
type EmbeddingResponse struct {
	Object string          `json:"object"`
	Data   []EmbeddingData `json:"data"`
	Model  string          `json:"model"`
	Usage  EmbeddingUsage  `json:"usage"`
}

// EmbeddingData is one embedding vector in the response.
type EmbeddingData struct {
	Index     int `json:"index"`
	Embedding any `json:"embedding"` // float array or base64 string; kept untyped
}

// EmbeddingUsage is the embeddings token accounting block.
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}

// RerankRequest is the rerank request body.
type RerankRequest struct {
	Model           string   `json:"model"`
	Query           string   `json:"query"`
	Documents       []string `json:"documents"`
	TopN            *int     `json:"top_n,omitempty"`
	ReturnDocuments *bool    `json:"return_documents,omitempty"`
}

// RerankResponse is the rerank response body.
type RerankResponse struct {
	Results []RerankResult `json:"results"`
	Model   string         `json:"model"`
}

// RerankResult is one scored document in a rerank response.
type RerankResult struct {
	Index          int     `json:"index"`
	RelevanceScore float64 `json:"relevance_score"`
	Document       *string `json:"document,omitempty"`
}

// ErrorResponse is the OpenAI error envelope.
type ErrorResponse struct {
	Error ErrorDetail `json:"error"`
}

// ErrorDetail carries the error message, type, offending parameter, and code.
type ErrorDetail struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Param   any    `json:"param"` // no omitempty: null must be serialized explicitly
	Code    string `json:"code"`
}
|
||||
19
backend/internal/conversion/provider.go
Normal file
19
backend/internal/conversion/provider.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package conversion
|
||||
|
||||
// TargetProvider identifies the upstream provider a request is routed to:
// endpoint, credential, model name, and optional adapter configuration.
type TargetProvider struct {
	BaseURL       string         `json:"base_url"`
	APIKey        string         `json:"api_key"`
	ModelName     string         `json:"model_name"`
	AdapterConfig map[string]any `json:"adapter_config,omitempty"`
}

// NewTargetProvider builds a TargetProvider with an empty (non-nil) adapter
// config so callers can add entries without a nil-map check.
func NewTargetProvider(baseURL, apiKey, modelName string) *TargetProvider {
	provider := &TargetProvider{
		BaseURL:   baseURL,
		APIKey:    apiKey,
		ModelName: modelName,
	}
	provider.AdapterConfig = map[string]any{}
	return provider
}
|
||||
265
backend/internal/conversion/stream.go
Normal file
265
backend/internal/conversion/stream.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package conversion
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// StreamDecoder decodes provider-native stream chunks into canonical events.
type StreamDecoder interface {
	// ProcessChunk consumes one raw chunk and returns the canonical events it
	// yields; it may return none while a frame is still incomplete.
	ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent
	// Flush returns any events still buffered at end of stream.
	Flush() []canonical.CanonicalStreamEvent
}

// StreamEncoder encodes canonical stream events into protocol-native chunks.
type StreamEncoder interface {
	// EncodeEvent renders one canonical event as zero or more output chunks.
	EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte
	// Flush returns any chunks still buffered at end of stream.
	Flush() [][]byte
}

// StreamConverter transforms a raw provider stream into client-facing chunks.
type StreamConverter interface {
	// ProcessChunk converts one raw chunk into zero or more output chunks.
	ProcessChunk(rawChunk []byte) [][]byte
	// Flush drains any data buffered by the converter at end of stream.
	Flush() [][]byte
}
|
||||
|
||||
// PassthroughStreamConverter forwards stream chunks unchanged; used when the
// client and provider speak the same protocol.
type PassthroughStreamConverter struct{}

// NewPassthroughStreamConverter returns a converter that never alters chunks.
func NewPassthroughStreamConverter() *PassthroughStreamConverter {
	return new(PassthroughStreamConverter)
}

// ProcessChunk hands the raw chunk straight through as a single output chunk.
func (c *PassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
	out := make([][]byte, 1)
	out[0] = rawChunk
	return out
}

// Flush reports nothing: a passthrough converter buffers no data.
func (c *PassthroughStreamConverter) Flush() [][]byte {
	return nil
}
|
||||
|
||||
// SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器
|
||||
// 按 SSE frame 改写 data JSON 中的 model 字段
|
||||
type SmartPassthroughStreamConverter struct {
|
||||
adapter ProtocolAdapter
|
||||
modelOverride string
|
||||
interfaceType InterfaceType
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
// NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器
|
||||
func NewSmartPassthroughStreamConverter(adapter ProtocolAdapter, modelOverride string, interfaceType InterfaceType) *SmartPassthroughStreamConverter {
|
||||
return &SmartPassthroughStreamConverter{
|
||||
adapter: adapter,
|
||||
modelOverride: modelOverride,
|
||||
interfaceType: interfaceType,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessChunk 按 SSE frame 改写 data JSON 中的 model 字段
|
||||
func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
||||
if len(rawChunk) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.buffer = append(c.buffer, rawChunk...)
|
||||
frames, rest := splitSSEFrames(c.buffer)
|
||||
c.buffer = rest
|
||||
|
||||
result := make([][]byte, 0, len(frames))
|
||||
for _, frame := range frames {
|
||||
result = append(result, c.rewriteFrame(frame))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *SmartPassthroughStreamConverter) rewriteFrame(frame []byte) []byte {
|
||||
payload, ok := sseFrameDataPayload(frame)
|
||||
if !ok || strings.TrimSpace(payload) == "[DONE]" {
|
||||
return frame
|
||||
}
|
||||
|
||||
rewrittenPayload, err := c.adapter.RewriteResponseModelName([]byte(payload), c.modelOverride, c.interfaceType)
|
||||
if err != nil {
|
||||
return frame
|
||||
}
|
||||
|
||||
return rebuildSSEFrameWithData(frame, string(rewrittenPayload))
|
||||
}
|
||||
|
||||
// Flush 输出未形成完整 frame 的剩余数据
|
||||
func (c *SmartPassthroughStreamConverter) Flush() [][]byte {
|
||||
if len(c.buffer) == 0 {
|
||||
return nil
|
||||
}
|
||||
frame := append([]byte(nil), c.buffer...)
|
||||
c.buffer = nil
|
||||
return [][]byte{c.rewriteFrame(frame)}
|
||||
}
|
||||
|
||||
// CanonicalStreamConverter 跨协议规范流式转换器
|
||||
type CanonicalStreamConverter struct {
|
||||
decoder StreamDecoder
|
||||
encoder StreamEncoder
|
||||
chain *MiddlewareChain
|
||||
ctx ConversionContext
|
||||
clientProtocol string
|
||||
providerProtocol string
|
||||
modelOverride string
|
||||
}
|
||||
|
||||
// NewCanonicalStreamConverter 创建规范流式转换器
|
||||
func NewCanonicalStreamConverter(decoder StreamDecoder, encoder StreamEncoder) *CanonicalStreamConverter {
|
||||
return &CanonicalStreamConverter{
|
||||
decoder: decoder,
|
||||
encoder: encoder,
|
||||
}
|
||||
}
|
||||
|
||||
// NewCanonicalStreamConverterWithMiddleware 创建带中间件的规范流式转换器
|
||||
func NewCanonicalStreamConverterWithMiddleware(decoder StreamDecoder, encoder StreamEncoder, chain *MiddlewareChain, ctx ConversionContext, clientProtocol, providerProtocol, modelOverride string) *CanonicalStreamConverter {
|
||||
return &CanonicalStreamConverter{
|
||||
decoder: decoder,
|
||||
encoder: encoder,
|
||||
chain: chain,
|
||||
ctx: ctx,
|
||||
clientProtocol: clientProtocol,
|
||||
providerProtocol: providerProtocol,
|
||||
modelOverride: modelOverride,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessChunk 解码 → 中间件 → modelOverride → 编码管道
|
||||
func (c *CanonicalStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
||||
events := c.decoder.ProcessChunk(rawChunk)
|
||||
var result [][]byte
|
||||
for i := range events {
|
||||
if c.chain != nil {
|
||||
processed, err := c.chain.ApplyStreamEvent(&events[i], c.clientProtocol, c.providerProtocol, &c.ctx)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
events[i] = *processed
|
||||
}
|
||||
c.applyModelOverride(&events[i])
|
||||
chunks := c.encoder.EncodeEvent(events[i])
|
||||
result = append(result, chunks...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Flush 刷新解码器和编码器缓冲区
|
||||
func (c *CanonicalStreamConverter) Flush() [][]byte {
|
||||
events := c.decoder.Flush()
|
||||
var result [][]byte
|
||||
for i := range events {
|
||||
if c.chain != nil {
|
||||
processed, err := c.chain.ApplyStreamEvent(&events[i], c.clientProtocol, c.providerProtocol, &c.ctx)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
events[i] = *processed
|
||||
}
|
||||
c.applyModelOverride(&events[i])
|
||||
chunks := c.encoder.EncodeEvent(events[i])
|
||||
result = append(result, chunks...)
|
||||
}
|
||||
encoderChunks := c.encoder.Flush()
|
||||
result = append(result, encoderChunks...)
|
||||
return result
|
||||
}
|
||||
|
||||
// applyModelOverride 在跨协议场景下覆写流式事件中的 Model 字段
|
||||
func (c *CanonicalStreamConverter) applyModelOverride(event *canonical.CanonicalStreamEvent) {
|
||||
if c.modelOverride != "" && event.Message != nil {
|
||||
event.Message.Model = c.modelOverride
|
||||
}
|
||||
}
|
||||
|
||||
// splitSSEFrames cuts data into complete SSE frames (each including its
// terminating blank line) and returns the trailing bytes that do not yet form
// a complete frame.
func splitSSEFrames(data []byte) ([][]byte, []byte) {
	var frames [][]byte
	rest := data
	for {
		idx, sepLen := findSSEFrameSeparator(rest)
		if idx < 0 {
			return frames, rest
		}
		cut := idx + sepLen
		frame := make([]byte, cut)
		copy(frame, rest[:cut])
		frames = append(frames, frame)
		rest = rest[cut:]
	}
}

// findSSEFrameSeparator locates the earliest SSE frame terminator in data and
// returns its index and byte width ("\n\n" is 2, "\r\n\r\n" is 4), or (-1, 0)
// when no terminator is present.
func findSSEFrameSeparator(data []byte) (int, int) {
	lf := bytes.Index(data, []byte("\n\n"))
	crlf := bytes.Index(data, []byte("\r\n\r\n"))
	if lf < 0 && crlf < 0 {
		return -1, 0
	}
	// Prefer whichever separator appears first; on a tie the CRLF form wins.
	if crlf >= 0 && (lf < 0 || crlf <= lf) {
		return crlf, 4
	}
	return lf, 2
}
|
||||
|
||||
// sseFrameDataPayload extracts and joins the values of every "data:" line in
// an SSE frame. The boolean result is false when the frame carries no data
// line at all.
func sseFrameDataPayload(frame []byte) (string, bool) {
	var payload []string
	for _, raw := range strings.Split(strings.TrimRight(string(frame), "\r\n"), "\n") {
		line := strings.TrimRight(raw, "\r")
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		value := line[len("data:"):]
		// Per SSE syntax, a single space after the colon is not payload.
		value = strings.TrimPrefix(value, " ")
		payload = append(payload, value)
	}
	if len(payload) == 0 {
		return "", false
	}
	return strings.Join(payload, "\n"), true
}
|
||||
|
||||
// rebuildSSEFrameWithData reproduces frame with its data line(s) replaced by
// data, preserving non-data lines, their order, and the frame's original line
// ending style. Multi-line data becomes one "data:" line per input line; a
// frame that had no data line gets one appended.
func rebuildSSEFrameWithData(frame []byte, data string) []byte {
	eol, terminator := sseLineEnding(frame)
	var rebuilt []string
	replaced := false
	for _, raw := range strings.Split(strings.TrimRight(string(frame), "\r\n"), "\n") {
		line := strings.TrimRight(raw, "\r")
		if !strings.HasPrefix(line, "data:") {
			rebuilt = append(rebuilt, line)
			continue
		}
		if replaced {
			// The payload was already written once; collapse extra data lines.
			continue
		}
		for _, part := range strings.Split(data, "\n") {
			rebuilt = append(rebuilt, "data: "+part)
		}
		replaced = true
	}
	if !replaced {
		rebuilt = append(rebuilt, "data: "+data)
	}
	return []byte(strings.Join(rebuilt, eol) + terminator)
}

// sseLineEnding reports the line ending and frame terminator used by frame:
// CRLF style when any "\r\n" appears, bare LF otherwise.
func sseLineEnding(frame []byte) (string, string) {
	if bytes.Contains(frame, []byte("\r\n")) {
		return "\r\n", "\r\n\r\n"
	}
	return "\n", "\n\n"
}
|
||||
199
backend/internal/conversion/stream_test.go
Normal file
199
backend/internal/conversion/stream_test.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package conversion
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPassthroughStreamConverter_ProcessChunk(t *testing.T) {
|
||||
converter := NewPassthroughStreamConverter()
|
||||
data := []byte("hello world")
|
||||
result := converter.ProcessChunk(data)
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, data, result[0])
|
||||
}
|
||||
|
||||
func TestPassthroughStreamConverter_Flush(t *testing.T) {
|
||||
converter := NewPassthroughStreamConverter()
|
||||
result := converter.Flush()
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
// mockStreamDecoder 模拟流式解码器
|
||||
type mockStreamDecoder struct {
|
||||
chunks [][]canonical.CanonicalStreamEvent
|
||||
flush []canonical.CanonicalStreamEvent
|
||||
}
|
||||
|
||||
// ProcessChunk 弹出下一个分片的事件
|
||||
func (d *mockStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
|
||||
if len(d.chunks) == 0 {
|
||||
return nil
|
||||
}
|
||||
events := d.chunks[0]
|
||||
d.chunks = d.chunks[1:]
|
||||
return events
|
||||
}
|
||||
|
||||
// Flush 返回刷新事件
|
||||
func (d *mockStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
|
||||
return d.flush
|
||||
}
|
||||
|
||||
// mockStreamEncoder 模拟流式编码器
|
||||
type mockStreamEncoder struct {
|
||||
events [][]byte
|
||||
flush [][]byte
|
||||
}
|
||||
|
||||
// EncodeEvent 返回编码后的事件
|
||||
func (e *mockStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEvent) [][]byte {
|
||||
if len(e.events) == 0 {
|
||||
return nil
|
||||
}
|
||||
return e.events
|
||||
}
|
||||
|
||||
// Flush 返回编码器刷新数据
|
||||
func (e *mockStreamEncoder) Flush() [][]byte {
|
||||
return e.flush
|
||||
}
|
||||
|
||||
func TestCanonicalStreamConverter_ProcessChunk(t *testing.T) {
|
||||
event := canonical.NewMessageStartEvent("id-1", "gpt-4")
|
||||
decoder := &mockStreamDecoder{
|
||||
chunks: [][]canonical.CanonicalStreamEvent{{event}},
|
||||
}
|
||||
encoder := &mockStreamEncoder{
|
||||
events: [][]byte{[]byte("data: test\n\n")},
|
||||
}
|
||||
|
||||
converter := NewCanonicalStreamConverter(decoder, encoder)
|
||||
result := converter.ProcessChunk([]byte("raw"))
|
||||
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, []byte("data: test\n\n"), result[0])
|
||||
}
|
||||
|
||||
func TestCanonicalStreamConverter_WithMiddleware(t *testing.T) {
|
||||
var records []string
|
||||
event := canonical.NewMessageStartEvent("id-1", "gpt-4")
|
||||
decoder := &mockStreamDecoder{
|
||||
chunks: [][]canonical.CanonicalStreamEvent{{event}},
|
||||
}
|
||||
encoder := &mockStreamEncoder{
|
||||
events: [][]byte{[]byte("data: ok\n\n")},
|
||||
}
|
||||
|
||||
chain := NewMiddlewareChain()
|
||||
chain.Use(&recordingMiddleware{name: "mw1", records: &records})
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
|
||||
result := converter.ProcessChunk([]byte("raw"))
|
||||
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, []string{"stream:mw1"}, records)
|
||||
assert.Equal(t, []byte("data: ok\n\n"), result[0])
|
||||
}
|
||||
|
||||
func TestCanonicalStreamConverter_Flush(t *testing.T) {
|
||||
decoder := &mockStreamDecoder{
|
||||
flush: []canonical.CanonicalStreamEvent{
|
||||
canonical.NewMessageStopEvent(),
|
||||
},
|
||||
}
|
||||
encoder := &mockStreamEncoder{
|
||||
events: [][]byte{[]byte("data: stop\n\n")},
|
||||
flush: [][]byte{[]byte("data: flush\n\n")},
|
||||
}
|
||||
|
||||
converter := NewCanonicalStreamConverter(decoder, encoder)
|
||||
result := converter.Flush()
|
||||
|
||||
assert.Len(t, result, 2)
|
||||
assert.Equal(t, []byte("data: stop\n\n"), result[0])
|
||||
assert.Equal(t, []byte("data: flush\n\n"), result[1])
|
||||
}
|
||||
|
||||
func TestCanonicalStreamConverter_EmptyDecoder(t *testing.T) {
|
||||
decoder := &mockStreamDecoder{}
|
||||
encoder := &mockStreamEncoder{}
|
||||
|
||||
converter := NewCanonicalStreamConverter(decoder, encoder)
|
||||
result := converter.ProcessChunk([]byte("raw"))
|
||||
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestCanonicalStreamConverter_MiddlewareError_Continue(t *testing.T) {
|
||||
event := canonical.NewMessageStartEvent("id-1", "gpt-4")
|
||||
decoder := &mockStreamDecoder{
|
||||
chunks: [][]canonical.CanonicalStreamEvent{{event}},
|
||||
}
|
||||
encoder := &mockStreamEncoder{
|
||||
events: [][]byte{[]byte("data: ok\n\n")},
|
||||
}
|
||||
|
||||
chain := NewMiddlewareChain()
|
||||
chain.Use(&errorMiddleware{})
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
|
||||
result := converter.ProcessChunk([]byte("raw"))
|
||||
|
||||
assert.Nil(t, result, "middleware error should cause the event to be skipped (continue)")
|
||||
}
|
||||
|
||||
func TestCanonicalStreamConverter_Flush_MiddlewareError_Continue(t *testing.T) {
|
||||
event := canonical.NewMessageStartEvent("id-1", "gpt-4")
|
||||
decoder := &mockStreamDecoder{
|
||||
flush: []canonical.CanonicalStreamEvent{event},
|
||||
}
|
||||
encoder := &mockStreamEncoder{
|
||||
events: [][]byte{[]byte("data: ok\n\n")},
|
||||
flush: [][]byte{[]byte("data: encoder_flush\n\n")},
|
||||
}
|
||||
|
||||
chain := NewMiddlewareChain()
|
||||
chain.Use(&errorMiddleware{})
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
|
||||
converter := NewCanonicalStreamConverterWithMiddleware(decoder, encoder, chain, *ctx, "openai", "anthropic", "")
|
||||
result := converter.Flush()
|
||||
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, []byte("data: encoder_flush\n\n"), result[0])
|
||||
}
|
||||
|
||||
func TestCanonicalStreamConverter_Flush_DecoderAndEncoderBothProduce(t *testing.T) {
|
||||
event := canonical.NewMessageStartEvent("id-1", "gpt-4")
|
||||
decoder := &mockStreamDecoder{
|
||||
flush: []canonical.CanonicalStreamEvent{event},
|
||||
}
|
||||
encoder := &mockStreamEncoder{
|
||||
events: [][]byte{[]byte("data: decoder_flush\n\n")},
|
||||
flush: [][]byte{[]byte("data: encoder_flush\n\n")},
|
||||
}
|
||||
|
||||
converter := NewCanonicalStreamConverter(decoder, encoder)
|
||||
result := converter.Flush()
|
||||
|
||||
assert.Len(t, result, 2)
|
||||
assert.Equal(t, []byte("data: decoder_flush\n\n"), result[0])
|
||||
assert.Equal(t, []byte("data: encoder_flush\n\n"), result[1])
|
||||
}
|
||||
|
||||
type errorMiddleware struct{}
|
||||
|
||||
func (m *errorMiddleware) Intercept(req *canonical.CanonicalRequest, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) {
|
||||
return nil, fmt.Errorf("middleware error")
|
||||
}
|
||||
|
||||
func (m *errorMiddleware) InterceptStreamEvent(event *canonical.CanonicalStreamEvent, clientProtocol, providerProtocol string, ctx *ConversionContext) (*canonical.CanonicalStreamEvent, error) {
|
||||
return nil, fmt.Errorf("stream middleware error")
|
||||
}
|
||||
151
backend/internal/database/database.go
Normal file
151
backend/internal/database/database.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
func Init(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
|
||||
moduleLogger := pkglogger.WithModule(zapLogger, "database")
|
||||
|
||||
db, err := initDB(cfg, moduleLogger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("初始化数据库失败: %w", err)
|
||||
}
|
||||
|
||||
if err := runMigrations(db, cfg.Driver, moduleLogger); err != nil {
|
||||
return nil, fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
configurePool(db, cfg, moduleLogger)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func Close(db *gorm.DB) {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sqlDB.Close()
|
||||
}
|
||||
|
||||
func initDB(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
|
||||
gormLogger := pkglogger.NewGormLogger(zapLogger)
|
||||
|
||||
gormConfig := &gorm.Config{
|
||||
Logger: gormLogger,
|
||||
}
|
||||
|
||||
switch cfg.Driver {
|
||||
case "mysql":
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||||
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName)
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("连接 MySQL 数据库",
|
||||
zap.String("host", cfg.Host),
|
||||
zap.Int("port", cfg.Port),
|
||||
zap.String("database", cfg.DBName))
|
||||
}
|
||||
return gorm.Open(mysql.Open(dsn), gormConfig)
|
||||
default:
|
||||
dbDir := filepath.Dir(cfg.Path)
|
||||
if err := os.MkdirAll(dbDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
|
||||
}
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("连接 SQLite 数据库", zap.String("path", cfg.Path))
|
||||
}
|
||||
return gorm.Open(sqlite.Open(cfg.Path), gormConfig)
|
||||
}
|
||||
}
|
||||
|
||||
func runMigrations(db *gorm.DB, driver string, zapLogger *zap.Logger) error {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
gooseDialect := "sqlite3"
|
||||
migrationsSubDir := "sqlite"
|
||||
if driver == "mysql" {
|
||||
gooseDialect = "mysql"
|
||||
migrationsSubDir = "mysql"
|
||||
}
|
||||
|
||||
migrationsDir := getMigrationsDir(driver)
|
||||
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
|
||||
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
|
||||
}
|
||||
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("执行数据库迁移",
|
||||
zap.String("dialect", gooseDialect),
|
||||
zap.String("dir", migrationsSubDir))
|
||||
}
|
||||
|
||||
if err := goose.SetDialect(gooseDialect); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := goose.Up(sqlDB, migrationsDir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func configurePool(db *gorm.DB, cfg *config.DatabaseConfig, zapLogger *zap.Logger) {
|
||||
if cfg.Driver == "sqlite" {
|
||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||
if zapLogger != nil {
|
||||
zapLogger.Warn("启用 WAL 模式失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime)
|
||||
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("数据库连接池配置",
|
||||
zap.Int("max_idle_conns", cfg.MaxIdleConns),
|
||||
zap.Int("max_open_conns", cfg.MaxOpenConns),
|
||||
zap.Duration("conn_max_lifetime", cfg.ConnMaxLifetime))
|
||||
}
|
||||
}
|
||||
|
||||
// getMigrationsDir resolves the migrations directory for the driver
// relative to this source file's location, falling back to "./migrations"
// when caller information or an absolute path is unavailable.
func getMigrationsDir(driver string) string {
	if _, filename, _, ok := runtime.Caller(0); ok {
		subDir := "sqlite"
		if driver == "mysql" {
			subDir = "mysql"
		}
		rel := filepath.Join(filepath.Dir(filename), "..", "..", "migrations", subDir)
		if abs, err := filepath.Abs(rel); err == nil {
			return abs
		}
	}
	return "./migrations"
}
|
||||
|
||||
func BuildDSN(cfg *config.DatabaseConfig) string {
|
||||
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||||
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName)
|
||||
}
|
||||
78
backend/internal/database/database_test.go
Normal file
78
backend/internal/database/database_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestInit_SQLite(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfg := &config.DatabaseConfig{
|
||||
Driver: "sqlite",
|
||||
Path: filepath.Join(dir, "test.db"),
|
||||
MaxIdleConns: 5,
|
||||
MaxOpenConns: 10,
|
||||
ConnMaxLifetime: 0,
|
||||
}
|
||||
|
||||
zapLogger := zap.NewNop()
|
||||
db, err := Init(cfg, zapLogger)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, db)
|
||||
defer Close(db)
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sqlDB)
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfg := &config.DatabaseConfig{
|
||||
Driver: "sqlite",
|
||||
Path: filepath.Join(dir, "test.db"),
|
||||
MaxIdleConns: 5,
|
||||
MaxOpenConns: 10,
|
||||
ConnMaxLifetime: 0,
|
||||
}
|
||||
|
||||
zapLogger := zap.NewNop()
|
||||
db, err := Init(cfg, zapLogger)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, db)
|
||||
|
||||
Close(db)
|
||||
}
|
||||
|
||||
func TestBuildDSN(t *testing.T) {
|
||||
cfg := &config.DatabaseConfig{
|
||||
Driver: "mysql",
|
||||
Host: "db.example.com",
|
||||
Port: 3306,
|
||||
User: "nexuser",
|
||||
Password: "secretpass",
|
||||
DBName: "nexdb",
|
||||
}
|
||||
|
||||
dsn := BuildDSN(cfg)
|
||||
assert.Equal(t, "nexuser:secretpass@tcp(db.example.com:3306)/nexdb?charset=utf8mb4&parseTime=true&loc=Local", dsn)
|
||||
}
|
||||
|
||||
func TestBuildDSN_EmptyPassword(t *testing.T) {
|
||||
cfg := &config.DatabaseConfig{
|
||||
Driver: "mysql",
|
||||
Host: "localhost",
|
||||
Port: 3306,
|
||||
User: "root",
|
||||
DBName: "nex",
|
||||
}
|
||||
|
||||
dsn := BuildDSN(cfg)
|
||||
assert.Equal(t, "root:@tcp(localhost:3306)/nex?charset=utf8mb4&parseTime=true&loc=Local", dsn)
|
||||
}
|
||||
@@ -1,8 +1,12 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"time"
|
||||
|
||||
// Model 模型领域模型
|
||||
"nex/backend/pkg/modelid"
|
||||
)
|
||||
|
||||
// Model 模型领域模型(id 为 UUID 自动生成)
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
ProviderID string `json:"provider_id"`
|
||||
@@ -10,3 +14,8 @@ type Model struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// UnifiedModelID 返回统一模型 ID(格式:provider_id/model_name)
|
||||
func (m *Model) UnifiedModelID() string {
|
||||
return modelid.FormatUnifiedModelID(m.ProviderID, m.ModelName)
|
||||
}
|
||||
|
||||
@@ -8,16 +8,8 @@ type Provider struct {
|
||||
Name string `json:"name"`
|
||||
APIKey string `json:"api_key"`
|
||||
BaseURL string `json:"base_url"`
|
||||
Protocol string `json:"protocol"`
|
||||
Enabled bool `json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// MaskAPIKey 掩码 API Key(仅显示最后 4 个字符)
|
||||
func (p *Provider) MaskAPIKey() {
|
||||
if len(p.APIKey) > 4 {
|
||||
p.APIKey = "***" + p.APIKey[len(p.APIKey)-4:]
|
||||
} else {
|
||||
p.APIKey = "***"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,217 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/protocol/anthropic"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// AnthropicHandler Anthropic 协议处理器
|
||||
type AnthropicHandler struct {
|
||||
client provider.ProviderClient
|
||||
routingService service.RoutingService
|
||||
statsService service.StatsService
|
||||
}
|
||||
|
||||
// NewAnthropicHandler 创建 Anthropic 处理器
|
||||
func NewAnthropicHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *AnthropicHandler {
|
||||
return &AnthropicHandler{
|
||||
client: client,
|
||||
routingService: routingService,
|
||||
statsService: statsService,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleMessages 处理 Messages 请求
|
||||
func (h *AnthropicHandler) HandleMessages(c *gin.Context) {
|
||||
var req anthropic.MessagesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: "无效的请求格式: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 请求验证
|
||||
if validationErrors := anthropic.ValidateRequest(&req); validationErrors != nil {
|
||||
errMsg := formatValidationErrors(validationErrors)
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: errMsg,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.checkMultimodalContent(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
openaiReq, err := anthropic.ConvertRequest(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "invalid_request_error",
|
||||
Message: "请求转换失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
routeResult, err := h.routingService.Route(openaiReq.Model)
|
||||
if err != nil {
|
||||
h.handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
h.handleStreamRequest(c, openaiReq, routeResult)
|
||||
} else {
|
||||
h.handleNonStreamRequest(c, openaiReq, routeResult)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AnthropicHandler) handleNonStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
openaiResp, err := h.client.SendRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "api_error",
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
anthropicResp, err := anthropic.ConvertResponse(openaiResp)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "api_error",
|
||||
Message: "响应转换失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, openaiReq.Model)
|
||||
}()
|
||||
|
||||
c.JSON(http.StatusOK, anthropicResp)
|
||||
}
|
||||
|
||||
func (h *AnthropicHandler) handleStreamRequest(c *gin.Context, openaiReq *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), openaiReq, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "api_error",
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
|
||||
converter := anthropic.NewStreamConverter(
|
||||
fmt.Sprintf("msg_%s", routeResult.Provider.ID),
|
||||
openaiReq.Model,
|
||||
)
|
||||
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if event.Done {
|
||||
break
|
||||
}
|
||||
|
||||
chunk, err := openai.NewAdapter().ParseStreamChunk(event.Data)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
anthropicEvents, err := converter.ConvertChunk(chunk)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, ae := range anthropicEvents {
|
||||
eventStr, err := anthropic.SerializeEvent(ae)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
writer.WriteString(eventStr)
|
||||
writer.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, openaiReq.Model)
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *AnthropicHandler) checkMultimodalContent(req *anthropic.MessagesRequest) error {
|
||||
for _, msg := range req.Messages {
|
||||
for _, block := range msg.Content {
|
||||
if block.Type == "image" {
|
||||
return fmt.Errorf("MVP 不支持多模态内容(图片)")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *AnthropicHandler) handleError(c *gin.Context, err error) {
|
||||
if appErr, ok := appErrors.AsAppError(err); ok {
|
||||
c.JSON(appErr.HTTPStatus, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "not_found_error",
|
||||
Message: appErr.Message,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, anthropic.ErrorResponse{
|
||||
Type: "error",
|
||||
Error: anthropic.ErrorDetail{
|
||||
Type: "internal_error",
|
||||
Message: "内部错误: " + err.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
209
backend/internal/handler/handler_supplemental_test.go
Normal file
209
backend/internal/handler/handler_supplemental_test.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/tests/mocks"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
func TestProviderHandler_CreateProvider_Success(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"id": "p1",
|
||||
"name": "Test",
|
||||
"api_key": "sk-test",
|
||||
"base_url": "https://api.test.com",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateProvider(c)
|
||||
assert.Equal(t, 201, w.Code)
|
||||
|
||||
var result domain.Provider
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Equal(t, "p1", result.ID)
|
||||
assert.Equal(t, "sk-test", result.APIKey)
|
||||
}
|
||||
|
||||
func TestProviderHandler_CreateProvider_WithProtocol(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"id": "p1",
|
||||
"name": "Test",
|
||||
"api_key": "sk-test",
|
||||
"base_url": "https://api.test.com",
|
||||
"protocol": "anthropic",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateProvider(c)
|
||||
assert.Equal(t, 201, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(nil)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("p1")).Return(&domain.Provider{ID: "p1", Name: "Updated", APIKey: "sk-test"}, nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateProvider(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_InvalidBody(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", nil)
|
||||
|
||||
h.UpdateProvider(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_DeleteProvider(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Delete(gomock.Eq("p1")).Return(nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("DELETE", "/api/providers/p1", bytes.NewReader([]byte{}))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.DeleteProvider(c)
|
||||
assert.True(t, w.Code == 204 || w.Code == 200)
|
||||
}
|
||||
|
||||
func TestModelHandler_DeleteModel(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Delete(gomock.Eq("m1")).Return(nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "m1"}}
|
||||
c.Request = httptest.NewRequest("DELETE", "/api/models/m1", bytes.NewReader([]byte{}))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.DeleteModel(c)
|
||||
assert.True(t, w.Code == 204 || w.Code == 200)
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_Success(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).DoAndReturn(func(model *domain.Model) error {
|
||||
model.ID = "mock-uuid-1234"
|
||||
return nil
|
||||
})
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "p1",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 201, w.Code)
|
||||
|
||||
var result domain.Model
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.NotEmpty(t, result.ID)
|
||||
}
|
||||
|
||||
func TestModelHandler_GetModel(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ModelName: "gpt-4"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "m1"}}
|
||||
c.Request = httptest.NewRequest("GET", "/api/models/m1", nil)
|
||||
|
||||
h.GetModel(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var result domain.Model
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Equal(t, "gpt-4", result.ModelName)
|
||||
}
|
||||
|
||||
func TestModelHandler_UpdateModel(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("m1"), gomock.Any()).Return(nil)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ModelName: "gpt-4o"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"model_name": "gpt-4o"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "m1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/models/m1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateModel(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
@@ -2,19 +2,22 @@ package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/tests/mocks"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
@@ -22,129 +25,12 @@ func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// ============ Mock 实现 ============
|
||||
|
||||
type mockRoutingService struct {
|
||||
result *domain.RouteResult
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockRoutingService) Route(modelName string) (*domain.RouteResult, error) {
|
||||
return m.result, m.err
|
||||
}
|
||||
|
||||
type mockStatsService struct {
|
||||
err error
|
||||
stats []domain.UsageStats
|
||||
aggrResult []map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *mockStatsService) Record(providerID, modelName string) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockStatsService) Get(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
return m.stats, nil
|
||||
}
|
||||
func (m *mockStatsService) Aggregate(stats []domain.UsageStats, groupBy string) []map[string]interface{} {
|
||||
return m.aggrResult
|
||||
}
|
||||
|
||||
type mockProviderService struct {
|
||||
provider *domain.Provider
|
||||
providers []domain.Provider
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProviderService) Create(provider *domain.Provider) error { return m.err }
|
||||
func (m *mockProviderService) Get(id string, maskKey bool) (*domain.Provider, error) {
|
||||
return m.provider, m.err
|
||||
}
|
||||
func (m *mockProviderService) List() ([]domain.Provider, error) { return m.providers, m.err }
|
||||
func (m *mockProviderService) Update(id string, updates map[string]interface{}) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockProviderService) Delete(id string) error { return m.err }
|
||||
|
||||
type mockModelService struct {
|
||||
model *domain.Model
|
||||
models []domain.Model
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockModelService) Create(model *domain.Model) error { return m.err }
|
||||
func (m *mockModelService) Get(id string) (*domain.Model, error) {
|
||||
return m.model, m.err
|
||||
}
|
||||
func (m *mockModelService) List(providerID string) ([]domain.Model, error) {
|
||||
return m.models, m.err
|
||||
}
|
||||
func (m *mockModelService) Update(id string, updates map[string]interface{}) error {
|
||||
return m.err
|
||||
}
|
||||
func (m *mockModelService) Delete(id string) error { return m.err }
|
||||
|
||||
type mockProviderClient struct {
|
||||
resp *openai.ChatCompletionResponse
|
||||
eventChan chan provider.StreamEvent
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProviderClient) SendRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (*openai.ChatCompletionResponse, error) {
|
||||
return m.resp, m.err
|
||||
}
|
||||
func (m *mockProviderClient) SendStreamRequest(ctx context.Context, req *openai.ChatCompletionRequest, apiKey, baseURL string) (<-chan provider.StreamEvent, error) {
|
||||
return m.eventChan, m.err
|
||||
}
|
||||
|
||||
// ============ OpenAI Handler 测试 ============
|
||||
|
||||
func TestOpenAIHandler_HandleChatCompletions_InvalidJSON(t *testing.T) {
|
||||
h := NewOpenAIHandler(nil, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader([]byte("invalid")))
|
||||
|
||||
h.HandleChatCompletions(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestOpenAIHandler_HandleChatCompletions_ValidationError(t *testing.T) {
|
||||
h := NewOpenAIHandler(nil, nil, nil)
|
||||
|
||||
// 缺少 model 字段
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"messages": []map[string]string{{"role": "user", "content": "hi"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.HandleChatCompletions(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestOpenAIHandler_HandleChatCompletions_RouteError(t *testing.T) {
|
||||
routingSvc := &mockRoutingService{err: appErrors.ErrModelNotFound}
|
||||
h := NewOpenAIHandler(nil, routingSvc, nil)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"model": "nonexistent",
|
||||
"messages": []map[string]string{{"role": "user", "content": "hi"}},
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.HandleChatCompletions(c)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
}
|
||||
|
||||
// ============ Provider Handler 测试 ============
|
||||
|
||||
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"id": "p1"})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -157,12 +43,15 @@ func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProviderHandler_ListProviders(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
providers: []domain.Provider{
|
||||
{ID: "p1", Name: "P1"},
|
||||
{ID: "p2", Name: "P2"},
|
||||
},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().List().Return([]domain.Provider{
|
||||
{ID: "p1", Name: "P1"},
|
||||
{ID: "p2", Name: "P2"},
|
||||
}, nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -172,14 +61,17 @@ func TestProviderHandler_ListProviders(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var result []domain.Provider
|
||||
json.Unmarshal(w.Body.Bytes(), &result)
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Len(t, result, 2)
|
||||
}
|
||||
|
||||
func TestProviderHandler_GetProvider(t *testing.T) {
|
||||
h := NewProviderHandler(&mockProviderService{
|
||||
provider: &domain.Provider{ID: "p1", Name: "P1", APIKey: "***"},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("p1")).Return(&domain.Provider{ID: "p1", Name: "P1", APIKey: "sk-test"}, nil)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -190,10 +82,12 @@ func TestProviderHandler_GetProvider(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
// ============ Model Handler 测试 ============
|
||||
|
||||
func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"id": "m1"})
|
||||
w := httptest.NewRecorder()
|
||||
@@ -206,12 +100,15 @@ func TestModelHandler_CreateModel_MissingFields(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelHandler_ListModels(t *testing.T) {
|
||||
h := NewModelHandler(&mockModelService{
|
||||
models: []domain.Model{
|
||||
{ID: "m1", ModelName: "gpt-4"},
|
||||
{ID: "m2", ModelName: "gpt-3.5"},
|
||||
},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().List(gomock.Eq("")).Return([]domain.Model{
|
||||
{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"},
|
||||
{ID: "m2", ProviderID: "anthropic", ModelName: "claude-3"},
|
||||
}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -219,16 +116,98 @@ func TestModelHandler_ListModels(t *testing.T) {
|
||||
|
||||
h.ListModels(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var result []modelResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
require.Len(t, result, 2)
|
||||
assert.Equal(t, "openai/gpt-4", result[0].UnifiedModelID)
|
||||
assert.Equal(t, "anthropic/claude-3", result[1].UnifiedModelID)
|
||||
}
|
||||
|
||||
// ============ Stats Handler 测试 ============
|
||||
func TestModelHandler_GetModel_UnifiedID(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "m1"}}
|
||||
c.Request = httptest.NewRequest("GET", "/api/models/m1", nil)
|
||||
|
||||
h.GetModel(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var result modelResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Equal(t, "m1", result.ID)
|
||||
assert.Equal(t, "openai/gpt-4", result.UnifiedModelID)
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_UnifiedID(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).DoAndReturn(func(model *domain.Model) error {
|
||||
model.ID = "mock-uuid-1234"
|
||||
return nil
|
||||
})
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 201, w.Code)
|
||||
|
||||
var result modelResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Equal(t, "mock-uuid-1234", result.ID)
|
||||
assert.Equal(t, "openai/gpt-4", result.UnifiedModelID)
|
||||
}
|
||||
|
||||
func TestModelHandler_UpdateModel_UnifiedID(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("m1"), gomock.Any()).Return(nil)
|
||||
mockSvc.EXPECT().Get(gomock.Eq("m1")).Return(&domain.Model{ID: "m1", ProviderID: "openai", ModelName: "gpt-4-turbo"}, nil)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"enabled": false})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "m1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/models/m1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateModel(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
var result modelResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Equal(t, "openai/gpt-4-turbo", result.UnifiedModelID)
|
||||
}
|
||||
|
||||
func TestStatsHandler_GetStats(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
|
||||
},
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockStatsService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]domain.UsageStats{
|
||||
{ProviderID: "p1", ModelName: "gpt-4", RequestCount: 10},
|
||||
}, nil)
|
||||
h := NewStatsHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -239,7 +218,11 @@ func TestStatsHandler_GetStats(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := NewStatsHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -250,14 +233,17 @@ func TestStatsHandler_GetStats_InvalidDate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStatsHandler_AggregateStats(t *testing.T) {
|
||||
h := NewStatsHandler(&mockStatsService{
|
||||
stats: []domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
},
|
||||
aggrResult: []map[string]interface{}{
|
||||
{"provider_id": "p1", "request_count": 10},
|
||||
},
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockStatsService(ctrl)
|
||||
mockSvc.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]domain.UsageStats{
|
||||
{ProviderID: "p1", RequestCount: 10},
|
||||
}, nil)
|
||||
mockSvc.EXPECT().Aggregate(gomock.Any(), gomock.Eq("provider")).Return([]map[string]interface{}{
|
||||
{"provider_id": "p1", "request_count": 10},
|
||||
})
|
||||
h := NewStatsHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -267,8 +253,6 @@ func TestStatsHandler_AggregateStats(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
// ============ writeError 测试 ============
|
||||
|
||||
func TestWriteError(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -283,8 +267,194 @@ func TestFormatValidationErrors(t *testing.T) {
|
||||
"model": "模型名称不能为空",
|
||||
"messages": "消息列表不能为空",
|
||||
}
|
||||
result := formatValidationErrors(errs)
|
||||
result := formatMapErrors(errs)
|
||||
require.Contains(t, result, "请求验证失败")
|
||||
require.Contains(t, result, "model")
|
||||
require.Contains(t, result, "messages")
|
||||
}
|
||||
|
||||
func formatMapErrors(errs map[string]string) string {
|
||||
parts := make([]string, 0, len(errs))
|
||||
for field, msg := range errs {
|
||||
parts = append(parts, fmt.Sprintf("%s: %s", field, msg))
|
||||
}
|
||||
return "请求验证失败: " + strings.Join(parts, "; ")
|
||||
}
|
||||
|
||||
func TestProviderHandler_CreateProvider_DuplicatedKey(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrConflict)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"id": "p1",
|
||||
"name": "Test",
|
||||
"api_key": "sk-test",
|
||||
"base_url": "https://test.com",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateProvider(c)
|
||||
assert.Equal(t, 409, w.Code)
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_ProviderNotFound(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrProviderNotFound)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "nonexistent",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "供应商不存在")
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_DuplicateModel(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(appErrors.ErrDuplicateModel)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 409, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "同一供应商下模型名称已存在")
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_InternalError(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
mockSvc.EXPECT().Create(gomock.Any()).Return(errors.New("database error"))
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"provider_id": "openai",
|
||||
"model_name": "gpt-4",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_NotFound(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(gorm.ErrRecordNotFound)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateProvider(c)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_ImmutableField(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(appErrors.ErrImmutableField)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateProvider(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "供应商 ID 不允许修改")
|
||||
}
|
||||
|
||||
func TestProviderHandler_UpdateProvider_InternalError(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
mockSvc.EXPECT().Update(gomock.Eq("p1"), gomock.Any()).Return(errors.New("database error"))
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"name": "Updated"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "p1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/api/providers/p1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.UpdateProvider(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
}
|
||||
|
||||
func TestModelHandler_CreateModel_InvalidJSON(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockModelService(ctrl)
|
||||
h := NewModelHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/models", bytes.NewReader([]byte("{invalid json")))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateModel(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
func TestProviderHandler_CreateProvider_InvalidJSON(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockSvc := mocks.NewMockProviderService(ctrl)
|
||||
h := NewProviderHandler(mockSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/api/providers", bytes.NewReader([]byte("{invalid json")))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.CreateProvider(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
}
|
||||
|
||||
@@ -5,9 +5,10 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// Logging 日志中间件
|
||||
func Logging(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
@@ -15,12 +16,16 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
|
||||
query := c.Request.URL.RawQuery
|
||||
|
||||
requestID, _ := c.Get(RequestIDKey)
|
||||
var requestIDStr string
|
||||
if id, ok := requestID.(string); ok {
|
||||
requestIDStr = id
|
||||
}
|
||||
logger.Info("请求开始",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", path),
|
||||
zap.String("query", query),
|
||||
zap.String("client_ip", c.ClientIP()),
|
||||
zap.Any("request_id", requestID),
|
||||
pkglogger.Method(c.Request.Method),
|
||||
pkglogger.Path(path),
|
||||
pkglogger.Query(query),
|
||||
pkglogger.ClientIP(c.ClientIP()),
|
||||
pkglogger.RequestID(requestIDStr),
|
||||
)
|
||||
|
||||
c.Next()
|
||||
@@ -29,12 +34,12 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
logger.Info("请求结束",
|
||||
zap.Int("status", statusCode),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", path),
|
||||
zap.Duration("latency", latency),
|
||||
zap.Int("body_size", c.Writer.Size()),
|
||||
zap.Any("request_id", requestID),
|
||||
pkglogger.StatusCode(statusCode),
|
||||
pkglogger.Method(c.Request.Method),
|
||||
pkglogger.Path(path),
|
||||
pkglogger.Latency(latency),
|
||||
pkglogger.BodySize(c.Writer.Size()),
|
||||
pkglogger.RequestID(requestIDStr),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// ModelHandler 模型管理处理器
|
||||
@@ -22,40 +23,59 @@ func NewModelHandler(modelService service.ModelService) *ModelHandler {
|
||||
return &ModelHandler{modelService: modelService}
|
||||
}
|
||||
|
||||
// modelResponse 模型响应 DTO,扩展 unified_id 字段
|
||||
type modelResponse struct {
|
||||
domain.Model
|
||||
UnifiedModelID string `json:"unified_id"`
|
||||
}
|
||||
|
||||
// newModelResponse 从 domain.Model 构造响应 DTO
|
||||
func newModelResponse(m *domain.Model) modelResponse {
|
||||
return modelResponse{
|
||||
Model: *m,
|
||||
UnifiedModelID: m.UnifiedModelID(),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateModel 创建模型
|
||||
func (h *ModelHandler) CreateModel(c *gin.Context) {
|
||||
var req struct {
|
||||
ID string `json:"id" binding:"required"`
|
||||
ProviderID string `json:"provider_id" binding:"required"`
|
||||
ModelName string `json:"model_name" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "缺少必需字段: id, provider_id, model_name",
|
||||
"error": "缺少必需字段: provider_id, model_name",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
model := &domain.Model{
|
||||
ID: req.ID,
|
||||
ProviderID: req.ProviderID,
|
||||
ModelName: req.ModelName,
|
||||
}
|
||||
|
||||
err := h.modelService.Create(model)
|
||||
if err != nil {
|
||||
if err == appErrors.ErrProviderNotFound {
|
||||
if errors.Is(err, appErrors.ErrProviderNotFound) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "供应商不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, appErrors.ErrDuplicateModel) {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": "同一供应商下模型名称已存在",
|
||||
"code": appErrors.ErrDuplicateModel.Code,
|
||||
})
|
||||
return
|
||||
}
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, model)
|
||||
c.JSON(http.StatusCreated, newModelResponse(model))
|
||||
}
|
||||
|
||||
// ListModels 列出模型
|
||||
@@ -68,7 +88,11 @@ func (h *ModelHandler) ListModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models)
|
||||
resp := make([]modelResponse, len(models))
|
||||
for i, m := range models {
|
||||
resp[i] = newModelResponse(&m)
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// GetModel 获取模型
|
||||
@@ -77,7 +101,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
|
||||
|
||||
model, err := h.modelService.Get(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
@@ -87,7 +111,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model)
|
||||
c.JSON(http.StatusOK, newModelResponse(model))
|
||||
}
|
||||
|
||||
// UpdateModel 更新模型
|
||||
@@ -104,18 +128,25 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
|
||||
|
||||
err := h.modelService.Update(id, req)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, appErrors.ErrModelNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err == appErrors.ErrProviderNotFound {
|
||||
if errors.Is(err, appErrors.ErrProviderNotFound) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "供应商不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, appErrors.ErrDuplicateModel) {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": appErrors.ErrDuplicateModel.Message,
|
||||
"code": appErrors.ErrDuplicateModel.Code,
|
||||
})
|
||||
return
|
||||
}
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
@@ -126,7 +157,7 @@ func (h *ModelHandler) UpdateModel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model)
|
||||
c.JSON(http.StatusOK, newModelResponse(model))
|
||||
}
|
||||
|
||||
// DeleteModel 删除模型
|
||||
@@ -135,7 +166,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) {
|
||||
|
||||
err := h.modelService.Delete(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/protocol/openai"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// OpenAIHandler OpenAI 协议处理器
|
||||
type OpenAIHandler struct {
|
||||
client provider.ProviderClient
|
||||
routingService service.RoutingService
|
||||
statsService service.StatsService
|
||||
}
|
||||
|
||||
// NewOpenAIHandler 创建 OpenAI 处理器
|
||||
func NewOpenAIHandler(client provider.ProviderClient, routingService service.RoutingService, statsService service.StatsService) *OpenAIHandler {
|
||||
return &OpenAIHandler{
|
||||
client: client,
|
||||
routingService: routingService,
|
||||
statsService: statsService,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleChatCompletions 处理 Chat Completions 请求
|
||||
func (h *OpenAIHandler) HandleChatCompletions(c *gin.Context) {
|
||||
var req openai.ChatCompletionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "无效的请求格式: " + err.Error(),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 请求验证
|
||||
if validationErrors := openai.ValidateRequest(&req); validationErrors != nil {
|
||||
c.JSON(http.StatusBadRequest, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: formatValidationErrors(validationErrors),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
routeResult, err := h.routingService.Route(req.Model)
|
||||
if err != nil {
|
||||
h.handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
h.handleStreamRequest(c, &req, routeResult)
|
||||
} else {
|
||||
h.handleNonStreamRequest(c, &req, routeResult)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIHandler) handleNonStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
resp, err := h.client.SendRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
Type: "api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
|
||||
}()
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (h *OpenAIHandler) handleStreamRequest(c *gin.Context, req *openai.ChatCompletionRequest, routeResult *domain.RouteResult) {
|
||||
eventChan, err := h.client.SendStreamRequest(c.Request.Context(), req, routeResult.Provider.APIKey, routeResult.Provider.BaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "供应商请求失败: " + err.Error(),
|
||||
Type: "api_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
|
||||
for event := range eventChan {
|
||||
if event.Error != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if event.Done {
|
||||
writer.WriteString("data: [DONE]\n\n")
|
||||
writer.Flush()
|
||||
break
|
||||
}
|
||||
|
||||
writer.WriteString("data: ")
|
||||
writer.Write(event.Data)
|
||||
writer.WriteString("\n\n")
|
||||
writer.Flush()
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, req.Model)
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *OpenAIHandler) handleError(c *gin.Context, err error) {
|
||||
if appErr, ok := appErrors.AsAppError(err); ok {
|
||||
c.JSON(appErr.HTTPStatus, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: appErr.Message,
|
||||
Type: "invalid_request_error",
|
||||
Code: appErr.Code,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, openai.ErrorResponse{
|
||||
Error: openai.ErrorDetail{
|
||||
Message: "内部错误: " + err.Error(),
|
||||
Type: "internal_error",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// formatValidationErrors 将验证错误 map 格式化为字符串
|
||||
func formatValidationErrors(errors map[string]string) string {
|
||||
parts := make([]string, 0, len(errors))
|
||||
for field, msg := range errors {
|
||||
parts = append(parts, fmt.Sprintf("%s: %s", field, msg))
|
||||
}
|
||||
return "请求验证失败: " + strings.Join(parts, "; ")
|
||||
}
|
||||
@@ -1,16 +1,16 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// ProviderHandler 供应商管理处理器
|
||||
@@ -26,10 +26,11 @@ func NewProviderHandler(providerService service.ProviderService) *ProviderHandle
|
||||
// CreateProvider 创建供应商
|
||||
func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
||||
var req struct {
|
||||
ID string `json:"id" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
APIKey string `json:"api_key" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
ID string `json:"id" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
APIKey string `json:"api_key" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
Protocol string `json:"protocol"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
@@ -39,18 +40,25 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
protocol := req.Protocol
|
||||
if protocol == "" {
|
||||
protocol = "openai"
|
||||
}
|
||||
|
||||
provider := &domain.Provider{
|
||||
ID: req.ID,
|
||||
Name: req.Name,
|
||||
APIKey: req.APIKey,
|
||||
BaseURL: req.BaseURL,
|
||||
ID: req.ID,
|
||||
Name: req.Name,
|
||||
APIKey: req.APIKey,
|
||||
BaseURL: req.BaseURL,
|
||||
Protocol: protocol,
|
||||
}
|
||||
|
||||
err := h.providerService.Create(provider)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": "供应商 ID 已存在",
|
||||
if errors.Is(err, appErrors.ErrInvalidProviderID) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": appErrors.ErrInvalidProviderID.Message,
|
||||
"code": appErrors.ErrInvalidProviderID.Code,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -58,7 +66,6 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
provider.MaskAPIKey()
|
||||
c.JSON(http.StatusCreated, provider)
|
||||
}
|
||||
|
||||
@@ -77,9 +84,9 @@ func (h *ProviderHandler) ListProviders(c *gin.Context) {
|
||||
func (h *ProviderHandler) GetProvider(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
provider, err := h.providerService.Get(id, true)
|
||||
provider, err := h.providerService.Get(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
@@ -106,17 +113,24 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
|
||||
|
||||
err := h.providerService.Update(id, req)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, appErrors.ErrImmutableField) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": appErrors.ErrImmutableField.Message,
|
||||
"code": appErrors.ErrImmutableField.Code,
|
||||
})
|
||||
return
|
||||
}
|
||||
writeError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := h.providerService.Get(id, true)
|
||||
provider, err := h.providerService.Get(id)
|
||||
if err != nil {
|
||||
writeError(c, err)
|
||||
return
|
||||
@@ -131,7 +145,7 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
|
||||
|
||||
err := h.providerService.Delete(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
|
||||
661
backend/internal/handler/proxy_handler.go
Normal file
661
backend/internal/handler/proxy_handler.go
Normal file
@@ -0,0 +1,661 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/service"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
"nex/backend/pkg/modelid"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// ProxyHandler is the unified proxy handler: it converts client-protocol
// requests into the target provider's protocol, forwards them upstream,
// and converts the responses back to the client protocol.
type ProxyHandler struct {
	engine          *conversion.ConversionEngine // protocol conversion engine
	client          provider.ProviderClient      // upstream HTTP/SSE client
	routingService  service.RoutingService       // resolves model name -> provider route
	providerService service.ProviderService      // provider/model lookups
	statsService    service.StatsService         // usage statistics recorder
	logger          *zap.Logger
}
|
||||
|
||||
// NewProxyHandler 创建统一代理处理器
|
||||
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService, logger *zap.Logger) *ProxyHandler {
|
||||
return &ProxyHandler{
|
||||
engine: engine,
|
||||
client: client,
|
||||
routingService: routingService,
|
||||
providerService: providerService,
|
||||
statsService: statsService,
|
||||
logger: pkglogger.WithModule(logger, "handler.proxy"),
|
||||
}
|
||||
}
|
||||
|
||||
// HandleProxy is the entry point for proxied API calls mounted at
// /{protocol}/v1/*. It resolves the client protocol adapter and interface
// type, serves models/model-info requests locally, routes unified model
// IDs to a provider, and dispatches to the streaming or non-streaming
// pipeline. Requests without an extractable unified model ID fall back to
// passthrough forwarding.
func (h *ProxyHandler) HandleProxy(c *gin.Context) {
	// Extract clientProtocol from the URL: /{protocol}/v1/...
	clientProtocol := c.Param("protocol")
	if clientProtocol == "" {
		h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "缺少协议前缀")
		return
	}

	// Original path: /{path}
	path := c.Param("path")
	if !strings.HasPrefix(path, "/") {
		path = "/" + path
	}
	nativePath := path
	requestPath := appendRawQuery(nativePath, c.Request.URL.RawQuery)

	// Look up the client-side protocol adapter.
	registry := h.engine.GetRegistry()
	clientAdapter, err := registry.Get(clientProtocol)
	if err != nil {
		h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
		return
	}

	// Detect the interface type from the native path.
	ifaceType := clientAdapter.DetectInterfaceType(nativePath)

	// Models listing is aggregated locally, never forwarded upstream.
	if ifaceType == conversion.InterfaceTypeModels {
		h.handleModelsList(c, clientAdapter)
		return
	}

	// Model info is also answered from the local database.
	if ifaceType == conversion.InterfaceTypeModelInfo {
		unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath)
		if err != nil {
			h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的模型 ID 格式")
			return
		}
		h.handleModelInfo(c, unifiedID, clientAdapter)
		return
	}

	// Read the request body.
	body, err := io.ReadAll(c.Request.Body)
	if err != nil {
		h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "读取请求体失败")
		return
	}

	// Build the inbound HTTPRequestSpec.
	inSpec := conversion.HTTPRequestSpec{
		URL:     requestPath,
		Method:  c.Request.Method,
		Headers: extractHeaders(c),
		Body:    body,
	}
	isStream := h.isStreamRequest(body, clientProtocol, nativePath)

	// Only interfaces the adapter explicitly supports get model extraction.
	// Unknown interfaces are never subjected to generic model guessing.
	if len(body) == 0 || !supportsModelExtraction(ifaceType) {
		h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
		return
	}

	unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
	if err != nil {
		// Malformed JSON is the client's fault; any other extraction
		// failure falls back to passthrough.
		if isInvalidJSONError(err) {
			h.writeProxyError(c, http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误")
			return
		}
		h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
		return
	}

	providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
	if err != nil {
		// Raw model names are tolerated for compatibility: anything that
		// is not a unified model ID bypasses routing and is passed through.
		h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
		return
	}
	if providerID == "" || modelName == "" {
		h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
		return
	}

	// Resolve the provider + model route.
	routeResult, err := h.routingService.RouteByModelName(providerID, modelName)
	if err != nil {
		h.writeRouteError(c, err)
		return
	}

	// Determine the provider-side protocol; the default is "openai".
	providerProtocol := routeResult.Provider.Protocol
	if providerProtocol == "" {
		providerProtocol = "openai"
	}

	// Build the TargetProvider.
	// Note: the ModelName field is used to rewrite the request body in the
	// smart-passthrough scenario:
	//   same protocol  -> the unified ID in the body is rewritten to ModelName
	//   cross protocol -> ModelName is encoded into the fully converted body
	targetProvider := conversion.NewTargetProvider(
		routeResult.Provider.BaseURL,
		routeResult.Provider.APIKey,
		routeResult.Model.ModelName, // upstream model name, used for request rewriting
	)

	// Unified model ID used to overwrite the model field in responses.
	unifiedModelID := routeResult.Model.UnifiedModelID()

	if isStream {
		h.handleStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
	} else {
		h.handleNonStream(c, inSpec, clientProtocol, providerProtocol, targetProvider, routeResult, unifiedModelID, ifaceType)
	}
}
|
||||
|
||||
func supportsModelExtraction(ifaceType conversion.InterfaceType) bool {
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isInvalidJSONError(err error) bool {
|
||||
var syntaxErr *json.SyntaxError
|
||||
var typeErr *json.UnmarshalTypeError
|
||||
return errors.As(err, &syntaxErr) || errors.As(err, &typeErr)
|
||||
}
|
||||
|
||||
// appendRawQuery re-attaches a raw query string to path, producing
// "path?rawQuery". An empty rawQuery leaves path untouched.
func appendRawQuery(path, rawQuery string) string {
	if rawQuery != "" {
		return path + "?" + rawQuery
	}
	return path
}
|
||||
|
||||
// handleNonStream runs the non-streaming pipeline: convert the request to
// the provider protocol, send it, convert a 2xx response back to the
// client protocol (overwriting the model field with unifiedModelID), and
// record usage stats asynchronously. Non-2xx upstream responses are
// relayed without conversion.
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
	// Convert the request.
	outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
	if err != nil {
		h.logger.Error("转换请求失败", zap.Error(err))
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	// Send the request upstream.
	resp, err := h.client.Send(c.Request.Context(), *outSpec)
	if err != nil {
		h.logger.Error("发送请求失败", zap.Error(err))
		h.writeUpstreamUnavailable(c, err)
		return
	}
	// Non-2xx: relay the upstream response as-is.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		h.writeUpstreamResponse(c, *resp)
		return
	}

	// Convert the response; unifiedModelID overwrites the model field in
	// cross-protocol scenarios.
	convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
	if err != nil {
		h.logger.Error("转换响应失败", zap.Error(err))
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	h.writeConvertedResponse(c, *convertedResp)

	go func() {
		_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget stats must not block the request
	}()
}
|
||||
|
||||
// handleStream runs the streaming (SSE) pipeline: convert and send the
// request, relay non-2xx statuses verbatim, then pump upstream events
// through a stream converter, flushing each converted chunk to the client
// immediately. Usage stats are recorded asynchronously when the stream
// ends.
func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
	// Convert the request.
	outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
	if err != nil {
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	// Open the upstream stream.
	streamResp, err := h.client.SendStream(c.Request.Context(), *outSpec)
	if err != nil {
		h.writeUpstreamUnavailable(c, err)
		return
	}
	// Non-2xx: relay status, headers and body without conversion.
	if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
		h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
			StatusCode: streamResp.StatusCode,
			Headers:    streamResp.Headers,
			Body:       streamResp.Body,
		})
		return
	}

	// Create the stream converter; unifiedModelID overwrites the model
	// field in cross-protocol scenarios.
	streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
	if err != nil {
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")

	writer := bufio.NewWriter(c.Writer)
	flushed := false // whether the converter's final Flush() already ran

	for event := range streamResp.Events {
		if event.Error != nil {
			h.logger.Error("流读取错误", zap.Error(event.Error))
			break
		}
		if event.Done {
			// Drain whatever the converter still buffers.
			chunks := streamConverter.Flush()
			if err := h.writeStreamChunks(writer, chunks); err != nil {
				h.logger.Warn("流式响应写回失败", zap.Error(err))
			}
			flushed = true
			break
		}

		chunks := streamConverter.ProcessChunk(event.Data)
		if err := h.writeStreamChunks(writer, chunks); err != nil {
			h.logger.Warn("流式响应写回失败", zap.Error(err))
			break
		}
	}
	// The loop may exit on an error or write failure before seeing Done;
	// flush the converter exactly once in that case too.
	if !flushed {
		chunks := streamConverter.Flush()
		if err := h.writeStreamChunks(writer, chunks); err != nil {
			h.logger.Warn("流式响应写回失败", zap.Error(err))
		}
	}

	go func() {
		_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget stats must not block the request
	}()
}
|
||||
|
||||
func (h *ProxyHandler) writeStreamChunks(writer *bufio.Writer, chunks [][]byte) error {
|
||||
for _, chunk := range chunks {
|
||||
if _, err := writer.Write(chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writer.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isStreamRequest 判断是否流式请求
|
||||
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
|
||||
ifaceType, err := h.engine.DetectInterfaceType(nativePath, clientProtocol)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if ifaceType != conversion.InterfaceTypeChat {
|
||||
return false
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return false
|
||||
}
|
||||
return req.Stream
|
||||
}
|
||||
|
||||
// handleModelsList serves GET /v1/models locally instead of forwarding
// upstream: it aggregates every enabled model across all providers and
// encodes the list in the client protocol's wire format.
func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.ProtocolAdapter) {
	// Query all enabled models from the database.
	models, err := h.providerService.ListEnabledModels()
	if err != nil {
		h.logger.Error("查询启用模型失败", zap.Error(err))
		h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "查询模型失败")
		return
	}

	// Build the protocol-neutral CanonicalModelList.
	modelList := &canonical.CanonicalModelList{
		Models: make([]canonical.CanonicalModel, 0, len(models)),
	}

	for _, m := range models {
		modelList.Models = append(modelList.Models, canonical.CanonicalModel{
			ID:      m.UnifiedModelID(),
			Name:    m.ModelName,
			Created: m.CreatedAt.Unix(),
			OwnedBy: m.ProviderID,
		})
	}

	// Encode the response in the client protocol's format via the adapter.
	body, err := adapter.EncodeModelsResponse(modelList)
	if err != nil {
		h.logger.Error("编码 Models 响应失败", zap.Error(err))
		h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
		return
	}

	c.Data(http.StatusOK, "application/json", body)
}
|
||||
|
||||
// handleModelInfo serves GET /v1/models/{unified_id} locally: it parses
// the unified model ID into provider + model name, looks the model up in
// the database, and encodes the result in the client protocol's format.
func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter conversion.ProtocolAdapter) {
	// Split the unified model ID into provider ID + upstream model name.
	providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
	if err != nil {
		h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的统一模型 ID 格式")
		return
	}

	// Look the model up in the database.
	model, err := h.providerService.GetModelByProviderAndName(providerID, modelName)
	if err != nil {
		h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", "模型未找到")
		return
	}

	// Build the protocol-neutral CanonicalModelInfo.
	modelInfo := &canonical.CanonicalModelInfo{
		ID:      model.UnifiedModelID(),
		Name:    model.ModelName,
		Created: model.CreatedAt.Unix(),
		OwnedBy: model.ProviderID,
	}

	// Encode the response in the client protocol's format via the adapter.
	body, err := adapter.EncodeModelInfoResponse(modelInfo)
	if err != nil {
		h.logger.Error("编码 ModelInfo 响应失败", zap.Error(err))
		h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
		return
	}

	c.Data(http.StatusOK, "application/json", body)
}
|
||||
|
||||
// writeConversionError 写入网关层转换错误
|
||||
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
|
||||
var convErr *conversion.ConversionError
|
||||
if errors.As(err, &convErr) {
|
||||
statusCode, code, message := mapConversionError(convErr)
|
||||
h.writeProxyError(c, statusCode, code, message)
|
||||
return
|
||||
}
|
||||
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", err.Error())
|
||||
}
|
||||
|
||||
// mapConversionError translates a ConversionError code into an HTTP
// status, a proxy error code, and a client-facing message.
func mapConversionError(err *conversion.ConversionError) (int, string, string) {
	switch err.Code {
	case conversion.ErrorCodeJSONParseError:
		// Only request-phase JSON errors are the client's fault; a parse
		// failure on the response side is an internal conversion problem.
		if phase, ok := err.Details[conversion.ErrorDetailPhase].(string); ok && phase == conversion.ErrorPhaseRequest {
			return http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误"
		}
		return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
	case conversion.ErrorCodeInvalidInput,
		conversion.ErrorCodeMissingRequiredField,
		conversion.ErrorCodeProtocolConstraint:
		return http.StatusBadRequest, "INVALID_REQUEST", err.Message
	case conversion.ErrorCodeInterfaceNotSupported:
		return http.StatusBadRequest, "UNSUPPORTED_INTERFACE", err.Message
	case conversion.ErrorCodeUnsupportedMultimodal:
		return http.StatusBadRequest, "UNSUPPORTED_MULTIMODAL", err.Message
	default:
		return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
	}
}
|
||||
|
||||
// writeRouteError maps a routing failure to a proxy error response. Known
// AppError codes keep their HTTP status and get a matching proxy code;
// any other error is reported as 404 MODEL_NOT_FOUND.
func (h *ProxyHandler) writeRouteError(c *gin.Context, err error) {
	if appErr, ok := appErrors.AsAppError(err); ok {
		switch appErr.Code {
		case appErrors.ErrModelNotFound.Code, appErrors.ErrModelDisabled.Code:
			h.writeProxyError(c, appErr.HTTPStatus, "MODEL_NOT_FOUND", appErr.Message)
		case appErrors.ErrProviderNotFound.Code, appErrors.ErrProviderDisabled.Code:
			h.writeProxyError(c, appErr.HTTPStatus, "PROVIDER_NOT_FOUND", appErr.Message)
		default:
			h.writeProxyError(c, appErr.HTTPStatus, "INVALID_REQUEST", appErr.Message)
		}
		return
	}
	h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", err.Error())
}
|
||||
|
||||
// writeUpstreamUnavailable logs the transport-level failure and answers
// the client with 502 UPSTREAM_UNAVAILABLE.
func (h *ProxyHandler) writeUpstreamUnavailable(c *gin.Context, err error) {
	h.logger.Error("上游不可达", zap.Error(err))
	h.writeProxyError(c, http.StatusBadGateway, "UPSTREAM_UNAVAILABLE", "上游服务不可达")
}
|
||||
|
||||
func (h *ProxyHandler) writeProxyError(c *gin.Context, status int, code, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": message,
|
||||
"code": code,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeConvertedResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
|
||||
for k, v := range resp.Headers {
|
||||
c.Header(k, v)
|
||||
}
|
||||
contentType := headerValue(resp.Headers, "Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, resp.Body)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeUpstreamResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
|
||||
for k, v := range filterHopByHopHeaders(resp.Headers) {
|
||||
c.Header(k, v)
|
||||
}
|
||||
contentType := headerValue(resp.Headers, "Content-Type")
|
||||
c.Data(resp.StatusCode, contentType, resp.Body)
|
||||
}
|
||||
|
||||
// forwardPassthrough forwards a request that carries no routable unified
// model ID (empty bodies, GET endpoints, raw model names, unsupported
// interfaces). It targets the FIRST listed provider:
//   - same protocol:  rebuild the upstream URL/headers and send the body
//     unchanged
//   - cross protocol: run the full request-conversion pipeline
//
// NOTE(review): picking providers[0] assumes the first provider can serve
// any passthrough request — confirm this fallback is intended.
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string, ifaceType conversion.InterfaceType, isStream bool) {
	registry := h.engine.GetRegistry()
	adapter, err := registry.Get(clientProtocol)
	if err != nil {
		h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
		return
	}

	providers, err := h.providerService.List()
	if err != nil || len(providers) == 0 {
		h.logger.Warn("无可用供应商转发请求", zap.String("path", inSpec.URL))
		h.writeProxyError(c, http.StatusNotFound, "PROVIDER_NOT_FOUND", "没有可用的供应商。请先创建供应商和模型。")
		return
	}

	p := providers[0]
	providerProtocol := p.Protocol
	if providerProtocol == "" {
		providerProtocol = "openai"
	}

	// Empty model name: passthrough never rewrites the body's model field.
	targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "")

	var outSpec *conversion.HTTPRequestSpec
	if clientProtocol == providerProtocol {
		// Same protocol: rebuild the upstream path (preserving the query
		// string) and headers; the body is forwarded untouched.
		upstreamPath := adapter.BuildUrl(stripRawQuery(inSpec.URL), ifaceType)
		upstreamPath = appendRawQuery(upstreamPath, rawQueryFromPath(inSpec.URL))
		headers := adapter.BuildHeaders(targetProvider)
		if _, ok := headers["Content-Type"]; !ok {
			headers["Content-Type"] = "application/json"
		}
		outSpec = &conversion.HTTPRequestSpec{
			URL:     joinBaseURL(p.BaseURL, upstreamPath),
			Method:  inSpec.Method,
			Headers: headers,
			Body:    inSpec.Body,
		}
	} else {
		// Cross protocol: full request conversion.
		outSpec, err = h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
		if err != nil {
			h.writeConversionError(c, err, clientProtocol)
			return
		}
	}

	if isStream {
		h.forwardStream(c, *outSpec, clientProtocol, providerProtocol, ifaceType)
		return
	}

	resp, err := h.client.Send(c.Request.Context(), *outSpec)
	if err != nil {
		h.writeUpstreamUnavailable(c, err)
		return
	}
	// Non-2xx: relay the upstream response as-is.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		h.writeUpstreamResponse(c, *resp)
		return
	}

	// Empty model override: passthrough keeps the upstream model field.
	convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, "")
	if err != nil {
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	h.writeConvertedResponse(c, *convertedResp)
}
|
||||
|
||||
// forwardStream is the streaming half of forwardPassthrough: it opens the
// upstream SSE stream, relays non-2xx statuses verbatim, then pumps the
// events through a stream converter (with no model override) to the
// client. Unlike handleStream, no usage stats are recorded here.
func (h *ProxyHandler) forwardStream(c *gin.Context, outSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, ifaceType conversion.InterfaceType) {
	streamResp, err := h.client.SendStream(c.Request.Context(), outSpec)
	if err != nil {
		h.writeUpstreamUnavailable(c, err)
		return
	}
	if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
		h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
			StatusCode: streamResp.StatusCode,
			Headers:    streamResp.Headers,
			Body:       streamResp.Body,
		})
		return
	}

	// Empty model override: passthrough keeps the upstream model field.
	streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, "", ifaceType)
	if err != nil {
		h.writeConversionError(c, err, clientProtocol)
		return
	}

	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")

	writer := bufio.NewWriter(c.Writer)
	flushed := false // whether the converter's final Flush() already ran
	for event := range streamResp.Events {
		if event.Error != nil {
			h.logger.Error("透传流读取错误", zap.Error(event.Error))
			break
		}
		if event.Done {
			// Drain whatever the converter still buffers.
			chunks := streamConverter.Flush()
			if err := h.writeStreamChunks(writer, chunks); err != nil {
				h.logger.Warn("透传流式响应写回失败", zap.Error(err))
			}
			flushed = true
			break
		}
		chunks := streamConverter.ProcessChunk(event.Data)
		if err := h.writeStreamChunks(writer, chunks); err != nil {
			h.logger.Warn("透传流式响应写回失败", zap.Error(err))
			break
		}
	}
	// The loop may exit on an error before seeing Done; flush the
	// converter exactly once in that case too.
	if !flushed {
		chunks := streamConverter.Flush()
		if err := h.writeStreamChunks(writer, chunks); err != nil {
			h.logger.Warn("透传流式响应写回失败", zap.Error(err))
		}
	}
}
|
||||
|
||||
// stripRawQuery returns path with any "?query" suffix removed.
func stripRawQuery(path string) string {
	if i := strings.Index(path, "?"); i >= 0 {
		return path[:i]
	}
	return path
}
|
||||
|
||||
// rawQueryFromPath returns the raw query portion of path (the text after
// the first '?'), or "" when path has no query component.
func rawQueryFromPath(path string) string {
	if i := strings.Index(path, "?"); i >= 0 {
		return path[i+1:]
	}
	return ""
}
|
||||
|
||||
// joinBaseURL concatenates baseURL and path with exactly one '/' between
// them, regardless of trailing/leading slashes on either side.
func joinBaseURL(baseURL, path string) string {
	base := strings.TrimRight(baseURL, "/")
	rel := strings.TrimLeft(path, "/")
	return base + "/" + rel
}
|
||||
|
||||
// headerValue looks up key in headers case-insensitively and returns the
// matching value, or "" when the header is absent.
func headerValue(headers map[string]string, key string) string {
	for name, value := range headers {
		if strings.EqualFold(name, key) {
			return value
		}
	}
	return ""
}
|
||||
|
||||
// hopByHopHeaders lists connection-scoped headers (RFC 7230 §6.1) that a
// proxy must not forward. Keys are lowercase; hoisted to package level so
// the set is not rebuilt on every call.
var hopByHopHeaders = map[string]struct{}{
	"connection":          {},
	"transfer-encoding":   {},
	"keep-alive":          {},
	"proxy-authenticate":  {},
	"proxy-authorization": {},
	"te":                  {},
	"trailer":             {},
	"upgrade":             {},
}

// filterHopByHopHeaders returns a copy of headers with hop-by-hop headers
// removed (matched case-insensitively). Empty or nil input returns nil.
func filterHopByHopHeaders(headers map[string]string) map[string]string {
	if len(headers) == 0 {
		return nil
	}
	filtered := make(map[string]string, len(headers))
	for k, v := range headers {
		if _, skip := hopByHopHeaders[strings.ToLower(k)]; skip {
			continue
		}
		filtered[k] = v
	}
	return filtered
}
|
||||
|
||||
// extractHeaders 从 Gin context 提取请求头
|
||||
func extractHeaders(c *gin.Context) map[string]string {
|
||||
headers := make(map[string]string)
|
||||
for k, vs := range c.Request.Header {
|
||||
if len(vs) > 0 {
|
||||
headers[k] = vs[0]
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
1453
backend/internal/handler/proxy_handler_test.go
Normal file
1453
backend/internal/handler/proxy_handler_test.go
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user