Compare commits
46 Commits
5b765c8b5e
...
v0.1.1
| Author | SHA1 | Date | |
|---|---|---|---|
| 235efb0e62 | |||
| 6b1af27ea2 | |||
| 32f48777f3 | |||
| bc7a7c6e81 | |||
| 3cd0458c2c | |||
| 8eea30ea11 | |||
| 9e33e570af | |||
| 7653385838 | |||
| 2c401f7ae6 | |||
| a9972360c2 | |||
| b00fa4dcee | |||
| 92525b39c3 | |||
| 38a2555c7b | |||
| 9622d44aac | |||
| 155244433f | |||
| 2c043c6cf7 | |||
| f5c82b6980 | |||
| 9105a36097 | |||
| f1ee646ca4 | |||
| b9b487c591 | |||
| 4c62c071fb | |||
| b2e9dd8b7f | |||
| d143c5f3df | |||
| 4eebdfb8db | |||
| b517946585 | |||
| 4ddae6be74 | |||
| 195762ff97 | |||
| bcf5ca89e5 | |||
| 365943e4c4 | |||
| 4c6b49099d | |||
| 4c78ab6cc8 | |||
| 52007c9461 | |||
| 086dd1fed7 | |||
| 1d7e839b49 | |||
| fa7babf13b | |||
| 280099b89c | |||
| 0a92a25451 | |||
| 8c075194e5 | |||
| 53e477d383 | |||
| 1522c87c74 | |||
| e0d05c9869 | |||
| 5b401e29cb | |||
| 65ac9f740a | |||
| 58ebcaa299 | |||
| b3258e76df | |||
| 64dc66afa6 |
7
.editorconfig
Normal file
7
.editorconfig
Normal file
@@ -0,0 +1,7 @@
|
||||
root = true
|
||||
|
||||
[*]
|
||||
end_of_line = lf
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
charset = utf-8
|
||||
9
.gitattributes
vendored
Normal file
9
.gitattributes
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
* text=auto eol=lf
|
||||
assets/*.png filter=lfs diff=lfs merge=lfs -text
|
||||
assets/**/*.png filter=lfs diff=lfs merge=lfs -text
|
||||
assets/*.icns filter=lfs diff=lfs merge=lfs -text
|
||||
assets/**/*.icns filter=lfs diff=lfs merge=lfs -text
|
||||
assets/*.ico filter=lfs diff=lfs merge=lfs -text
|
||||
assets/**/*.ico filter=lfs diff=lfs merge=lfs -text
|
||||
frontend/public/*.png filter=lfs diff=lfs merge=lfs -text
|
||||
frontend/public/**/*.png filter=lfs diff=lfs merge=lfs -text
|
||||
210
.github/workflows/release.yml
vendored
Normal file
210
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,210 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*.*.*'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
prepare:
|
||||
name: Prepare Release
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
outputs:
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.work
|
||||
cache-dependency-path: |
|
||||
backend/go.sum
|
||||
versionctl/go.sum
|
||||
|
||||
- name: Verify tag and VERSION
|
||||
id: version
|
||||
run: |
|
||||
version=$(go run ./versionctl print)
|
||||
go run ./versionctl verify-tag "${GITHUB_REF_NAME}"
|
||||
printf 'version=%s\n' "$version" >> "$GITHUB_OUTPUT"
|
||||
|
||||
build-linux:
|
||||
name: Build Linux Assets
|
||||
needs: prepare
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.work
|
||||
cache-dependency-path: |
|
||||
backend/go.sum
|
||||
versionctl/go.sum
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
|
||||
- name: Install Linux desktop build dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libayatana-appindicator3-dev libgtk-3-dev
|
||||
|
||||
- name: Preflight Linux release toolchain
|
||||
run: |
|
||||
set -euo pipefail
|
||||
command -v go
|
||||
go version
|
||||
command -v bun
|
||||
bun --version
|
||||
command -v gcc
|
||||
gcc --version
|
||||
command -v pkg-config
|
||||
pkg-config --modversion ayatana-appindicator3-0.1
|
||||
pkg-config --modversion gtk+-3.0
|
||||
|
||||
- name: Build Linux release assets
|
||||
run: make release-assets-linux
|
||||
|
||||
- name: Upload Linux release assets
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-linux
|
||||
path: build/release/*
|
||||
|
||||
build-windows:
|
||||
name: Build Windows Assets
|
||||
needs: prepare
|
||||
runs-on: windows-latest
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.work
|
||||
cache-dependency-path: |
|
||||
backend/go.sum
|
||||
versionctl/go.sum
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
|
||||
- name: Setup MSYS2 toolchain
|
||||
uses: msys2/setup-msys2@v2
|
||||
with:
|
||||
msystem: MINGW64
|
||||
path-type: inherit
|
||||
update: true
|
||||
install: >-
|
||||
make
|
||||
mingw-w64-x86_64-gcc
|
||||
|
||||
- name: Preflight Windows release toolchain
|
||||
shell: msys2 {0}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
command -v go
|
||||
go version
|
||||
command -v bun
|
||||
bun --version
|
||||
command -v make
|
||||
make --version
|
||||
command -v gcc
|
||||
gcc --version
|
||||
command -v windres
|
||||
windres --version
|
||||
if command -v powershell.exe >/dev/null 2>&1; then
|
||||
powershell.exe -NoProfile -Command '$PSVersionTable.PSVersion.ToString()'
|
||||
else
|
||||
command -v powershell
|
||||
powershell -NoProfile -Command '$PSVersionTable.PSVersion.ToString()'
|
||||
fi
|
||||
|
||||
- name: Build Windows release assets
|
||||
shell: msys2 {0}
|
||||
run: make release-assets-windows
|
||||
|
||||
- name: Upload Windows release assets
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-windows
|
||||
path: build/release/*
|
||||
|
||||
build-macos:
|
||||
name: Build macOS Assets
|
||||
needs: prepare
|
||||
runs-on: macos-latest
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.work
|
||||
cache-dependency-path: |
|
||||
backend/go.sum
|
||||
versionctl/go.sum
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
|
||||
- name: Preflight macOS release toolchain
|
||||
run: |
|
||||
set -euo pipefail
|
||||
command -v go
|
||||
go version
|
||||
command -v bun
|
||||
bun --version
|
||||
command -v ditto
|
||||
xcrun --find lipo
|
||||
xcrun --find vtool
|
||||
|
||||
- name: Build macOS release assets
|
||||
run: make release-assets-macos
|
||||
|
||||
- name: Upload macOS release assets
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-macos
|
||||
path: build/release/*
|
||||
|
||||
draft-release:
|
||||
name: Create Draft Release
|
||||
needs: [prepare, build-linux, build-windows, build-macos]
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Download release assets
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: release-*
|
||||
merge-multiple: true
|
||||
path: dist
|
||||
|
||||
- name: Publish draft release
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
name: v${{ needs.prepare.outputs.version }}
|
||||
tag_name: ${{ github.ref_name }}
|
||||
draft: true
|
||||
files: |
|
||||
dist/*
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -401,13 +401,16 @@ cython_debug/
|
||||
# Custom
|
||||
.claude
|
||||
.opencode
|
||||
.codex
|
||||
openspec/changes/archive
|
||||
temp
|
||||
.agents
|
||||
skills-lock.json
|
||||
.worktrees
|
||||
!scripts/build/
|
||||
backend/bin
|
||||
|
||||
# Embedfs generated
|
||||
embedfs/assets/
|
||||
embedfs/frontend-dist/
|
||||
backend/cmd/desktop/rsrc_windows_*.syso
|
||||
|
||||
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"files.eol": "\n"
|
||||
}
|
||||
184
LICENSE
Normal file
184
LICENSE
Normal file
@@ -0,0 +1,184 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction, and
|
||||
distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by the copyright
|
||||
owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all other entities
|
||||
that control, are controlled by, or are under common control with that entity.
|
||||
For the purposes of this definition, "control" means (i) the power, direct or
|
||||
indirect, to cause the direction or management of such entity, whether by
|
||||
contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity exercising
|
||||
permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications, including
|
||||
but not limited to software source code, documentation source, and configuration
|
||||
files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical transformation or
|
||||
translation of a Source form, including but not limited to compiled object code,
|
||||
generated documentation, and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or Object form,
|
||||
made available under the License, as indicated by a copyright notice that is
|
||||
included in or attached to the work (an example is provided in the Appendix
|
||||
below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object form, that
|
||||
is based on (or derived from) the Work and for which the editorial revisions,
|
||||
annotations, elaborations, or other modifications represent, as a whole, an
|
||||
original work of authorship. For the purposes of this License, Derivative Works
|
||||
shall not include works that remain separable from, or merely link (or bind by
|
||||
name) to the interfaces of, the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including the original version
|
||||
of the Work and any modifications or additions to that Work or Derivative Works
|
||||
thereof, that is intentionally submitted to Licensor for inclusion in the Work
|
||||
by the copyright owner or by an individual or Legal Entity authorized to submit
|
||||
on behalf of the copyright owner. For the purposes of this definition,
|
||||
"submitted" means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems, and
|
||||
issue tracking systems that are managed by, or on behalf of, the Licensor for
|
||||
the purpose of discussing and improving the Work, but excluding communication
|
||||
that is conspicuously marked or otherwise designated in writing by the copyright
|
||||
owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf
|
||||
of whom a Contribution has been received by Licensor and subsequently
|
||||
incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this
|
||||
License, each Contributor hereby grants to You a perpetual, worldwide,
|
||||
non-exclusive, no-charge, royalty-free, irrevocable copyright license to
|
||||
reproduce, prepare Derivative Works of, publicly display, publicly perform,
|
||||
sublicense, and distribute the Work and such Derivative Works in Source or
|
||||
Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this License,
|
||||
each Contributor hereby grants to You a perpetual, worldwide, non-exclusive,
|
||||
no-charge, royalty-free, irrevocable (except as stated in this section) patent
|
||||
license to make, have made, use, offer to sell, sell, import, and otherwise
|
||||
transfer the Work, where such license applies only to those patent claims
|
||||
licensable by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s) with the Work
|
||||
to which such Contribution(s) was submitted. If You institute patent litigation
|
||||
against any entity (including a cross-claim or counterclaim in a lawsuit)
|
||||
alleging that the Work or a Contribution incorporated within the Work
|
||||
constitutes direct or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate as of the date
|
||||
such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the Work or
|
||||
Derivative Works thereof in any medium, with or without modifications, and in
|
||||
Source or Object form, provided that You meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or Derivative Works a copy of
|
||||
this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices stating that
|
||||
You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works that You
|
||||
distribute, all copyright, patent, trademark, and attribution notices from the
|
||||
Source form of the Work, excluding those notices that do not pertain to any part
|
||||
of the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its distribution, then
|
||||
any Derivative Works that You distribute must include a readable copy of the
|
||||
attribution notices contained within such NOTICE file, excluding those notices
|
||||
that do not pertain to any part of the Derivative Works, in at least one of the
|
||||
following places: within a NOTICE text file distributed as part of the
|
||||
Derivative Works; within the Source form or documentation, if provided along
|
||||
with the Derivative Works; or, within a display generated by the Derivative
|
||||
Works, if and wherever such third-party notices normally appear. The contents of
|
||||
the NOTICE file are for informational purposes only and do not modify the
|
||||
License. You may add Your own attribution notices within Derivative Works that
|
||||
You distribute, alongside or as an addendum to the NOTICE text from the Work,
|
||||
provided that such additional attribution notices cannot be construed as
|
||||
modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and may provide
|
||||
additional or different license terms and conditions for use, reproduction, or
|
||||
distribution of Your modifications, or for any such Derivative Works as a whole,
|
||||
provided Your use, reproduction, and distribution of the Work otherwise complies
|
||||
with the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise, any
|
||||
Contribution intentionally submitted for inclusion in the Work by You to the
|
||||
Licensor shall be under the terms and conditions of this License, without any
|
||||
additional terms or conditions. Notwithstanding the above, nothing herein shall
|
||||
supersede or modify the terms of any separate license agreement you may have
|
||||
executed with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade names,
|
||||
trademarks, service marks, or product names of the Licensor, except as required
|
||||
for reasonable and customary use in describing the origin of the Work and
|
||||
reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in
|
||||
writing, Licensor provides the Work (and each Contributor provides its
|
||||
Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
KIND, either express or implied, including, without limitation, any warranties
|
||||
or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any risks
|
||||
associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory, whether in
|
||||
tort (including negligence), contract, or otherwise, unless required by
|
||||
applicable law (such as deliberate and grossly negligent acts) or agreed to in
|
||||
writing, shall any Contributor be liable to You for damages, including any
|
||||
direct, indirect, special, incidental, or consequential damages of any character
|
||||
arising as a result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill, work stoppage,
|
||||
computer failure or malfunction, or any and all other commercial damages or
|
||||
losses), even if such Contributor has been advised of the possibility of such
|
||||
damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing the Work or
|
||||
Derivative Works thereof, You may choose to offer, and charge a fee for,
|
||||
acceptance of support, warranty, indemnity, or other liability obligations
|
||||
and/or rights consistent with this License. However, in accepting such
|
||||
obligations, You may act only on Your own behalf and on Your sole
|
||||
responsibility, not on behalf of any other Contributor, and only if You agree to
|
||||
indemnify, defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason of your
|
||||
accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following boilerplate
|
||||
notice, with the fields enclosed by brackets "[]" replaced with your own
|
||||
identifying information. (Don't include the brackets!) The text should be
|
||||
enclosed in the appropriate comment syntax for the file format. We also
|
||||
recommend that a file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier identification within
|
||||
third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
318
Makefile
318
Makefile
@@ -1,119 +1,251 @@
|
||||
.PHONY: all clean \
|
||||
backend-build backend-run backend-test backend-test-unit backend-test-integration backend-test-coverage \
|
||||
backend-lint backend-deps backend-generate \
|
||||
backend-migrate-up backend-migrate-down backend-migrate-status backend-migrate-create \
|
||||
frontend-build frontend-dev frontend-test frontend-test-watch frontend-test-coverage frontend-test-e2e frontend-lint \
|
||||
desktop desktop-darwin desktop-windows desktop-linux package-macos
|
||||
.PHONY: \
|
||||
lint test clean \
|
||||
version-sync version-check version-bump \
|
||||
server-run server-build server-lint server-test server-clean \
|
||||
desktop-build-mac desktop-build-win desktop-build-linux \
|
||||
desktop-lint desktop-test desktop-clean \
|
||||
release-assets-linux release-assets-windows release-assets-macos \
|
||||
_backend-lint _backend-test _backend-clean _backend-build \
|
||||
_versionctl-lint _versionctl-test \
|
||||
_frontend-install _frontend-build _frontend-check _frontend-test _frontend-dev _frontend-clean \
|
||||
_desktop-test _desktop-clean _desktop-prepare-frontend _desktop-prepare-embedfs _desktop-prepare-windows-resource \
|
||||
_server-run-backend _server-run-frontend
|
||||
|
||||
# Delay shell lookups until a target needs them, then cache the result for this make run.
|
||||
lazy_shell = $(or $($(1)),$(eval $(1) := $(shell $(2)))$($(1)))
|
||||
|
||||
VERSION = $(call lazy_shell,_VERSION,go run ./versionctl print)
|
||||
GIT_COMMIT ?= $(call lazy_shell,_GIT_COMMIT,git rev-parse --short HEAD 2>/dev/null || printf 'unknown')
|
||||
BUILD_TIME ?= $(call lazy_shell,_BUILD_TIME,date -u +"%Y-%m-%dT%H:%M:%SZ")
|
||||
GO_LDFLAGS = -X nex/backend/pkg/buildinfo.version=$(VERSION) -X nex/backend/pkg/buildinfo.commit=$(GIT_COMMIT) -X nex/backend/pkg/buildinfo.buildTime=$(BUILD_TIME)
|
||||
GO_LDFLAGS_WIN = $(GO_LDFLAGS) -H=windowsgui
|
||||
RELEASE_DIR := build/release
|
||||
SERVER_LINUX_ASSET = $(call lazy_shell,_SERVER_LINUX_ASSET,go run ./versionctl asset-name server linux amd64)
|
||||
SERVER_WINDOWS_ASSET = $(call lazy_shell,_SERVER_WINDOWS_ASSET,go run ./versionctl asset-name server windows amd64)
|
||||
SERVER_DARWIN_AMD64_ASSET = $(call lazy_shell,_SERVER_DARWIN_AMD64_ASSET,go run ./versionctl asset-name server darwin amd64)
|
||||
SERVER_DARWIN_ARM64_ASSET = $(call lazy_shell,_SERVER_DARWIN_ARM64_ASSET,go run ./versionctl asset-name server darwin arm64)
|
||||
DESKTOP_LINUX_ASSET = $(call lazy_shell,_DESKTOP_LINUX_ASSET,go run ./versionctl asset-name desktop linux)
|
||||
DESKTOP_WINDOWS_ASSET = $(call lazy_shell,_DESKTOP_WINDOWS_ASSET,go run ./versionctl asset-name desktop windows)
|
||||
DESKTOP_MACOS_ASSET = $(call lazy_shell,_DESKTOP_MACOS_ASSET,go run ./versionctl asset-name desktop macos)
|
||||
|
||||
# ============================================
|
||||
# 后端
|
||||
# 全局命令
|
||||
# ============================================
|
||||
|
||||
all: backend-build
|
||||
lint: _backend-lint _frontend-check _versionctl-lint
|
||||
@printf 'Lint complete\n'
|
||||
|
||||
backend-build:
|
||||
cd backend && go build -o bin/server ./cmd/server
|
||||
test: _backend-test _frontend-test _desktop-test _versionctl-test
|
||||
@printf 'All tests passed\n'
|
||||
|
||||
backend-run:
|
||||
cd backend && go run ./cmd/server
|
||||
|
||||
backend-test:
|
||||
cd backend && go test ./... -v
|
||||
|
||||
backend-test-unit:
|
||||
cd backend && go test ./internal/... ./pkg/... -v
|
||||
|
||||
backend-test-integration:
|
||||
cd backend && go test ./tests/... -v
|
||||
|
||||
backend-test-coverage:
|
||||
cd backend && go test ./... -coverprofile=coverage.out
|
||||
cd backend && go tool cover -html=coverage.out -o coverage.html
|
||||
@echo "Coverage report generated: backend/coverage.html"
|
||||
|
||||
backend-lint:
|
||||
cd backend && go tool golangci-lint run ./...
|
||||
|
||||
backend-deps:
|
||||
cd backend && go mod tidy
|
||||
|
||||
backend-generate:
|
||||
cd backend && go generate ./...
|
||||
|
||||
DB_DRIVER ?= sqlite3
|
||||
DB_DSN ?= $(DB_PATH)
|
||||
|
||||
backend-migrate-up:
|
||||
cd backend && goose -dir migrations/$(DB_DRIVER) $(DB_DRIVER) "$(DB_DSN)" up
|
||||
|
||||
backend-migrate-down:
|
||||
cd backend && goose -dir migrations/$(DB_DRIVER) $(DB_DRIVER) "$(DB_DSN)" down
|
||||
|
||||
backend-migrate-status:
|
||||
cd backend && goose -dir migrations/$(DB_DRIVER) $(DB_DRIVER) "$(DB_DSN)" status
|
||||
|
||||
backend-migrate-create:
|
||||
@read -p "Migration name: " name; \
|
||||
cd backend && goose -dir migrations/sqlite create $$name sql; \
|
||||
cd backend && goose -dir migrations/mysql create $$name sql
|
||||
clean: _backend-clean _frontend-clean _desktop-clean
|
||||
@printf 'Clean complete\n'
|
||||
|
||||
# ============================================
|
||||
# 前端
|
||||
# 版本管理
|
||||
# ============================================
|
||||
|
||||
frontend-build:
|
||||
cd frontend && bun install && bun run build
|
||||
version-sync:
|
||||
go run ./versionctl sync
|
||||
|
||||
frontend-dev:
|
||||
cd frontend && bun dev
|
||||
version-check:
|
||||
go run ./versionctl check
|
||||
|
||||
frontend-test:
|
||||
cd frontend && bun run test
|
||||
|
||||
frontend-test-watch:
|
||||
cd frontend && bun run test:watch
|
||||
|
||||
frontend-test-coverage:
|
||||
cd frontend && bun run test:coverage
|
||||
|
||||
frontend-test-e2e:
|
||||
cd frontend && bun run test:e2e
|
||||
|
||||
frontend-lint:
|
||||
cd frontend && bun run lint
|
||||
version-bump: BUMP ?= patch
|
||||
version-bump:
|
||||
$(eval _BUMP_ARG := $(if $(SET_VERSION),$(SET_VERSION),$(BUMP)))
|
||||
$(eval _NEW_VERSION := $(shell go run ./versionctl bump $(_BUMP_ARG)))
|
||||
git add VERSION frontend/
|
||||
git commit -m "chore: 版本升迁 v$(_NEW_VERSION)"
|
||||
git tag "v$(_NEW_VERSION)"
|
||||
@printf '版本升迁完成: v%s\n' "$(_NEW_VERSION)"
|
||||
|
||||
# ============================================
|
||||
# 桌面应用
|
||||
# Server 模式
|
||||
# ============================================
|
||||
|
||||
desktop: frontend-build-desktop embedfs-prepare
|
||||
cd backend && CGO_ENABLED=1 go build -o ../build/nex ./cmd/desktop
|
||||
server-run:
|
||||
@$(MAKE) -j2 _server-run-backend _server-run-frontend
|
||||
|
||||
frontend-build-desktop:
|
||||
cd frontend && cp .env.desktop .env.production.local && bun install && bun run build && rm -f .env.production.local
|
||||
server-build: version-check _backend-build _frontend-build
|
||||
@printf 'Server build complete\n'
|
||||
|
||||
embedfs-prepare:
|
||||
server-lint: _backend-lint _frontend-check
|
||||
@printf 'Server lint complete\n'
|
||||
|
||||
server-test: _backend-test _frontend-test
|
||||
@printf 'Server tests passed\n'
|
||||
|
||||
server-clean: _backend-clean _frontend-clean
|
||||
@printf 'Server artifacts cleaned\n'
|
||||
|
||||
_server-run-backend:
|
||||
@$(MAKE) -C backend run
|
||||
|
||||
_server-run-frontend: _frontend-install
|
||||
cd frontend && bun run dev
|
||||
|
||||
# ============================================
|
||||
# Desktop 模式
|
||||
# ============================================
|
||||
|
||||
desktop-build-mac: version-check _desktop-prepare-frontend _desktop-prepare-embedfs
|
||||
@printf 'Building macOS desktop...\n'
|
||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-mac-arm64 ./cmd/desktop
|
||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-mac-amd64 ./cmd/desktop
|
||||
lipo -create build/nex-mac-arm64 build/nex-mac-amd64 -output build/nex-mac-universal
|
||||
@printf 'Packaging macOS app bundle...\n'
|
||||
mkdir -p build/Nex.app/Contents/MacOS build/Nex.app/Contents/Resources
|
||||
cp build/nex-mac-universal build/Nex.app/Contents/MacOS/nex
|
||||
@if [ -f assets/icon.icns ]; then \
|
||||
cp assets/icon.icns build/Nex.app/Contents/Resources/; \
|
||||
else \
|
||||
printf 'Missing assets/icon.icns\n'; \
|
||||
fi
|
||||
@MIN_MACOS_VERSION=$$(vtool -show-build build/nex-mac-universal | awk '/minos / {print $$2; exit}'); \
|
||||
if [ -z "$$MIN_MACOS_VERSION" ]; then \
|
||||
printf 'Unable to read macOS minimum version\n'; \
|
||||
exit 1; \
|
||||
fi; \
|
||||
go run ./versionctl macos-plist "$$MIN_MACOS_VERSION" > build/Nex.app/Contents/Info.plist
|
||||
chmod +x build/Nex.app/Contents/MacOS/nex
|
||||
@printf 'macOS desktop build complete\n'
|
||||
|
||||
desktop-build-win: version-check _desktop-prepare-frontend _desktop-prepare-embedfs _desktop-prepare-windows-resource
|
||||
@printf 'Building Windows desktop...\n'
|
||||
ifeq ($(OS),Windows_NT)
|
||||
powershell -NoProfile -Command "New-Item -ItemType Directory -Path 'build' -Force | Out-Null"
|
||||
cd backend && set "CGO_ENABLED=1"&& set "GOOS=windows"&& set "GOARCH=amd64"&& go build -ldflags "$(GO_LDFLAGS_WIN)" -o ../build/nex-win-amd64.exe ./cmd/desktop
|
||||
else
|
||||
mkdir -p build
|
||||
cd backend && CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS_WIN)" -o ../build/nex-win-amd64.exe ./cmd/desktop
|
||||
endif
|
||||
@printf 'Windows desktop build complete\n'
|
||||
|
||||
desktop-build-linux: version-check _desktop-prepare-frontend _desktop-prepare-embedfs
|
||||
@printf 'Building Linux desktop...\n'
|
||||
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-linux-amd64 ./cmd/desktop
|
||||
@printf 'Linux desktop build complete\n'
|
||||
|
||||
desktop-lint: _backend-lint _frontend-check
|
||||
@printf 'Desktop lint complete\n'
|
||||
|
||||
desktop-test: _desktop-test
|
||||
@printf 'Desktop tests passed\n'
|
||||
|
||||
desktop-clean: _desktop-clean
|
||||
@printf 'Desktop artifacts cleaned\n'
|
||||
|
||||
_desktop-test:
|
||||
cd backend && go test ./cmd/desktop/... -v
|
||||
|
||||
_desktop-clean:
|
||||
rm -rf build/ embedfs/assets embedfs/frontend-dist backend/cmd/desktop/rsrc_windows_amd64.syso
|
||||
|
||||
_desktop-prepare-frontend: _frontend-install
|
||||
@printf 'Preparing frontend for desktop...\n'
|
||||
ifeq ($(OS),Windows_NT)
|
||||
powershell -NoProfile -Command "Copy-Item -LiteralPath 'frontend/.env.desktop' -Destination 'frontend/.env.production.local' -Force"
|
||||
cd frontend && bun run build
|
||||
powershell -NoProfile -Command "Remove-Item -LiteralPath 'frontend/.env.production.local' -Force -ErrorAction SilentlyContinue"
|
||||
else
|
||||
cd frontend && cp .env.desktop .env.production.local
|
||||
cd frontend && bun run build
|
||||
rm -f frontend/.env.production.local
|
||||
endif
|
||||
|
||||
_desktop-prepare-embedfs:
|
||||
@printf 'Preparing embedded filesystem...\n'
|
||||
ifeq ($(OS),Windows_NT)
|
||||
powershell -NoProfile -Command "Remove-Item -LiteralPath 'embedfs/assets' -Recurse -Force -ErrorAction SilentlyContinue; Remove-Item -LiteralPath 'embedfs/frontend-dist' -Recurse -Force -ErrorAction SilentlyContinue; Copy-Item -LiteralPath 'assets' -Destination 'embedfs/assets' -Recurse; Copy-Item -LiteralPath 'frontend/dist' -Destination 'embedfs/frontend-dist' -Recurse"
|
||||
else
|
||||
rm -rf embedfs/assets embedfs/frontend-dist
|
||||
cp -r assets embedfs/assets
|
||||
cp -r frontend/dist embedfs/frontend-dist
|
||||
endif
|
||||
|
||||
desktop-darwin: frontend-build-desktop embedfs-prepare
|
||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -o ../build/nex-darwin-arm64 ./cmd/desktop
|
||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -o ../build/nex-darwin-amd64 ./cmd/desktop
|
||||
|
||||
desktop-windows: frontend-build-desktop embedfs-prepare
|
||||
cd backend && CGO_ENABLED=1 GOOS=windows GOARCH=amd64 go build -ldflags "-H=windowsgui" -o ../build/nex-windows-amd64.exe ./cmd/desktop
|
||||
|
||||
desktop-linux: frontend-build-desktop embedfs-prepare
|
||||
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -o ../build/nex-linux-amd64 ./cmd/desktop
|
||||
|
||||
package-macos:
|
||||
./scripts/build/package-macos.sh
|
||||
_desktop-prepare-windows-resource:
|
||||
@printf 'Preparing Windows executable icon...\n'
|
||||
ifeq ($(OS),Windows_NT)
|
||||
cd backend/cmd/desktop && windres -O coff -F pe-x86-64 -i icon_windows.rc -o rsrc_windows_amd64.syso
|
||||
else
|
||||
@if command -v x86_64-w64-mingw32-windres >/dev/null 2>&1; then \
|
||||
cd backend/cmd/desktop && x86_64-w64-mingw32-windres -O coff -F pe-x86-64 -i icon_windows.rc -o rsrc_windows_amd64.syso; \
|
||||
elif command -v windres >/dev/null 2>&1; then \
|
||||
cd backend/cmd/desktop && windres -O coff -F pe-x86-64 -i icon_windows.rc -o rsrc_windows_amd64.syso; \
|
||||
else \
|
||||
printf 'Missing windres for Windows icon resource generation\n'; \
|
||||
exit 1; \
|
||||
fi
|
||||
endif
|
||||
|
||||
# ============================================
|
||||
# 清理
|
||||
# 发布资产
|
||||
# ============================================
|
||||
|
||||
clean:
|
||||
rm -rf backend/bin/ backend/coverage.out backend/coverage.html
|
||||
rm -rf build/
|
||||
release-assets-linux: version-check desktop-build-linux
|
||||
rm -rf "$(RELEASE_DIR)"
|
||||
mkdir -p "$(RELEASE_DIR)"
|
||||
cd backend && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-linux-amd64 ./cmd/server
|
||||
tar -C build -czf "$(RELEASE_DIR)/$(SERVER_LINUX_ASSET)" nex-server-linux-amd64
|
||||
tar -C build -czf "$(RELEASE_DIR)/$(DESKTOP_LINUX_ASSET)" nex-linux-amd64
|
||||
|
||||
release-assets-windows: version-check desktop-build-win
|
||||
ifeq ($(OS),Windows_NT)
|
||||
powershell -NoProfile -Command "Remove-Item -LiteralPath '$(RELEASE_DIR)' -Recurse -Force -ErrorAction SilentlyContinue; New-Item -ItemType Directory -Path '$(RELEASE_DIR)' -Force | Out-Null"
|
||||
cd backend && set "CGO_ENABLED=1"&& set "GOOS=windows"&& set "GOARCH=amd64"&& go build -ldflags "$(GO_LDFLAGS_WIN)" -o ../build/nex-server-win-amd64.exe ./cmd/server
|
||||
powershell -NoProfile -Command "Compress-Archive -LiteralPath 'build/nex-server-win-amd64.exe' -DestinationPath '$(RELEASE_DIR)/$(SERVER_WINDOWS_ASSET)' -Force"
|
||||
powershell -NoProfile -Command "Compress-Archive -LiteralPath 'build/nex-win-amd64.exe' -DestinationPath '$(RELEASE_DIR)/$(DESKTOP_WINDOWS_ASSET)' -Force"
|
||||
else
|
||||
@printf 'release-assets-windows requires Windows\n'
|
||||
@exit 1
|
||||
endif
|
||||
|
||||
release-assets-macos: version-check desktop-build-mac
|
||||
rm -rf "$(RELEASE_DIR)"
|
||||
mkdir -p "$(RELEASE_DIR)"
|
||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=amd64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-darwin-amd64 ./cmd/server
|
||||
cd backend && CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 go build -ldflags "$(GO_LDFLAGS)" -o ../build/nex-server-darwin-arm64 ./cmd/server
|
||||
tar -C build -czf "$(RELEASE_DIR)/$(SERVER_DARWIN_AMD64_ASSET)" nex-server-darwin-amd64
|
||||
tar -C build -czf "$(RELEASE_DIR)/$(SERVER_DARWIN_ARM64_ASSET)" nex-server-darwin-arm64
|
||||
ditto -c -k --keepParent build/Nex.app "$(RELEASE_DIR)/$(DESKTOP_MACOS_ASSET)"
|
||||
|
||||
# ============================================
|
||||
# 共享 helper targets
|
||||
# ============================================
|
||||
|
||||
_backend-build:
|
||||
@$(MAKE) -C backend build
|
||||
|
||||
_backend-lint:
|
||||
@$(MAKE) -C backend lint
|
||||
|
||||
_backend-test:
|
||||
@$(MAKE) -C backend test
|
||||
|
||||
_backend-clean:
|
||||
@$(MAKE) -C backend clean
|
||||
|
||||
_versionctl-lint:
|
||||
@$(MAKE) -C versionctl lint
|
||||
|
||||
_versionctl-test:
|
||||
@$(MAKE) -C versionctl test
|
||||
|
||||
_frontend-install:
|
||||
cd frontend && bun install
|
||||
|
||||
_frontend-build: _frontend-install
|
||||
cd frontend && bun run build
|
||||
|
||||
_frontend-check: _frontend-install
|
||||
cd frontend && bun run check
|
||||
|
||||
_frontend-test: _frontend-install
|
||||
cd frontend && bun run test
|
||||
|
||||
_frontend-dev: _frontend-install
|
||||
cd frontend && bun run dev
|
||||
|
||||
_frontend-clean:
|
||||
rm -rf frontend/dist frontend/.next frontend/coverage frontend/playwright-report frontend/test-results frontend/tsconfig.tsbuildinfo
|
||||
|
||||
184
README.md
184
README.md
@@ -36,13 +36,9 @@ nex/
|
||||
│
|
||||
├── assets/ # 应用资源
|
||||
│ ├── icon.png # 托盘图标
|
||||
│ ├── AppIcon.icns # macOS 应用图标
|
||||
│ ├── icon.icns # macOS 应用图标
|
||||
│ └── icon.ico # Windows 应用图标
|
||||
│
|
||||
├── scripts/ # 构建脚本
|
||||
│ └── build/
|
||||
│ └── package-macos.sh # macOS .app 打包脚本
|
||||
│
|
||||
└── README.md # 本文件
|
||||
```
|
||||
|
||||
@@ -51,7 +47,7 @@ nex/
|
||||
- **双协议支持**:同时支持 OpenAI 和 Anthropic 协议
|
||||
- **跨协议转换**:Hub-and-Spoke 架构实现 OpenAI ↔ Anthropic 双向转换
|
||||
- **统一模型 ID**:`provider_id/model_name` 格式全局唯一标识模型(如 `openai/gpt-4`)
|
||||
- **Smart Passthrough**:同协议请求零序列化开销,仅改写 model 字段
|
||||
- **Smart Passthrough**:同协议请求跳过 Canonical 全量转换,仅在 JSON 层改写 model 字段
|
||||
- **流式响应**:完整支持 SSE 流式传输,包括跨协议流式转换
|
||||
- **Function Calling**:支持工具调用(Tools)
|
||||
- **Thinking / Reasoning**:支持 OpenAI `reasoning_effort` 和 Anthropic `thinking` 配置
|
||||
@@ -67,11 +63,25 @@ nex/
|
||||
- **HTTP 框架**: Gin
|
||||
- **ORM**: GORM
|
||||
- **数据库**: SQLite / MySQL
|
||||
- **日志**: zap + lumberjack(结构化日志 + 日志轮转)
|
||||
- **日志**: zap + lumberjack(结构化日志 + 日志轮转 + 模块标识)
|
||||
- **配置**: Viper + pflag(多层配置:CLI > 环境变量 > 配置文件 > 默认值)
|
||||
- **验证**: go-playground/validator/v10
|
||||
- **迁移**: goose
|
||||
|
||||
#### 日志模块标识规范
|
||||
|
||||
每个模块通过依赖注入获取带模块标识的 logger,日志输出格式为 `[module.name]`:
|
||||
|
||||
```
|
||||
Console: INFO [handler.proxy] 处理请求 method=POST path=/v1/chat
|
||||
JSON: {"level":"info","logger":"handler.proxy","msg":"处理请求","method":"POST"}
|
||||
```
|
||||
|
||||
模块命名规范:
|
||||
- 单一职责包:`database`、`config`
|
||||
- 多实体包:`handler.proxy`、`service.provider`
|
||||
- 子包:`handler.middleware`
|
||||
|
||||
### 前端
|
||||
- **运行时**: Bun
|
||||
- **构建工具**: Vite
|
||||
@@ -81,7 +91,7 @@ nex/
|
||||
- **图表库**: Recharts
|
||||
- **路由**: React Router v7
|
||||
- **数据获取**: TanStack Query v5
|
||||
- **样式**: SCSS Modules
|
||||
- **样式**: TDesign 组件 props 优先,TDesign tokens 次之,SCSS 作为兜底补充
|
||||
- **测试**: Vitest + React Testing Library + Playwright
|
||||
|
||||
## 快速开始
|
||||
@@ -91,22 +101,18 @@ nex/
|
||||
**构建桌面应用**:
|
||||
|
||||
```bash
|
||||
# 当前平台
|
||||
make desktop
|
||||
|
||||
# macOS (arm64 + amd64)
|
||||
make desktop-darwin
|
||||
make package-macos # 打包为 .app
|
||||
# macOS (arm64 + amd64,并打包为 .app)
|
||||
make desktop-build-mac
|
||||
|
||||
# Windows
|
||||
make desktop-windows
|
||||
make desktop-build-win
|
||||
|
||||
# Linux
|
||||
make desktop-linux
|
||||
make desktop-build-linux
|
||||
```
|
||||
|
||||
**使用桌面应用**:
|
||||
- 双击启动应用(macOS: Nex.app,Windows: nex.exe,Linux: nex)
|
||||
- 双击启动应用(macOS: Nex.app,Windows: nex-win-amd64.exe,Linux: nex-linux-amd64)
|
||||
- 系统托盘图标出现,浏览器自动打开管理界面
|
||||
- 点击托盘图标显示菜单,可打开管理界面或退出
|
||||
- 关闭浏览器后服务继续运行,可通过托盘重新打开
|
||||
@@ -123,50 +129,54 @@ make desktop-linux
|
||||
- Xfce: 需要 libappindicator
|
||||
- 其他支持 StatusNotifierItem 规范的环境
|
||||
|
||||
### CLI 模式
|
||||
|
||||
#### 后端
|
||||
### Server 模式(前后端分离)
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
go mod download
|
||||
go run cmd/server/main.go
|
||||
make server-run
|
||||
```
|
||||
|
||||
后端服务将在 `http://localhost:9826` 启动。首次启动会自动:
|
||||
`make server-run` 会并行启动:
|
||||
- 后端服务:`http://localhost:9826`
|
||||
- 前端开发服务器:`http://localhost:5173`
|
||||
|
||||
前端请求会继续通过 Vite proxy 转发到后端。后端首次启动会自动:
|
||||
- 创建配置文件 `~/.nex/config.yaml`
|
||||
- 初始化数据库 `~/.nex/config.db`
|
||||
- 运行数据库迁移
|
||||
- 创建日志目录 `~/.nex/log/`
|
||||
|
||||
### 前端
|
||||
**构建 server 模式产物**:
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
bun install
|
||||
bun dev
|
||||
make server-build
|
||||
```
|
||||
|
||||
前端开发服务器将在 `http://localhost:5173` 启动,API 请求通过 Vite proxy 转发到后端。
|
||||
|
||||
## API 接口
|
||||
|
||||
### 代理接口(对外部应用)
|
||||
|
||||
代理接口统一使用 `/{protocol}/*path` 路由格式,模型 ID 使用 `provider_id/model_name` 格式(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough,最小化 JSON 改写保持参数保真;跨协议请求走完整 decode/encode 转换。
|
||||
代理接口统一使用 `/{protocol}/*path` 路由格式,模型 ID 使用 `provider_id/model_name` 格式(如 `openai/gpt-4`)。同协议请求走 Smart Passthrough,最小化 JSON 改写并保持未改写字段的 JSON 内容和类型不变;跨协议请求走完整 decode/encode 转换。
|
||||
|
||||
**OpenAI 协议**(`protocol=openai`):
|
||||
- `POST /openai/chat/completions` - 对话补全
|
||||
- `GET /openai/models` - 模型列表(本地数据库聚合)
|
||||
- `GET /openai/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
|
||||
- `POST /openai/embeddings` - 嵌入
|
||||
- `POST /openai/rerank` - 重排序
|
||||
- `POST /openai/v1/chat/completions` - 对话补全
|
||||
- `GET /openai/v1/models` - 模型列表(本地数据库聚合)
|
||||
- `GET /openai/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
|
||||
- `POST /openai/v1/embeddings` - 嵌入
|
||||
- `POST /openai/v1/rerank` - 重排序
|
||||
|
||||
**Anthropic 协议**(`protocol=anthropic`):
|
||||
- `POST /anthropic/v1/messages` - 消息对话
|
||||
- `GET /anthropic/v1/models` - 模型列表(本地数据库聚合)
|
||||
- `GET /anthropic/v1/models/{provider_id}/{model_name}` - 模型详情(本地数据库查询)
|
||||
|
||||
路径边界:网关只剥离第一段协议前缀,剩余路径保持协议原生形态交给 adapter。OpenAI adapter 接收 `/v1/chat/completions`、`/v1/models`、`/v1/embeddings`、`/v1/rerank`,并在构建上游 URL 时去掉 `/v1`;Anthropic adapter 接收 `/v1/messages`、`/v1/models`。因此 OpenAI 供应商 `base_url` 配置到版本路径一级(如 `https://api.openai.com/v1`),Anthropic 供应商 `base_url` 配置到域名级(如 `https://api.anthropic.com`)。
|
||||
|
||||
代理错误边界:网关层错误统一返回 `{"error":"...","code":"..."}`,例如 `INVALID_JSON`、`MODEL_NOT_FOUND`、`CONVERSION_FAILED`、`UPSTREAM_UNAVAILABLE`。只要上游已经返回 HTTP 响应,非 2xx 的 status、过滤 hop-by-hop header 后的 headers 和 body 会直接透传,不包装为应用错误或协议错误。
|
||||
|
||||
模型路由边界:只有 adapter 明确适配的接口会解析请求体中的 `model` 并使用统一模型 ID 路由;未知接口即使包含顶层 `model` 也按无 model 透传处理。
|
||||
|
||||
流式边界:同协议无响应 model 改写时原样透传 SSE frame 和 `[DONE]`;同协议需要响应 model 改写时只解析 SSE frame 的 `data` JSON 并改写 `model`;跨协议流式仍走 provider decoder → Canonical stream event → client encoder。
|
||||
|
||||
### 管理接口(对前端)
|
||||
|
||||
#### 供应商管理
|
||||
@@ -189,6 +199,9 @@ bun dev
|
||||
|
||||
查询参数支持:`provider_id`、`model_name`、`start_date`、`end_date`、`group_by`
|
||||
|
||||
#### 版本信息
|
||||
- `GET /api/version` - 获取后端构建版本信息(`version`、`commit`、`build_time`),用于前端 About 页面诊断前后端版本一致性
|
||||
|
||||
## 配置
|
||||
|
||||
配置支持多种方式,优先级为:**CLI 参数 > 环境变量 > 配置文件 > 默认值**
|
||||
@@ -262,25 +275,100 @@ export NEX_DATABASE_DBNAME=nex
|
||||
## 测试
|
||||
|
||||
```bash
|
||||
make backend-test # 后端测试
|
||||
make backend-test-coverage # 后端覆盖率
|
||||
make frontend-test # 前端测试
|
||||
make frontend-test-e2e # 前端 E2E 测试
|
||||
# 全局默认测试(不含 MySQL 和前端 E2E)
|
||||
make test
|
||||
|
||||
# 产品级测试
|
||||
make server-test
|
||||
make desktop-test
|
||||
```
|
||||
|
||||
backend 分类测试、MySQL 专项测试和前端 E2E 测试请分别查看 `backend/README.md` 与 `frontend/README.md`。
|
||||
|
||||
## 开发
|
||||
|
||||
```bash
|
||||
make backend-build # 构建后端
|
||||
make backend-run # 运行后端
|
||||
make backend-lint # 后端代码检查
|
||||
make backend-migrate-up # 数据库迁移
|
||||
# 首次克隆后安装 Git hooks
|
||||
lefthook install
|
||||
|
||||
make frontend-build # 构建前端
|
||||
make frontend-dev # 前端开发模式
|
||||
make frontend-lint # 前端代码检查
|
||||
# 全局命令
|
||||
make lint # 前后端共享检查
|
||||
make test # 默认全量测试(不含 MySQL/E2E)
|
||||
make clean # 清理所有构建产物和测试报告
|
||||
|
||||
# server 模式
|
||||
make server-run # 并行启动后端和前端开发服务
|
||||
make server-build # 构建 backend/bin/server 和 frontend/dist
|
||||
make server-lint # server 模式检查
|
||||
make server-test # server 模式测试
|
||||
make server-clean # 清理 server 模式产物
|
||||
|
||||
# desktop 模式
|
||||
make desktop-build-mac # 构建 macOS 桌面应用
|
||||
make desktop-build-win # 构建 Windows 桌面应用
|
||||
make desktop-build-linux # 构建 Linux 桌面应用
|
||||
make desktop-lint # desktop 模式检查
|
||||
make desktop-test # desktop 专属测试
|
||||
make desktop-clean # 清理 desktop 产物
|
||||
```
|
||||
|
||||
## 版本与发布
|
||||
|
||||
### 统一版本源
|
||||
|
||||
- 仓库根目录 `VERSION` 是全仓唯一版本源,格式固定为 `x.y.z`
|
||||
- `frontend/package.json` 和前端 `.env.*` 中的 `VITE_APP_VERSION` 由仓库工具同步,不能手工漂移
|
||||
|
||||
### 本地版本演进
|
||||
|
||||
```bash
|
||||
# 递增版本(自动 sync + check + commit + tag)
|
||||
make version-bump BUMP=minor
|
||||
|
||||
# 或指定具体版本号
|
||||
make version-bump SET_VERSION=1.0.0
|
||||
|
||||
# 推送到远程
|
||||
git push --follow-tags
|
||||
```
|
||||
|
||||
手动同步和校验:
|
||||
|
||||
```bash
|
||||
make version-sync
|
||||
make version-check
|
||||
```
|
||||
|
||||
### 本地生成发布资产
|
||||
|
||||
```bash
|
||||
# Linux: server + desktop
|
||||
make release-assets-linux
|
||||
|
||||
# Windows: server + desktop(需在 Windows 环境执行)
|
||||
make release-assets-windows
|
||||
|
||||
# macOS: darwin-amd64 server、darwin-arm64 server、desktop universal
|
||||
make release-assets-macos
|
||||
```
|
||||
|
||||
生成的版本化发布资产位于 `build/release/`。
|
||||
|
||||
### GitHub Draft Release
|
||||
|
||||
- 推送 `vX.Y.Z` tag 后,`.github/workflows/release.yml` 会自动执行发布流水线
|
||||
- 三个平台 job 会在正式构建前先检查 `go`、`bun` 和各自的平台打包工具链,缺失时快速失败并在日志中输出诊断信息
|
||||
- Windows 发布 job 在 `MSYS2 / MINGW64` shell 中执行,并继承 `setup-go` / `setup-bun` 准备好的工具链路径
|
||||
- 流水线会先校验 tag 与 `VERSION` 一致,再构建以下资产并上传到 GitHub Draft Release:
|
||||
- Linux server
|
||||
- Windows server
|
||||
- darwin-amd64 server
|
||||
- darwin-arm64 server
|
||||
- Linux desktop
|
||||
- Windows desktop
|
||||
- macOS desktop universal
|
||||
- Release 默认以 Draft 形式创建,需人工检查后再公开发布
|
||||
|
||||
## 开发规范
|
||||
|
||||
详见各子项目的 README.md:
|
||||
@@ -289,4 +377,4 @@ make frontend-lint # 前端代码检查
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT
|
||||
Apache License 2.0
|
||||
|
||||
Binary file not shown.
@@ -1,64 +0,0 @@
|
||||
# Assets
|
||||
|
||||
应用资源文件目录。
|
||||
|
||||
## 文件说明
|
||||
|
||||
| 文件 | 用途 | 尺寸 | 格式 |
|
||||
|------|------|------|------|
|
||||
| `icon.svg` | 源图标 | 64x64 | SVG |
|
||||
| `icon.png` | 托盘图标 | 64x64 | PNG |
|
||||
| `AppIcon.icns` | macOS 应用图标 | 多尺寸 | ICNS |
|
||||
| `icon.ico` | Windows 应用图标 | 256x256 | ICO |
|
||||
|
||||
## 替换图标
|
||||
|
||||
### 1. 准备图标
|
||||
|
||||
推荐使用 SVG 格式的源图标,尺寸至少 256x256。
|
||||
|
||||
### 2. 生成各平台图标
|
||||
|
||||
**托盘图标 (PNG)**:
|
||||
```bash
|
||||
magick your-icon.svg -resize 64x64 icon.png
|
||||
```
|
||||
|
||||
**macOS 应用图标 (ICNS)**:
|
||||
```bash
|
||||
mkdir icon.iconset
|
||||
magick your-icon.svg -resize 16x16 icon.iconset/icon_16x16.png
|
||||
magick your-icon.svg -resize 32x32 icon.iconset/icon_16x16@2x.png
|
||||
magick your-icon.svg -resize 32x32 icon.iconset/icon_32x32.png
|
||||
magick your-icon.svg -resize 64x64 icon.iconset/icon_32x32@2x.png
|
||||
magick your-icon.svg -resize 128x128 icon.iconset/icon_128x128.png
|
||||
magick your-icon.svg -resize 256x256 icon.iconset/icon_128x128@2x.png
|
||||
iconutil -c icns icon.iconset -o AppIcon.icns
|
||||
rm -rf icon.iconset
|
||||
```
|
||||
|
||||
**Windows 应用图标 (ICO)**:
|
||||
```bash
|
||||
magick your-icon.svg -resize 256x256 icon.ico
|
||||
```
|
||||
|
||||
### 3. 替换文件
|
||||
|
||||
将生成的文件放入此目录,然后重新构建桌面应用:
|
||||
```bash
|
||||
./scripts/build/build-darwin-arm64.sh
|
||||
```
|
||||
|
||||
## macOS Template 图标
|
||||
|
||||
macOS 支持 Template 图标,自动适配深浅色模式:
|
||||
- 使用黑色 + 透明设计
|
||||
- 文件名以 `Template` 结尾(如 `iconTemplate.png`)
|
||||
- 黑色在深色模式下自动变为白色
|
||||
|
||||
## 设计建议
|
||||
|
||||
- 托盘图标应简洁,在小尺寸下清晰可辨
|
||||
- 避免过多细节和文字
|
||||
- 使用高对比度颜色
|
||||
- macOS 建议使用 Template 图标风格
|
||||
BIN
assets/icon.icns
LFS
Normal file
BIN
assets/icon.icns
LFS
Normal file
Binary file not shown.
BIN
assets/icon.ico
BIN
assets/icon.ico
Binary file not shown.
|
Before Width: | Height: | Size: 264 KiB After Width: | Height: | Size: 130 B |
BIN
assets/icon.png
BIN
assets/icon.png
Binary file not shown.
|
Before Width: | Height: | Size: 2.0 KiB After Width: | Height: | Size: 131 B |
@@ -1,13 +0,0 @@
|
||||
<svg width="64" height="64" viewBox="0 0 64 64" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect width="64" height="64" rx="12" fill="#4A90D9"/>
|
||||
<polygon points="32,8 52,20 52,44 32,56 12,44 12,20" fill="none" stroke="white" stroke-width="3"/>
|
||||
<circle cx="32" cy="32" r="6" fill="white"/>
|
||||
<line x1="32" y1="32" x2="20" y2="20" stroke="white" stroke-width="2"/>
|
||||
<line x1="32" y1="32" x2="44" y2="20" stroke="white" stroke-width="2"/>
|
||||
<line x1="32" y1="32" x2="20" y2="44" stroke="white" stroke-width="2"/>
|
||||
<line x1="32" y1="32" x2="44" y2="44" stroke="white" stroke-width="2"/>
|
||||
<circle cx="20" cy="20" r="3" fill="white"/>
|
||||
<circle cx="44" cy="20" r="3" fill="white"/>
|
||||
<circle cx="20" cy="44" r="3" fill="white"/>
|
||||
<circle cx="44" cy="44" r="3" fill="white"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 779 B |
BIN
assets/icons/hicolor/128x128/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/128x128/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/16x16/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/16x16/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/22x22/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/22x22/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/24x24/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/24x24/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/256x256/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/256x256/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/32x32/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/32x32/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/48x48/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/48x48/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/512x512/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/512x512/apps/nex.png
LFS
Normal file
Binary file not shown.
BIN
assets/icons/hicolor/64x64/apps/nex.png
LFS
Normal file
BIN
assets/icons/hicolor/64x64/apps/nex.png
LFS
Normal file
Binary file not shown.
91
backend/.golangci.yml
Normal file
91
backend/.golangci.yml
Normal file
@@ -0,0 +1,91 @@
|
||||
run:
|
||||
timeout: 5m
|
||||
tests: true
|
||||
|
||||
linters:
|
||||
disable-all: true
|
||||
enable:
|
||||
- forbidigo
|
||||
- errorlint
|
||||
- errcheck
|
||||
- staticcheck
|
||||
- revive
|
||||
- gocritic
|
||||
- gosec
|
||||
- bodyclose
|
||||
- noctx
|
||||
- nilerr
|
||||
- goimports
|
||||
- gocyclo
|
||||
|
||||
linters-settings:
|
||||
errcheck:
|
||||
check-blank: true
|
||||
check-type-assertions: true
|
||||
exclude-functions:
|
||||
- fmt.Fprintf
|
||||
forbidigo:
|
||||
analyze-types: true
|
||||
forbid:
|
||||
- p: '^fmt\.Print.*$'
|
||||
msg: 使用 zap logger,不要直接输出到 stdout/stderr
|
||||
- p: '^fmt\.Fprint.*$'
|
||||
msg: 使用 zap logger,不要直接输出到 stdout/stderr
|
||||
- p: '^log\.(Print|Println|Printf|Fatal|Fatalln|Fatalf|Panic|Panicln|Panicf)$'
|
||||
msg: 使用 zap logger,不要使用标准库 log
|
||||
- p: '^zap\.L$'
|
||||
msg: 通过依赖注入传递 *zap.Logger,不要使用全局 logger
|
||||
- p: '^zap\.S$'
|
||||
msg: 不使用 Sugar logger
|
||||
revive:
|
||||
rules:
|
||||
- name: exported
|
||||
- name: var-naming
|
||||
- name: indent-error-flow
|
||||
- name: error-strings
|
||||
- name: error-return
|
||||
- name: blank-imports
|
||||
- name: context-as-argument
|
||||
- name: unexported-return
|
||||
goimports:
|
||||
local-prefixes: nex/backend
|
||||
gocyclo:
|
||||
min-complexity: 10
|
||||
|
||||
issues:
|
||||
exclude-dirs:
|
||||
- tests/mocks
|
||||
exclude-generated: true
|
||||
exclude-rules:
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- forbidigo
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- errcheck
|
||||
source: '(^\s*_\s*=|,\s*_)'
|
||||
- path: 'tests/integration/e2e_conversion_test\.go'
|
||||
linters:
|
||||
- errcheck
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- revive
|
||||
text: '^exported:'
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- gosec
|
||||
text: 'G(101|401|501)'
|
||||
- path: '(_test\.go|tests/)'
|
||||
linters:
|
||||
- gocyclo
|
||||
text: 'cyclomatic complexity (1[1-9]|20) of .* is high \(> 10\)'
|
||||
- linters:
|
||||
- revive
|
||||
text: '(that stutters|BuildUrl should be BuildURL|ConvertHttpRequest should be ConvertHTTPRequest|ConvertHttpResponse should be ConvertHTTPResponse)'
|
||||
- path: 'internal/conversion/.*\.go'
|
||||
linters:
|
||||
- gocyclo
|
||||
- gocritic
|
||||
- path: '(internal/provider/client\.go|internal/service/model_service_impl\.go|internal/service/stats_buffer\.go|internal/handler/proxy_handler\.go|cmd/(desktop|server)/main\.go)'
|
||||
linters:
|
||||
- gocyclo
|
||||
97
backend/Makefile
Normal file
97
backend/Makefile
Normal file
@@ -0,0 +1,97 @@
|
||||
.PHONY: \
|
||||
build run \
|
||||
test test-unit test-integration test-coverage \
|
||||
lint clean \
|
||||
migrate-up migrate-down migrate-status migrate-create \
|
||||
mysql-up mysql-down mysql-test mysql-test-quick
|
||||
|
||||
VERSION := $(shell go run ../versionctl print)
|
||||
GIT_COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || printf 'unknown')
|
||||
BUILD_TIME ?= $(shell date -u +"%Y-%m-%dT%H:%M:%SZ")
|
||||
GO_LDFLAGS := -X nex/backend/pkg/buildinfo.version=$(VERSION) -X nex/backend/pkg/buildinfo.commit=$(GIT_COMMIT) -X nex/backend/pkg/buildinfo.buildTime=$(BUILD_TIME)
|
||||
|
||||
DB_DRIVER ?= sqlite3
|
||||
DB_DSN ?= $(HOME)/.nex/config.db
|
||||
|
||||
ifeq ($(DB_DRIVER),mysql)
|
||||
GOOSE_DIR := migrations/mysql
|
||||
GOOSE_DRIVER := mysql
|
||||
else ifeq ($(DB_DRIVER),sqlite3)
|
||||
GOOSE_DIR := migrations/sqlite
|
||||
GOOSE_DRIVER := sqlite3
|
||||
else
|
||||
$(error unsupported DB_DRIVER '$(DB_DRIVER)', use sqlite3 or mysql)
|
||||
endif
|
||||
|
||||
build:
|
||||
go build -ldflags "$(GO_LDFLAGS)" -o bin/server ./cmd/server
|
||||
|
||||
run:
|
||||
go run -ldflags "$(GO_LDFLAGS)" ./cmd/server
|
||||
|
||||
test:
|
||||
go test ./internal/... ./pkg/... ./tests/... ./cmd/server/... -v
|
||||
|
||||
test-unit:
|
||||
go test ./internal/... ./pkg/... -v
|
||||
|
||||
test-integration:
|
||||
go test ./tests/... -v
|
||||
|
||||
test-coverage:
|
||||
go test ./... -coverprofile=coverage.out
|
||||
go tool cover -html=coverage.out -o coverage.html
|
||||
@printf 'Coverage report generated: backend/coverage.html\n'
|
||||
|
||||
lint:
|
||||
go tool golangci-lint run ./...
|
||||
|
||||
clean:
|
||||
rm -rf bin/ coverage.out coverage.html
|
||||
|
||||
migrate-up:
|
||||
@printf 'Running database migration up...\n'
|
||||
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" up
|
||||
|
||||
migrate-down:
|
||||
@printf 'Running database migration down...\n'
|
||||
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" down
|
||||
|
||||
migrate-status:
|
||||
@printf 'Checking database migration status...\n'
|
||||
goose -dir $(GOOSE_DIR) $(GOOSE_DRIVER) "$(DB_DSN)" status
|
||||
|
||||
migrate-create:
|
||||
@printf 'Migration name: '; \
|
||||
read name; \
|
||||
goose -dir migrations/sqlite create $$name sql; \
|
||||
goose -dir migrations/mysql create $$name sql
|
||||
|
||||
mysql-up:
|
||||
@printf 'Starting MySQL test container...\n'
|
||||
cd tests/mysql && docker-compose up -d
|
||||
@printf 'Waiting for MySQL to be ready...\n'
|
||||
@for i in $$(seq 1 30); do \
|
||||
if docker exec nex-mysql-test mysqladmin ping -h localhost -u root -ptestpass --silent 2>/dev/null; then \
|
||||
printf 'MySQL is ready\n'; \
|
||||
exit 0; \
|
||||
fi; \
|
||||
printf 'Waiting... (%s/30)\n' $$i; \
|
||||
sleep 1; \
|
||||
done; \
|
||||
printf 'MySQL failed to start\n'; \
|
||||
exit 1
|
||||
|
||||
mysql-down:
|
||||
@printf 'Stopping MySQL test container...\n'
|
||||
cd tests/mysql && docker-compose down -v
|
||||
|
||||
mysql-test:
|
||||
@set -e; \
|
||||
$(MAKE) mysql-up; \
|
||||
trap '$(MAKE) mysql-down' EXIT; \
|
||||
go test -tags=mysql ./tests/mysql/... -v -count=1
|
||||
|
||||
mysql-test-quick:
|
||||
@printf 'Running MySQL tests without container management...\n'
|
||||
go test -tags=mysql ./tests/mysql/... -v -count=1
|
||||
@@ -4,21 +4,67 @@ AI 网关后端服务,提供统一的大模型 API 代理接口。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 支持 OpenAI 协议(`/openai/v1/...`)
|
||||
- 支持 OpenAI 协议(`/openai/v1/...`,例如 `/openai/v1/chat/completions`)
|
||||
- 支持 Anthropic 协议(`/anthropic/v1/...`)
|
||||
- 支持 Hub-and-Spoke 跨协议双向转换(OpenAI ↔ Anthropic)
|
||||
- 同协议透传(零语义损失、零序列化开销)
|
||||
- 同协议透传(跳过 Canonical 全量转换,保持协议语义)
|
||||
- 支持流式响应(SSE)
|
||||
- 支持 Function Calling / Tools
|
||||
- 支持 Thinking / Reasoning
|
||||
- 支持扩展层接口(Models、Embeddings、Rerank)
|
||||
- 多供应商配置和路由
|
||||
- 用量统计
|
||||
- 结构化日志(zap + lumberjack)
|
||||
- 结构化日志(zap + lumberjack + 模块标识)
|
||||
- YAML 配置管理
|
||||
- 请求验证
|
||||
- 中间件支持(请求 ID、日志、恢复、CORS)
|
||||
|
||||
## 日志规范
|
||||
|
||||
### 模块标识
|
||||
|
||||
每个模块通过依赖注入获取带模块标识的 logger:
|
||||
|
||||
```go
|
||||
func NewProxyHandler(..., logger *zap.Logger) *ProxyHandler {
|
||||
return &ProxyHandler{
|
||||
logger: pkglogger.WithModule(logger, "handler.proxy"),
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
输出格式:
|
||||
- Console: `INFO [handler.proxy] 处理请求 method=POST path=/v1/chat`
|
||||
- JSON: `{"level":"info","logger":"handler.proxy","msg":"处理请求"}`
|
||||
|
||||
### 模块命名规范
|
||||
|
||||
| 模块 | 命名 |
|
||||
|------|------|
|
||||
| ProxyHandler | `handler.proxy` |
|
||||
| ProviderHandler | `handler.provider` |
|
||||
| Provider Client | `provider.client` |
|
||||
| ConversionEngine | `conversion.engine` |
|
||||
| RoutingCache | `service.routing_cache` |
|
||||
| StatsBuffer | `service.stats_buffer` |
|
||||
| Database | `database` |
|
||||
|
||||
### 标准字段
|
||||
|
||||
使用 `pkg/logger/field.go` 中定义的字段构造函数:
|
||||
|
||||
```go
|
||||
logger.Debug("请求开始",
|
||||
pkglogger.Method("POST"),
|
||||
pkglogger.Path("/v1/chat"),
|
||||
pkglogger.RequestID("xxx"),
|
||||
)
|
||||
```
|
||||
|
||||
### GORM 日志
|
||||
|
||||
GORM 日志自动桥接到 zap,SQL 查询映射到 Debug 级别。
|
||||
|
||||
## 技术栈
|
||||
|
||||
- **语言**: Go 1.26+
|
||||
@@ -105,9 +151,13 @@ backend/
|
||||
│ │ ├── errors.go
|
||||
│ │ └── wrap.go
|
||||
│ ├── logger/ # 日志系统
|
||||
│ │ ├── logger.go
|
||||
│ │ ├── rotate.go
|
||||
│ │ └── context.go
|
||||
│ │ ├── logger.go # 核心初始化
|
||||
│ │ ├── field.go # 标准字段定义
|
||||
│ │ ├── module.go # 模块日志器
|
||||
│ │ ├── context.go # Context 辅助函数
|
||||
│ │ ├── gorm.go # GORM 适配器
|
||||
│ │ ├── minimal.go # 最小化 logger
|
||||
│ │ └── rotate.go # 日志轮转
|
||||
│ ├── modelid/ # 统一模型 ID 工具包
|
||||
│ │ ├── model_id.go
|
||||
│ │ └── model_id_test.go
|
||||
@@ -170,7 +220,7 @@ OpenAI Response ← Canonical Response ← Anthropic Response
|
||||
|
||||
### Smart Passthrough 机制
|
||||
|
||||
同协议请求走 Smart Passthrough 路径,**零序列化开销**:
|
||||
同协议请求走 Smart Passthrough 路径,不进入 Canonical 全量转换:
|
||||
|
||||
```
|
||||
1. 检测 clientProtocol == providerProtocol
|
||||
@@ -179,12 +229,14 @@ OpenAI Response ← Canonical Response ← Anthropic Response
|
||||
4. 响应中仅改写 model 字段:upstream_model_name → unified_id
|
||||
```
|
||||
|
||||
Smart Passthrough 保持未改写 JSON 字段的内容和类型不变,不承诺保留原始字节顺序、空白或对象字段顺序。
|
||||
|
||||
### 流式转换器层次
|
||||
|
||||
```
|
||||
StreamConverter (接口)
|
||||
├── PassthroughStreamConverter # 直接透传,无任何处理
|
||||
├── SmartPassthroughStreamConverter # 同协议 + 逐 chunk 改写 model
|
||||
├── SmartPassthroughStreamConverter # 同协议 + 按 SSE frame 改写 data JSON model
|
||||
└── CanonicalStreamConverter # 跨协议完整转换(decode → encode)
|
||||
```
|
||||
|
||||
@@ -251,6 +303,7 @@ StreamConverter (接口)
|
||||
| `PROTOCOL_CONSTRAINT_VIOLATION` | 协议约束违反 |
|
||||
| `ENCODING_FAILURE` | 编码失败 |
|
||||
| `INTERFACE_NOT_SUPPORTED` | 接口不支持(如 Anthropic Embeddings) |
|
||||
| `UNSUPPORTED_MULTIMODAL` | 跨协议暂不支持多模态内容 |
|
||||
|
||||
### AppError 预定义错误
|
||||
|
||||
@@ -384,24 +437,37 @@ docker run -d -e NEX_SERVER_PORT=9000 -e NEX_LOG_LEVEL=info nex-server
|
||||
## 测试
|
||||
|
||||
```bash
|
||||
# 运行所有测试
|
||||
# 运行 backend 默认测试
|
||||
make test
|
||||
|
||||
# 分类测试
|
||||
make test-unit
|
||||
make test-integration
|
||||
|
||||
# 生成覆盖率报告
|
||||
make test-coverage
|
||||
|
||||
# MySQL 专项测试
|
||||
make mysql-up
|
||||
make mysql-down
|
||||
make mysql-test
|
||||
make mysql-test-quick
|
||||
```
|
||||
|
||||
## 数据库迁移
|
||||
|
||||
```bash
|
||||
# 使用 Makefile
|
||||
make migrate-up DB_PATH=~/.nex/config.db
|
||||
make migrate-down DB_PATH=~/.nex/config.db
|
||||
make migrate-status DB_PATH=~/.nex/config.db
|
||||
make migrate-up DB_DSN=~/.nex/config.db
|
||||
make migrate-down DB_DSN=~/.nex/config.db
|
||||
make migrate-status DB_DSN=~/.nex/config.db
|
||||
|
||||
# 创建新迁移
|
||||
make migrate-create
|
||||
|
||||
# MySQL 迁移
|
||||
make migrate-up DB_DRIVER=mysql DB_DSN='user:pass@tcp(localhost:3306)/nex'
|
||||
|
||||
# 或直接使用 goose
|
||||
goose -dir migrations sqlite3 ~/.nex/config.db up
|
||||
```
|
||||
@@ -410,15 +476,15 @@ goose -dir migrations sqlite3 ~/.nex/config.db up
|
||||
|
||||
### 代理接口
|
||||
|
||||
使用 `/{protocol}/v1/{path}` URL 前缀路由:
|
||||
使用 `/{protocol}/*path` URL 前缀路由。网关只剥离第一段协议前缀,不在 Handler 中统一添加或移除 `/v1`;剩余 path 是协议原生 nativePath,由对应 adapter 识别和组合上游 URL。
|
||||
|
||||
#### OpenAI 协议
|
||||
|
||||
```
|
||||
POST /openai/chat/completions
|
||||
GET /openai/models
|
||||
POST /openai/embeddings
|
||||
POST /openai/rerank
|
||||
POST /openai/v1/chat/completions
|
||||
GET /openai/v1/models
|
||||
POST /openai/v1/embeddings
|
||||
POST /openai/v1/rerank
|
||||
```
|
||||
|
||||
#### Anthropic 协议
|
||||
@@ -428,10 +494,20 @@ POST /anthropic/v1/messages
|
||||
GET /anthropic/v1/models
|
||||
```
|
||||
|
||||
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传,零序列化开销。
|
||||
**协议转换**:网关支持任意协议间的双向转换。客户端使用 OpenAI 协议请求,上游供应商可以是 Anthropic 协议(反之亦然)。同协议时自动透传或 Smart Passthrough,跳过 Canonical 全量转换。
|
||||
|
||||
**统一模型 ID**:代理请求中的 `model` 字段使用 `provider_id/model_name` 格式(如 `openai/gpt-4`),网关据此路由到对应供应商。同协议时自动改写为上游 `model_name`,跨协议时通过全量转换处理。
|
||||
|
||||
**base_url 约定**:
|
||||
- OpenAI 供应商配置到版本路径一级,例如 `https://api.openai.com/v1`;当客户端请求 `/openai/v1/chat/completions` 时,OpenAI adapter 会把 nativePath `/v1/chat/completions` 映射为上游 path `/chat/completions`,最终 URL 为 `https://api.openai.com/v1/chat/completions`。
|
||||
- Anthropic 供应商配置到域名级,例如 `https://api.anthropic.com`。
|
||||
|
||||
**模型提取边界**:只有 adapter 明确适配的 Chat、Embeddings、Rerank 等接口会提取 `model` 并尝试统一模型 ID 路由。未知接口不做顶层 `model` 猜测,直接按无 model 透传。
|
||||
|
||||
**流式透传边界**:同协议无响应 model 改写时 raw passthrough,保留 SSE frame 边界和 `[DONE]`;同协议需要改写时按 SSE frame 解析 `data` JSON,仅改写 `model`;跨协议继续使用 StreamDecoder → CanonicalStreamConverter → StreamEncoder。
|
||||
|
||||
**错误边界**:网关层代理错误返回 `{"error":"...","code":"..."}`。已收到上游 HTTP 响应时,非 2xx status、过滤 hop-by-hop header 后的 headers 和 body 直接透传;没有收到上游响应的连接/DNS/TLS/超时错误返回 `UPSTREAM_UNAVAILABLE`。
|
||||
|
||||
### 管理接口
|
||||
|
||||
#### 供应商管理
|
||||
@@ -459,7 +535,7 @@ GET /anthropic/v1/models
|
||||
- Anthropic 协议:配置到域名,不包含版本路径,如 `https://api.anthropic.com`
|
||||
|
||||
**对外 URL 格式**:
|
||||
- OpenAI 协议:`/{protocol}/{endpoint}`,如 `/openai/chat/completions`、`/openai/models`、`/openai/embeddings`
|
||||
- OpenAI 协议:`/{protocol}/v1/{endpoint}`,如 `/openai/v1/chat/completions`、`/openai/v1/models`、`/openai/v1/embeddings`
|
||||
- Anthropic 协议:`/{protocol}/v1/{endpoint}`,如 `/anthropic/v1/messages`、`/anthropic/v1/models`
|
||||
|
||||
#### 模型管理
|
||||
@@ -501,6 +577,20 @@ GET /anthropic/v1/models
|
||||
|
||||
查询参数:`provider_id`、`model_name`、`start_date`(YYYY-MM-DD)、`end_date`、`group_by`(provider/model/date)
|
||||
|
||||
#### 版本信息
|
||||
|
||||
- `GET /api/version` - 获取后端构建版本信息
|
||||
|
||||
响应字段来源于构建阶段注入的 `buildinfo` 元数据:
|
||||
|
||||
```json
|
||||
{
|
||||
"version": "0.1.0",
|
||||
"commit": "abc1234",
|
||||
"build_time": "2026-05-05T00:00:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
#### 健康检查
|
||||
|
||||
- `GET /health` - 返回 `{"status": "ok"}`
|
||||
@@ -508,9 +598,12 @@ GET /anthropic/v1/models
|
||||
## 开发
|
||||
|
||||
```bash
|
||||
make build # 构建
|
||||
make build # 构建 backend/bin/server
|
||||
make run # 运行后端服务
|
||||
make lint # 代码检查
|
||||
make deps # 整理依赖
|
||||
make clean # 清理 backend 构建产物
|
||||
go mod tidy # 整理依赖
|
||||
go generate ./... # 刷新 mock 等生成代码
|
||||
```
|
||||
|
||||
环境要求:Go 1.26 或更高版本
|
||||
@@ -559,6 +652,7 @@ err := v.Validate(myStruct)
|
||||
|
||||
- **JSON 解析**:使用 `encoding/json` 标准库,不手动扫描字节
|
||||
- **字符串拼接**:使用 `strings.Join`,不手写循环拼接
|
||||
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配
|
||||
- **错误判断**:使用 `errors.Is` / `errors.As`,不使用字符串匹配(lint 强约束:errorlint 禁止 `err ==` 直接比较和 `err.(*T)` 直接断言)
|
||||
- **日志使用**:通过依赖注入 `*zap.Logger`,不直接调用 `zap.L()`
|
||||
- **日志 error 字段**:使用 `zap.Error(err)`,不使用 `zap.String("error", err.Error())` 手工字符串化
|
||||
- **字符串分割**:使用 `strings.SplitN` 等精确分割,不使用索引切片
|
||||
|
||||
25
backend/cmd/desktop/dialog_darwin.go
Normal file
25
backend/cmd/desktop/dialog_darwin.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build darwin
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func showError(title, message string) {
|
||||
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`,
|
||||
escapeAppleScript(message), escapeAppleScript(title))
|
||||
if err := exec.Command("osascript", "-e", script).Run(); err != nil {
|
||||
dialogLogger().Warn("显示错误对话框失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func escapeAppleScript(s string) string {
|
||||
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||
s = strings.ReplaceAll(s, "\"", "\\\"")
|
||||
return s
|
||||
}
|
||||
67
backend/cmd/desktop/dialog_linux.go
Normal file
67
backend/cmd/desktop/dialog_linux.go
Normal file
@@ -0,0 +1,67 @@
|
||||
//go:build linux
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type dialogToolType int
|
||||
|
||||
const (
|
||||
toolNone dialogToolType = iota
|
||||
toolZenity
|
||||
toolKdialog
|
||||
toolNotifySend
|
||||
toolXmessage
|
||||
)
|
||||
|
||||
var (
|
||||
dialogTool dialogToolType
|
||||
dialogToolOnce sync.Once
|
||||
)
|
||||
|
||||
func init() {
|
||||
dialogToolOnce.Do(detectDialogTool)
|
||||
}
|
||||
|
||||
func detectDialogTool() {
|
||||
tools := []struct {
|
||||
name string
|
||||
typ dialogToolType
|
||||
}{
|
||||
{"zenity", toolZenity},
|
||||
{"kdialog", toolKdialog},
|
||||
{"notify-send", toolNotifySend},
|
||||
{"xmessage", toolXmessage},
|
||||
}
|
||||
|
||||
for _, tool := range tools {
|
||||
if _, err := exec.LookPath(tool.name); err == nil {
|
||||
dialogTool = tool.typ
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
dialogTool = toolNone
|
||||
}
|
||||
|
||||
func showError(title, message string) {
|
||||
switch dialogTool {
|
||||
case toolZenity:
|
||||
exec.Command("zenity", "--error",
|
||||
fmt.Sprintf("--title=%s", title),
|
||||
fmt.Sprintf("--text=%s", message)).Run()
|
||||
case toolKdialog:
|
||||
exec.Command("kdialog", "--error", message, "--title", title).Run()
|
||||
case toolNotifySend:
|
||||
exec.Command("notify-send", "-u", "critical", title, message).Run()
|
||||
case toolXmessage:
|
||||
exec.Command("xmessage", "-center",
|
||||
fmt.Sprintf("%s: %s", title, message)).Run()
|
||||
default:
|
||||
dialogLogger().Error("无法显示错误对话框")
|
||||
}
|
||||
}
|
||||
15
backend/cmd/desktop/dialog_logger.go
Normal file
15
backend/cmd/desktop/dialog_logger.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
func dialogLogger() *zap.Logger {
|
||||
if zapLogger != nil {
|
||||
return zapLogger
|
||||
}
|
||||
|
||||
return pkgLogger.NewMinimal()
|
||||
}
|
||||
62
backend/cmd/desktop/dialog_windows.go
Normal file
62
backend/cmd/desktop/dialog_windows.go
Normal file
@@ -0,0 +1,62 @@
|
||||
//go:build windows
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
mbIconError = 0x10
|
||||
mbIconInformation = 0x40
|
||||
)
|
||||
|
||||
var (
|
||||
user32 = syscall.NewLazyDLL("user32.dll")
|
||||
procMessageBoxW = user32.NewProc("MessageBoxW")
|
||||
callMessageBoxW = func(hwnd, text, caption, flags uintptr) (uintptr, error) {
|
||||
ret, _, err := procMessageBoxW.Call(hwnd, text, caption, flags)
|
||||
return ret, err
|
||||
}
|
||||
)
|
||||
|
||||
func showError(title, message string) {
|
||||
if err := messageBox(title, message, mbIconError); err != nil {
|
||||
if zapLogger != nil {
|
||||
zapLogger.Warn("显示错误对话框失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func messageBox(title, message string, flags uint) error {
|
||||
titlePtr, err := syscall.UTF16PtrFromString(title)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
messagePtr, err := syscall.UTF16PtrFromString(message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ret, callErr := callMessageBoxW(
|
||||
0,
|
||||
uintptr(unsafe.Pointer(messagePtr)),
|
||||
uintptr(unsafe.Pointer(titlePtr)),
|
||||
uintptr(flags),
|
||||
)
|
||||
if ret != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if callErr != nil && !errors.Is(callErr, syscall.Errno(0)) {
|
||||
return callErr
|
||||
}
|
||||
|
||||
return fmt.Errorf("MessageBoxW 调用失败")
|
||||
}
|
||||
33
backend/cmd/desktop/icon_test.go
Normal file
33
backend/cmd/desktop/icon_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"nex/embedfs"
|
||||
)
|
||||
|
||||
func TestIconSelection_Windows(t *testing.T) {
|
||||
if runtime.GOOS != "windows" {
|
||||
t.Skip("图标格式选择测试仅在 Windows 上运行")
|
||||
}
|
||||
|
||||
if err := testIconLoad("assets/icon.ico"); err != nil {
|
||||
t.Fatalf("Windows 应加载 .ico 文件: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIconSelection_NonWindows(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("图标格式选择测试在非 Windows 平台运行")
|
||||
}
|
||||
|
||||
if err := testIconLoad("assets/icon.png"); err != nil {
|
||||
t.Fatalf("非 Windows 平台应加载 .png 文件: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func testIconLoad(path string) error {
|
||||
_, err := embedfs.Assets.ReadFile(path)
|
||||
return err
|
||||
}
|
||||
1
backend/cmd/desktop/icon_windows.rc
Normal file
1
backend/cmd/desktop/icon_windows.rc
Normal file
@@ -0,0 +1 @@
|
||||
1 ICON "../../../assets/icon.ico"
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -14,10 +13,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/getlantern/systray"
|
||||
"github.com/gofrs/flock"
|
||||
"go.uber.org/zap"
|
||||
"nex/embedfs"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/conversion"
|
||||
@@ -29,9 +25,14 @@ import (
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/internal/service"
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
"nex/backend/pkg/buildinfo"
|
||||
|
||||
"nex/embedfs"
|
||||
"github.com/getlantern/systray"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/flock"
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -44,25 +45,32 @@ var (
|
||||
func main() {
|
||||
port := 9826
|
||||
|
||||
minimalLogger := pkgLogger.NewMinimal()
|
||||
|
||||
singleLock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway.lock"))
|
||||
if err := singleLock.Lock(); err != nil {
|
||||
showError("Nex Gateway", "已有 Nex 实例运行")
|
||||
minimalLogger.Error("已有 Nex 实例运行")
|
||||
showError(appName, "已有 Nex 实例运行")
|
||||
os.Exit(1)
|
||||
}
|
||||
defer singleLock.Unlock()
|
||||
defer func() {
|
||||
if err := singleLock.Unlock(); err != nil {
|
||||
minimalLogger.Warn("释放实例锁失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := checkPortAvailable(port); err != nil {
|
||||
showError("Nex Gateway", err.Error())
|
||||
os.Exit(1)
|
||||
minimalLogger.Error("端口不可用", zap.Error(err))
|
||||
showError(appName, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig()
|
||||
if err != nil {
|
||||
showError("Nex Gateway", fmt.Sprintf("加载配置失败: %v", err))
|
||||
os.Exit(1)
|
||||
minimalLogger.Fatal("加载配置失败", zap.Error(err))
|
||||
}
|
||||
|
||||
zapLogger, err = pkgLogger.New(pkgLogger.Config{
|
||||
zapLogger, err = pkgLogger.Upgrade(minimalLogger, pkgLogger.Config{
|
||||
Level: cfg.Log.Level,
|
||||
Path: cfg.Log.Path,
|
||||
MaxSize: cfg.Log.MaxSize,
|
||||
@@ -71,15 +79,19 @@ func main() {
|
||||
Compress: cfg.Log.Compress,
|
||||
})
|
||||
if err != nil {
|
||||
showError("Nex Gateway", fmt.Sprintf("初始化日志失败: %v", err))
|
||||
os.Exit(1)
|
||||
minimalLogger.Fatal("初始化日志失败", zap.Error(err))
|
||||
}
|
||||
defer zapLogger.Sync()
|
||||
defer func() {
|
||||
if err := zapLogger.Sync(); err != nil {
|
||||
minimalLogger.Warn("同步日志失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
cfg.PrintSummary(zapLogger)
|
||||
|
||||
db, err := database.Init(&cfg.Database, zapLogger)
|
||||
if err != nil {
|
||||
showError("Nex Gateway", fmt.Sprintf("初始化数据库失败: %v", err))
|
||||
os.Exit(1)
|
||||
zapLogger.Fatal("初始化数据库失败", zap.Error(err))
|
||||
}
|
||||
defer database.Close(db)
|
||||
|
||||
@@ -105,19 +117,20 @@ func main() {
|
||||
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
if err := registry.Register(openai.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.Error(err))
|
||||
}
|
||||
if err := registry.Register(anthropic.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.Error(err))
|
||||
}
|
||||
engine := conversion.NewConversionEngine(registry, zapLogger)
|
||||
|
||||
providerClient := provider.NewClient()
|
||||
providerClient := provider.NewClient(zapLogger)
|
||||
|
||||
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
|
||||
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService, zapLogger)
|
||||
providerHandler := handler.NewProviderHandler(providerService)
|
||||
modelHandler := handler.NewModelHandler(modelService)
|
||||
statsHandler := handler.NewStatsHandler(statsService)
|
||||
versionHandler := handler.NewVersionHandler()
|
||||
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
r := gin.New()
|
||||
@@ -127,7 +140,7 @@ func main() {
|
||||
r.Use(middleware.Logging(zapLogger))
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
|
||||
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler, versionHandler)
|
||||
setupStaticFiles(r)
|
||||
|
||||
server = &http.Server{
|
||||
@@ -140,24 +153,30 @@ func main() {
|
||||
shutdownCtx, shutdownCancel = context.WithCancel(context.Background())
|
||||
|
||||
go func() {
|
||||
zapLogger.Info("AI Gateway 启动", zap.String("addr", server.Addr))
|
||||
zapLogger.Info("AI Gateway 启动",
|
||||
zap.String("addr", server.Addr),
|
||||
zap.String("version", buildinfo.Version()),
|
||||
zap.String("commit", buildinfo.Commit()),
|
||||
zap.String("build_time", buildinfo.BuildTime()))
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("服务器启动失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
|
||||
zapLogger.Warn("无法打开浏览器", zap.String("error", err.Error()))
|
||||
zapLogger.Warn("无法打开浏览器", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
setupSystray(port)
|
||||
}
|
||||
|
||||
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
|
||||
r.Any("/v1/*path", proxyHandler.HandleProxy)
|
||||
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler, versionHandler *handler.VersionHandler) {
|
||||
r.Any("/openai/*path", withProtocol("openai", proxyHandler.HandleProxy))
|
||||
r.Any("/anthropic/*path", withProtocol("anthropic", proxyHandler.HandleProxy))
|
||||
r.GET("/api/version", versionHandler.GetVersion)
|
||||
|
||||
providers := r.Group("/api/providers")
|
||||
{
|
||||
@@ -188,12 +207,26 @@ func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHand
|
||||
})
|
||||
}
|
||||
|
||||
func setupStaticFiles(r *gin.Engine) {
|
||||
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
|
||||
if err != nil {
|
||||
zapLogger.Fatal("无法加载前端资源", zap.String("error", err.Error()))
|
||||
func withProtocol(protocol string, next gin.HandlerFunc) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Params = append(c.Params, gin.Param{Key: "protocol", Value: protocol})
|
||||
next(c)
|
||||
}
|
||||
}
|
||||
|
||||
func setupStaticFiles(r *gin.Engine) {
|
||||
distFS, err := frontendDistFS()
|
||||
if err != nil {
|
||||
zapLogger.Fatal("无法加载前端资源", zap.Error(err))
|
||||
}
|
||||
setupStaticFilesWithFS(r, distFS)
|
||||
}
|
||||
|
||||
func frontendDistFS() (fs.FS, error) {
|
||||
return fs.Sub(embedfs.FrontendDist, "frontend-dist")
|
||||
}
|
||||
|
||||
func setupStaticFilesWithFS(r *gin.Engine, distFS fs.FS) {
|
||||
getContentType := func(path string) string {
|
||||
if strings.HasSuffix(path, ".js") {
|
||||
return "application/javascript"
|
||||
@@ -226,20 +259,23 @@ func setupStaticFiles(r *gin.Engine) {
|
||||
c.Data(200, getContentType(filepath), data)
|
||||
})
|
||||
|
||||
r.GET("/favicon.svg", func(c *gin.Context) {
|
||||
data, err := fs.ReadFile(distFS, "favicon.svg")
|
||||
r.GET("/icon.png", func(c *gin.Context) {
|
||||
data, err := fs.ReadFile(distFS, "icon.png")
|
||||
if err != nil {
|
||||
c.Status(404)
|
||||
return
|
||||
}
|
||||
c.Data(200, "image/svg+xml", data)
|
||||
c.Data(200, "image/png", data)
|
||||
})
|
||||
|
||||
r.NoRoute(func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
|
||||
if strings.HasPrefix(path, "/api/") ||
|
||||
strings.HasPrefix(path, "/v1/") ||
|
||||
strings.HasPrefix(path, "/openai/") ||
|
||||
strings.HasPrefix(path, "/anthropic/") ||
|
||||
path == "/openai" ||
|
||||
path == "/anthropic" ||
|
||||
strings.HasPrefix(path, "/health") {
|
||||
c.JSON(404, gin.H{"error": "not found"})
|
||||
return
|
||||
@@ -256,13 +292,18 @@ func setupStaticFiles(r *gin.Engine) {
|
||||
|
||||
func setupSystray(port int) {
|
||||
systray.Run(func() {
|
||||
icon, err := embedfs.Assets.ReadFile("assets/icon.png")
|
||||
var icon []byte
|
||||
var err error
|
||||
if runtime.GOOS == "windows" {
|
||||
icon, err = embedfs.Assets.ReadFile("assets/icon.ico")
|
||||
} else {
|
||||
icon, err = embedfs.Assets.ReadFile("assets/icon.png")
|
||||
}
|
||||
if err != nil {
|
||||
zapLogger.Error("无法加载托盘图标", zap.String("error", err.Error()))
|
||||
zapLogger.Error("无法加载托盘图标", zap.Error(err))
|
||||
}
|
||||
systray.SetIcon(icon)
|
||||
systray.SetTitle("Nex Gateway")
|
||||
systray.SetTooltip("AI Gateway")
|
||||
systray.SetTooltip(appTooltip)
|
||||
|
||||
mOpen := systray.AddMenuItem("打开管理界面", "在浏览器中打开")
|
||||
systray.AddSeparator()
|
||||
@@ -271,17 +312,15 @@ func setupSystray(port int) {
|
||||
mPort := systray.AddMenuItem(fmt.Sprintf("端口: %d", port), "")
|
||||
mPort.Disable()
|
||||
systray.AddSeparator()
|
||||
mAbout := systray.AddMenuItem("关于", "")
|
||||
systray.AddSeparator()
|
||||
mQuit := systray.AddMenuItem("退出", "停止服务并退出")
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-mOpen.ClickedCh:
|
||||
openBrowser(fmt.Sprintf("http://localhost:%d", port))
|
||||
case <-mAbout.ClickedCh:
|
||||
showAbout()
|
||||
if err := openBrowser(fmt.Sprintf("http://localhost:%d", port)); err != nil {
|
||||
zapLogger.Warn("打开浏览器失败", zap.Error(err))
|
||||
}
|
||||
case <-mQuit.ClickedCh:
|
||||
doShutdown()
|
||||
systray.Quit()
|
||||
@@ -300,7 +339,9 @@ func doShutdown() {
|
||||
if server != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
server.Shutdown(ctx)
|
||||
if err := server.Shutdown(ctx); err != nil && zapLogger != nil {
|
||||
zapLogger.Warn("关闭服务器失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
if shutdownCancel != nil {
|
||||
@@ -338,8 +379,8 @@ func (s *SingletonLock) Lock() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SingletonLock) Unlock() {
|
||||
s.flock.Unlock()
|
||||
func (s *SingletonLock) Unlock() error {
|
||||
return s.flock.Unlock()
|
||||
}
|
||||
|
||||
func openBrowser(url string) error {
|
||||
@@ -366,28 +407,3 @@ func openBrowser(url string) error {
|
||||
|
||||
return cmd.Start()
|
||||
}
|
||||
|
||||
func showError(title, message string) {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "%s"`, message, title)
|
||||
exec.Command("osascript", "-e", script).Run()
|
||||
case "windows":
|
||||
exec.Command("msg", "*", message).Run()
|
||||
case "linux":
|
||||
exec.Command("zenity", "--error", fmt.Sprintf("--title=%s", title), fmt.Sprintf("--text=%s", message)).Run()
|
||||
}
|
||||
}
|
||||
|
||||
func showAbout() {
|
||||
message := "Nex Gateway\n\nAI Gateway - 统一的大模型 API 网关\n\nhttps://github.com/nex/gateway"
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
script := fmt.Sprintf(`display dialog "%s" buttons {"OK"} default button "OK" with title "关于 Nex Gateway"`, message)
|
||||
exec.Command("osascript", "-e", script).Run()
|
||||
case "windows":
|
||||
exec.Command("msg", "*", message).Run()
|
||||
case "linux":
|
||||
exec.Command("zenity", "--info", "--title=关于 Nex Gateway", fmt.Sprintf("--text=%s", message)).Run()
|
||||
}
|
||||
}
|
||||
|
||||
61
backend/cmd/desktop/messagebox_test.go
Normal file
61
backend/cmd/desktop/messagebox_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
//go:build windows
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"syscall"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func withMessageBoxW(t *testing.T, fn func(hwnd, text, caption, flags uintptr) (uintptr, error)) {
|
||||
t.Helper()
|
||||
|
||||
old := callMessageBoxW
|
||||
callMessageBoxW = fn
|
||||
t.Cleanup(func() {
|
||||
callMessageBoxW = old
|
||||
})
|
||||
}
|
||||
|
||||
func TestMessageBoxW_WindowsOnly_InvalidUTF16(t *testing.T) {
|
||||
err := messageBox("bad\x00title", "测试消息", mbIconInformation)
|
||||
if err == nil {
|
||||
t.Fatal("包含 NUL 字符时应该返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageBoxW_WindowsOnly_SuccessIgnoresLastError(t *testing.T) {
|
||||
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
|
||||
return 1, syscall.Errno(123)
|
||||
})
|
||||
|
||||
if err := messageBox("测试标题", "测试消息", mbIconInformation); err != nil {
|
||||
t.Fatalf("MessageBoxW 返回成功时应忽略 last error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageBoxW_WindowsOnly_FailureUsesReturnValue(t *testing.T) {
|
||||
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
|
||||
return 0, syscall.Errno(5)
|
||||
})
|
||||
|
||||
err := messageBox("测试标题", "测试消息", mbIconInformation)
|
||||
if !errors.Is(err, syscall.Errno(5)) {
|
||||
t.Fatalf("MessageBoxW 返回 0 时应返回调用错误: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowError_WindowsBranch(t *testing.T) {
|
||||
withMessageBoxW(t, func(_, _, _, _ uintptr) (uintptr, error) {
|
||||
return 0, syscall.Errno(5)
|
||||
})
|
||||
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
t.Fatalf("showError 不应因 MessageBoxW 失败而 panic: %v", recovered)
|
||||
}
|
||||
}()
|
||||
|
||||
showError("测试错误", "这是一条测试错误消息")
|
||||
}
|
||||
9
backend/cmd/desktop/metadata.go
Normal file
9
backend/cmd/desktop/metadata.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package main
|
||||
|
||||
const (
|
||||
appName = "Nex"
|
||||
appTooltip = appName
|
||||
appDescription = "AI Gateway - 统一的大模型 API 网关"
|
||||
// #nosec G101 -- 项目官网地址不是凭据
|
||||
appWebsite = "https://github.com/nex/gateway"
|
||||
)
|
||||
13
backend/cmd/desktop/metadata_test.go
Normal file
13
backend/cmd/desktop/metadata_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package main
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDesktopMetadata(t *testing.T) {
|
||||
if appName != "Nex" {
|
||||
t.Fatalf("appName = %q, want %q", appName, "Nex")
|
||||
}
|
||||
|
||||
if appTooltip != appName {
|
||||
t.Fatalf("appTooltip = %q, want %q", appTooltip, appName)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
@@ -21,19 +22,12 @@ func TestCheckPortAvailable(t *testing.T) {
|
||||
func TestCheckPortOccupied(t *testing.T) {
|
||||
port := 19827
|
||||
|
||||
listener, err := net.Listen("tcp", ":19827")
|
||||
listener, err := net.Listen("tcp", ":19827") //nolint:gosec // 需要验证 checkPortAvailable 对通配地址占用的检测行为
|
||||
if err != nil {
|
||||
t.Fatalf("无法启动测试服务器: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = checkPortAvailable(port)
|
||||
@@ -47,13 +41,19 @@ func TestCheckPortOccupied(t *testing.T) {
|
||||
func TestCheckPortAvailableAfterClose(t *testing.T) {
|
||||
port := 19828
|
||||
|
||||
listener, err := net.Listen("tcp", ":19828")
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:19828")
|
||||
if err != nil {
|
||||
t.Fatalf("无法启动测试服务器: %v", err)
|
||||
}
|
||||
|
||||
server := &http.Server{}
|
||||
go server.Serve(listener)
|
||||
server := &http.Server{ReadHeaderTimeout: time.Second}
|
||||
defer server.Close()
|
||||
go func() {
|
||||
err := server.Serve(listener)
|
||||
if err != nil && err != http.ErrServerClosed && !errors.Is(err, net.ErrClosed) {
|
||||
t.Errorf("serve failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
|
||||
44
backend/cmd/desktop/routes_test.go
Normal file
44
backend/cmd/desktop/routes_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"nex/backend/internal/handler"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestSetupRoutes_VersionDoesNotFallback(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
setupRoutes(r, &handler.ProxyHandler{}, &handler.ProviderHandler{}, &handler.ModelHandler{}, &handler.StatsHandler{}, handler.NewVersionHandler())
|
||||
setupStaticFilesWithFS(r, fstest.MapFS{
|
||||
"index.html": {Data: []byte("<html>fallback</html>")},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/version", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
if contentType := w.Header().Get("Content-Type"); contentType == "text/html; charset=utf-8" {
|
||||
t.Fatalf("版本接口不应返回 SPA fallback HTML")
|
||||
}
|
||||
|
||||
var result map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
for _, key := range []string{"version", "commit", "build_time"} {
|
||||
if result[key] == "" {
|
||||
t.Fatalf("响应缺少 %s 字段: %#v", key, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,11 @@ func TestSingletonLock_FirstLockSuccess(t *testing.T) {
|
||||
if err := lock.Lock(); err != nil {
|
||||
t.Fatalf("首次加锁应成功,但返回错误: %v", err)
|
||||
}
|
||||
defer lock.Unlock()
|
||||
defer func() {
|
||||
if err := lock.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func TestSingletonLock_DuplicateLockFails(t *testing.T) {
|
||||
@@ -25,12 +29,18 @@ func TestSingletonLock_DuplicateLockFails(t *testing.T) {
|
||||
if err := lock1.Lock(); err != nil {
|
||||
t.Fatalf("首次加锁应成功: %v", err)
|
||||
}
|
||||
defer lock1.Unlock()
|
||||
defer func() {
|
||||
if err := lock1.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
lock2 := NewSingletonLock(lockPath)
|
||||
err := lock2.Lock()
|
||||
if err == nil {
|
||||
lock2.Unlock()
|
||||
if unlockErr := lock2.Unlock(); unlockErr != nil {
|
||||
t.Fatalf("解锁失败: %v", unlockErr)
|
||||
}
|
||||
t.Fatal("重复加锁应失败,但返回 nil")
|
||||
}
|
||||
}
|
||||
@@ -43,16 +53,22 @@ func TestSingletonLock_UnlockThenRelock(t *testing.T) {
|
||||
if err := lock1.Lock(); err != nil {
|
||||
t.Fatalf("首次加锁应成功: %v", err)
|
||||
}
|
||||
lock1.Unlock()
|
||||
if err := lock1.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
|
||||
lock2 := NewSingletonLock(lockPath)
|
||||
if err := lock2.Lock(); err != nil {
|
||||
t.Fatalf("释放后重新加锁应成功: %v", err)
|
||||
}
|
||||
lock2.Unlock()
|
||||
if err := lock2.Unlock(); err != nil {
|
||||
t.Fatalf("解锁失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingletonLock_UnlockWithoutLock(t *testing.T) {
|
||||
lock := NewSingletonLock(filepath.Join(os.TempDir(), "nex-gateway-test-nil.lock"))
|
||||
lock.Unlock()
|
||||
if err := lock.Unlock(); err != nil {
|
||||
t.Fatalf("未加锁时解锁失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,73 +1,26 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/embedfs"
|
||||
)
|
||||
|
||||
func TestSetupStaticFiles(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
distFS, err := fs.Sub(embedfs.FrontendDist, "frontend-dist")
|
||||
distFS, err := frontendDistFS()
|
||||
if err != nil {
|
||||
t.Skipf("跳过测试: 前端资源未构建: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
getContentType := func(path string) string {
|
||||
if strings.HasSuffix(path, ".js") {
|
||||
return "application/javascript"
|
||||
}
|
||||
if strings.HasSuffix(path, ".css") {
|
||||
return "text/css"
|
||||
}
|
||||
if strings.HasSuffix(path, ".svg") {
|
||||
return "image/svg+xml"
|
||||
}
|
||||
return "application/octet-stream"
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/assets/*filepath", func(c *gin.Context) {
|
||||
filepath := c.Param("filepath")
|
||||
data, err := fs.ReadFile(distFS, "assets"+filepath)
|
||||
if err != nil {
|
||||
c.Status(404)
|
||||
return
|
||||
}
|
||||
c.Data(200, getContentType(filepath), data)
|
||||
})
|
||||
|
||||
r.GET("/favicon.svg", func(c *gin.Context) {
|
||||
data, err := fs.ReadFile(distFS, "favicon.svg")
|
||||
if err != nil {
|
||||
c.Status(404)
|
||||
return
|
||||
}
|
||||
c.Data(200, "image/svg+xml", data)
|
||||
})
|
||||
|
||||
r.NoRoute(func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
if strings.HasPrefix(path, "/api/") ||
|
||||
strings.HasPrefix(path, "/v1/") ||
|
||||
strings.HasPrefix(path, "/health") {
|
||||
c.JSON(404, gin.H{"error": "not found"})
|
||||
return
|
||||
}
|
||||
data, err := fs.ReadFile(distFS, "index.html")
|
||||
if err != nil {
|
||||
c.Status(500)
|
||||
return
|
||||
}
|
||||
c.Data(200, "text/html; charset=utf-8", data)
|
||||
})
|
||||
setupStaticFilesWithFS(r, distFS)
|
||||
|
||||
t.Run("API 404", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
@@ -79,6 +32,32 @@ func TestSetupStaticFiles(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OpenAI proxy prefix 404", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/openai/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("期望状态码 404, 实际 %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "not found") {
|
||||
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Anthropic proxy prefix 404", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/anthropic/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("期望状态码 404, 实际 %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "not found") {
|
||||
t.Errorf("期望返回 API 风格错误,实际 %s", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SPA fallback", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/providers", nil)
|
||||
w := httptest.NewRecorder()
|
||||
@@ -121,3 +100,139 @@ func TestSetupStaticFiles(t *testing.T) {
|
||||
|
||||
t.Log("静态文件服务测试通过")
|
||||
}
|
||||
|
||||
func TestSetupStaticFilesWithFS_IconPNG(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
setupStaticFilesWithFS(r, fstest.MapFS{
|
||||
"icon.png": {Data: []byte("png")},
|
||||
"index.html": {Data: []byte("<html>fallback</html>")},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/icon.png", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
if w.Header().Get("Content-Type") != "image/png" {
|
||||
t.Fatalf("期望 Content-Type image/png, 实际 %s", w.Header().Get("Content-Type"))
|
||||
}
|
||||
if w.Body.String() != "png" {
|
||||
t.Fatalf("期望返回 PNG 内容,实际 %q", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithProtocolAndStaticRoutes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
distFS, err := frontendDistFS()
|
||||
if err != nil {
|
||||
t.Skipf("跳过测试: 前端资源未构建: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
|
||||
var gotProtocol string
|
||||
var gotPath string
|
||||
r.Any("/openai/*path", withProtocol("openai", func(c *gin.Context) {
|
||||
gotProtocol = c.Param("protocol")
|
||||
gotPath = c.Param("path")
|
||||
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
|
||||
}))
|
||||
r.Any("/anthropic/*path", withProtocol("anthropic", func(c *gin.Context) {
|
||||
gotProtocol = c.Param("protocol")
|
||||
gotPath = c.Param("path")
|
||||
c.JSON(http.StatusOK, gin.H{"protocol": gotProtocol, "path": gotPath})
|
||||
}))
|
||||
setupStaticFilesWithFS(r, distFS)
|
||||
|
||||
t.Run("OpenAI route enters proxy handler wrapper", func(t *testing.T) {
|
||||
gotProtocol = ""
|
||||
gotPath = ""
|
||||
|
||||
req := httptest.NewRequest("POST", "/openai/v1/chat/completions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
if gotProtocol != "openai" {
|
||||
t.Errorf("期望 protocol=openai, 实际 %s", gotProtocol)
|
||||
}
|
||||
if gotPath != "/v1/chat/completions" {
|
||||
t.Errorf("期望 path=/v1/chat/completions, 实际 %s", gotPath)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Anthropic route enters proxy handler wrapper", func(t *testing.T) {
|
||||
gotProtocol = ""
|
||||
gotPath = ""
|
||||
|
||||
req := httptest.NewRequest("POST", "/anthropic/v1/messages", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
if gotProtocol != "anthropic" {
|
||||
t.Errorf("期望 protocol=anthropic, 实际 %s", gotProtocol)
|
||||
}
|
||||
if gotPath != "/v1/messages" {
|
||||
t.Errorf("期望 path=/v1/messages, 实际 %s", gotPath)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Static assets are not hijacked", func(t *testing.T) {
|
||||
gotProtocol = ""
|
||||
gotPath = ""
|
||||
|
||||
req := httptest.NewRequest("GET", "/assets/test.js", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if gotProtocol != "" || gotPath != "" {
|
||||
t.Errorf("静态资源不应进入代理包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
|
||||
}
|
||||
if w.Code == http.StatusOK {
|
||||
if !strings.HasPrefix(w.Header().Get("Content-Type"), "application/javascript") {
|
||||
t.Errorf("期望 JS Content-Type, 实际 %s", w.Header().Get("Content-Type"))
|
||||
}
|
||||
return
|
||||
}
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("期望静态资源返回 200 或 404, 实际 %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SPA path keeps fallback", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/providers", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("期望状态码 200, 实际 %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Header().Get("Content-Type"), "text/html") {
|
||||
t.Errorf("期望返回 HTML,实际 %s", w.Header().Get("Content-Type"))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Unknown proxy-like path does not return index html", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/openai/unknown", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("显式代理路由应进入代理包装器,实际状态码 %d", w.Code)
|
||||
}
|
||||
if gotProtocol != "openai" || gotPath != "/unknown" {
|
||||
t.Errorf("期望 unknown 代理路径进入 openai 包装器,实际 protocol=%s path=%s", gotProtocol, gotPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -23,18 +22,19 @@ import (
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/internal/service"
|
||||
"nex/backend/pkg/buildinfo"
|
||||
pkgLogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
func main() {
|
||||
minimalLogger := pkgLogger.NewMinimal()
|
||||
|
||||
cfg, err := config.LoadConfig()
|
||||
if err != nil {
|
||||
log.Fatalf("加载配置失败: %v", err)
|
||||
minimalLogger.Fatal("加载配置失败", zap.Error(err))
|
||||
}
|
||||
|
||||
cfg.PrintSummary()
|
||||
|
||||
zapLogger, err := pkgLogger.New(pkgLogger.Config{
|
||||
zapLogger, err := pkgLogger.Upgrade(minimalLogger, pkgLogger.Config{
|
||||
Level: cfg.Log.Level,
|
||||
Path: cfg.Log.Path,
|
||||
MaxSize: cfg.Log.MaxSize,
|
||||
@@ -43,13 +43,19 @@ func main() {
|
||||
Compress: cfg.Log.Compress,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("初始化日志失败: %v", err)
|
||||
minimalLogger.Fatal("初始化日志失败", zap.Error(err))
|
||||
}
|
||||
defer zapLogger.Sync()
|
||||
defer func() {
|
||||
if err := zapLogger.Sync(); err != nil {
|
||||
minimalLogger.Warn("同步日志失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
cfg.PrintSummary(zapLogger)
|
||||
|
||||
db, err := database.Init(&cfg.Database, zapLogger)
|
||||
if err != nil {
|
||||
zapLogger.Fatal("初始化数据库失败", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("初始化数据库失败", zap.Error(err))
|
||||
}
|
||||
defer database.Close(db)
|
||||
|
||||
@@ -74,19 +80,20 @@ func main() {
|
||||
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
if err := registry.Register(openai.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("注册 OpenAI 适配器失败", zap.Error(err))
|
||||
}
|
||||
if err := registry.Register(anthropic.NewAdapter()); err != nil {
|
||||
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("注册 Anthropic 适配器失败", zap.Error(err))
|
||||
}
|
||||
engine := conversion.NewConversionEngine(registry, zapLogger)
|
||||
|
||||
providerClient := provider.NewClient()
|
||||
providerClient := provider.NewClient(zapLogger)
|
||||
|
||||
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService)
|
||||
proxyHandler := handler.NewProxyHandler(engine, providerClient, routingService, providerService, statsService, zapLogger)
|
||||
providerHandler := handler.NewProviderHandler(providerService)
|
||||
modelHandler := handler.NewModelHandler(modelService)
|
||||
statsHandler := handler.NewStatsHandler(statsService)
|
||||
versionHandler := handler.NewVersionHandler()
|
||||
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
r := gin.New()
|
||||
@@ -96,7 +103,7 @@ func main() {
|
||||
r.Use(middleware.Logging(zapLogger))
|
||||
r.Use(middleware.CORS())
|
||||
|
||||
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler)
|
||||
setupRoutes(r, proxyHandler, providerHandler, modelHandler, statsHandler, versionHandler)
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", cfg.Server.Port),
|
||||
@@ -106,9 +113,13 @@ func main() {
|
||||
}
|
||||
|
||||
go func() {
|
||||
zapLogger.Info("AI Gateway 启动", zap.String("addr", srv.Addr))
|
||||
zapLogger.Info("AI Gateway 启动",
|
||||
zap.String("addr", srv.Addr),
|
||||
zap.String("version", buildinfo.Version()),
|
||||
zap.String("commit", buildinfo.Commit()),
|
||||
zap.String("build_time", buildinfo.BuildTime()))
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
zapLogger.Fatal("服务器启动失败", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("服务器启动失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -122,7 +133,7 @@ func main() {
|
||||
defer cancel()
|
||||
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
zapLogger.Fatal("服务器强制关闭", zap.String("error", err.Error()))
|
||||
zapLogger.Fatal("服务器强制关闭", zap.Error(err))
|
||||
}
|
||||
|
||||
statsBuffer.Stop()
|
||||
@@ -130,8 +141,9 @@ func main() {
|
||||
zapLogger.Info("服务器已关闭")
|
||||
}
|
||||
|
||||
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler) {
|
||||
func setupRoutes(r *gin.Engine, proxyHandler *handler.ProxyHandler, providerHandler *handler.ProviderHandler, modelHandler *handler.ModelHandler, statsHandler *handler.StatsHandler, versionHandler *handler.VersionHandler) {
|
||||
r.Any("/:protocol/*path", proxyHandler.HandleProxy)
|
||||
r.GET("/api/version", versionHandler.GetVersion)
|
||||
|
||||
providers := r.Group("/api/providers")
|
||||
{
|
||||
|
||||
37
backend/cmd/server/routes_test.go
Normal file
37
backend/cmd/server/routes_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/handler"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestSetupRoutes_Version(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
setupRoutes(r, &handler.ProxyHandler{}, &handler.ProviderHandler{}, &handler.ModelHandler{}, &handler.StatsHandler{}, handler.NewVersionHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/version", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var result map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
for _, key := range []string{"version", "commit", "build_time"} {
|
||||
if result[key] == "" {
|
||||
t.Fatalf("响应缺少 %s 字段: %#v", key, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
@@ -57,7 +59,10 @@ type LogConfig struct {
|
||||
// DefaultConfig returns default config values
|
||||
func DefaultConfig() *Config {
|
||||
// Use home dir for default paths
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
homeDir = "."
|
||||
}
|
||||
nexDir := filepath.Join(homeDir, ".nex")
|
||||
|
||||
return &Config{
|
||||
@@ -96,7 +101,7 @@ func GetConfigDir() (string, error) {
|
||||
return "", err
|
||||
}
|
||||
configDir := filepath.Join(homeDir, ".nex")
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return configDir, nil
|
||||
@@ -122,7 +127,10 @@ func GetConfigPath() (string, error) {
|
||||
|
||||
// setupDefaults 设置默认配置值
|
||||
func setupDefaults(v *viper.Viper) {
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
homeDir = "."
|
||||
}
|
||||
nexDir := filepath.Join(homeDir, ".nex")
|
||||
|
||||
v.SetDefault("server.port", 9826)
|
||||
@@ -176,27 +184,33 @@ func setupFlags(v *viper.Viper, flagSet *pflag.FlagSet) {
|
||||
|
||||
// 绑定所有 flag 到 viper
|
||||
// 注意:必须在设置默认值之后绑定
|
||||
v.BindPFlag("server.port", flagSet.Lookup("server-port"))
|
||||
v.BindPFlag("server.read_timeout", flagSet.Lookup("server-read-timeout"))
|
||||
v.BindPFlag("server.write_timeout", flagSet.Lookup("server-write-timeout"))
|
||||
bindPFlag(v, "server.port", flagSet.Lookup("server-port"))
|
||||
bindPFlag(v, "server.read_timeout", flagSet.Lookup("server-read-timeout"))
|
||||
bindPFlag(v, "server.write_timeout", flagSet.Lookup("server-write-timeout"))
|
||||
|
||||
v.BindPFlag("database.driver", flagSet.Lookup("database-driver"))
|
||||
v.BindPFlag("database.path", flagSet.Lookup("database-path"))
|
||||
v.BindPFlag("database.host", flagSet.Lookup("database-host"))
|
||||
v.BindPFlag("database.port", flagSet.Lookup("database-port"))
|
||||
v.BindPFlag("database.user", flagSet.Lookup("database-user"))
|
||||
v.BindPFlag("database.password", flagSet.Lookup("database-password"))
|
||||
v.BindPFlag("database.dbname", flagSet.Lookup("database-dbname"))
|
||||
v.BindPFlag("database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
|
||||
v.BindPFlag("database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
|
||||
v.BindPFlag("database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
|
||||
bindPFlag(v, "database.driver", flagSet.Lookup("database-driver"))
|
||||
bindPFlag(v, "database.path", flagSet.Lookup("database-path"))
|
||||
bindPFlag(v, "database.host", flagSet.Lookup("database-host"))
|
||||
bindPFlag(v, "database.port", flagSet.Lookup("database-port"))
|
||||
bindPFlag(v, "database.user", flagSet.Lookup("database-user"))
|
||||
bindPFlag(v, "database.password", flagSet.Lookup("database-password"))
|
||||
bindPFlag(v, "database.dbname", flagSet.Lookup("database-dbname"))
|
||||
bindPFlag(v, "database.max_idle_conns", flagSet.Lookup("database-max-idle-conns"))
|
||||
bindPFlag(v, "database.max_open_conns", flagSet.Lookup("database-max-open-conns"))
|
||||
bindPFlag(v, "database.conn_max_lifetime", flagSet.Lookup("database-conn-max-lifetime"))
|
||||
|
||||
v.BindPFlag("log.level", flagSet.Lookup("log-level"))
|
||||
v.BindPFlag("log.path", flagSet.Lookup("log-path"))
|
||||
v.BindPFlag("log.max_size", flagSet.Lookup("log-max-size"))
|
||||
v.BindPFlag("log.max_backups", flagSet.Lookup("log-max-backups"))
|
||||
v.BindPFlag("log.max_age", flagSet.Lookup("log-max-age"))
|
||||
v.BindPFlag("log.compress", flagSet.Lookup("log-compress"))
|
||||
bindPFlag(v, "log.level", flagSet.Lookup("log-level"))
|
||||
bindPFlag(v, "log.path", flagSet.Lookup("log-path"))
|
||||
bindPFlag(v, "log.max_size", flagSet.Lookup("log-max-size"))
|
||||
bindPFlag(v, "log.max_backups", flagSet.Lookup("log-max-backups"))
|
||||
bindPFlag(v, "log.max_age", flagSet.Lookup("log-max-age"))
|
||||
bindPFlag(v, "log.compress", flagSet.Lookup("log-compress"))
|
||||
}
|
||||
|
||||
func bindPFlag(v *viper.Viper, key string, flag *pflag.Flag) {
|
||||
if err := v.BindPFlag(key, flag); err != nil {
|
||||
panic(fmt.Sprintf("绑定 flag %s 失败: %v", key, err))
|
||||
}
|
||||
}
|
||||
|
||||
// setupEnv 绑定环境变量
|
||||
@@ -217,10 +231,17 @@ func setupConfigFile(v *viper.Viper, configPath string) error {
|
||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
// 配置文件不存在,创建默认配置文件
|
||||
if err := v.SafeWriteConfig(); err != nil {
|
||||
// 忽略写入错误(可能目录已存在等)
|
||||
writeErr := v.SafeWriteConfigAs(configPath)
|
||||
if writeErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var alreadyExistsErr viper.ConfigFileAlreadyExistsError
|
||||
if errors.As(writeErr, &alreadyExistsErr) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return appErrors.Wrap(appErrors.ErrInternal, writeErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -245,7 +266,9 @@ func LoadConfigFromPath(configPath string) (*Config, error) {
|
||||
setupFlags(v, flagSet)
|
||||
|
||||
// 3. 解析 CLI 参数(忽略错误,因为可能没有参数)
|
||||
flagSet.Parse(os.Args[1:])
|
||||
if err := flagSet.Parse(os.Args[1:]); err != nil {
|
||||
return nil, appErrors.Wrap(appErrors.ErrInvalidRequest, err)
|
||||
}
|
||||
|
||||
// 4. 获取配置文件路径(可能被 --config 参数覆盖)
|
||||
if configPathFlag, err := flagSet.GetString("config"); err == nil && configPathFlag != "" {
|
||||
@@ -294,11 +317,11 @@ func SaveConfig(cfg *Config) error {
|
||||
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(configPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return appErrors.Wrap(appErrors.ErrInternal, err)
|
||||
}
|
||||
|
||||
return os.WriteFile(configPath, data, 0600)
|
||||
return os.WriteFile(configPath, data, 0o600)
|
||||
}
|
||||
|
||||
// Validate validates the config
|
||||
@@ -311,22 +334,24 @@ func (c *Config) Validate() error {
|
||||
}
|
||||
|
||||
// PrintSummary 打印配置摘要
|
||||
func (c *Config) PrintSummary() {
|
||||
fmt.Println("\nAI Gateway 启动配置")
|
||||
fmt.Println("==================")
|
||||
fmt.Printf("服务器端口: %d\n", c.Server.Port)
|
||||
func (c *Config) PrintSummary(logger *zap.Logger) {
|
||||
logger.Info("AI Gateway 启动配置",
|
||||
zap.Int("server_port", c.Server.Port),
|
||||
zap.String("database_driver", c.Database.Driver),
|
||||
zap.String("log_level", c.Log.Level),
|
||||
)
|
||||
|
||||
if c.Database.Driver == "mysql" {
|
||||
fmt.Printf("数据库类型: mysql\n")
|
||||
fmt.Printf("数据库地址: %s:%d/%s\n", c.Database.Host, c.Database.Port, c.Database.DBName)
|
||||
logger.Info("数据库配置",
|
||||
zap.String("driver", "mysql"),
|
||||
zap.String("host", c.Database.Host),
|
||||
zap.Int("port", c.Database.Port),
|
||||
zap.String("database", c.Database.DBName),
|
||||
)
|
||||
} else {
|
||||
fmt.Printf("数据库类型: sqlite\n")
|
||||
fmt.Printf("数据库路径: %s\n", c.Database.Path)
|
||||
logger.Info("数据库配置",
|
||||
zap.String("driver", "sqlite"),
|
||||
zap.String("path", c.Database.Path),
|
||||
)
|
||||
}
|
||||
fmt.Printf("日志级别: %s\n", c.Log.Level)
|
||||
fmt.Println("\n配置来源:")
|
||||
configPath, _ := GetConfigPath()
|
||||
fmt.Printf(" 配置文件: %s\n", configPath)
|
||||
fmt.Println(" 环境变量: 待统计")
|
||||
fmt.Println(" CLI 参数: 待统计")
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -171,7 +172,9 @@ func TestConfig_Validate(t *testing.T) {
|
||||
err := cfg.Validate()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -233,7 +236,7 @@ func TestSaveAndLoadConfig(t *testing.T) {
|
||||
configPath := filepath.Join(dir, "config.yaml")
|
||||
data, err := yaml.Marshal(cfg)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(configPath, data, 0644)
|
||||
err = os.WriteFile(configPath, data, 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 加载配置
|
||||
@@ -302,7 +305,7 @@ func TestPrintSummary(t *testing.T) {
|
||||
t.Run("SQLite模式摘要", func(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
assert.NotPanics(t, func() {
|
||||
cfg.PrintSummary()
|
||||
cfg.PrintSummary(zap.NewNop())
|
||||
})
|
||||
})
|
||||
t.Run("MySQL模式摘要", func(t *testing.T) {
|
||||
@@ -313,7 +316,7 @@ func TestPrintSummary(t *testing.T) {
|
||||
cfg.Database.User = "nex"
|
||||
cfg.Database.DBName = "nex"
|
||||
assert.NotPanics(t, func() {
|
||||
cfg.PrintSummary()
|
||||
cfg.PrintSummary(zap.NewNop())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -29,8 +29,8 @@ type Model struct {
|
||||
// UsageStats 用量统计
|
||||
type UsageStats struct {
|
||||
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
|
||||
ProviderID string `gorm:"not null;index" json:"provider_id"`
|
||||
ModelName string `gorm:"not null;index" json:"model_name"`
|
||||
ProviderID string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"provider_id"`
|
||||
ModelName string `gorm:"not null;index;uniqueIndex:idx_provider_model_date" json:"model_name"`
|
||||
RequestCount int `gorm:"default:0" json:"request_count"`
|
||||
Date time.Time `gorm:"type:date;not null;uniqueIndex:idx_provider_model_date" json:"date"`
|
||||
}
|
||||
@@ -47,4 +47,3 @@ func (Model) TableName() string {
|
||||
func (UsageStats) TableName() string {
|
||||
return "usage_stats"
|
||||
}
|
||||
|
||||
|
||||
@@ -141,7 +141,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
|
||||
Message: err.Message,
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(errMsg)
|
||||
body, marshalErr := json.Marshal(errMsg)
|
||||
if marshalErr != nil {
|
||||
return []byte(`{"type":"error","error":{"type":"internal_error","message":"internal error"}}`), statusCode
|
||||
}
|
||||
return body, statusCode
|
||||
}
|
||||
|
||||
@@ -235,7 +238,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
|
||||
return "", nil, err
|
||||
}
|
||||
rewriteFunc := func(newModel string) ([]byte, error) {
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
}
|
||||
return current, rewriteFunc, nil
|
||||
@@ -269,7 +276,11 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat:
|
||||
// Chat 响应必须有 model 字段,存在则改写,不存在则添加
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
default:
|
||||
return body, nil
|
||||
|
||||
@@ -2,6 +2,7 @@ package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
@@ -48,6 +49,28 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_APIReferenceNativePaths(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
// docs/api_reference/anthropic defines messages and models under /v1.
|
||||
tests := []struct {
|
||||
path string
|
||||
expected conversion.InterfaceType
|
||||
}{
|
||||
{"/v1/messages", conversion.InterfaceTypeChat},
|
||||
{"/v1/models", conversion.InterfaceTypeModels},
|
||||
{"/v1/models/claude-sonnet-4-5", conversion.InterfaceTypeModelInfo},
|
||||
{"/messages", conversion.InterfaceTypePassthrough},
|
||||
{"/models", conversion.InterfaceTypePassthrough},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, a.DetectInterfaceType(tt.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_BuildUrl(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
@@ -141,8 +164,8 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
|
||||
t.Run("解码嵌入请求", func(t *testing.T) {
|
||||
_, err := a.DecodeEmbeddingRequest([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
@@ -150,24 +173,24 @@ func TestAdapter_UnsupportedEmbedding(t *testing.T) {
|
||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
||||
_, err := a.EncodeEmbeddingRequest(&canonical.CanonicalEmbeddingRequest{}, provider)
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.True(t, errors.As(err, &convErr))
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("解码嵌入响应", func(t *testing.T) {
|
||||
_, err := a.DecodeEmbeddingResponse([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("编码嵌入响应", func(t *testing.T) {
|
||||
_, err := a.EncodeEmbeddingResponse(&canonical.CanonicalEmbeddingResponse{})
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
}
|
||||
@@ -178,8 +201,8 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
|
||||
t.Run("解码重排序请求", func(t *testing.T) {
|
||||
_, err := a.DecodeRerankRequest([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
@@ -187,24 +210,24 @@ func TestAdapter_UnsupportedRerank(t *testing.T) {
|
||||
provider := conversion.NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
||||
_, err := a.EncodeRerankRequest(&canonical.CanonicalRerankRequest{}, provider)
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("解码重排序响应", func(t *testing.T) {
|
||||
_, err := a.DecodeRerankResponse([]byte(`{}`))
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
|
||||
t.Run("编码重排序响应", func(t *testing.T) {
|
||||
_, err := a.EncodeRerankResponse(&canonical.CanonicalRerankResponse{})
|
||||
require.Error(t, err)
|
||||
convErr, ok := err.(*conversion.ConversionError)
|
||||
require.True(t, ok)
|
||||
var convErr *conversion.ConversionError
|
||||
require.ErrorAs(t, err, &convErr)
|
||||
assert.Equal(t, conversion.ErrorCodeInterfaceNotSupported, convErr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -28,7 +28,10 @@ func decodeRequest(body []byte) (*canonical.CanonicalRequest, error) {
|
||||
|
||||
var canonicalMsgs []canonical.CanonicalMessage
|
||||
for _, msg := range req.Messages {
|
||||
decoded := decodeMessage(msg)
|
||||
decoded, err := decodeMessage(msg)
|
||||
if err != nil {
|
||||
return nil, conversion.NewConversionError(conversion.ErrorCodeJSONParseError, "解析消息内容失败").WithCause(err)
|
||||
}
|
||||
canonicalMsgs = append(canonicalMsgs, decoded...)
|
||||
}
|
||||
|
||||
@@ -94,10 +97,13 @@ func decodeSystem(system any) any {
|
||||
}
|
||||
|
||||
// decodeMessage 解码 Anthropic 消息
|
||||
func decodeMessage(msg Message) []canonical.CanonicalMessage {
|
||||
func decodeMessage(msg Message) ([]canonical.CanonicalMessage, error) {
|
||||
switch msg.Role {
|
||||
case "user":
|
||||
blocks := decodeContentBlocks(msg.Content)
|
||||
blocks, err := decodeContentBlocks(msg.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var toolResults []canonical.ContentBlock
|
||||
var others []canonical.ContentBlock
|
||||
for _, b := range blocks {
|
||||
@@ -117,58 +123,83 @@ func decodeMessage(msg Message) []canonical.CanonicalMessage {
|
||||
if len(result) == 0 {
|
||||
result = append(result, canonical.CanonicalMessage{Role: canonical.RoleUser, Content: []canonical.ContentBlock{canonical.NewTextBlock("")}})
|
||||
}
|
||||
return result
|
||||
return result, nil
|
||||
|
||||
case "assistant":
|
||||
blocks := decodeContentBlocks(msg.Content)
|
||||
blocks, err := decodeContentBlocks(msg.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(blocks) == 0 {
|
||||
blocks = append(blocks, canonical.NewTextBlock(""))
|
||||
}
|
||||
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}
|
||||
return []canonical.CanonicalMessage{{Role: canonical.RoleAssistant, Content: blocks}}, nil
|
||||
}
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// decodeContentBlocks 解码内容块列表
|
||||
func decodeContentBlocks(content any) []canonical.ContentBlock {
|
||||
func decodeContentBlocks(content any) ([]canonical.ContentBlock, error) {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(v)}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(v)}, nil
|
||||
case []any:
|
||||
var blocks []canonical.ContentBlock
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
block := decodeSingleContentBlock(m)
|
||||
block, err := decodeSingleContentBlock(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if block != nil {
|
||||
blocks = append(blocks, *block)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(blocks) > 0 {
|
||||
return blocks
|
||||
return blocks, nil
|
||||
}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
|
||||
case nil:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock("")}, nil
|
||||
default:
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}
|
||||
return []canonical.ContentBlock{canonical.NewTextBlock(fmt.Sprintf("%v", v))}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// decodeSingleContentBlock 解码单个内容块
|
||||
func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
|
||||
t, _ := m["type"].(string)
|
||||
func decodeSingleContentBlock(m map[string]any) (*canonical.ContentBlock, error) {
|
||||
t, ok := m["type"].(string)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch t {
|
||||
case "text":
|
||||
text, _ := m["text"].(string)
|
||||
return &canonical.ContentBlock{Type: "text", Text: text}
|
||||
text, ok := m["text"].(string)
|
||||
if !ok {
|
||||
text = ""
|
||||
}
|
||||
return &canonical.ContentBlock{Type: "text", Text: text}, nil
|
||||
case "tool_use":
|
||||
id, _ := m["id"].(string)
|
||||
name, _ := m["name"].(string)
|
||||
input, _ := json.Marshal(m["input"])
|
||||
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}
|
||||
id, ok := m["id"].(string)
|
||||
if !ok {
|
||||
id = ""
|
||||
}
|
||||
name, ok := m["name"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
input, err := json.Marshal(m["input"])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &canonical.ContentBlock{Type: "tool_use", ID: id, Name: name, Input: input}, nil
|
||||
case "tool_result":
|
||||
toolUseID, _ := m["tool_use_id"].(string)
|
||||
toolUseID, ok := m["tool_use_id"].(string)
|
||||
if !ok {
|
||||
toolUseID = ""
|
||||
}
|
||||
isErr := false
|
||||
if ie, ok := m["is_error"].(bool); ok {
|
||||
isErr = ie
|
||||
@@ -179,7 +210,11 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
|
||||
case string:
|
||||
content = json.RawMessage(fmt.Sprintf("%q", cv))
|
||||
default:
|
||||
content, _ = json.Marshal(cv)
|
||||
encoded, err := json.Marshal(cv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content = encoded
|
||||
}
|
||||
} else {
|
||||
content = json.RawMessage(`""`)
|
||||
@@ -189,15 +224,18 @@ func decodeSingleContentBlock(m map[string]any) *canonical.ContentBlock {
|
||||
ToolUseID: toolUseID,
|
||||
Content: content,
|
||||
IsError: &isErr,
|
||||
}
|
||||
}, nil
|
||||
case "thinking":
|
||||
thinking, _ := m["thinking"].(string)
|
||||
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}
|
||||
thinking, ok := m["thinking"].(string)
|
||||
if !ok {
|
||||
thinking = ""
|
||||
}
|
||||
return &canonical.ContentBlock{Type: "thinking", Thinking: thinking}, nil
|
||||
case "redacted_thinking":
|
||||
// 丢弃
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// decodeTools 解码工具定义
|
||||
@@ -232,7 +270,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
||||
return canonical.NewToolChoiceAny()
|
||||
}
|
||||
case map[string]any:
|
||||
t, _ := v["type"].(string)
|
||||
t, ok := v["type"].(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
switch t {
|
||||
case "auto":
|
||||
return canonical.NewToolChoiceAuto()
|
||||
@@ -241,7 +282,10 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
||||
case "any":
|
||||
return canonical.NewToolChoiceAny()
|
||||
case "tool":
|
||||
name, _ := v["name"].(string)
|
||||
name, ok := v["name"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
return canonical.NewToolChoiceNamed(name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +33,8 @@ func TestEncodeRequest_Basic(t *testing.T) {
|
||||
assert.Equal(t, true, result["stream"])
|
||||
assert.Equal(t, float64(1024), result["max_tokens"])
|
||||
|
||||
msgs := result["messages"].([]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, msgs, 1)
|
||||
}
|
||||
|
||||
@@ -55,17 +56,20 @@ func TestEncodeRequest_ToolMergeIntoUser(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
|
||||
// tool 消息应被合并到相邻 user 消息
|
||||
foundToolResult := false
|
||||
for _, m := range msgs {
|
||||
msgMap := m.(map[string]any)
|
||||
msgMap, ok := m.(map[string]any)
|
||||
require.True(t, ok)
|
||||
if msgMap["role"] == "user" {
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if ok {
|
||||
for _, c := range content {
|
||||
block := c.(map[string]any)
|
||||
block, ok := c.(map[string]any)
|
||||
require.True(t, ok)
|
||||
if block["type"] == "tool_result" {
|
||||
foundToolResult = true
|
||||
}
|
||||
@@ -93,8 +97,10 @@ func TestEncodeRequest_FirstUserGuarantee(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
firstMsg := msgs[0].(map[string]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
firstMsg, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "user", firstMsg["role"])
|
||||
}
|
||||
|
||||
@@ -140,9 +146,11 @@ func TestEncodeResponse_Basic(t *testing.T) {
|
||||
assert.Equal(t, "assistant", result["role"])
|
||||
assert.Equal(t, "end_turn", result["stop_reason"])
|
||||
|
||||
content := result["content"].([]any)
|
||||
content, ok := result["content"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, content, 1)
|
||||
block := content[0].(map[string]any)
|
||||
block, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "text", block["type"])
|
||||
assert.Equal(t, "你好", block["text"])
|
||||
}
|
||||
@@ -160,10 +168,12 @@ func TestEncodeModelsResponse(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
data := result["data"].([]any)
|
||||
data, ok := result["data"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, data, 1)
|
||||
|
||||
model := data[0].(map[string]any)
|
||||
model, ok := data[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "claude-3-opus", model["id"])
|
||||
// created 应为 RFC3339 格式
|
||||
createdAt, ok := model["created_at"].(string)
|
||||
@@ -280,11 +290,14 @@ func TestEncodeRequest_ConsecutiveRoleMerge(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, msgs, 1)
|
||||
userMsg := msgs[0].(map[string]any)
|
||||
userMsg, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "user", userMsg["role"])
|
||||
content := userMsg["content"].([]any)
|
||||
content, ok := userMsg["content"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, content, 2)
|
||||
}
|
||||
|
||||
@@ -321,7 +334,8 @@ func TestEncodeResponse_ReasoningTokens(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
usage := result["usage"].(map[string]any)
|
||||
usage, ok := result["usage"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
_, hasReasoning := usage["reasoning_tokens"]
|
||||
assert.False(t, hasReasoning)
|
||||
}
|
||||
@@ -341,9 +355,11 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
content := result["content"].([]any)
|
||||
content, ok := result["content"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, content, 1)
|
||||
block := content[0].(map[string]any)
|
||||
block, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "tool_use", block["type"])
|
||||
assert.Equal(t, "tool_1", block["id"])
|
||||
assert.Equal(t, "search", block["name"])
|
||||
|
||||
@@ -28,7 +28,7 @@ func NewStreamDecoder() *StreamDecoder {
|
||||
func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
|
||||
data := rawChunk
|
||||
if len(d.utf8Remainder) > 0 {
|
||||
data = append(d.utf8Remainder, rawChunk...)
|
||||
data = append(append([]byte{}, d.utf8Remainder...), rawChunk...)
|
||||
d.utf8Remainder = nil
|
||||
}
|
||||
|
||||
@@ -50,9 +50,10 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
|
||||
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
line = strings.TrimRight(line, "\r")
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
switch {
|
||||
case strings.HasPrefix(line, "event: "):
|
||||
eventType = strings.TrimPrefix(line, "event: ")
|
||||
} else if strings.HasPrefix(line, "data: ") {
|
||||
case strings.HasPrefix(line, "data: "):
|
||||
eventData = strings.TrimPrefix(line, "data: ")
|
||||
if eventType != "" && eventData != "" {
|
||||
chunkEvents := d.processEvent(eventType, []byte(eventData))
|
||||
@@ -60,8 +61,8 @@ func (d *StreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStrea
|
||||
}
|
||||
eventType = ""
|
||||
eventData = ""
|
||||
} else if line == "" {
|
||||
// SSE 事件分隔符
|
||||
case line == "":
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -51,15 +51,23 @@ func (e *StreamEncoder) encodeMessageStart(event canonical.CanonicalStreamEvent)
|
||||
if event.Message != nil {
|
||||
msg := map[string]any{
|
||||
"id": event.Message.ID,
|
||||
"model": event.Message.Model,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []any{},
|
||||
"model": event.Message.Model,
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
}
|
||||
if event.Message.Usage != nil {
|
||||
usage := map[string]any{
|
||||
msg["usage"] = map[string]any{
|
||||
"input_tokens": event.Message.Usage.InputTokens,
|
||||
"output_tokens": event.Message.Usage.OutputTokens,
|
||||
}
|
||||
msg["usage"] = usage
|
||||
} else {
|
||||
msg["usage"] = map[string]any{
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
}
|
||||
}
|
||||
payload["message"] = msg
|
||||
}
|
||||
@@ -147,6 +155,10 @@ func (e *StreamEncoder) encodeMessageDelta(event canonical.CanonicalStreamEvent)
|
||||
payload["usage"] = map[string]any{
|
||||
"output_tokens": event.Usage.OutputTokens,
|
||||
}
|
||||
} else {
|
||||
payload["usage"] = map[string]any{
|
||||
"output_tokens": 0,
|
||||
}
|
||||
}
|
||||
return e.marshalEvent("message_delta", payload)
|
||||
}
|
||||
|
||||
@@ -21,8 +21,55 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
|
||||
s := string(chunks[0])
|
||||
assert.True(t, strings.HasPrefix(s, "event: message_start\n"))
|
||||
assert.Contains(t, s, "data: ")
|
||||
assert.Contains(t, s, "msg_1")
|
||||
assert.Contains(t, s, "claude-3")
|
||||
|
||||
var payload map[string]any
|
||||
lines := strings.Split(s, "\n")
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "data: ") {
|
||||
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
msg, ok := payload["message"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "msg_1", msg["id"])
|
||||
assert.Equal(t, "message", msg["type"])
|
||||
assert.Equal(t, "assistant", msg["role"])
|
||||
assert.Equal(t, []any{}, msg["content"])
|
||||
assert.Equal(t, "claude-3", msg["model"])
|
||||
assert.Nil(t, msg["stop_reason"])
|
||||
assert.Nil(t, msg["stop_sequence"])
|
||||
|
||||
usage, okU := msg["usage"].(map[string]any)
|
||||
require.True(t, okU)
|
||||
assert.Equal(t, float64(0), usage["input_tokens"])
|
||||
assert.Equal(t, float64(0), usage["output_tokens"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageStart_WithUsage(t *testing.T) {
|
||||
e := NewStreamEncoder()
|
||||
event := canonical.NewMessageStartEventWithUsage("msg_2", "gpt-4", &canonical.CanonicalUsage{InputTokens: 100, OutputTokens: 50})
|
||||
|
||||
chunks := e.EncodeEvent(event)
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
s := string(chunks[0])
|
||||
var payload map[string]any
|
||||
lines := strings.Split(s, "\n")
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "data: ") {
|
||||
require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(l, "data: ")), &payload))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
msg, ok := payload["message"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
usage, okU := msg["usage"].(map[string]any)
|
||||
require.True(t, okU)
|
||||
assert.Equal(t, float64(100), usage["input_tokens"])
|
||||
assert.Equal(t, float64(50), usage["output_tokens"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_ContentBlockDelta(t *testing.T) {
|
||||
@@ -80,7 +127,8 @@ func TestStreamEncoder_ContentBlockStart_Text(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
cb := payload["content_block"].(map[string]any)
|
||||
cb, ok := payload["content_block"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "text", cb["type"])
|
||||
}
|
||||
|
||||
@@ -107,7 +155,8 @@ func TestStreamEncoder_ContentBlockStart_ToolUse(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
cb := payload["content_block"].(map[string]any)
|
||||
cb, ok := payload["content_block"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "tool_use", cb["type"])
|
||||
assert.Equal(t, "toolu_1", cb["id"])
|
||||
assert.Equal(t, "search", cb["name"])
|
||||
@@ -131,7 +180,8 @@ func TestStreamEncoder_ContentBlockStart_Thinking(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
cb := payload["content_block"].(map[string]any)
|
||||
cb, ok := payload["content_block"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "thinking", cb["type"])
|
||||
}
|
||||
|
||||
@@ -173,8 +223,13 @@ func TestStreamEncoder_MessageDelta_WithStopReason(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
delta := payload["delta"].(map[string]any)
|
||||
delta, okd := payload["delta"].(map[string]any)
|
||||
require.True(t, okd)
|
||||
assert.Equal(t, "end_turn", delta["stop_reason"])
|
||||
|
||||
usage, oku := payload["usage"].(map[string]any)
|
||||
require.True(t, oku, "message_delta SHALL always include usage")
|
||||
assert.Equal(t, float64(0), usage["output_tokens"])
|
||||
}
|
||||
|
||||
func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
|
||||
@@ -199,7 +254,8 @@ func TestStreamEncoder_MessageDelta_WithUsage(t *testing.T) {
|
||||
break
|
||||
}
|
||||
}
|
||||
u := payload["usage"].(map[string]any)
|
||||
u, oku := payload["usage"].(map[string]any)
|
||||
require.True(t, oku)
|
||||
assert.Equal(t, float64(88), u["output_tokens"])
|
||||
}
|
||||
|
||||
|
||||
@@ -173,13 +173,15 @@ func TestDecodeMessage_UserWithOnlyToolResults(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDecodeContentBlocks_Nil(t *testing.T) {
|
||||
blocks := decodeContentBlocks(nil)
|
||||
blocks, err := decodeContentBlocks(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, blocks, 1)
|
||||
assert.Equal(t, "", blocks[0].Text)
|
||||
}
|
||||
|
||||
func TestDecodeContentBlocks_String(t *testing.T) {
|
||||
blocks := decodeContentBlocks("hello")
|
||||
blocks, err := decodeContentBlocks("hello")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, blocks, 1)
|
||||
assert.Equal(t, "hello", blocks[0].Text)
|
||||
}
|
||||
@@ -217,8 +219,10 @@ func TestEncodeToolChoice(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := encodeToolChoice(tt.choice)
|
||||
assert.Equal(t, tt.want["type"], result.(map[string]any)["type"])
|
||||
assert.Equal(t, tt.want["name"], result.(map[string]any)["name"])
|
||||
r, ok := result.(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, tt.want["type"], r["type"])
|
||||
assert.Equal(t, tt.want["name"], r["name"])
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -315,12 +319,15 @@ func TestEncodeRequest_WithTools(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
tools := result["tools"].([]any)
|
||||
tools, okt := result["tools"].([]any)
|
||||
require.True(t, okt)
|
||||
assert.Len(t, tools, 1)
|
||||
tool := tools[0].(map[string]any)
|
||||
tool, okt2 := tools[0].(map[string]any)
|
||||
require.True(t, okt2)
|
||||
assert.Equal(t, "search", tool["name"])
|
||||
assert.Equal(t, "Search things", tool["description"])
|
||||
tc := result["tool_choice"].(map[string]any)
|
||||
tc, oktc := result["tool_choice"].(map[string]any)
|
||||
require.True(t, oktc)
|
||||
assert.Equal(t, "auto", tc["type"])
|
||||
}
|
||||
|
||||
@@ -366,7 +373,8 @@ func TestEncodeResponse_UsageWithCacheAndCreation(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
usage := result["usage"].(map[string]any)
|
||||
usage, oku := result["usage"].(map[string]any)
|
||||
require.True(t, oku)
|
||||
assert.Equal(t, float64(100), usage["input_tokens"])
|
||||
assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
|
||||
assert.Equal(t, float64(10), usage["cache_creation_input_tokens"])
|
||||
|
||||
@@ -3,10 +3,14 @@ package conversion
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// HTTPRequestSpec HTTP 请求规格
|
||||
@@ -33,13 +37,10 @@ type ConversionEngine struct {
|
||||
|
||||
// NewConversionEngine 创建转换引擎
|
||||
func NewConversionEngine(registry AdapterRegistry, logger *zap.Logger) *ConversionEngine {
|
||||
if logger == nil {
|
||||
logger = zap.L()
|
||||
}
|
||||
return &ConversionEngine{
|
||||
registry: registry,
|
||||
middlewareChain: NewMiddlewareChain(),
|
||||
logger: logger,
|
||||
logger: pkglogger.WithModule(logger, "conversion.engine"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,7 +73,7 @@ func (e *ConversionEngine) IsPassthrough(clientProtocol, providerProtocol string
|
||||
|
||||
// ConvertHttpRequest 转换 HTTP 请求
|
||||
func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtocol, providerProtocol string, provider *TargetProvider) (*HTTPRequestSpec, error) {
|
||||
nativePath := spec.URL
|
||||
nativePath, rawQuery := splitRequestPath(spec.URL)
|
||||
|
||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||
providerAdapter, err := e.registry.Get(providerProtocol)
|
||||
@@ -90,15 +91,18 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
||||
rewrittenBody, err = providerAdapter.RewriteRequestModelName(spec.Body, provider.ModelName, interfaceType)
|
||||
if err != nil {
|
||||
e.logger.Warn("Smart Passthrough 改写请求失败,使用原始请求体",
|
||||
zap.String("error", err.Error()),
|
||||
zap.Error(err),
|
||||
zap.String("interface", string(interfaceType)))
|
||||
rewrittenBody = spec.Body
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
|
||||
providerURL = appendRawQuery(providerURL, rawQuery)
|
||||
|
||||
return &HTTPRequestSpec{
|
||||
URL: provider.BaseURL + nativePath,
|
||||
URL: joinBaseURL(provider.BaseURL, providerURL),
|
||||
Method: spec.Method,
|
||||
Headers: providerAdapter.BuildHeaders(provider),
|
||||
Body: rewrittenBody,
|
||||
@@ -115,7 +119,8 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
||||
}
|
||||
|
||||
interfaceType := clientAdapter.DetectInterfaceType(nativePath)
|
||||
providerUrl := providerAdapter.BuildUrl(nativePath, interfaceType)
|
||||
providerURL := providerAdapter.BuildUrl(nativePath, interfaceType)
|
||||
providerURL = appendRawQuery(providerURL, rawQuery)
|
||||
providerHeaders := providerAdapter.BuildHeaders(provider)
|
||||
providerBody, err := e.convertBody(interfaceType, clientAdapter, providerAdapter, provider, spec.Body)
|
||||
if err != nil {
|
||||
@@ -123,7 +128,7 @@ func (e *ConversionEngine) ConvertHttpRequest(spec HTTPRequestSpec, clientProtoc
|
||||
}
|
||||
|
||||
return &HTTPRequestSpec{
|
||||
URL: provider.BaseURL + providerUrl,
|
||||
URL: joinBaseURL(provider.BaseURL, providerURL),
|
||||
Method: spec.Method,
|
||||
Headers: providerHeaders,
|
||||
Body: providerBody,
|
||||
@@ -135,25 +140,22 @@ func (e *ConversionEngine) ConvertHttpResponse(spec HTTPResponseSpec, clientProt
|
||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||
// Smart Passthrough: 同协议时最小化改写 model 字段
|
||||
if modelOverride != "" && len(spec.Body) > 0 {
|
||||
adapter, err := e.registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
return &spec, nil
|
||||
}
|
||||
|
||||
rewrittenBody, err := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
|
||||
if err != nil {
|
||||
adapter, getErr := e.registry.Get(clientProtocol)
|
||||
if getErr == nil {
|
||||
rewrittenBody, rewriteErr := adapter.RewriteResponseModelName(spec.Body, modelOverride, interfaceType)
|
||||
if rewriteErr != nil {
|
||||
e.logger.Warn("Smart Passthrough 改写响应失败,使用原始响应体",
|
||||
zap.String("error", err.Error()),
|
||||
zap.Error(rewriteErr),
|
||||
zap.String("interface", string(interfaceType)))
|
||||
return &spec, nil
|
||||
}
|
||||
|
||||
} else {
|
||||
return &HTTPResponseSpec{
|
||||
StatusCode: spec.StatusCode,
|
||||
Headers: spec.Headers,
|
||||
Body: rewrittenBody,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return &spec, nil
|
||||
}
|
||||
|
||||
@@ -183,12 +185,11 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
|
||||
if e.IsPassthrough(clientProtocol, providerProtocol) {
|
||||
// Smart Passthrough: 同协议流式场景需要逐 chunk 改写 model 字段
|
||||
if modelOverride != "" {
|
||||
adapter, err := e.registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
return NewPassthroughStreamConverter(), nil
|
||||
}
|
||||
adapter, getErr := e.registry.Get(clientProtocol)
|
||||
if getErr == nil {
|
||||
return NewSmartPassthroughStreamConverter(adapter, modelOverride, interfaceType), nil
|
||||
}
|
||||
}
|
||||
return NewPassthroughStreamConverter(), nil
|
||||
}
|
||||
|
||||
@@ -203,7 +204,7 @@ func (e *ConversionEngine) CreateStreamConverter(clientProtocol, providerProtoco
|
||||
|
||||
ctx := ConversionContext{
|
||||
ConversionID: uuid.New().String(),
|
||||
InterfaceType: InterfaceTypeChat,
|
||||
InterfaceType: interfaceType,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
@@ -273,7 +274,7 @@ func (e *ConversionEngine) convertResponseBody(interfaceType InterfaceType, clie
|
||||
func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||
canonicalReq, err := clientAdapter.DecodeRequest(body)
|
||||
if err != nil {
|
||||
return nil, NewConversionError(ErrorCodeJSONParseError, "解码请求失败").WithCause(err)
|
||||
return nil, NewRequestJSONParseError("解码请求失败", err)
|
||||
}
|
||||
|
||||
ctx := NewConversionContext(InterfaceTypeChat)
|
||||
@@ -281,6 +282,9 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if containsUnsupportedMultimodal(canonicalReq) {
|
||||
return nil, NewConversionError(ErrorCodeUnsupportedMultimodal, "跨协议暂不支持多模态内容")
|
||||
}
|
||||
|
||||
encoded, err := providerAdapter.EncodeRequest(canonicalReq, provider)
|
||||
if err != nil {
|
||||
@@ -292,7 +296,7 @@ func (e *ConversionEngine) convertChatBody(clientAdapter, providerAdapter Protoc
|
||||
func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
canonicalResp, err := providerAdapter.DecodeResponse(body)
|
||||
if err != nil {
|
||||
return nil, NewConversionError(ErrorCodeJSONParseError, "解码响应失败").WithCause(err)
|
||||
return nil, NewResponseJSONParseError("解码响应失败", err)
|
||||
}
|
||||
if modelOverride != "" {
|
||||
canonicalResp.Model = modelOverride
|
||||
@@ -307,12 +311,12 @@ func (e *ConversionEngine) convertChatResponseBody(clientAdapter, providerAdapte
|
||||
func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
models, err := providerAdapter.DecodeModelsResponse(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 Models 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
encoded, err := clientAdapter.EncodeModelsResponse(models)
|
||||
if err != nil {
|
||||
e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("编码 Models 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return encoded, nil
|
||||
@@ -321,12 +325,12 @@ func (e *ConversionEngine) convertModelsResponseBody(clientAdapter, providerAdap
|
||||
func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte) ([]byte, error) {
|
||||
info, err := providerAdapter.DecodeModelInfoResponse(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
encoded, err := clientAdapter.EncodeModelInfoResponse(info)
|
||||
if err != nil {
|
||||
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("编码 ModelInfo 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return encoded, nil
|
||||
@@ -335,7 +339,7 @@ func (e *ConversionEngine) convertModelInfoResponseBody(clientAdapter, providerA
|
||||
func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||
req, err := clientAdapter.DecodeEmbeddingRequest(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 Embedding 请求失败,返回原始请求", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return providerAdapter.EncodeEmbeddingRequest(req, provider)
|
||||
@@ -344,7 +348,7 @@ func (e *ConversionEngine) convertEmbeddingBody(clientAdapter, providerAdapter P
|
||||
func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
resp, err := providerAdapter.DecodeEmbeddingResponse(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 Embedding 响应失败,返回原始响应", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
if modelOverride != "" {
|
||||
@@ -356,21 +360,22 @@ func (e *ConversionEngine) convertEmbeddingResponseBody(clientAdapter, providerA
|
||||
func (e *ConversionEngine) convertRerankBody(clientAdapter, providerAdapter ProtocolAdapter, provider *TargetProvider, body []byte) ([]byte, error) {
|
||||
req, err := clientAdapter.DecodeRerankRequest(body)
|
||||
if err != nil {
|
||||
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.String("error", err.Error()))
|
||||
e.logger.Warn("解码 Rerank 请求失败,返回原始请求", zap.Error(err))
|
||||
return body, nil
|
||||
}
|
||||
return providerAdapter.EncodeRerankRequest(req, provider)
|
||||
}
|
||||
|
||||
func (e *ConversionEngine) convertRerankResponseBody(clientAdapter, providerAdapter ProtocolAdapter, body []byte, modelOverride string) ([]byte, error) {
|
||||
resp, err := providerAdapter.DecodeRerankResponse(body)
|
||||
if err != nil {
|
||||
return body, nil
|
||||
}
|
||||
resp, decodeErr := providerAdapter.DecodeRerankResponse(body)
|
||||
if decodeErr == nil {
|
||||
if modelOverride != "" {
|
||||
resp.Model = modelOverride
|
||||
}
|
||||
return clientAdapter.EncodeRerankResponse(resp)
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// DetectInterfaceType 检测接口类型
|
||||
@@ -379,6 +384,7 @@ func (e *ConversionEngine) DetectInterfaceType(nativePath, clientProtocol string
|
||||
if err != nil {
|
||||
return InterfaceTypePassthrough, err
|
||||
}
|
||||
nativePath, _ = splitRequestPath(nativePath)
|
||||
return adapter.DetectInterfaceType(nativePath), nil
|
||||
}
|
||||
|
||||
@@ -392,9 +398,56 @@ func (e *ConversionEngine) EncodeError(err *ConversionError, clientProtocol stri
|
||||
"type": "internal_error",
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(fallback)
|
||||
body, marshalErr := json.Marshal(fallback)
|
||||
if marshalErr == nil {
|
||||
return body, 500, nil
|
||||
}
|
||||
|
||||
return []byte(`{"error":{"message":"internal error","type":"internal_error"}}`), 500, nil
|
||||
}
|
||||
body, statusCode := adapter.EncodeError(err)
|
||||
return body, statusCode, nil
|
||||
}
|
||||
|
||||
func splitRequestPath(rawPath string) (string, string) {
|
||||
path, query, found := strings.Cut(rawPath, "?")
|
||||
if !found {
|
||||
return rawPath, ""
|
||||
}
|
||||
return path, query
|
||||
}
|
||||
|
||||
func appendRawQuery(path, rawQuery string) string {
|
||||
if rawQuery == "" {
|
||||
return path
|
||||
}
|
||||
if strings.Contains(path, "?") {
|
||||
return path + "&" + rawQuery
|
||||
}
|
||||
return path + "?" + rawQuery
|
||||
}
|
||||
|
||||
func joinBaseURL(baseURL, path string) string {
|
||||
if baseURL == "" {
|
||||
return path
|
||||
}
|
||||
if path == "" {
|
||||
return baseURL
|
||||
}
|
||||
return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/")
|
||||
}
|
||||
|
||||
func containsUnsupportedMultimodal(req *canonical.CanonicalRequest) bool {
|
||||
if req == nil {
|
||||
return false
|
||||
}
|
||||
for _, msg := range req.Messages {
|
||||
for _, block := range msg.Content {
|
||||
switch block.Type {
|
||||
case "image", "audio", "video", "file":
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
63
backend/internal/conversion/engine_adapter_test.go
Normal file
63
backend/internal/conversion/engine_adapter_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package conversion_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
"nex/backend/internal/conversion/openai"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestConvertHttpRequest_SameProtocolUsesAdapterBuildURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
adapter conversion.ProtocolAdapter
|
||||
clientProtocol string
|
||||
providerProtocol string
|
||||
baseURL string
|
||||
nativePath string
|
||||
expectedURL string
|
||||
body []byte
|
||||
}{
|
||||
{
|
||||
name: "openai base url includes version path",
|
||||
adapter: openai.NewAdapter(),
|
||||
clientProtocol: "openai",
|
||||
providerProtocol: "openai",
|
||||
baseURL: "http://example.com/v1",
|
||||
nativePath: "/chat/completions",
|
||||
expectedURL: "http://example.com/v1/chat/completions",
|
||||
body: []byte(`{"model":"gpt-4","messages":[]}`),
|
||||
},
|
||||
{
|
||||
name: "anthropic native path keeps v1",
|
||||
adapter: anthropic.NewAdapter(),
|
||||
clientProtocol: "anthropic",
|
||||
providerProtocol: "anthropic",
|
||||
baseURL: "http://example.com",
|
||||
nativePath: "/v1/messages",
|
||||
expectedURL: "http://example.com/v1/messages",
|
||||
body: []byte(`{"model":"claude","messages":[]}`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
engine := conversion.NewConversionEngine(registry, zap.NewNop())
|
||||
require.NoError(t, registry.Register(tt.adapter))
|
||||
|
||||
out, err := engine.ConvertHttpRequest(conversion.HTTPRequestSpec{
|
||||
URL: tt.nativePath,
|
||||
Method: "POST",
|
||||
Body: tt.body,
|
||||
}, tt.clientProtocol, tt.providerProtocol, conversion.NewTargetProvider(tt.baseURL, "key", "upstream-model"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedURL, out.URL)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestConversionError_WithProviderProtocol(t *testing.T) {
|
||||
@@ -39,7 +40,7 @@ func TestConversionError_FullBuilder(t *testing.T) {
|
||||
|
||||
func TestEngine_Use(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
called := false
|
||||
engine.Use(&testMiddleware{fn: func(req *canonical.CanonicalRequest, cp, pp string, ctx *ConversionContext) (*canonical.CanonicalRequest, error) {
|
||||
called = true
|
||||
@@ -66,7 +67,7 @@ func TestEngine_Use(t *testing.T) {
|
||||
|
||||
func TestConvertHttpRequest_DecodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
|
||||
return nil, errors.New("decode failed")
|
||||
@@ -82,7 +83,7 @@ func TestConvertHttpRequest_DecodeError(t *testing.T) {
|
||||
|
||||
func TestConvertHttpRequest_EncodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.encodeReqFn = func(req *canonical.CanonicalRequest, p *TargetProvider) ([]byte, error) {
|
||||
@@ -98,7 +99,7 @@ func TestConvertHttpRequest_EncodeError(t *testing.T) {
|
||||
|
||||
func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||||
@@ -121,7 +122,7 @@ func TestConvertHttpResponse_CrossProtocol(t *testing.T) {
|
||||
|
||||
func TestConvertHttpResponse_DecodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.decodeRespFn = func(raw []byte) (*canonical.CanonicalResponse, error) {
|
||||
return nil, errors.New("decode error")
|
||||
@@ -135,7 +136,7 @@ func TestConvertHttpResponse_DecodeError(t *testing.T) {
|
||||
|
||||
func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.ifaceType = InterfaceTypeEmbeddings
|
||||
@@ -158,7 +159,7 @@ func TestConvertHttpRequest_EmbeddingInterface(t *testing.T) {
|
||||
|
||||
func TestConvertHttpRequest_RerankInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.ifaceType = InterfaceTypeRerank
|
||||
@@ -178,7 +179,7 @@ func TestConvertHttpRequest_RerankInterface(t *testing.T) {
|
||||
|
||||
func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeEmbeddings: true}
|
||||
@@ -196,7 +197,7 @@ func TestConvertHttpResponse_EmbeddingInterface(t *testing.T) {
|
||||
|
||||
func TestConvertHttpResponse_RerankInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeRerank: true}
|
||||
@@ -214,7 +215,7 @@ func TestConvertHttpResponse_RerankInterface(t *testing.T) {
|
||||
|
||||
func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.ifaceType = InterfaceTypeModels
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
@@ -232,7 +233,7 @@ func TestConvertHttpRequest_ModelsInterface_Passthrough(t *testing.T) {
|
||||
|
||||
func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModels: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
@@ -249,7 +250,7 @@ func TestConvertHttpResponse_ModelsInterface(t *testing.T) {
|
||||
|
||||
func TestConvertHttpResponse_ModelInfoInterface(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.supportsIface = map[InterfaceType]bool{InterfaceTypeModelInfo: true}
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
@@ -324,7 +325,7 @@ var _ = json.Marshal
|
||||
|
||||
func TestConvertEmbeddingBody_DecodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.decodeEmbeddingReqFn = func(raw []byte) (*canonical.CanonicalEmbeddingRequest, error) {
|
||||
@@ -344,7 +345,7 @@ func TestConvertEmbeddingBody_DecodeError(t *testing.T) {
|
||||
|
||||
func TestConvertRerankBody_DecodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.decodeRerankReqFn = func(raw []byte) (*canonical.CanonicalRerankRequest, error) {
|
||||
@@ -364,7 +365,7 @@ func TestConvertRerankBody_DecodeError(t *testing.T) {
|
||||
|
||||
func TestConvertBody_UnknownInterfaceType(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
|
||||
@@ -2,6 +2,7 @@ package conversion
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
@@ -190,7 +191,9 @@ func (m *mockProtocolAdapter) RewriteResponseModelName(body []byte, newModel str
|
||||
// noopStreamDecoder 空流式解码器
|
||||
type noopStreamDecoder struct{}
|
||||
|
||||
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent { return nil }
|
||||
func (d *noopStreamDecoder) ProcessChunk(rawChunk []byte) []canonical.CanonicalStreamEvent {
|
||||
return nil
|
||||
}
|
||||
func (d *noopStreamDecoder) Flush() []canonical.CanonicalStreamEvent { return nil }
|
||||
|
||||
// noopStreamEncoder 空流式编码器
|
||||
@@ -203,7 +206,7 @@ func (e *noopStreamEncoder) Flush() [][]byte
|
||||
|
||||
func TestNewConversionEngine(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
assert.NotNil(t, engine)
|
||||
assert.Equal(t, registry, engine.GetRegistry())
|
||||
}
|
||||
@@ -211,7 +214,7 @@ func TestNewConversionEngine(t *testing.T) {
|
||||
func TestNewConversionEngine_LoggerInjection(t *testing.T) {
|
||||
t.Run("nil_logger_uses_global", func(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
assert.NotNil(t, engine.logger)
|
||||
})
|
||||
|
||||
@@ -219,13 +222,14 @@ func TestNewConversionEngine_LoggerInjection(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
customLogger := zap.NewNop()
|
||||
engine := NewConversionEngine(registry, customLogger)
|
||||
assert.Equal(t, customLogger, engine.logger)
|
||||
assert.NotNil(t, engine.logger)
|
||||
assert.Contains(t, engine.logger.Name(), "conversion.engine")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegisterAdapter(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
adapter := newMockAdapter("test-proto", true)
|
||||
err := engine.RegisterAdapter(adapter)
|
||||
@@ -237,7 +241,7 @@ func TestRegisterAdapter(t *testing.T) {
|
||||
|
||||
func TestIsPassthrough_SameProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
adapter := newMockAdapter("openai", true)
|
||||
_ = engine.RegisterAdapter(adapter)
|
||||
|
||||
@@ -246,7 +250,7 @@ func TestIsPassthrough_SameProtocol(t *testing.T) {
|
||||
|
||||
func TestIsPassthrough_DifferentProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
_ = engine.RegisterAdapter(newMockAdapter("anthropic", true))
|
||||
|
||||
@@ -255,7 +259,7 @@ func TestIsPassthrough_DifferentProtocol(t *testing.T) {
|
||||
|
||||
func TestIsPassthrough_NoPassthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("custom", false))
|
||||
|
||||
assert.False(t, engine.IsPassthrough("custom", "custom"))
|
||||
@@ -263,7 +267,7 @@ func TestIsPassthrough_NoPassthrough(t *testing.T) {
|
||||
|
||||
func TestDetectInterfaceType(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
adapter := newMockAdapter("test", true)
|
||||
adapter.ifaceType = InterfaceTypeChat
|
||||
_ = engine.RegisterAdapter(adapter)
|
||||
@@ -275,7 +279,7 @@ func TestDetectInterfaceType(t *testing.T) {
|
||||
|
||||
func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
_, err := engine.DetectInterfaceType("/v1/chat", "nonexistent")
|
||||
assert.Error(t, err)
|
||||
@@ -283,25 +287,39 @@ func TestDetectInterfaceType_NonExistentProtocol(t *testing.T) {
|
||||
|
||||
func TestConvertHttpRequest_Passthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
openaiAdapter := &buildURLMockAdapter{
|
||||
mockProtocolAdapter: newMockAdapter("openai", true),
|
||||
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
|
||||
if interfaceType == InterfaceTypeChat {
|
||||
return "/chat/completions"
|
||||
}
|
||||
return nativePath
|
||||
},
|
||||
}
|
||||
openaiAdapter.ifaceType = InterfaceTypeChat
|
||||
openaiAdapter.supportsIface[InterfaceTypeChat] = true
|
||||
openaiAdapter.rewriteReqFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||
return []byte(`{"model":"` + newModel + `","messages":[{"role":"user","content":"hi"}]}`), nil
|
||||
}
|
||||
_ = engine.RegisterAdapter(openaiAdapter)
|
||||
|
||||
provider := NewTargetProvider("https://api.openai.com/v1", "sk-test", "gpt-4")
|
||||
spec := HTTPRequestSpec{
|
||||
URL: "/chat/completions",
|
||||
URL: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`),
|
||||
Body: []byte(`{"model":"openai/gpt-4","messages":[{"role":"user","content":"hi"}]}`),
|
||||
}
|
||||
|
||||
result, err := engine.ConvertHttpRequest(spec, "openai", "openai", provider)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
|
||||
assert.Equal(t, spec.Body, result.Body)
|
||||
assert.JSONEq(t, `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`, string(result.Body))
|
||||
}
|
||||
|
||||
func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client-proto", false)
|
||||
clientAdapter.decodeReqFn = func(raw []byte) (*canonical.CanonicalRequest, error) {
|
||||
@@ -331,9 +349,80 @@ func TestConvertHttpRequest_CrossProtocol(t *testing.T) {
|
||||
assert.NotNil(t, result.Body)
|
||||
}
|
||||
|
||||
func TestConvertHttpRequest_UsesProviderAdapterBuildURL(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
openaiAdapter := &buildURLMockAdapter{
|
||||
mockProtocolAdapter: newMockAdapter("openai", true),
|
||||
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
|
||||
if interfaceType == InterfaceTypeChat {
|
||||
return "/chat/completions"
|
||||
}
|
||||
return nativePath
|
||||
},
|
||||
}
|
||||
openaiAdapter.ifaceType = InterfaceTypeChat
|
||||
openaiAdapter.supportsIface[InterfaceTypeChat] = true
|
||||
openaiAdapter.rewriteReqFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||
return []byte(`{"model":"` + newModel + `"}`), nil
|
||||
}
|
||||
require.NoError(t, registry.Register(openaiAdapter))
|
||||
|
||||
anthropicAdapter := &buildURLMockAdapter{
|
||||
mockProtocolAdapter: newMockAdapter("anthropic", false),
|
||||
buildURLFn: func(nativePath string, interfaceType InterfaceType) string {
|
||||
if interfaceType == InterfaceTypeChat {
|
||||
return "/v1/messages"
|
||||
}
|
||||
return nativePath
|
||||
},
|
||||
}
|
||||
anthropicAdapter.ifaceType = InterfaceTypeChat
|
||||
anthropicAdapter.supportsIface[InterfaceTypeChat] = true
|
||||
require.NoError(t, registry.Register(anthropicAdapter))
|
||||
|
||||
t.Run("OpenAI to Anthropic", func(t *testing.T) {
|
||||
provider := NewTargetProvider("https://api.anthropic.com", "key", "claude-3")
|
||||
spec := HTTPRequestSpec{
|
||||
URL: "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Body: []byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"max_tokens":16}`),
|
||||
}
|
||||
|
||||
result, err := engine.ConvertHttpRequest(spec, "openai", "anthropic", provider)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://api.anthropic.com/v1/messages", result.URL)
|
||||
})
|
||||
|
||||
t.Run("Anthropic to OpenAI", func(t *testing.T) {
|
||||
provider := NewTargetProvider("https://api.openai.com/v1", "key", "gpt-4")
|
||||
spec := HTTPRequestSpec{
|
||||
URL: "/v1/messages",
|
||||
Method: "POST",
|
||||
Body: []byte(`{"model":"p1/claude-3","max_tokens":16,"messages":[{"role":"user","content":"hi"}]}`),
|
||||
}
|
||||
|
||||
result, err := engine.ConvertHttpRequest(spec, "anthropic", "openai", provider)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://api.openai.com/v1/chat/completions", result.URL)
|
||||
})
|
||||
}
|
||||
|
||||
type buildURLMockAdapter struct {
|
||||
*mockProtocolAdapter
|
||||
buildURLFn func(string, InterfaceType) string
|
||||
}
|
||||
|
||||
func (m *buildURLMockAdapter) BuildUrl(nativePath string, interfaceType InterfaceType) string {
|
||||
if m.buildURLFn != nil {
|
||||
return m.buildURLFn(nativePath, interfaceType)
|
||||
}
|
||||
return m.mockProtocolAdapter.BuildUrl(nativePath, interfaceType)
|
||||
}
|
||||
|
||||
func TestConvertHttpResponse_Passthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
|
||||
spec := HTTPResponseSpec{
|
||||
@@ -349,7 +438,7 @@ func TestConvertHttpResponse_Passthrough(t *testing.T) {
|
||||
|
||||
func TestCreateStreamConverter_Passthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
|
||||
converter, err := engine.CreateStreamConverter("openai", "openai", "", InterfaceTypeChat)
|
||||
@@ -360,7 +449,7 @@ func TestCreateStreamConverter_Passthrough(t *testing.T) {
|
||||
|
||||
func TestCreateStreamConverter_Canonical(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("client", false))
|
||||
_ = engine.RegisterAdapter(newMockAdapter("provider", false))
|
||||
|
||||
@@ -372,7 +461,7 @@ func TestCreateStreamConverter_Canonical(t *testing.T) {
|
||||
|
||||
func TestEncodeError(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
_ = engine.RegisterAdapter(newMockAdapter("openai", true))
|
||||
|
||||
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
|
||||
@@ -384,7 +473,7 @@ func TestEncodeError(t *testing.T) {
|
||||
|
||||
func TestEncodeError_NonExistentProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
convErr := NewConversionError(ErrorCodeInvalidInput, "测试错误")
|
||||
body, statusCode, err := engine.EncodeError(convErr, "nonexistent")
|
||||
@@ -417,7 +506,7 @@ func TestRegistry_GetNonExistent(t *testing.T) {
|
||||
|
||||
func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
clientAdapter := newMockAdapter("client", false)
|
||||
clientAdapter.encodeRespFn = func(resp *canonical.CanonicalResponse) ([]byte, error) {
|
||||
@@ -446,7 +535,7 @@ func TestConvertHttpResponse_ModelOverride_CrossProtocol(t *testing.T) {
|
||||
|
||||
func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
// 使用真实 OpenAI adapter 验证 Smart Passthrough 改写
|
||||
openaiAdapter := newMockAdapter("openai", true)
|
||||
@@ -476,7 +565,7 @@ func TestConvertHttpResponse_ModelOverride_SameProtocol(t *testing.T) {
|
||||
|
||||
func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
openaiAdapter := newMockAdapter("openai", true)
|
||||
openaiAdapter.rewriteRespFn = func(body []byte, newModel string, ifaceType InterfaceType) ([]byte, error) {
|
||||
@@ -495,18 +584,19 @@ func TestCreateStreamConverter_ModelOverride_SmartPassthrough(t *testing.T) {
|
||||
_, ok := converter.(*SmartPassthroughStreamConverter)
|
||||
assert.True(t, ok)
|
||||
|
||||
// 验证 chunk 改写
|
||||
chunks := converter.ProcessChunk([]byte(`{"model":"gpt-4","choices":[]}`))
|
||||
// 验证 SSE frame 中的 data JSON 被改写
|
||||
chunks := converter.ProcessChunk([]byte(`data: {"model":"gpt-4","choices":[]}` + "\n\n"))
|
||||
require.Len(t, chunks, 1)
|
||||
|
||||
var resp map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(chunks[0], &resp))
|
||||
payload := strings.TrimPrefix(strings.TrimSpace(string(chunks[0])), "data: ")
|
||||
require.NoError(t, json.Unmarshal([]byte(payload), &resp))
|
||||
assert.Equal(t, "openai/gpt-4", resp["model"])
|
||||
}
|
||||
|
||||
func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
// provider adapter 解码出含 model 的流式事件
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
@@ -560,7 +650,7 @@ func TestCreateStreamConverter_ModelOverride_CrossProtocol(t *testing.T) {
|
||||
|
||||
func TestCreateStreamConverter_ModelOverride_CrossProtocol_Empty(t *testing.T) {
|
||||
registry := NewMemoryRegistry()
|
||||
engine := NewConversionEngine(registry, nil)
|
||||
engine := NewConversionEngine(registry, zap.NewNop())
|
||||
|
||||
providerAdapter := newMockAdapter("provider", false)
|
||||
providerAdapter.streamDecoderFn = func() StreamDecoder {
|
||||
@@ -614,6 +704,7 @@ func (d *engineTestStreamDecoder) ProcessChunk(raw []byte) []canonical.Canonical
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *engineTestStreamDecoder) Flush() []canonical.CanonicalStreamEvent {
|
||||
if d.flushFn != nil {
|
||||
return d.flushFn()
|
||||
@@ -633,6 +724,7 @@ func (e *engineTestStreamEncoder) EncodeEvent(event canonical.CanonicalStreamEve
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *engineTestStreamEncoder) Flush() [][]byte {
|
||||
if e.flushFn != nil {
|
||||
return e.flushFn()
|
||||
|
||||
@@ -17,6 +17,13 @@ const (
|
||||
ErrorCodeProtocolConstraint ErrorCode = "PROTOCOL_CONSTRAINT_VIOLATION"
|
||||
ErrorCodeEncodingFailure ErrorCode = "ENCODING_FAILURE"
|
||||
ErrorCodeInterfaceNotSupported ErrorCode = "INTERFACE_NOT_SUPPORTED"
|
||||
ErrorCodeUnsupportedMultimodal ErrorCode = "UNSUPPORTED_MULTIMODAL"
|
||||
)
|
||||
|
||||
const (
|
||||
ErrorDetailPhase = "phase"
|
||||
ErrorPhaseRequest = "request"
|
||||
ErrorPhaseResponse = "response"
|
||||
)
|
||||
|
||||
// ConversionError 协议转换错误
|
||||
@@ -39,6 +46,20 @@ func NewConversionError(code ErrorCode, message string) *ConversionError {
|
||||
}
|
||||
}
|
||||
|
||||
// NewRequestJSONParseError 创建请求 JSON 解析错误。
|
||||
func NewRequestJSONParseError(message string, cause error) *ConversionError {
|
||||
return NewConversionError(ErrorCodeJSONParseError, message).
|
||||
WithDetail(ErrorDetailPhase, ErrorPhaseRequest).
|
||||
WithCause(cause)
|
||||
}
|
||||
|
||||
// NewResponseJSONParseError 创建响应 JSON 解析错误。
|
||||
func NewResponseJSONParseError(message string, cause error) *ConversionError {
|
||||
return NewConversionError(ErrorCodeJSONParseError, message).
|
||||
WithDetail(ErrorDetailPhase, ErrorPhaseResponse).
|
||||
WithCause(cause)
|
||||
}
|
||||
|
||||
// WithClientProtocol 设置客户端协议
|
||||
func (e *ConversionError) WithClientProtocol(protocol string) *ConversionError {
|
||||
e.ClientProtocol = protocol
|
||||
|
||||
@@ -29,27 +29,27 @@ func (a *Adapter) SupportsPassthrough() bool { return true }
|
||||
// DetectInterfaceType 根据路径检测接口类型
|
||||
func (a *Adapter) DetectInterfaceType(nativePath string) conversion.InterfaceType {
|
||||
switch {
|
||||
case nativePath == "/chat/completions":
|
||||
case nativePath == "/v1/chat/completions":
|
||||
return conversion.InterfaceTypeChat
|
||||
case nativePath == "/models":
|
||||
case nativePath == "/v1/models":
|
||||
return conversion.InterfaceTypeModels
|
||||
case isModelInfoPath(nativePath):
|
||||
return conversion.InterfaceTypeModelInfo
|
||||
case nativePath == "/embeddings":
|
||||
case nativePath == "/v1/embeddings":
|
||||
return conversion.InterfaceTypeEmbeddings
|
||||
case nativePath == "/rerank":
|
||||
case nativePath == "/v1/rerank":
|
||||
return conversion.InterfaceTypeRerank
|
||||
default:
|
||||
return conversion.InterfaceTypePassthrough
|
||||
}
|
||||
}
|
||||
|
||||
// isModelInfoPath 判断是否为模型详情路径(/models/{id},允许 id 含 /)
|
||||
// isModelInfoPath 判断是否为模型详情路径(/v1/models/{id},允许 id 含 /)
|
||||
func isModelInfoPath(path string) bool {
|
||||
if !strings.HasPrefix(path, "/models/") {
|
||||
if !strings.HasPrefix(path, "/v1/models/") {
|
||||
return false
|
||||
}
|
||||
suffix := path[len("/models/"):]
|
||||
suffix := path[len("/v1/models/"):]
|
||||
return suffix != ""
|
||||
}
|
||||
|
||||
@@ -60,6 +60,11 @@ func (a *Adapter) BuildUrl(nativePath string, interfaceType conversion.Interface
|
||||
return "/chat/completions"
|
||||
case conversion.InterfaceTypeModels:
|
||||
return "/models"
|
||||
case conversion.InterfaceTypeModelInfo:
|
||||
if modelID, err := a.ExtractUnifiedModelID(nativePath); err == nil {
|
||||
return "/models/" + modelID
|
||||
}
|
||||
return nativePath
|
||||
case conversion.InterfaceTypeEmbeddings:
|
||||
return "/embeddings"
|
||||
case conversion.InterfaceTypeRerank:
|
||||
@@ -138,7 +143,10 @@ func (a *Adapter) EncodeError(err *conversion.ConversionError) ([]byte, int) {
|
||||
Code: string(err.Code),
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(errMsg)
|
||||
body, marshalErr := json.Marshal(errMsg)
|
||||
if marshalErr != nil {
|
||||
return []byte(`{"error":{"message":"internal error","type":"internal_error","code":"INTERNAL_ERROR"}}`), statusCode
|
||||
}
|
||||
return body, statusCode
|
||||
}
|
||||
|
||||
@@ -218,12 +226,12 @@ func (a *Adapter) EncodeRerankResponse(resp *canonical.CanonicalRerankResponse)
|
||||
return encodeRerankResponse(resp)
|
||||
}
|
||||
|
||||
// ExtractUnifiedModelID 从路径中提取统一模型 ID(/models/{provider_id}/{model_name})
|
||||
// ExtractUnifiedModelID 从路径中提取统一模型 ID(/v1/models/{provider_id}/{model_name})
|
||||
func (a *Adapter) ExtractUnifiedModelID(nativePath string) (string, error) {
|
||||
if !strings.HasPrefix(nativePath, "/models/") {
|
||||
if !strings.HasPrefix(nativePath, "/v1/models/") {
|
||||
return "", fmt.Errorf("不是模型详情路径: %s", nativePath)
|
||||
}
|
||||
suffix := nativePath[len("/models/"):]
|
||||
suffix := nativePath[len("/v1/models/"):]
|
||||
if suffix == "" {
|
||||
return "", fmt.Errorf("路径缺少模型 ID")
|
||||
}
|
||||
@@ -248,7 +256,11 @@ func locateModelFieldInRequest(body []byte, ifaceType conversion.InterfaceType)
|
||||
return "", nil, err
|
||||
}
|
||||
rewriteFunc := func(newModel string) ([]byte, error) {
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
}
|
||||
return current, rewriteFunc, nil
|
||||
@@ -282,12 +294,20 @@ func (a *Adapter) RewriteResponseModelName(body []byte, newModel string, ifaceTy
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings:
|
||||
// Chat/Embedding 响应必须有 model 字段(协议要求),存在则改写,不存在则添加
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
return json.Marshal(m)
|
||||
case conversion.InterfaceTypeRerank:
|
||||
// Rerank 响应:存在 model 字段则改写,不存在则不添加
|
||||
if _, exists := m["model"]; exists {
|
||||
m["model"], _ = json.Marshal(newModel)
|
||||
encodedModel, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m["model"] = encodedModel
|
||||
}
|
||||
return json.Marshal(m)
|
||||
default:
|
||||
|
||||
@@ -28,11 +28,11 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
|
||||
path string
|
||||
expected conversion.InterfaceType
|
||||
}{
|
||||
{"聊天补全", "/chat/completions", conversion.InterfaceTypeChat},
|
||||
{"模型列表", "/models", conversion.InterfaceTypeModels},
|
||||
{"模型详情", "/models/gpt-4", conversion.InterfaceTypeModelInfo},
|
||||
{"嵌入接口", "/embeddings", conversion.InterfaceTypeEmbeddings},
|
||||
{"重排序接口", "/rerank", conversion.InterfaceTypeRerank},
|
||||
{"聊天补全", "/v1/chat/completions", conversion.InterfaceTypeChat},
|
||||
{"模型列表", "/v1/models", conversion.InterfaceTypeModels},
|
||||
{"模型详情", "/v1/models/openai/gpt-4", conversion.InterfaceTypeModelInfo},
|
||||
{"嵌入接口", "/v1/embeddings", conversion.InterfaceTypeEmbeddings},
|
||||
{"重排序接口", "/v1/rerank", conversion.InterfaceTypeRerank},
|
||||
{"未知路径", "/unknown", conversion.InterfaceTypePassthrough},
|
||||
}
|
||||
|
||||
@@ -44,6 +44,27 @@ func TestAdapter_DetectInterfaceType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_OldPathsBecomePassthrough(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
expected conversion.InterfaceType
|
||||
}{
|
||||
{"/chat/completions", conversion.InterfaceTypePassthrough},
|
||||
{"/models", conversion.InterfaceTypePassthrough},
|
||||
{"/models/gpt-4.1", conversion.InterfaceTypePassthrough},
|
||||
{"/embeddings", conversion.InterfaceTypePassthrough},
|
||||
{"/rerank", conversion.InterfaceTypePassthrough},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, a.DetectInterfaceType(tt.path))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_BuildUrl(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
@@ -53,10 +74,12 @@ func TestAdapter_BuildUrl(t *testing.T) {
|
||||
interfaceType conversion.InterfaceType
|
||||
expected string
|
||||
}{
|
||||
{"聊天", "/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"},
|
||||
{"模型", "/models", conversion.InterfaceTypeModels, "/models"},
|
||||
{"嵌入", "/embeddings", conversion.InterfaceTypeEmbeddings, "/embeddings"},
|
||||
{"重排序", "/rerank", conversion.InterfaceTypeRerank, "/rerank"},
|
||||
{"聊天", "/v1/chat/completions", conversion.InterfaceTypeChat, "/chat/completions"},
|
||||
{"模型", "/v1/models", conversion.InterfaceTypeModels, "/models"},
|
||||
{"模型详情", "/v1/models/openai/gpt-4", conversion.InterfaceTypeModelInfo, "/models/openai/gpt-4"},
|
||||
{"复杂模型详情", "/v1/models/azure/accounts/org/models/gpt-4", conversion.InterfaceTypeModelInfo, "/models/azure/accounts/org/models/gpt-4"},
|
||||
{"嵌入", "/v1/embeddings", conversion.InterfaceTypeEmbeddings, "/embeddings"},
|
||||
{"重排序", "/v1/rerank", conversion.InterfaceTypeRerank, "/rerank"},
|
||||
{"默认透传", "/other", conversion.InterfaceTypePassthrough, "/other"},
|
||||
}
|
||||
|
||||
@@ -118,12 +141,12 @@ func TestIsModelInfoPath(t *testing.T) {
|
||||
path string
|
||||
expected bool
|
||||
}{
|
||||
{"model_info", "/models/gpt-4", true},
|
||||
{"model_info_with_dots", "/models/gpt-4.1-preview", true},
|
||||
{"models_list", "/models", false},
|
||||
{"nested_path", "/models/gpt-4/versions", true},
|
||||
{"empty_suffix", "/models/", false},
|
||||
{"unrelated", "/chat/completions", false},
|
||||
{"model_info", "/v1/models/openai/gpt-4", true},
|
||||
{"model_info_with_dots", "/v1/models/openai/gpt-4.1-preview", true},
|
||||
{"models_list", "/v1/models", false},
|
||||
{"nested_path", "/v1/models/azure/accounts/org-123/models/gpt-4", true},
|
||||
{"empty_suffix", "/v1/models/", false},
|
||||
{"unrelated", "/v1/chat/completions", false},
|
||||
{"partial_prefix", "/model", false},
|
||||
}
|
||||
|
||||
@@ -134,6 +157,27 @@ func TestIsModelInfoPath(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_ExtractUnifiedModelID(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("标准路径", func(t *testing.T) {
|
||||
modelID, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/gpt-4", modelID)
|
||||
})
|
||||
|
||||
t.Run("复杂路径", func(t *testing.T) {
|
||||
modelID, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "azure/accounts/org/models/gpt-4", modelID)
|
||||
})
|
||||
|
||||
t.Run("非模型详情路径报错", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/v1/models")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdapter_EncodeError_InvalidInput(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "参数无效")
|
||||
|
||||
@@ -18,35 +18,35 @@ func TestExtractUnifiedModelID(t *testing.T) {
|
||||
a := NewAdapter()
|
||||
|
||||
t.Run("standard_path", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/models/openai/gpt-4")
|
||||
id, err := a.ExtractUnifiedModelID("/v1/models/openai/gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "openai/gpt-4", id)
|
||||
})
|
||||
|
||||
t.Run("multi_segment_path", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/models/azure/accounts/org/models/gpt-4")
|
||||
id, err := a.ExtractUnifiedModelID("/v1/models/azure/accounts/org/models/gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "azure/accounts/org/models/gpt-4", id)
|
||||
})
|
||||
|
||||
t.Run("single_segment", func(t *testing.T) {
|
||||
id, err := a.ExtractUnifiedModelID("/models/gpt-4")
|
||||
id, err := a.ExtractUnifiedModelID("/v1/models/gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4", id)
|
||||
})
|
||||
|
||||
t.Run("non_model_path", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/chat/completions")
|
||||
_, err := a.ExtractUnifiedModelID("/v1/chat/completions")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("empty_suffix", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/models/")
|
||||
_, err := a.ExtractUnifiedModelID("/v1/models/")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("models_list_no_slash", func(t *testing.T) {
|
||||
_, err := a.ExtractUnifiedModelID("/models")
|
||||
_, err := a.ExtractUnifiedModelID("/v1/models")
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
@@ -344,12 +344,12 @@ func TestIsModelInfoPath_UnifiedModelID(t *testing.T) {
|
||||
path string
|
||||
expected bool
|
||||
}{
|
||||
{"simple_model_id", "/models/gpt-4", true},
|
||||
{"unified_model_id_with_slash", "/models/openai/gpt-4", true},
|
||||
{"models_list", "/models", false},
|
||||
{"models_list_trailing_slash", "/models/", false},
|
||||
{"chat_completions", "/chat/completions", false},
|
||||
{"deeply_nested", "/models/azure/eastus/deployments/my-dept/models/gpt-4", true},
|
||||
{"simple_model_id", "/v1/models/gpt-4", true},
|
||||
{"unified_model_id_with_slash", "/v1/models/openai/gpt-4", true},
|
||||
{"models_list", "/v1/models", false},
|
||||
{"models_list_trailing_slash", "/v1/models/", false},
|
||||
{"chat_completions", "/v1/chat/completions", false},
|
||||
{"deeply_nested", "/v1/models/azure/eastus/deployments/my-dept/models/gpt-4", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -215,10 +215,16 @@ func decodeUserContent(content any) []canonical.ContentBlock {
|
||||
var blocks []canonical.ContentBlock
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
t, _ := m["type"].(string)
|
||||
t, ok := m["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch t {
|
||||
case "text":
|
||||
text, _ := m["text"].(string)
|
||||
text, ok := m["text"].(string)
|
||||
if !ok {
|
||||
text = ""
|
||||
}
|
||||
blocks = append(blocks, canonical.NewTextBlock(text))
|
||||
case "image_url":
|
||||
blocks = append(blocks, canonical.ContentBlock{Type: "image"})
|
||||
@@ -256,13 +262,22 @@ func decodeContentParts(content any) []contentPart {
|
||||
var result []contentPart
|
||||
for _, item := range parts {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
t, _ := m["type"].(string)
|
||||
t, ok := m["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch t {
|
||||
case "text":
|
||||
text, _ := m["text"].(string)
|
||||
text, ok := m["text"].(string)
|
||||
if !ok {
|
||||
text = ""
|
||||
}
|
||||
result = append(result, contentPart{Type: "text", Text: text})
|
||||
case "refusal":
|
||||
refusal, _ := m["refusal"].(string)
|
||||
refusal, ok := m["refusal"].(string)
|
||||
if !ok {
|
||||
refusal = ""
|
||||
}
|
||||
result = append(result, contentPart{Type: "refusal", Refusal: refusal})
|
||||
}
|
||||
}
|
||||
@@ -307,21 +322,33 @@ func decodeToolChoice(toolChoice any) *canonical.ToolChoice {
|
||||
return canonical.NewToolChoiceAny()
|
||||
}
|
||||
case map[string]any:
|
||||
t, _ := v["type"].(string)
|
||||
t, ok := v["type"].(string)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
switch t {
|
||||
case "function":
|
||||
if fn, ok := v["function"].(map[string]any); ok {
|
||||
name, _ := fn["name"].(string)
|
||||
name, ok := fn["name"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
return canonical.NewToolChoiceNamed(name)
|
||||
}
|
||||
case "custom":
|
||||
if custom, ok := v["custom"].(map[string]any); ok {
|
||||
name, _ := custom["name"].(string)
|
||||
name, ok := custom["name"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
return canonical.NewToolChoiceNamed(name)
|
||||
}
|
||||
case "allowed_tools":
|
||||
if at, ok := v["allowed_tools"].(map[string]any); ok {
|
||||
mode, _ := at["mode"].(string)
|
||||
mode, ok := at["mode"].(string)
|
||||
if !ok {
|
||||
mode = ""
|
||||
}
|
||||
if mode == "required" {
|
||||
return canonical.NewToolChoiceAny()
|
||||
}
|
||||
|
||||
@@ -45,9 +45,11 @@ func TestEncodeRequest_SystemInjection(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, msgs, 2)
|
||||
firstMsg := msgs[0].(map[string]any)
|
||||
firstMsg, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "system", firstMsg["role"])
|
||||
assert.Equal(t, "你是助手", firstMsg["content"])
|
||||
}
|
||||
@@ -72,12 +74,15 @@ func TestEncodeRequest_ToolCalls(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
msgs := result["messages"].([]any)
|
||||
assistantMsg := msgs[0].(map[string]any)
|
||||
msgs, ok := result["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assistantMsg, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
toolCalls, ok := assistantMsg["tool_calls"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, toolCalls, 1)
|
||||
tc := toolCalls[0].(map[string]any)
|
||||
tc, ok := toolCalls[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "call_1", tc["id"])
|
||||
}
|
||||
|
||||
@@ -115,9 +120,12 @@ func TestEncodeResponse_Basic(t *testing.T) {
|
||||
assert.Equal(t, "resp-1", result["id"])
|
||||
assert.Equal(t, "chat.completion", result["object"])
|
||||
|
||||
choices := result["choices"].([]any)
|
||||
choice := choices[0].(map[string]any)
|
||||
msg := choice["message"].(map[string]any)
|
||||
choices, ok := result["choices"].([]any)
|
||||
require.True(t, ok)
|
||||
choice, ok := choices[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
msg, ok := choice["message"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "你好", msg["content"])
|
||||
assert.Equal(t, "stop", choice["finish_reason"])
|
||||
}
|
||||
@@ -137,8 +145,12 @@ func TestEncodeResponse_ToolUse(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
choices := result["choices"].([]any)
|
||||
msg := choices[0].(map[string]any)["message"].(map[string]any)
|
||||
choices, okc := result["choices"].([]any)
|
||||
require.True(t, okc)
|
||||
msgMap, okm := choices[0].(map[string]any)
|
||||
require.True(t, okm)
|
||||
msg, okmsg := msgMap["message"].(map[string]any)
|
||||
require.True(t, okmsg)
|
||||
tcs, ok := msg["tool_calls"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Len(t, tcs, 1)
|
||||
@@ -158,7 +170,8 @@ func TestEncodeModelsResponse(t *testing.T) {
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "list", result["object"])
|
||||
data := result["data"].([]any)
|
||||
data, okd := result["data"].([]any)
|
||||
require.True(t, okd)
|
||||
assert.Len(t, data, 2)
|
||||
}
|
||||
|
||||
@@ -317,8 +330,12 @@ func TestEncodeResponse_Thinking(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
choices := result["choices"].([]any)
|
||||
msg := choices[0].(map[string]any)["message"].(map[string]any)
|
||||
choices, okch := result["choices"].([]any)
|
||||
require.True(t, okch)
|
||||
msgMap, okmm := choices[0].(map[string]any)
|
||||
require.True(t, okmm)
|
||||
msg, okmsg := msgMap["message"].(map[string]any)
|
||||
require.True(t, okmsg)
|
||||
assert.Equal(t, "回答", msg["content"])
|
||||
assert.Equal(t, "思考过程", msg["reasoning_content"])
|
||||
}
|
||||
|
||||
@@ -27,8 +27,12 @@ func TestStreamEncoder_MessageStart(t *testing.T) {
|
||||
data := strings.TrimPrefix(s, "data: ")
|
||||
data = strings.TrimRight(data, "\n")
|
||||
require.NoError(t, json.Unmarshal([]byte(data), &payload))
|
||||
choices := payload["choices"].([]any)
|
||||
delta := choices[0].(map[string]any)["delta"].(map[string]any)
|
||||
choices, okch := payload["choices"].([]any)
|
||||
require.True(t, okch)
|
||||
msgMap, okmm := choices[0].(map[string]any)
|
||||
require.True(t, okmm)
|
||||
delta, okd := msgMap["delta"].(map[string]any)
|
||||
require.True(t, okd)
|
||||
assert.Equal(t, "assistant", delta["role"])
|
||||
}
|
||||
|
||||
|
||||
@@ -177,7 +177,8 @@ func TestEncodeRerankResponse(t *testing.T) {
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
assert.Equal(t, "rerank-1", result["model"])
|
||||
results := result["results"].([]any)
|
||||
results, okr := result["results"].([]any)
|
||||
require.True(t, okr)
|
||||
assert.Len(t, results, 1)
|
||||
}
|
||||
|
||||
@@ -373,7 +374,8 @@ func TestEncodeResponse_UsageWithCacheAndReasoning(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
usage := result["usage"].(map[string]any)
|
||||
usage, oku := result["usage"].(map[string]any)
|
||||
require.True(t, oku)
|
||||
assert.Equal(t, float64(100), usage["prompt_tokens"])
|
||||
ptd, ok := usage["prompt_tokens_details"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
@@ -412,8 +414,10 @@ func TestEncodeResponse_StopReasons(t *testing.T) {
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal(body, &result))
|
||||
choices := result["choices"].([]any)
|
||||
choice := choices[0].(map[string]any)
|
||||
choices, okch := result["choices"].([]any)
|
||||
require.True(t, okch)
|
||||
choice, okc := choices[0].(map[string]any)
|
||||
require.True(t, okc)
|
||||
assert.Equal(t, tt.want, choice["finish_reason"])
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package conversion
|
||||
|
||||
import "nex/backend/internal/conversion/canonical"
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
)
|
||||
|
||||
// StreamDecoder 流式解码器接口
|
||||
type StreamDecoder interface {
|
||||
@@ -39,11 +44,12 @@ func (c *PassthroughStreamConverter) Flush() [][]byte {
|
||||
}
|
||||
|
||||
// SmartPassthroughStreamConverter 同协议 Smart Passthrough 流式转换器
|
||||
// 逐 chunk 改写 model 字段
|
||||
// 按 SSE frame 改写 data JSON 中的 model 字段
|
||||
type SmartPassthroughStreamConverter struct {
|
||||
adapter ProtocolAdapter
|
||||
modelOverride string
|
||||
interfaceType InterfaceType
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
// NewSmartPassthroughStreamConverter 创建 Smart Passthrough 流式转换器
|
||||
@@ -55,24 +61,45 @@ func NewSmartPassthroughStreamConverter(adapter ProtocolAdapter, modelOverride s
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessChunk 改写 chunk 中的 model 字段
|
||||
// ProcessChunk 按 SSE frame 改写 data JSON 中的 model 字段
|
||||
func (c *SmartPassthroughStreamConverter) ProcessChunk(rawChunk []byte) [][]byte {
|
||||
if len(rawChunk) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
rewrittenChunk, err := c.adapter.RewriteResponseModelName(rawChunk, c.modelOverride, c.interfaceType)
|
||||
if err != nil {
|
||||
// 改写失败,返回原始 chunk
|
||||
return [][]byte{rawChunk}
|
||||
}
|
||||
c.buffer = append(c.buffer, rawChunk...)
|
||||
frames, rest := splitSSEFrames(c.buffer)
|
||||
c.buffer = rest
|
||||
|
||||
return [][]byte{rewrittenChunk}
|
||||
result := make([][]byte, 0, len(frames))
|
||||
for _, frame := range frames {
|
||||
result = append(result, c.rewriteFrame(frame))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Flush 无缓冲数据
|
||||
func (c *SmartPassthroughStreamConverter) rewriteFrame(frame []byte) []byte {
|
||||
payload, ok := sseFrameDataPayload(frame)
|
||||
if !ok || strings.TrimSpace(payload) == "[DONE]" {
|
||||
return frame
|
||||
}
|
||||
|
||||
rewrittenPayload, err := c.adapter.RewriteResponseModelName([]byte(payload), c.modelOverride, c.interfaceType)
|
||||
if err != nil {
|
||||
return frame
|
||||
}
|
||||
|
||||
return rebuildSSEFrameWithData(frame, string(rewrittenPayload))
|
||||
}
|
||||
|
||||
// Flush 输出未形成完整 frame 的剩余数据
|
||||
func (c *SmartPassthroughStreamConverter) Flush() [][]byte {
|
||||
if len(c.buffer) == 0 {
|
||||
return nil
|
||||
}
|
||||
frame := append([]byte(nil), c.buffer...)
|
||||
c.buffer = nil
|
||||
return [][]byte{c.rewriteFrame(frame)}
|
||||
}
|
||||
|
||||
// CanonicalStreamConverter 跨协议规范流式转换器
|
||||
@@ -153,3 +180,86 @@ func (c *CanonicalStreamConverter) applyModelOverride(event *canonical.Canonical
|
||||
event.Message.Model = c.modelOverride
|
||||
}
|
||||
}
|
||||
|
||||
func splitSSEFrames(data []byte) ([][]byte, []byte) {
|
||||
var frames [][]byte
|
||||
for len(data) > 0 {
|
||||
idx, sepLen := findSSEFrameSeparator(data)
|
||||
if idx < 0 {
|
||||
break
|
||||
}
|
||||
end := idx + sepLen
|
||||
frames = append(frames, append([]byte(nil), data[:end]...))
|
||||
data = data[end:]
|
||||
}
|
||||
return frames, data
|
||||
}
|
||||
|
||||
func findSSEFrameSeparator(data []byte) (int, int) {
|
||||
lf := bytes.Index(data, []byte("\n\n"))
|
||||
crlf := bytes.Index(data, []byte("\r\n\r\n"))
|
||||
switch {
|
||||
case lf < 0 && crlf < 0:
|
||||
return -1, 0
|
||||
case lf < 0:
|
||||
return crlf, 4
|
||||
case crlf < 0:
|
||||
return lf, 2
|
||||
case crlf <= lf:
|
||||
return crlf, 4
|
||||
default:
|
||||
return lf, 2
|
||||
}
|
||||
}
|
||||
|
||||
func sseFrameDataPayload(frame []byte) (string, bool) {
|
||||
text := strings.TrimRight(string(frame), "\r\n")
|
||||
lines := strings.Split(text, "\n")
|
||||
var dataLines []string
|
||||
for _, line := range lines {
|
||||
line = strings.TrimRight(line, "\r")
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
value := strings.TrimPrefix(line, "data:")
|
||||
if strings.HasPrefix(value, " ") {
|
||||
value = value[1:]
|
||||
}
|
||||
dataLines = append(dataLines, value)
|
||||
}
|
||||
}
|
||||
if len(dataLines) == 0 {
|
||||
return "", false
|
||||
}
|
||||
return strings.Join(dataLines, "\n"), true
|
||||
}
|
||||
|
||||
func rebuildSSEFrameWithData(frame []byte, data string) []byte {
|
||||
lineEnding, separator := sseLineEnding(frame)
|
||||
text := strings.TrimRight(string(frame), "\r\n")
|
||||
lines := strings.Split(text, "\n")
|
||||
out := make([]string, 0, len(lines)+1)
|
||||
dataWritten := false
|
||||
for _, line := range lines {
|
||||
line = strings.TrimRight(line, "\r")
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
if !dataWritten {
|
||||
for _, dataLine := range strings.Split(data, "\n") {
|
||||
out = append(out, "data: "+dataLine)
|
||||
}
|
||||
dataWritten = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
out = append(out, line)
|
||||
}
|
||||
if !dataWritten {
|
||||
out = append(out, "data: "+data)
|
||||
}
|
||||
return []byte(strings.Join(out, lineEnding) + separator)
|
||||
}
|
||||
|
||||
func sseLineEnding(frame []byte) (string, string) {
|
||||
if bytes.Contains(frame, []byte("\r\n")) {
|
||||
return "\r\n", "\r\n\r\n"
|
||||
}
|
||||
return "\n", "\n\n"
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -12,22 +11,24 @@ import (
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
func Init(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
|
||||
db, err := initDB(cfg)
|
||||
moduleLogger := pkglogger.WithModule(zapLogger, "database")
|
||||
|
||||
db, err := initDB(cfg, moduleLogger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("初始化数据库失败: %w", err)
|
||||
}
|
||||
|
||||
if err := runMigrations(db, cfg.Driver); err != nil {
|
||||
if err := runMigrations(db, cfg.Driver, moduleLogger); err != nil {
|
||||
return nil, fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
configurePool(db, cfg)
|
||||
configurePool(db, cfg, moduleLogger)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
@@ -40,36 +41,42 @@ func Close(db *gorm.DB) {
|
||||
sqlDB.Close()
|
||||
}
|
||||
|
||||
func initDB(cfg *config.DatabaseConfig) (*gorm.DB, error) {
|
||||
func initDB(cfg *config.DatabaseConfig, zapLogger *zap.Logger) (*gorm.DB, error) {
|
||||
gormLogger := pkglogger.NewGormLogger(zapLogger)
|
||||
|
||||
gormConfig := &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Info),
|
||||
Logger: gormLogger,
|
||||
}
|
||||
|
||||
switch cfg.Driver {
|
||||
case "mysql":
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||||
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName)
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("连接 MySQL 数据库",
|
||||
zap.String("host", cfg.Host),
|
||||
zap.Int("port", cfg.Port),
|
||||
zap.String("database", cfg.DBName))
|
||||
}
|
||||
return gorm.Open(mysql.Open(dsn), gormConfig)
|
||||
default:
|
||||
dbDir := filepath.Dir(cfg.Path)
|
||||
if err := os.MkdirAll(dbDir, 0755); err != nil {
|
||||
if err := os.MkdirAll(dbDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
|
||||
}
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("连接 SQLite 数据库", zap.String("path", cfg.Path))
|
||||
}
|
||||
return gorm.Open(sqlite.Open(cfg.Path), gormConfig)
|
||||
}
|
||||
}
|
||||
|
||||
func runMigrations(db *gorm.DB, driver string) error {
|
||||
func runMigrations(db *gorm.DB, driver string, zapLogger *zap.Logger) error {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
migrationsDir := getMigrationsDir(driver)
|
||||
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
|
||||
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
|
||||
}
|
||||
|
||||
gooseDialect := "sqlite3"
|
||||
migrationsSubDir := "sqlite"
|
||||
if driver == "mysql" {
|
||||
@@ -77,19 +84,33 @@ func runMigrations(db *gorm.DB, driver string) error {
|
||||
migrationsSubDir = "mysql"
|
||||
}
|
||||
|
||||
goose.SetDialect(gooseDialect)
|
||||
migrationsDir := getMigrationsDir(driver)
|
||||
if _, err := os.Stat(migrationsDir); os.IsNotExist(err) {
|
||||
return fmt.Errorf("迁移目录不存在: %s", migrationsDir)
|
||||
}
|
||||
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("执行数据库迁移",
|
||||
zap.String("dialect", gooseDialect),
|
||||
zap.String("dir", migrationsSubDir))
|
||||
}
|
||||
|
||||
if err := goose.SetDialect(gooseDialect); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := goose.Up(sqlDB, migrationsDir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("使用 %s 方言执行迁移,目录: %s", gooseDialect, migrationsSubDir)
|
||||
return nil
|
||||
}
|
||||
|
||||
func configurePool(db *gorm.DB, cfg *config.DatabaseConfig) {
|
||||
func configurePool(db *gorm.DB, cfg *config.DatabaseConfig, zapLogger *zap.Logger) {
|
||||
if cfg.Driver == "sqlite" {
|
||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||
log.Printf("警告: 启用 WAL 模式失败: %v", err)
|
||||
if zapLogger != nil {
|
||||
zapLogger.Warn("启用 WAL 模式失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,8 +122,12 @@ func configurePool(db *gorm.DB, cfg *config.DatabaseConfig) {
|
||||
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime)
|
||||
|
||||
log.Printf("数据库连接池配置: MaxIdle=%d, MaxOpen=%d, MaxLifetime=%v",
|
||||
cfg.MaxIdleConns, cfg.MaxOpenConns, cfg.ConnMaxLifetime)
|
||||
if zapLogger != nil {
|
||||
zapLogger.Info("数据库连接池配置",
|
||||
zap.Int("max_idle_conns", cfg.MaxIdleConns),
|
||||
zap.Int("max_open_conns", cfg.MaxOpenConns),
|
||||
zap.Duration("conn_max_lifetime", cfg.ConnMaxLifetime))
|
||||
}
|
||||
}
|
||||
|
||||
func getMigrationsDir(driver string) string {
|
||||
|
||||
@@ -4,10 +4,11 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestInit_SQLite(t *testing.T) {
|
||||
@@ -20,7 +21,8 @@ func TestInit_SQLite(t *testing.T) {
|
||||
ConnMaxLifetime: 0,
|
||||
}
|
||||
|
||||
db, err := Init(cfg, nil)
|
||||
zapLogger := zap.NewNop()
|
||||
db, err := Init(cfg, zapLogger)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, db)
|
||||
defer Close(db)
|
||||
@@ -40,7 +42,8 @@ func TestClose(t *testing.T) {
|
||||
ConnMaxLifetime: 0,
|
||||
}
|
||||
|
||||
db, err := Init(cfg, nil)
|
||||
zapLogger := zap.NewNop()
|
||||
db, err := Init(cfg, zapLogger)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, db)
|
||||
|
||||
|
||||
@@ -13,4 +13,3 @@ type Provider struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
|
||||
@@ -6,13 +6,13 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/tests/mocks"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/tests/mocks"
|
||||
)
|
||||
|
||||
func TestProviderHandler_CreateProvider_Success(t *testing.T) {
|
||||
|
||||
@@ -9,23 +9,22 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/tests/mocks"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
"nex/backend/tests/mocks"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
|
||||
|
||||
func TestProviderHandler_CreateProvider_MissingFields(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -5,9 +5,10 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// Logging 日志中间件
|
||||
func Logging(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
@@ -15,12 +16,16 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
|
||||
query := c.Request.URL.RawQuery
|
||||
|
||||
requestID, _ := c.Get(RequestIDKey)
|
||||
logger.Info("请求开始",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", path),
|
||||
zap.String("query", query),
|
||||
zap.String("client_ip", c.ClientIP()),
|
||||
zap.Any("request_id", requestID),
|
||||
var requestIDStr string
|
||||
if id, ok := requestID.(string); ok {
|
||||
requestIDStr = id
|
||||
}
|
||||
logger.Debug("请求开始",
|
||||
pkglogger.Method(c.Request.Method),
|
||||
pkglogger.Path(path),
|
||||
pkglogger.Query(query),
|
||||
pkglogger.ClientIP(c.ClientIP()),
|
||||
pkglogger.RequestID(requestIDStr),
|
||||
)
|
||||
|
||||
c.Next()
|
||||
@@ -28,13 +33,13 @@ func Logging(logger *zap.Logger) gin.HandlerFunc {
|
||||
latency := time.Since(start)
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
logger.Info("请求结束",
|
||||
zap.Int("status", statusCode),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", path),
|
||||
zap.Duration("latency", latency),
|
||||
zap.Int("body_size", c.Writer.Size()),
|
||||
zap.Any("request_id", requestID),
|
||||
logger.Debug("请求结束",
|
||||
pkglogger.StatusCode(statusCode),
|
||||
pkglogger.Method(c.Request.Method),
|
||||
pkglogger.Path(path),
|
||||
pkglogger.Latency(latency),
|
||||
pkglogger.BodySize(c.Writer.Size()),
|
||||
pkglogger.RequestID(requestIDStr),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"go.uber.org/zap/zaptest/observer"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -65,6 +67,61 @@ func TestLogging(t *testing.T) {
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestLogging_DoesNotLogLifecycleAtInfoLevel(t *testing.T) {
|
||||
core, logs := observer.New(zapcore.InfoLevel)
|
||||
logger := zap.New(core)
|
||||
|
||||
w := serveLoggingRequest(logger)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Empty(t, logs.FilterMessage("请求开始").All())
|
||||
assert.Empty(t, logs.FilterMessage("请求结束").All())
|
||||
}
|
||||
|
||||
func TestLogging_LogsLifecycleAtDebugLevel(t *testing.T) {
|
||||
core, logs := observer.New(zapcore.DebugLevel)
|
||||
logger := zap.New(core)
|
||||
|
||||
w := serveLoggingRequest(logger)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
startLogs := logs.FilterMessage("请求开始").All()
|
||||
endLogs := logs.FilterMessage("请求结束").All()
|
||||
if assert.Len(t, startLogs, 1) {
|
||||
fields := startLogs[0].ContextMap()
|
||||
assert.Equal(t, "GET", fields["method"])
|
||||
assert.Equal(t, "/test", fields["path"])
|
||||
assert.Equal(t, "key=value", fields["query"])
|
||||
assert.Equal(t, "existing-id-123", fields["request_id"])
|
||||
assert.NotEmpty(t, fields["client_ip"])
|
||||
}
|
||||
if assert.Len(t, endLogs, 1) {
|
||||
fields := endLogs[0].ContextMap()
|
||||
assert.Equal(t, int64(200), fields["status"])
|
||||
assert.Equal(t, "GET", fields["method"])
|
||||
assert.Equal(t, "/test", fields["path"])
|
||||
assert.Equal(t, int64(2), fields["body_size"])
|
||||
assert.Equal(t, "existing-id-123", fields["request_id"])
|
||||
assert.Contains(t, fields, "latency")
|
||||
}
|
||||
}
|
||||
|
||||
func serveLoggingRequest(logger *zap.Logger) *httptest.ResponseRecorder {
|
||||
r := gin.New()
|
||||
r.Use(RequestID())
|
||||
r.Use(Logging(logger))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.String(200, "ok")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/test?key=value", nil)
|
||||
req.Header.Set("X-Request-ID", "existing-id-123")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
return w
|
||||
}
|
||||
|
||||
func TestRecovery_NoPanic(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
|
||||
@@ -4,13 +4,13 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// ModelHandler 模型管理处理器
|
||||
@@ -58,13 +58,13 @@ func (h *ModelHandler) CreateModel(c *gin.Context) {
|
||||
|
||||
err := h.modelService.Create(model)
|
||||
if err != nil {
|
||||
if err == appErrors.ErrProviderNotFound {
|
||||
if errors.Is(err, appErrors.ErrProviderNotFound) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "供应商不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err == appErrors.ErrDuplicateModel {
|
||||
if errors.Is(err, appErrors.ErrDuplicateModel) {
|
||||
c.JSON(http.StatusConflict, gin.H{
|
||||
"error": "同一供应商下模型名称已存在",
|
||||
"code": appErrors.ErrDuplicateModel.Code,
|
||||
@@ -101,7 +101,7 @@ func (h *ModelHandler) GetModel(c *gin.Context) {
|
||||
|
||||
model, err := h.modelService.Get(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
@@ -166,7 +166,7 @@ func (h *ModelHandler) DeleteModel(c *gin.Context) {
|
||||
|
||||
err := h.modelService.Delete(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "模型未找到",
|
||||
})
|
||||
|
||||
@@ -4,13 +4,13 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/service"
|
||||
)
|
||||
|
||||
// ProviderHandler 供应商管理处理器
|
||||
@@ -55,7 +55,7 @@ func (h *ProviderHandler) CreateProvider(c *gin.Context) {
|
||||
|
||||
err := h.providerService.Create(provider)
|
||||
if err != nil {
|
||||
if err == appErrors.ErrInvalidProviderID {
|
||||
if errors.Is(err, appErrors.ErrInvalidProviderID) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": appErrors.ErrInvalidProviderID.Message,
|
||||
"code": appErrors.ErrInvalidProviderID.Code,
|
||||
@@ -86,7 +86,7 @@ func (h *ProviderHandler) GetProvider(c *gin.Context) {
|
||||
|
||||
provider, err := h.providerService.Get(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
@@ -113,7 +113,7 @@ func (h *ProviderHandler) UpdateProvider(c *gin.Context) {
|
||||
|
||||
err := h.providerService.Update(id, req)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
@@ -145,7 +145,7 @@ func (h *ProviderHandler) DeleteProvider(c *gin.Context) {
|
||||
|
||||
err := h.providerService.Delete(id)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "供应商未找到",
|
||||
})
|
||||
|
||||
@@ -3,19 +3,23 @@ package handler
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/canonical"
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/provider"
|
||||
"nex/backend/internal/service"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
"nex/backend/pkg/modelid"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// ProxyHandler 统一代理处理器
|
||||
@@ -29,14 +33,14 @@ type ProxyHandler struct {
|
||||
}
|
||||
|
||||
// NewProxyHandler 创建统一代理处理器
|
||||
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService) *ProxyHandler {
|
||||
func NewProxyHandler(engine *conversion.ConversionEngine, client provider.ProviderClient, routingService service.RoutingService, providerService service.ProviderService, statsService service.StatsService, logger *zap.Logger) *ProxyHandler {
|
||||
return &ProxyHandler{
|
||||
engine: engine,
|
||||
client: client,
|
||||
routingService: routingService,
|
||||
providerService: providerService,
|
||||
statsService: statsService,
|
||||
logger: zap.L(),
|
||||
logger: pkglogger.WithModule(logger, "handler.proxy"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +49,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
// 从 URL 提取 clientProtocol: /{protocol}/v1/...
|
||||
clientProtocol := c.Param("protocol")
|
||||
if clientProtocol == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少协议前缀"})
|
||||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "缺少协议前缀")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -55,12 +59,13 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
path = "/" + path
|
||||
}
|
||||
nativePath := path
|
||||
requestPath := appendRawQuery(nativePath, c.Request.URL.RawQuery)
|
||||
|
||||
// 获取 client adapter
|
||||
registry := h.engine.GetRegistry()
|
||||
clientAdapter, err := registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
|
||||
h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -77,7 +82,7 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
if ifaceType == conversion.InterfaceTypeModelInfo {
|
||||
unifiedID, err := clientAdapter.ExtractUnifiedModelID(nativePath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的模型 ID 格式"})
|
||||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的模型 ID 格式")
|
||||
return
|
||||
}
|
||||
h.handleModelInfo(c, unifiedID, clientAdapter)
|
||||
@@ -87,40 +92,50 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
// 读取请求体
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
|
||||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_REQUEST", "读取请求体失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析统一模型 ID(使用 adapter.ExtractModelName)
|
||||
var providerID, modelName string
|
||||
if len(body) > 0 {
|
||||
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
|
||||
if err == nil && unifiedID != "" {
|
||||
pid, mn, err := modelid.ParseUnifiedModelID(unifiedID)
|
||||
if err == nil {
|
||||
providerID = pid
|
||||
modelName = mn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 构建输入 HTTPRequestSpec
|
||||
inSpec := conversion.HTTPRequestSpec{
|
||||
URL: nativePath,
|
||||
URL: requestPath,
|
||||
Method: c.Request.Method,
|
||||
Headers: extractHeaders(c),
|
||||
Body: body,
|
||||
}
|
||||
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
|
||||
|
||||
// 只有 adapter 明确适配的接口才提取 model。未知接口不做通用 model 猜测。
|
||||
if len(body) == 0 || !supportsModelExtraction(ifaceType) {
|
||||
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||||
return
|
||||
}
|
||||
|
||||
unifiedID, err := clientAdapter.ExtractModelName(body, ifaceType)
|
||||
if err != nil {
|
||||
if isInvalidJSONError(err) {
|
||||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误")
|
||||
return
|
||||
}
|
||||
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||||
return
|
||||
}
|
||||
|
||||
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
|
||||
if err != nil {
|
||||
// 原始模型名兼容透传:非统一模型 ID 不参与路由。
|
||||
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||||
return
|
||||
}
|
||||
if providerID == "" || modelName == "" {
|
||||
h.forwardPassthrough(c, inSpec, clientProtocol, ifaceType, isStream)
|
||||
return
|
||||
}
|
||||
|
||||
// 路由
|
||||
routeResult, err := h.routingService.RouteByModelName(providerID, modelName)
|
||||
if err != nil {
|
||||
// GET 请求或无法提取 model 时,直接转发到上游
|
||||
if len(body) == 0 || modelName == "" {
|
||||
h.forwardPassthrough(c, inSpec, clientProtocol)
|
||||
return
|
||||
}
|
||||
h.writeError(c, err, clientProtocol)
|
||||
h.writeRouteError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -140,9 +155,6 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
routeResult.Model.ModelName, // 上游模型名,用于请求改写
|
||||
)
|
||||
|
||||
// 判断是否流式
|
||||
isStream := h.isStreamRequest(body, clientProtocol, nativePath)
|
||||
|
||||
// 计算统一模型 ID(用于响应覆写)
|
||||
unifiedModelID := routeResult.Model.UnifiedModelID()
|
||||
|
||||
@@ -153,12 +165,34 @@ func (h *ProxyHandler) HandleProxy(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func supportsModelExtraction(ifaceType conversion.InterfaceType) bool {
|
||||
switch ifaceType {
|
||||
case conversion.InterfaceTypeChat, conversion.InterfaceTypeEmbeddings, conversion.InterfaceTypeRerank:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isInvalidJSONError(err error) bool {
|
||||
var syntaxErr *json.SyntaxError
|
||||
var typeErr *json.UnmarshalTypeError
|
||||
return errors.As(err, &syntaxErr) || errors.As(err, &typeErr)
|
||||
}
|
||||
|
||||
func appendRawQuery(path, rawQuery string) string {
|
||||
if rawQuery == "" {
|
||||
return path
|
||||
}
|
||||
return path + "?" + rawQuery
|
||||
}
|
||||
|
||||
// handleNonStream 处理非流式请求
|
||||
func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, targetProvider *conversion.TargetProvider, routeResult *domain.RouteResult, unifiedModelID string, ifaceType conversion.InterfaceType) {
|
||||
// 转换请求
|
||||
outSpec, err := h.engine.ConvertHttpRequest(inSpec, clientProtocol, providerProtocol, targetProvider)
|
||||
if err != nil {
|
||||
h.logger.Error("转换请求失败", zap.String("error", err.Error()))
|
||||
h.logger.Error("转换请求失败", zap.Error(err))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
@@ -166,31 +200,27 @@ func (h *ProxyHandler) handleNonStream(c *gin.Context, inSpec conversion.HTTPReq
|
||||
// 发送请求
|
||||
resp, err := h.client.Send(c.Request.Context(), *outSpec)
|
||||
if err != nil {
|
||||
h.logger.Error("发送请求失败", zap.String("error", err.Error()))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
h.logger.Error("发送请求失败", zap.Error(err))
|
||||
h.writeUpstreamUnavailable(c, err)
|
||||
return
|
||||
}
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
h.writeUpstreamResponse(c, *resp)
|
||||
return
|
||||
}
|
||||
|
||||
// 转换响应,传入 modelOverride(跨协议场景覆写 model 字段)
|
||||
convertedResp, err := h.engine.ConvertHttpResponse(*resp, clientProtocol, providerProtocol, ifaceType, unifiedModelID)
|
||||
if err != nil {
|
||||
h.logger.Error("转换响应失败", zap.String("error", err.Error()))
|
||||
h.logger.Error("转换响应失败", zap.Error(err))
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
for k, v := range convertedResp.Headers {
|
||||
c.Header(k, v)
|
||||
}
|
||||
if c.GetHeader("Content-Type") == "" {
|
||||
c.Header("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
|
||||
h.writeConvertedResponse(c, *convertedResp)
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -203,15 +233,23 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
|
||||
return
|
||||
}
|
||||
|
||||
// 创建流式转换器,传入 modelOverride(跨协议场景覆写 model 字段)
|
||||
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
|
||||
// 发送流式请求
|
||||
streamResp, err := h.client.SendStream(c.Request.Context(), *outSpec)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
h.writeUpstreamUnavailable(c, err)
|
||||
return
|
||||
}
|
||||
if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
|
||||
h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
|
||||
StatusCode: streamResp.StatusCode,
|
||||
Headers: streamResp.Headers,
|
||||
Body: streamResp.Body,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 发送流式请求
|
||||
eventChan, err := h.client.SendStream(c.Request.Context(), *outSpec)
|
||||
// 创建流式转换器,传入 modelOverride(跨协议场景覆写 model 字段)
|
||||
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, unifiedModelID, ifaceType)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
@@ -222,37 +260,61 @@ func (h *ProxyHandler) handleStream(c *gin.Context, inSpec conversion.HTTPReques
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
flushed := false
|
||||
|
||||
for event := range eventChan {
|
||||
for event := range streamResp.Events {
|
||||
if event.Error != nil {
|
||||
h.logger.Error("流读取错误", zap.String("error", event.Error.Error()))
|
||||
h.logger.Error("流读取错误", zap.Error(event.Error))
|
||||
break
|
||||
}
|
||||
if event.Done {
|
||||
// flush 转换器
|
||||
chunks := streamConverter.Flush()
|
||||
for _, chunk := range chunks {
|
||||
writer.Write(chunk)
|
||||
writer.Flush()
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("流式响应写回失败", zap.Error(err))
|
||||
}
|
||||
flushed = true
|
||||
break
|
||||
}
|
||||
|
||||
chunks := streamConverter.ProcessChunk(event.Data)
|
||||
for _, chunk := range chunks {
|
||||
writer.Write(chunk)
|
||||
writer.Flush()
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("流式响应写回失败", zap.Error(err))
|
||||
break
|
||||
}
|
||||
}
|
||||
if !flushed {
|
||||
chunks := streamConverter.Flush()
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("流式响应写回失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName)
|
||||
_ = h.statsService.Record(routeResult.Provider.ID, routeResult.Model.ModelName) //nolint:errcheck // fire-and-forget 统计记录不阻塞请求
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeStreamChunks(writer *bufio.Writer, chunks [][]byte) error {
|
||||
for _, chunk := range chunks {
|
||||
if _, err := writer.Write(chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writer.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isStreamRequest 判断是否流式请求
|
||||
func (h *ProxyHandler) isStreamRequest(body []byte, clientProtocol, nativePath string) bool {
|
||||
ifaceType, _ := h.engine.DetectInterfaceType(nativePath, clientProtocol)
|
||||
ifaceType, err := h.engine.DetectInterfaceType(nativePath, clientProtocol)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if ifaceType != conversion.InterfaceTypeChat {
|
||||
return false
|
||||
}
|
||||
@@ -271,8 +333,8 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
|
||||
// 从数据库查询所有启用的模型
|
||||
models, err := h.providerService.ListEnabledModels()
|
||||
if err != nil {
|
||||
h.logger.Error("查询启用模型失败", zap.String("error", err.Error()))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询模型失败"})
|
||||
h.logger.Error("查询启用模型失败", zap.Error(err))
|
||||
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "查询模型失败")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -293,8 +355,8 @@ func (h *ProxyHandler) handleModelsList(c *gin.Context, adapter conversion.Proto
|
||||
// 使用 adapter 编码返回
|
||||
body, err := adapter.EncodeModelsResponse(modelList)
|
||||
if err != nil {
|
||||
h.logger.Error("编码 Models 响应失败", zap.String("error", err.Error()))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
|
||||
h.logger.Error("编码 Models 响应失败", zap.Error(err))
|
||||
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -306,17 +368,14 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
|
||||
// 解析统一模型 ID
|
||||
providerID, modelName, err := modelid.ParseUnifiedModelID(unifiedID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "无效的统一模型 ID 格式",
|
||||
"code": "INVALID_MODEL_ID",
|
||||
})
|
||||
h.writeProxyError(c, http.StatusBadRequest, "INVALID_MODEL_ID", "无效的统一模型 ID 格式")
|
||||
return
|
||||
}
|
||||
|
||||
// 从数据库查询模型
|
||||
model, err := h.providerService.GetModelByProviderAndName(providerID, modelName)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "模型未找到"})
|
||||
h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", "模型未找到")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -331,42 +390,104 @@ func (h *ProxyHandler) handleModelInfo(c *gin.Context, unifiedID string, adapter
|
||||
// 使用 adapter 编码返回
|
||||
body, err := adapter.EncodeModelInfoResponse(modelInfo)
|
||||
if err != nil {
|
||||
h.logger.Error("编码 ModelInfo 响应失败", zap.String("error", err.Error()))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "编码响应失败"})
|
||||
h.logger.Error("编码 ModelInfo 响应失败", zap.Error(err))
|
||||
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", "编码响应失败")
|
||||
return
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "application/json", body)
|
||||
}
|
||||
|
||||
// writeConversionError 写入转换错误
|
||||
// writeConversionError 写入网关层转换错误
|
||||
func (h *ProxyHandler) writeConversionError(c *gin.Context, err error, clientProtocol string) {
|
||||
if convErr, ok := err.(*conversion.ConversionError); ok {
|
||||
body, statusCode, _ := h.engine.EncodeError(convErr, clientProtocol)
|
||||
c.Data(statusCode, "application/json", body)
|
||||
var convErr *conversion.ConversionError
|
||||
if errors.As(err, &convErr) {
|
||||
statusCode, code, message := mapConversionError(convErr)
|
||||
h.writeProxyError(c, statusCode, code, message)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
h.writeProxyError(c, http.StatusInternalServerError, "CONVERSION_FAILED", err.Error())
|
||||
}
|
||||
|
||||
// writeError 写入路由错误
|
||||
func (h *ProxyHandler) writeError(c *gin.Context, err error, clientProtocol string) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
func mapConversionError(err *conversion.ConversionError) (int, string, string) {
|
||||
switch err.Code {
|
||||
case conversion.ErrorCodeJSONParseError:
|
||||
if phase, ok := err.Details[conversion.ErrorDetailPhase].(string); ok && phase == conversion.ErrorPhaseRequest {
|
||||
return http.StatusBadRequest, "INVALID_JSON", "请求体 JSON 格式错误"
|
||||
}
|
||||
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
|
||||
case conversion.ErrorCodeInvalidInput,
|
||||
conversion.ErrorCodeMissingRequiredField,
|
||||
conversion.ErrorCodeProtocolConstraint:
|
||||
return http.StatusBadRequest, "INVALID_REQUEST", err.Message
|
||||
case conversion.ErrorCodeInterfaceNotSupported:
|
||||
return http.StatusBadRequest, "UNSUPPORTED_INTERFACE", err.Message
|
||||
case conversion.ErrorCodeUnsupportedMultimodal:
|
||||
return http.StatusBadRequest, "UNSUPPORTED_MULTIMODAL", err.Message
|
||||
default:
|
||||
return http.StatusInternalServerError, "CONVERSION_FAILED", err.Message
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeRouteError(c *gin.Context, err error) {
|
||||
if appErr, ok := appErrors.AsAppError(err); ok {
|
||||
switch appErr.Code {
|
||||
case appErrors.ErrModelNotFound.Code, appErrors.ErrModelDisabled.Code:
|
||||
h.writeProxyError(c, appErr.HTTPStatus, "MODEL_NOT_FOUND", appErr.Message)
|
||||
case appErrors.ErrProviderNotFound.Code, appErrors.ErrProviderDisabled.Code:
|
||||
h.writeProxyError(c, appErr.HTTPStatus, "PROVIDER_NOT_FOUND", appErr.Message)
|
||||
default:
|
||||
h.writeProxyError(c, appErr.HTTPStatus, "INVALID_REQUEST", appErr.Message)
|
||||
}
|
||||
return
|
||||
}
|
||||
h.writeProxyError(c, http.StatusNotFound, "MODEL_NOT_FOUND", err.Error())
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeUpstreamUnavailable(c *gin.Context, err error) {
|
||||
h.logger.Error("上游不可达", zap.Error(err))
|
||||
h.writeProxyError(c, http.StatusBadGateway, "UPSTREAM_UNAVAILABLE", "上游服务不可达")
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeProxyError(c *gin.Context, status int, code, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": message,
|
||||
"code": code,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeConvertedResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
|
||||
for k, v := range resp.Headers {
|
||||
c.Header(k, v)
|
||||
}
|
||||
contentType := headerValue(resp.Headers, "Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, resp.Body)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) writeUpstreamResponse(c *gin.Context, resp conversion.HTTPResponseSpec) {
|
||||
for k, v := range filterHopByHopHeaders(resp.Headers) {
|
||||
c.Header(k, v)
|
||||
}
|
||||
contentType := headerValue(resp.Headers, "Content-Type")
|
||||
c.Data(resp.StatusCode, contentType, resp.Body)
|
||||
}
|
||||
|
||||
// forwardPassthrough 直接转发请求到上游(用于 GET 等无 model 的请求)
|
||||
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string) {
|
||||
func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTPRequestSpec, clientProtocol string, ifaceType conversion.InterfaceType, isStream bool) {
|
||||
registry := h.engine.GetRegistry()
|
||||
adapter, err := registry.Get(clientProtocol)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的协议: " + clientProtocol})
|
||||
h.writeProxyError(c, http.StatusNotFound, "UNSUPPORTED_INTERFACE", "不支持的协议: "+clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
providers, err := h.providerService.List()
|
||||
if err != nil || len(providers) == 0 {
|
||||
h.logger.Warn("无可用供应商转发 GET 请求", zap.String("path", inSpec.URL))
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "没有可用的供应商。请先创建供应商和模型。"})
|
||||
h.logger.Warn("无可用供应商转发请求", zap.String("path", inSpec.URL))
|
||||
h.writeProxyError(c, http.StatusNotFound, "PROVIDER_NOT_FOUND", "没有可用的供应商。请先创建供应商和模型。")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -376,19 +497,18 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
|
||||
providerProtocol = "openai"
|
||||
}
|
||||
|
||||
ifaceType := adapter.DetectInterfaceType(inSpec.URL)
|
||||
|
||||
targetProvider := conversion.NewTargetProvider(p.BaseURL, p.APIKey, "")
|
||||
|
||||
var outSpec *conversion.HTTPRequestSpec
|
||||
if clientProtocol == providerProtocol {
|
||||
upstreamURL := p.BaseURL + inSpec.URL
|
||||
upstreamPath := adapter.BuildUrl(stripRawQuery(inSpec.URL), ifaceType)
|
||||
upstreamPath = appendRawQuery(upstreamPath, rawQueryFromPath(inSpec.URL))
|
||||
headers := adapter.BuildHeaders(targetProvider)
|
||||
if _, ok := headers["Content-Type"]; !ok {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
outSpec = &conversion.HTTPRequestSpec{
|
||||
URL: upstreamURL,
|
||||
URL: joinBaseURL(p.BaseURL, upstreamPath),
|
||||
Method: inSpec.Method,
|
||||
Headers: headers,
|
||||
Body: inSpec.Body,
|
||||
@@ -401,9 +521,18 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
|
||||
}
|
||||
}
|
||||
|
||||
if isStream {
|
||||
h.forwardStream(c, *outSpec, clientProtocol, providerProtocol, ifaceType)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.client.Send(c.Request.Context(), *outSpec)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
h.writeUpstreamUnavailable(c, err)
|
||||
return
|
||||
}
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
h.writeUpstreamResponse(c, *resp)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -413,13 +542,111 @@ func (h *ProxyHandler) forwardPassthrough(c *gin.Context, inSpec conversion.HTTP
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range convertedResp.Headers {
|
||||
c.Header(k, v)
|
||||
h.writeConvertedResponse(c, *convertedResp)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) forwardStream(c *gin.Context, outSpec conversion.HTTPRequestSpec, clientProtocol, providerProtocol string, ifaceType conversion.InterfaceType) {
|
||||
streamResp, err := h.client.SendStream(c.Request.Context(), outSpec)
|
||||
if err != nil {
|
||||
h.writeUpstreamUnavailable(c, err)
|
||||
return
|
||||
}
|
||||
if c.GetHeader("Content-Type") == "" {
|
||||
c.Header("Content-Type", "application/json")
|
||||
if streamResp.StatusCode < http.StatusOK || streamResp.StatusCode >= http.StatusMultipleChoices {
|
||||
h.writeUpstreamResponse(c, conversion.HTTPResponseSpec{
|
||||
StatusCode: streamResp.StatusCode,
|
||||
Headers: streamResp.Headers,
|
||||
Body: streamResp.Body,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Data(convertedResp.StatusCode, "application/json", convertedResp.Body)
|
||||
|
||||
streamConverter, err := h.engine.CreateStreamConverter(clientProtocol, providerProtocol, "", ifaceType)
|
||||
if err != nil {
|
||||
h.writeConversionError(c, err, clientProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
writer := bufio.NewWriter(c.Writer)
|
||||
flushed := false
|
||||
for event := range streamResp.Events {
|
||||
if event.Error != nil {
|
||||
h.logger.Error("透传流读取错误", zap.Error(event.Error))
|
||||
break
|
||||
}
|
||||
if event.Done {
|
||||
chunks := streamConverter.Flush()
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
|
||||
}
|
||||
flushed = true
|
||||
break
|
||||
}
|
||||
chunks := streamConverter.ProcessChunk(event.Data)
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
|
||||
break
|
||||
}
|
||||
}
|
||||
if !flushed {
|
||||
chunks := streamConverter.Flush()
|
||||
if err := h.writeStreamChunks(writer, chunks); err != nil {
|
||||
h.logger.Warn("透传流式响应写回失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stripRawQuery(path string) string {
|
||||
pathOnly, _, _ := strings.Cut(path, "?")
|
||||
return pathOnly
|
||||
}
|
||||
|
||||
func rawQueryFromPath(path string) string {
|
||||
_, rawQuery, found := strings.Cut(path, "?")
|
||||
if !found {
|
||||
return ""
|
||||
}
|
||||
return rawQuery
|
||||
}
|
||||
|
||||
func joinBaseURL(baseURL, path string) string {
|
||||
return strings.TrimRight(baseURL, "/") + "/" + strings.TrimLeft(path, "/")
|
||||
}
|
||||
|
||||
func headerValue(headers map[string]string, key string) string {
|
||||
for k, v := range headers {
|
||||
if strings.EqualFold(k, key) {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func filterHopByHopHeaders(headers map[string]string) map[string]string {
|
||||
if len(headers) == 0 {
|
||||
return nil
|
||||
}
|
||||
hopByHop := map[string]struct{}{
|
||||
"connection": {},
|
||||
"transfer-encoding": {},
|
||||
"keep-alive": {},
|
||||
"proxy-authenticate": {},
|
||||
"proxy-authorization": {},
|
||||
"te": {},
|
||||
"trailer": {},
|
||||
"upgrade": {},
|
||||
}
|
||||
filtered := make(map[string]string, len(headers))
|
||||
for k, v := range headers {
|
||||
if _, skip := hopByHop[strings.ToLower(k)]; skip {
|
||||
continue
|
||||
}
|
||||
filtered[k] = v
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// extractHeaders 从 Gin context 提取请求头
|
||||
|
||||
@@ -5,33 +5,34 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
"nex/backend/internal/conversion/anthropic"
|
||||
"nex/backend/internal/conversion/openai"
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/provider"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
"nex/backend/tests/mocks"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"go.uber.org/zap"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
|
||||
|
||||
func setupProxyEngine(t *testing.T) *conversion.ConversionEngine {
|
||||
t.Helper()
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
engine := conversion.NewConversionEngine(registry, nil)
|
||||
engine := conversion.NewConversionEngine(registry, zap.NewNop())
|
||||
require.NoError(t, registry.Register(openai.NewAdapter()))
|
||||
require.NoError(t, registry.Register(anthropic.NewAdapter()))
|
||||
return engine
|
||||
@@ -44,6 +45,7 @@ func newTestProxyHandler(engine *conversion.ConversionEngine, client *mocks.Mock
|
||||
routingSvc,
|
||||
providerSvc,
|
||||
statsSvc,
|
||||
zap.NewNop(),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -72,7 +74,7 @@ func TestProxyHandler_HandleProxy_NonStreamSuccess(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
@@ -91,8 +93,8 @@ func TestProxyHandler_HandleProxy_NonStreamSuccess(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -108,20 +110,20 @@ func TestProxyHandler_HandleProxy_RoutingError_WithBody(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(nil, appErrors.ErrModelNotFound)
|
||||
routingSvc.EXPECT().RouteByModelName("unknown", "model").Return(nil, appErrors.ErrModelNotFound)
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
providerSvc.EXPECT().List().Return(nil, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"unknown","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "MODEL_NOT_FOUND")
|
||||
}
|
||||
|
||||
func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
|
||||
@@ -130,7 +132,7 @@ func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
@@ -144,11 +146,12 @@ func TestProxyHandler_HandleProxy_ConversionError(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
assert.Equal(t, 502, w.Code)
|
||||
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
|
||||
@@ -157,7 +160,7 @@ func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
@@ -171,11 +174,12 @@ func TestProxyHandler_HandleProxy_ClientSendError(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
assert.Equal(t, 502, w.Code)
|
||||
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
|
||||
@@ -184,12 +188,12 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||
ch := make(chan provider.StreamEvent, 10)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
@@ -198,7 +202,7 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
|
||||
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
|
||||
ch <- provider.StreamEvent{Done: true}
|
||||
}()
|
||||
return ch, nil
|
||||
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
|
||||
})
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
@@ -207,13 +211,14 @@ func TestProxyHandler_HandleProxy_StreamSuccess(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type"))
|
||||
assert.Contains(t, w.Body.String(), "Hello")
|
||||
assert.Contains(t, w.Body.String(), "p1/gpt-4")
|
||||
}
|
||||
|
||||
func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
|
||||
@@ -222,12 +227,12 @@ func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||
return nil, context.DeadlineExceeded
|
||||
})
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
@@ -236,11 +241,12 @@ func TestProxyHandler_HandleProxy_StreamError(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
assert.Equal(t, 502, w.Code)
|
||||
assert.JSONEq(t, `{"error":"上游服务不可达","code":"UPSTREAM_UNAVAILABLE"}`, w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyHandler_ForwardPassthrough_GET(t *testing.T) {
|
||||
@@ -260,8 +266,8 @@ func TestProxyHandler_ForwardPassthrough_GET(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
|
||||
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
|
||||
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -281,11 +287,11 @@ func TestProxyHandler_ForwardPassthrough_UnsupportedProtocol(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "unknown"}, {Key: "path", Value: "/models"}}
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "unknown"}, {Key: "path", Value: "/v1/models"}}
|
||||
c.Request = httptest.NewRequest("GET", "/unknown/models", nil)
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
}
|
||||
|
||||
func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) {
|
||||
@@ -303,8 +309,8 @@ func TestProxyHandler_ForwardPassthrough_NoProviders(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
|
||||
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
|
||||
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -328,7 +334,7 @@ func TestProxyHandler_HandleProxy_ProviderProtocolDefault(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
@@ -346,8 +352,8 @@ func TestProxyHandler_HandleProxy_ProviderProtocolDefault(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -370,6 +376,7 @@ func TestProxyHandler_WriteConversionError_NonConversionError(t *testing.T) {
|
||||
|
||||
h.writeConversionError(c, context.DeadlineExceeded, "openai")
|
||||
assert.Equal(t, 500, w.Code)
|
||||
assert.JSONEq(t, `{"error":"context deadline exceeded","code":"CONVERSION_FAILED"}`, w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyHandler_WriteConversionError_ConversionError(t *testing.T) {
|
||||
@@ -389,7 +396,40 @@ func TestProxyHandler_WriteConversionError_ConversionError(t *testing.T) {
|
||||
|
||||
convErr := conversion.NewConversionError(conversion.ErrorCodeInvalidInput, "bad request")
|
||||
h.writeConversionError(c, convErr, "openai")
|
||||
assert.Equal(t, 500, w.Code)
|
||||
assert.Equal(t, 400, w.Code)
|
||||
assert.JSONEq(t, `{"error":"bad request","code":"INVALID_REQUEST"}`, w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyHandler_WriteConversionError_JSONPhase(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
t.Run("request json parse error", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||
|
||||
h.writeConversionError(c, conversion.NewRequestJSONParseError("解码请求失败", context.Canceled), "openai")
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.JSONEq(t, `{"error":"请求体 JSON 格式错误","code":"INVALID_JSON"}`, w.Body.String())
|
||||
})
|
||||
|
||||
t.Run("response json parse error", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||
|
||||
h.writeConversionError(c, conversion.NewResponseJSONParseError("解码响应失败", context.Canceled), "openai")
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
assert.JSONEq(t, `{"error":"解码响应失败","code":"CONVERSION_FAILED"}`, w.Body.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxyHandler_HandleProxy_EmptyBody(t *testing.T) {
|
||||
@@ -409,8 +449,8 @@ func TestProxyHandler_HandleProxy_EmptyBody(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
|
||||
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
|
||||
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -422,19 +462,19 @@ func TestProxyHandler_HandleStream_MidStreamError(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||
ch := make(chan provider.StreamEvent, 10)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ch <- provider.StreamEvent{Data: []byte("data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n")}
|
||||
ch <- provider.StreamEvent{Error: fmt.Errorf("connection reset by peer")}
|
||||
}()
|
||||
return ch, nil
|
||||
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
|
||||
})
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
@@ -443,8 +483,8 @@ func TestProxyHandler_HandleStream_MidStreamError(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -459,12 +499,12 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||
ch := make(chan provider.StreamEvent, 10)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
@@ -472,7 +512,7 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
|
||||
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
|
||||
ch <- provider.StreamEvent{Done: true}
|
||||
}()
|
||||
return ch, nil
|
||||
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
|
||||
})
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
@@ -481,8 +521,8 @@ func TestProxyHandler_HandleStream_FlushOutput(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -499,12 +539,12 @@ func TestProxyHandler_HandleStream_CreateStreamConverterError(t *testing.T) {
|
||||
defer ctrl.Finish()
|
||||
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
engine := conversion.NewConversionEngine(registry, nil)
|
||||
engine := conversion.NewConversionEngine(registry, zap.NewNop())
|
||||
err := registry.Register(openai.NewAdapter())
|
||||
require.NoError(t, err)
|
||||
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
@@ -515,8 +555,8 @@ func TestProxyHandler_HandleStream_CreateStreamConverterError(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
@@ -527,11 +567,11 @@ func TestProxyHandler_HandleStream_ConvertRequestError(t *testing.T) {
|
||||
defer ctrl.Finish()
|
||||
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
engine := conversion.NewConversionEngine(registry, nil)
|
||||
engine := conversion.NewConversionEngine(registry, zap.NewNop())
|
||||
require.NoError(t, registry.Register(openai.NewAdapter()))
|
||||
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "nonexistent", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
@@ -542,8 +582,8 @@ func TestProxyHandler_HandleStream_ConvertRequestError(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
@@ -554,12 +594,12 @@ func TestProxyHandler_HandleNonStream_ConvertResponseError(t *testing.T) {
|
||||
defer ctrl.Finish()
|
||||
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
engine := conversion.NewConversionEngine(registry, nil)
|
||||
engine := conversion.NewConversionEngine(registry, zap.NewNop())
|
||||
require.NoError(t, registry.Register(openai.NewAdapter()))
|
||||
require.NoError(t, registry.Register(anthropic.NewAdapter()))
|
||||
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "anthropic", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "claude-3", Enabled: true},
|
||||
}, nil)
|
||||
@@ -577,8 +617,8 @@ func TestProxyHandler_HandleNonStream_ConvertResponseError(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"claude-3","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 500, w.Code)
|
||||
@@ -590,7 +630,7 @@ func TestProxyHandler_HandleNonStream_ResponseHeaders(t *testing.T) {
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("", "").Return(&domain.RouteResult{
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
@@ -609,8 +649,8 @@ func TestProxyHandler_HandleNonStream_ResponseHeaders(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -623,7 +663,7 @@ func TestProxyHandler_ForwardPassthrough_CrossProtocol(t *testing.T) {
|
||||
defer ctrl.Finish()
|
||||
|
||||
registry := conversion.NewMemoryRegistry()
|
||||
engine := conversion.NewConversionEngine(registry, nil)
|
||||
engine := conversion.NewConversionEngine(registry, zap.NewNop())
|
||||
require.NoError(t, registry.Register(openai.NewAdapter()))
|
||||
|
||||
anthropicAdapter := anthropic.NewAdapter()
|
||||
@@ -641,8 +681,8 @@ func TestProxyHandler_ForwardPassthrough_CrossProtocol(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
|
||||
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
|
||||
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -665,8 +705,8 @@ func TestProxyHandler_ForwardPassthrough_NoBody_NoModel(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
|
||||
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
|
||||
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -689,10 +729,10 @@ func TestIsStreamRequest_EdgeCases(t *testing.T) {
|
||||
path string
|
||||
expected bool
|
||||
}{
|
||||
{"stream at end of JSON", `{"messages":[],"stream":true}`, "/chat/completions", true},
|
||||
{"stream with spaces", `{"stream" : true}`, "/chat/completions", true},
|
||||
{"stream embedded in string value", `{"model":"stream:true"}`, "/chat/completions", false},
|
||||
{"empty body", "", "/chat/completions", false},
|
||||
{"stream at end of JSON", `{"messages":[],"stream":true}`, "/v1/chat/completions", true},
|
||||
{"stream with spaces", `{"stream" : true}`, "/v1/chat/completions", true},
|
||||
{"stream embedded in string value", `{"model":"stream:true"}`, "/v1/chat/completions", false},
|
||||
{"empty body", "", "/v1/chat/completions", false},
|
||||
{"stream true embeddings", `{"model":"text-emb","stream":true}`, "/v1/embeddings", false},
|
||||
}
|
||||
|
||||
@@ -719,8 +759,9 @@ func TestProxyHandler_WriteError_RouteError(t *testing.T) {
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||
|
||||
h.writeError(c, fmt.Errorf("model not found"), "openai")
|
||||
h.writeRouteError(c, fmt.Errorf("model not found"))
|
||||
assert.Equal(t, 404, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "MODEL_NOT_FOUND")
|
||||
}
|
||||
|
||||
func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) {
|
||||
@@ -740,8 +781,8 @@ func TestProxyHandler_HandleProxy_RouteEmptyBody_NoModel(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
|
||||
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
|
||||
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -764,35 +805,35 @@ func TestIsStreamRequest(t *testing.T) {
|
||||
name: "stream true",
|
||||
body: []byte(`{"model": "gpt-4", "stream": true}`),
|
||||
clientProtocol: "openai",
|
||||
nativePath: "/chat/completions",
|
||||
nativePath: "/v1/chat/completions",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "stream false",
|
||||
body: []byte(`{"model": "gpt-4", "stream": false}`),
|
||||
clientProtocol: "openai",
|
||||
nativePath: "/chat/completions",
|
||||
nativePath: "/v1/chat/completions",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "no stream field",
|
||||
body: []byte(`{"model": "gpt-4"}`),
|
||||
clientProtocol: "openai",
|
||||
nativePath: "/chat/completions",
|
||||
nativePath: "/v1/chat/completions",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
body: []byte(`{invalid}`),
|
||||
clientProtocol: "openai",
|
||||
nativePath: "/chat/completions",
|
||||
nativePath: "/v1/chat/completions",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "not chat endpoint",
|
||||
body: []byte(`{"model": "gpt-4", "stream": true}`),
|
||||
clientProtocol: "openai",
|
||||
nativePath: "/models",
|
||||
nativePath: "/v1/models",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
@@ -830,8 +871,8 @@ func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models"}}
|
||||
c.Request = httptest.NewRequest("GET", "/openai/models", nil)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models"}}
|
||||
c.Request = httptest.NewRequest("GET", "/openai/v1/models", nil)
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -842,7 +883,8 @@ func TestProxyHandler_HandleProxy_Models_LocalAggregation(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
assert.Len(t, data, 2)
|
||||
|
||||
first := data[0].(map[string]interface{})
|
||||
first, ok2 := data[0].(map[string]interface{})
|
||||
require.True(t, ok2)
|
||||
assert.Equal(t, "openai/gpt-4", first["id"])
|
||||
}
|
||||
|
||||
@@ -860,8 +902,8 @@ func TestProxyHandler_HandleProxy_ModelInfo_LocalQuery(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/openai/gpt-4"}}
|
||||
c.Request = httptest.NewRequest("GET", "/openai/models/openai/gpt-4", nil)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models/openai/gpt-4"}}
|
||||
c.Request = httptest.NewRequest("GET", "/openai/v1/models/openai/gpt-4", nil)
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -894,8 +936,8 @@ func TestProxyHandler_HandleProxy_Models_EmptySuffix_ForwardPassthrough(t *testi
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/models/"}}
|
||||
c.Request = httptest.NewRequest("GET", "/openai/models/", nil)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/models/"}}
|
||||
c.Request = httptest.NewRequest("GET", "/openai/v1/models/", nil)
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -916,7 +958,7 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
||||
var req map[string]interface{}
|
||||
json.Unmarshal(spec.Body, &req)
|
||||
require.NoError(t, json.Unmarshal(spec.Body, &req))
|
||||
assert.Equal(t, "gpt-4", req["model"])
|
||||
|
||||
return &conversion.HTTPResponseSpec{
|
||||
@@ -932,8 +974,8 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_UnifiedID(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -970,8 +1012,8 @@ func TestProxyHandler_HandleProxy_CrossProtocol_NonStream_UnifiedID(t *testing.T
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -992,7 +1034,7 @@ func TestProxyHandler_HandleProxy_CrossProtocol_Stream_UnifiedID(t *testing.T) {
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude-3", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan provider.StreamEvent, error) {
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||
ch := make(chan provider.StreamEvent, 10)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
@@ -1010,7 +1052,7 @@ data: {"type":"message_stop"}
|
||||
`)}
|
||||
ch <- provider.StreamEvent{Done: true}
|
||||
}()
|
||||
return ch, nil
|
||||
return &provider.StreamResponse{StatusCode: 200, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
|
||||
})
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
@@ -1019,8 +1061,8 @@ data: {"type":"message_stop"}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"anthropic_p/claude-3","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -1057,8 +1099,8 @@ func TestProxyHandler_HandleProxy_SmartPassthrough_Fidelity(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}],"custom_param":"should_be_preserved"}`)))
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"openai_p/gpt-4","messages":[{"role":"user","content":"hi"}],"custom_param":"should_be_preserved"}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
@@ -1088,8 +1130,8 @@ func TestProxyHandler_HandleProxy_UnifiedID_ModelNotFound(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"unknown/model","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
assert.Equal(t, 404, w.Code)
|
||||
@@ -1098,3 +1140,314 @@ func TestProxyHandler_HandleProxy_UnifiedID_ModelNotFound(t *testing.T) {
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||
assert.Contains(t, resp, "error")
|
||||
}
|
||||
|
||||
func TestProxyHandler_HandleProxy_OpenAIAndAnthropicNativePaths(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
protocol string
|
||||
path string
|
||||
requestPath string
|
||||
baseURL string
|
||||
expectedURL string
|
||||
body string
|
||||
responseBody string
|
||||
responseModel string
|
||||
}{
|
||||
{
|
||||
name: "openai path keeps v1 after gateway prefix",
|
||||
protocol: "openai",
|
||||
path: "/v1/chat/completions",
|
||||
requestPath: "/openai/v1/chat/completions",
|
||||
baseURL: "https://api.test.com/v1",
|
||||
expectedURL: "https://api.test.com/v1/chat/completions",
|
||||
body: `{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`,
|
||||
responseBody: `{"id":"r1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`,
|
||||
responseModel: "p1/gpt-4",
|
||||
},
|
||||
{
|
||||
name: "anthropic path keeps v1 after gateway prefix",
|
||||
protocol: "anthropic",
|
||||
path: "/v1/messages",
|
||||
requestPath: "/anthropic/v1/messages",
|
||||
baseURL: "https://api.anthropic.test",
|
||||
expectedURL: "https://api.anthropic.test/v1/messages",
|
||||
body: `{"model":"p1/gpt-4","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`,
|
||||
responseBody: `{"id":"msg-1","type":"message","role":"assistant","model":"gpt-4","content":[{"type":"text","text":"ok"}],"stop_reason":"end_turn","usage":{"input_tokens":1,"output_tokens":1}}`,
|
||||
responseModel: "p1/gpt-4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: tt.baseURL, Protocol: tt.protocol, Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
||||
assert.Equal(t, tt.expectedURL, spec.URL)
|
||||
return &conversion.HTTPResponseSpec{
|
||||
StatusCode: http.StatusOK,
|
||||
Headers: map[string]string{"Content-Type": "application/json"},
|
||||
Body: []byte(tt.responseBody),
|
||||
}, nil
|
||||
})
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
statsSvc.EXPECT().Record(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: tt.protocol}, {Key: "path", Value: tt.path}}
|
||||
c.Request = httptest.NewRequest("POST", tt.requestPath, bytes.NewReader([]byte(tt.body)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), tt.responseModel)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyHandler_UpstreamNon2xx_Passthrough(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().Send(gomock.Any(), gomock.Any()).Return(&conversion.HTTPResponseSpec{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"X-Upstream-Error": "rate-limit",
|
||||
"Transfer-Encoding": "chunked",
|
||||
},
|
||||
Body: []byte(`{"error":{"message":"rate limited"}}`),
|
||||
}, nil)
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}]}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
require.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||
assert.JSONEq(t, `{"error":{"message":"rate limited"}}`, w.Body.String())
|
||||
assert.Equal(t, "rate-limit", w.Header().Get("X-Upstream-Error"))
|
||||
assert.Empty(t, w.Header().Get("Transfer-Encoding"))
|
||||
}
|
||||
|
||||
func TestProxyHandler_StreamUpstreamNon2xx_Passthrough(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).Return(&provider.StreamResponse{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Headers: map[string]string{"Content-Type": "application/json", "Connection": "close"},
|
||||
Body: []byte(`{"error":"upstream down"}`),
|
||||
}, nil)
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
require.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||
assert.JSONEq(t, `{"error":"upstream down"}`, w.Body.String())
|
||||
assert.Empty(t, w.Header().Get("Connection"))
|
||||
}
|
||||
|
||||
func TestFilterHopByHopHeaders(t *testing.T) {
|
||||
filtered := filterHopByHopHeaders(map[string]string{
|
||||
"Connection": "close",
|
||||
"Transfer-Encoding": "chunked",
|
||||
"Keep-Alive": "timeout=5",
|
||||
"Proxy-Authenticate": "Basic",
|
||||
"Proxy-Authorization": "Basic token",
|
||||
"TE": "trailers",
|
||||
"Trailer": "Expires",
|
||||
"Upgrade": "websocket",
|
||||
"Content-Type": "application/json",
|
||||
"X-Request-ID": "req-1",
|
||||
})
|
||||
|
||||
assert.Equal(t, map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"X-Request-ID": "req-1",
|
||||
}, filtered)
|
||||
}
|
||||
|
||||
func TestProxyHandler_UnknownInterface_DoesNotGuessModel(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
providerSvc.EXPECT().List().Return([]domain.Provider{
|
||||
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
||||
assert.Equal(t, "https://api.test.com/v1/unknown?trace=1", spec.URL)
|
||||
assert.JSONEq(t, `{"model":"p1/gpt-4","payload":true}`, string(spec.Body))
|
||||
return &conversion.HTTPResponseSpec{
|
||||
StatusCode: http.StatusOK,
|
||||
Headers: map[string]string{"Content-Type": "application/json"},
|
||||
Body: []byte(`{"ok":true}`),
|
||||
}, nil
|
||||
})
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/unknown"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/unknown?trace=1", bytes.NewReader([]byte(`{"model":"p1/gpt-4","payload":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
assert.JSONEq(t, `{"ok":true}`, w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyHandler_InvalidJSON_UsesGatewayError(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.JSONEq(t, `{"error":"请求体 JSON 格式错误","code":"INVALID_JSON"}`, w.Body.String())
|
||||
}
|
||||
|
||||
func TestProxyHandler_CrossProtocolMultimodal_Unsupported(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("anthropic_p", "claude").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "anthropic_p", Name: "Anthropic", APIKey: "sk-test", BaseURL: "https://api.anthropic.test", Protocol: "anthropic", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "anthropic_p", ModelName: "claude", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
body := []byte(`{"model":"anthropic_p/claude","messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader(body))
|
||||
|
||||
h.HandleProxy(c)
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "UNSUPPORTED_MULTIMODAL")
|
||||
}
|
||||
|
||||
func TestProxyHandler_SameProtocolMultimodal_SmartPassthrough(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
routingSvc.EXPECT().RouteByModelName("p1", "gpt-4").Return(&domain.RouteResult{
|
||||
Provider: &domain.Provider{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
|
||||
Model: &domain.Model{ID: "m1", ProviderID: "p1", ModelName: "gpt-4", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error) {
|
||||
assert.Contains(t, string(spec.Body), "image_url")
|
||||
assert.Contains(t, string(spec.Body), `"model":"gpt-4"`)
|
||||
return &conversion.HTTPResponseSpec{
|
||||
StatusCode: http.StatusOK,
|
||||
Headers: map[string]string{"Content-Type": "application/json"},
|
||||
Body: []byte(`{"id":"r1","object":"chat.completion","model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`),
|
||||
}, nil
|
||||
})
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
statsSvc.EXPECT().Record(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
body := []byte(`{"model":"p1/gpt-4","messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader(body))
|
||||
|
||||
h.HandleProxy(c)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "p1/gpt-4")
|
||||
}
|
||||
|
||||
func TestProxyHandler_RawStreamPassthrough_PreservesSSEFrames(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
engine := setupProxyEngine(t)
|
||||
routingSvc := mocks.NewMockRoutingService(ctrl)
|
||||
providerSvc := mocks.NewMockProviderService(ctrl)
|
||||
providerSvc.EXPECT().List().Return([]domain.Provider{
|
||||
{ID: "p1", Name: "Test", APIKey: "sk-test", BaseURL: "https://api.test.com/v1", Protocol: "openai", Enabled: true},
|
||||
}, nil)
|
||||
client := mocks.NewMockProviderClient(ctrl)
|
||||
client.EXPECT().SendStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, spec conversion.HTTPRequestSpec) (*provider.StreamResponse, error) {
|
||||
assert.Contains(t, string(spec.Body), `"model":"gpt-4"`)
|
||||
ch := make(chan provider.StreamEvent, 3)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
ch <- provider.StreamEvent{Data: []byte("data: {\"model\":\"gpt-4\",\"choices\":[]}\n\n")}
|
||||
ch <- provider.StreamEvent{Data: []byte("data: [DONE]\n\n")}
|
||||
ch <- provider.StreamEvent{Done: true}
|
||||
}()
|
||||
return &provider.StreamResponse{StatusCode: http.StatusOK, Headers: map[string]string{"Content-Type": "text/event-stream"}, Events: ch}, nil
|
||||
})
|
||||
statsSvc := mocks.NewMockStatsService(ctrl)
|
||||
h := newTestProxyHandler(engine, client, routingSvc, providerSvc, statsSvc)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "protocol", Value: "openai"}, {Key: "path", Value: "/v1/chat/completions"}}
|
||||
c.Request = httptest.NewRequest("POST", "/openai/v1/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"stream":true}`)))
|
||||
|
||||
h.HandleProxy(c)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "data: {\"model\":\"gpt-4\",\"choices\":[]}\n\ndata: [DONE]\n\n", w.Body.String())
|
||||
}
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"nex/backend/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// StatsHandler 统计处理器
|
||||
|
||||
26
backend/internal/handler/version_handler.go
Normal file
26
backend/internal/handler/version_handler.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"nex/backend/pkg/buildinfo"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// VersionHandler 提供后端构建版本信息。
|
||||
type VersionHandler struct{}
|
||||
|
||||
// NewVersionHandler 创建版本信息处理器。
|
||||
func NewVersionHandler() *VersionHandler {
|
||||
return &VersionHandler{}
|
||||
}
|
||||
|
||||
// GetVersion 返回构建注入的版本元数据。
|
||||
func (h *VersionHandler) GetVersion(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"version": buildinfo.Version(),
|
||||
"commit": buildinfo.Commit(),
|
||||
"build_time": buildinfo.BuildTime(),
|
||||
})
|
||||
}
|
||||
31
backend/internal/handler/version_handler_test.go
Normal file
31
backend/internal/handler/version_handler_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVersionHandler_GetVersion(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
h := NewVersionHandler()
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/version", nil)
|
||||
|
||||
h.GetVersion(c)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var result map[string]string
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &result))
|
||||
assert.Equal(t, "dev", result["version"])
|
||||
assert.Equal(t, "unknown", result["commit"])
|
||||
assert.Equal(t, "unknown", result["build_time"])
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
pkgErrors "nex/backend/pkg/errors"
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
// StreamConfig 流式处理配置
|
||||
@@ -42,6 +44,14 @@ type StreamEvent struct {
|
||||
Done bool
|
||||
}
|
||||
|
||||
// StreamResponse 表示上游流式 HTTP 响应。
|
||||
type StreamResponse struct {
|
||||
StatusCode int
|
||||
Headers map[string]string
|
||||
Body []byte
|
||||
Events <-chan StreamEvent
|
||||
}
|
||||
|
||||
// Client 协议无关的供应商客户端
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
@@ -50,19 +60,20 @@ type Client struct {
|
||||
}
|
||||
|
||||
// ProviderClient 供应商客户端接口
|
||||
//
|
||||
//go:generate go run go.uber.org/mock/mockgen -source=client.go -destination=../../tests/mocks/mock_provider_client.go -package=mocks
|
||||
type ProviderClient interface {
|
||||
Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*conversion.HTTPResponseSpec, error)
|
||||
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error)
|
||||
SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*StreamResponse, error)
|
||||
}
|
||||
|
||||
// NewClient 创建供应商客户端
|
||||
func NewClient() *Client {
|
||||
func NewClient(logger *zap.Logger) *Client {
|
||||
return &Client{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
logger: zap.L(),
|
||||
logger: pkglogger.WithModule(logger, "provider.client"),
|
||||
streamCfg: DefaultStreamConfig(),
|
||||
}
|
||||
}
|
||||
@@ -114,7 +125,7 @@ func (c *Client) Send(ctx context.Context, spec conversion.HTTPRequestSpec) (*co
|
||||
}
|
||||
|
||||
// SendStream 发送流式请求
|
||||
func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (<-chan StreamEvent, error) {
|
||||
func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec) (*StreamResponse, error) {
|
||||
var bodyReader io.Reader
|
||||
if len(spec.Body) > 0 {
|
||||
bodyReader = bytes.NewReader(spec.Body)
|
||||
@@ -137,20 +148,29 @@ func (c *Client) SendStream(ctx context.Context, spec conversion.HTTPRequestSpec
|
||||
return nil, pkgErrors.ErrRequestSend.WithCause(err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respHeaders := extractResponseHeaders(resp.Header)
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
defer resp.Body.Close()
|
||||
cancel()
|
||||
errBody, _ := io.ReadAll(resp.Body)
|
||||
if len(errBody) > 0 {
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d: %s", resp.StatusCode, string(errBody))
|
||||
errBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, pkgErrors.ErrResponseRead.WithCause(readErr)
|
||||
}
|
||||
return nil, fmt.Errorf("供应商返回错误: HTTP %d", resp.StatusCode)
|
||||
return &StreamResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Headers: respHeaders,
|
||||
Body: errBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
eventChan := make(chan StreamEvent, c.streamCfg.ChannelBufferSize)
|
||||
go c.readStream(streamCtx, cancel, resp.Body, eventChan)
|
||||
|
||||
return eventChan, nil
|
||||
return &StreamResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Headers: respHeaders,
|
||||
Events: eventChan,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// readStream 读取 SSE 流
|
||||
@@ -183,10 +203,10 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
if isNetworkError(err) {
|
||||
c.logger.Error("流网络错误", zap.String("error", err.Error()))
|
||||
c.logger.Error("流网络错误", zap.Error(err))
|
||||
eventChan <- StreamEvent{Error: fmt.Errorf("网络错误: %w", err)}
|
||||
} else {
|
||||
c.logger.Error("流读取错误", zap.String("error", err.Error()))
|
||||
c.logger.Error("流读取错误", zap.Error(err))
|
||||
eventChan <- StreamEvent{Error: fmt.Errorf("读取错误: %w", err)}
|
||||
}
|
||||
return
|
||||
@@ -203,15 +223,17 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
|
||||
}
|
||||
|
||||
for {
|
||||
idx := bytes.Index(dataBuf, []byte("\n\n"))
|
||||
idx, sepLen := findSSEFrameSeparator(dataBuf)
|
||||
if idx == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
rawEvent := dataBuf[:idx]
|
||||
dataBuf = dataBuf[idx+2:]
|
||||
frameEnd := idx + sepLen
|
||||
rawEvent := append([]byte(nil), dataBuf[:frameEnd]...)
|
||||
dataBuf = dataBuf[frameEnd:]
|
||||
|
||||
if bytes.Contains(rawEvent, []byte("data: [DONE]")) {
|
||||
if isSSEDoneFrame(rawEvent) {
|
||||
eventChan <- StreamEvent{Data: rawEvent}
|
||||
eventChan <- StreamEvent{Done: true}
|
||||
return
|
||||
}
|
||||
@@ -220,11 +242,66 @@ func (c *Client) readStream(ctx context.Context, cancel context.CancelFunc, body
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
if len(dataBuf) > 0 {
|
||||
eventChan <- StreamEvent{Data: dataBuf}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isSSEDoneFrame(frame []byte) bool {
|
||||
payload, ok := sseFrameDataPayload(frame)
|
||||
return ok && strings.TrimSpace(payload) == "[DONE]"
|
||||
}
|
||||
|
||||
func sseFrameDataPayload(frame []byte) (string, bool) {
|
||||
text := strings.TrimRight(string(frame), "\r\n")
|
||||
lines := strings.Split(text, "\n")
|
||||
var dataLines []string
|
||||
for _, line := range lines {
|
||||
line = strings.TrimRight(line, "\r")
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
value := strings.TrimPrefix(line, "data:")
|
||||
if strings.HasPrefix(value, " ") {
|
||||
value = value[1:]
|
||||
}
|
||||
dataLines = append(dataLines, value)
|
||||
}
|
||||
}
|
||||
if len(dataLines) == 0 {
|
||||
return "", false
|
||||
}
|
||||
return strings.Join(dataLines, "\n"), true
|
||||
}
|
||||
|
||||
func extractResponseHeaders(header http.Header) map[string]string {
|
||||
respHeaders := make(map[string]string)
|
||||
for k, vs := range header {
|
||||
if len(vs) > 0 {
|
||||
respHeaders[k] = vs[0]
|
||||
}
|
||||
}
|
||||
return respHeaders
|
||||
}
|
||||
|
||||
func findSSEFrameSeparator(data []byte) (int, int) {
|
||||
lf := bytes.Index(data, []byte("\n\n"))
|
||||
crlf := bytes.Index(data, []byte("\r\n\r\n"))
|
||||
switch {
|
||||
case lf < 0 && crlf < 0:
|
||||
return -1, 0
|
||||
case lf < 0:
|
||||
return crlf, 4
|
||||
case crlf < 0:
|
||||
return lf, 2
|
||||
case crlf <= lf:
|
||||
return crlf, 4
|
||||
default:
|
||||
return lf, 2
|
||||
}
|
||||
}
|
||||
|
||||
// isNetworkError 判断是否为网络相关错误
|
||||
func isNetworkError(err error) bool {
|
||||
if err == nil {
|
||||
|
||||
@@ -13,12 +13,13 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/conversion"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
require.NotNil(t, client)
|
||||
assert.NotNil(t, client.httpClient)
|
||||
assert.Equal(t, 4096, client.streamCfg.InitialBufferSize)
|
||||
@@ -40,11 +41,12 @@ func TestClient_Send_Success(t *testing.T) {
|
||||
assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
|
||||
_, err := w.Write([]byte(`{"id":"test","model":"gpt-4"}`))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -64,11 +66,12 @@ func TestClient_Send_Success(t *testing.T) {
|
||||
func TestClient_Send_ErrorResponse(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
|
||||
_, err := w.Write([]byte(`{"error":{"message":"Invalid API key"}}`))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -82,7 +85,7 @@ func TestClient_Send_ErrorResponse(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClient_Send_ConnectionError(t *testing.T) {
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: "http://localhost:1/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -99,7 +102,7 @@ func TestClient_SendStream_CreatesChannel(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -107,11 +110,13 @@ func TestClient_SendStream_CreatesChannel(t *testing.T) {
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStream(context.Background(), spec)
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, eventChan)
|
||||
require.NotNil(t, streamResp)
|
||||
require.Equal(t, http.StatusOK, streamResp.StatusCode)
|
||||
require.NotNil(t, streamResp.Events)
|
||||
|
||||
for range eventChan {
|
||||
for range streamResp.Events {
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,7 +126,7 @@ func TestClient_SendStream_ErrorResponse(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -129,8 +134,10 @@ func TestClient_SendStream_ErrorResponse(t *testing.T) {
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
_, err := client.SendStream(context.Background(), spec)
|
||||
assert.Error(t, err)
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, streamResp)
|
||||
assert.Equal(t, http.StatusInternalServerError, streamResp.StatusCode)
|
||||
}
|
||||
|
||||
func TestClient_SendStream_SSEEvents(t *testing.T) {
|
||||
@@ -139,18 +146,21 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, ok := w.(http.Flusher)
|
||||
require.True(t, ok)
|
||||
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
|
||||
_, err := w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
|
||||
_, err = w.Write([]byte("data: {\"id\":\"1\",\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
w.Write([]byte("data: [DONE]\n\n"))
|
||||
_, err = w.Write([]byte("data: [DONE]\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -158,24 +168,73 @@ func TestClient_SendStream_SSEEvents(t *testing.T) {
|
||||
Body: []byte(`{"model":"gpt-4","messages":[],"stream":true}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStream(context.Background(), spec)
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, streamResp)
|
||||
|
||||
var dataEvents [][]byte
|
||||
var doneEvents int
|
||||
for event := range eventChan {
|
||||
if event.Done {
|
||||
for event := range streamResp.Events {
|
||||
switch {
|
||||
case event.Done:
|
||||
doneEvents++
|
||||
} else if event.Error != nil {
|
||||
case event.Error != nil:
|
||||
t.Fatalf("unexpected error: %v", event.Error)
|
||||
} else {
|
||||
default:
|
||||
dataEvents = append(dataEvents, event.Data)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 2, len(dataEvents), "expected exactly 2 data events from SSE stream")
|
||||
assert.Equal(t, 3, len(dataEvents), "expected 2 data frames plus DONE frame from SSE stream")
|
||||
assert.Contains(t, string(dataEvents[0]), "Hello")
|
||||
assert.Contains(t, string(dataEvents[1]), "World")
|
||||
assert.Contains(t, string(dataEvents[2]), "[DONE]")
|
||||
assert.Equal(t, 1, doneEvents)
|
||||
assert.Contains(t, string(dataEvents[0]), "\n\n")
|
||||
}
|
||||
|
||||
func TestClient_SendStream_DONEOnlyWhenDataPayloadEqualsDone(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, ok := w.(http.Flusher)
|
||||
require.True(t, ok)
|
||||
_, err := w.Write([]byte("data: {\"text\":\"data: [DONE] is plain text\"}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
_, err = w.Write([]byte("data: [DONE]\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, streamResp)
|
||||
|
||||
var dataEvents [][]byte
|
||||
var doneEvents int
|
||||
for event := range streamResp.Events {
|
||||
switch {
|
||||
case event.Done:
|
||||
doneEvents++
|
||||
case event.Error != nil:
|
||||
t.Fatalf("unexpected error: %v", event.Error)
|
||||
default:
|
||||
dataEvents = append(dataEvents, event.Data)
|
||||
}
|
||||
}
|
||||
|
||||
require.Len(t, dataEvents, 2)
|
||||
assert.Contains(t, string(dataEvents[0]), "plain text")
|
||||
assert.Contains(t, string(dataEvents[1]), "[DONE]")
|
||||
assert.Equal(t, 1, doneEvents)
|
||||
}
|
||||
|
||||
@@ -188,7 +247,7 @@ func TestClient_SendStream_ContextCancellation(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -196,13 +255,13 @@ func TestClient_SendStream_ContextCancellation(t *testing.T) {
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStream(ctx, spec)
|
||||
streamResp, err := client.SendStream(ctx, spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
cancel()
|
||||
|
||||
var gotError bool
|
||||
for event := range eventChan {
|
||||
for event := range streamResp.Events {
|
||||
if event.Error != nil {
|
||||
gotError = true
|
||||
}
|
||||
@@ -214,11 +273,12 @@ func TestClient_Send_EmptyBody(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"result":"ok"}`))
|
||||
_, err := w.Write([]byte(`{"result":"ok"}`))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/models",
|
||||
Method: "GET",
|
||||
@@ -237,16 +297,18 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, ok := w.(http.Flusher)
|
||||
require.True(t, ok)
|
||||
w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
|
||||
_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.Write([]byte("data: [DONE]\n\n"))
|
||||
_, err = w.Write([]byte("data: [DONE]\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -254,21 +316,22 @@ func TestClient_SendStream_SlowSSE(t *testing.T) {
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStream(context.Background(), spec)
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
var dataCount int
|
||||
var doneCount int
|
||||
for event := range eventChan {
|
||||
if event.Done {
|
||||
for event := range streamResp.Events {
|
||||
switch {
|
||||
case event.Done:
|
||||
doneCount++
|
||||
} else if event.Error != nil {
|
||||
case event.Error != nil:
|
||||
t.Fatalf("unexpected error: %v", event.Error)
|
||||
} else {
|
||||
default:
|
||||
dataCount++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, dataCount, "expected exactly 1 data event from slow SSE")
|
||||
assert.Equal(t, 2, dataCount, "expected 1 data frame plus DONE frame from slow SSE")
|
||||
assert.Equal(t, 1, doneCount, "expected exactly 1 done event from slow SSE")
|
||||
}
|
||||
|
||||
@@ -278,16 +341,18 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, ok := w.(http.Flusher)
|
||||
require.True(t, ok)
|
||||
w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
|
||||
_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\ndata: {\"id\":\"2\"}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
w.Write([]byte("data: [DONE]\n\n"))
|
||||
_, err = w.Write([]byte("data: [DONE]\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -295,19 +360,19 @@ func TestClient_SendStream_SplitSSEEvents(t *testing.T) {
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStream(context.Background(), spec)
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
var dataEvents int
|
||||
var doneEvents int
|
||||
for event := range eventChan {
|
||||
for event := range streamResp.Events {
|
||||
if event.Done {
|
||||
doneEvents++
|
||||
} else {
|
||||
dataEvents++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 2, dataEvents, "expected exactly 2 data events from split SSE")
|
||||
assert.Equal(t, 3, dataEvents, "expected 2 data frames plus DONE frame from split SSE")
|
||||
assert.Equal(t, 1, doneEvents)
|
||||
}
|
||||
|
||||
@@ -363,19 +428,20 @@ func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, ok := w.(http.Flusher)
|
||||
require.True(t, ok)
|
||||
w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
|
||||
_, err := w.Write([]byte("data: {\"id\":\"1\"}\n\n"))
|
||||
require.NoError(t, err)
|
||||
flusher.Flush()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
if hijacker, ok := w.(http.Hijacker); ok {
|
||||
conn, _, _ := hijacker.Hijack()
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
require.NoError(t, conn.Close())
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client := NewClient(zap.NewNop())
|
||||
spec := conversion.HTTPRequestSpec{
|
||||
URL: server.URL + "/v1/chat/completions",
|
||||
Method: "POST",
|
||||
@@ -383,11 +449,11 @@ func TestClient_SendStream_MidStreamNetworkError(t *testing.T) {
|
||||
Body: []byte(`{}`),
|
||||
}
|
||||
|
||||
eventChan, err := client.SendStream(context.Background(), spec)
|
||||
streamResp, err := client.SendStream(context.Background(), spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
var gotData bool
|
||||
for event := range eventChan {
|
||||
for event := range streamResp.Events {
|
||||
if event.Error != nil {
|
||||
} else if !event.Done {
|
||||
gotData = true
|
||||
|
||||
@@ -3,10 +3,11 @@ package repository
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/domain"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
@@ -3,13 +3,13 @@ package repository
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
testHelpers "nex/backend/tests"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) *gorm.DB {
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"nex/backend/internal/config"
|
||||
"nex/backend/internal/domain"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type statsRepository struct {
|
||||
@@ -19,50 +19,46 @@ func NewStatsRepository(db *gorm.DB) StatsRepository {
|
||||
}
|
||||
|
||||
func (r *statsRepository) Record(providerID, modelName string) error {
|
||||
today := time.Now().Format("2006-01-02")
|
||||
todayTime, _ := time.Parse("2006-01-02", today)
|
||||
now := time.Now()
|
||||
todayTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
var stats config.UsageStats
|
||||
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
|
||||
providerID, modelName, todayTime).First(&stats).Error
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
stats = config.UsageStats{
|
||||
stats := config.UsageStats{
|
||||
ProviderID: providerID,
|
||||
ModelName: modelName,
|
||||
RequestCount: 1,
|
||||
Date: todayTime,
|
||||
}
|
||||
return tx.Create(&stats).Error
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Model(&stats).Update("request_count", gorm.Expr("request_count + 1")).Error
|
||||
})
|
||||
return r.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "provider_id"},
|
||||
{Name: "model_name"},
|
||||
{Name: "date"},
|
||||
},
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||
"request_count": gorm.Expr("request_count + 1"),
|
||||
}),
|
||||
}).Create(&stats).Error
|
||||
}
|
||||
|
||||
func (r *statsRepository) BatchUpdate(providerID, modelName string, date time.Time, delta int) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
var stats config.UsageStats
|
||||
err := tx.Where("provider_id = ? AND model_name = ? AND date = ?",
|
||||
providerID, modelName, date).First(&stats).Error
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return tx.Create(&config.UsageStats{
|
||||
stats := config.UsageStats{
|
||||
ProviderID: providerID,
|
||||
ModelName: modelName,
|
||||
RequestCount: delta,
|
||||
Date: date,
|
||||
}).Error
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Model(&stats).
|
||||
Update("request_count", gorm.Expr("request_count + ?", delta)).Error
|
||||
})
|
||||
return r.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "provider_id"},
|
||||
{Name: "model_name"},
|
||||
{Name: "date"},
|
||||
},
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||
"request_count": gorm.Expr("request_count + ?", delta),
|
||||
}),
|
||||
}).Create(&stats).Error
|
||||
}
|
||||
|
||||
func (r *statsRepository) Query(providerID, modelName string, startDate, endDate *time.Time) ([]domain.UsageStats, error) {
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
"errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
type modelService struct {
|
||||
@@ -108,8 +112,12 @@ func (s *modelService) Delete(id string) error {
|
||||
func (s *modelService) checkDuplicateModelName(providerID, modelName, excludeID string) error {
|
||||
existing, err := s.modelRepo.FindByProviderAndModelName(providerID, modelName)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil // 未找到,不重复
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
if excludeID != "" && existing.ID == excludeID {
|
||||
return nil // 排除自身
|
||||
}
|
||||
|
||||
@@ -3,10 +3,10 @@ package service
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"nex/backend/pkg/modelid"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
"nex/backend/pkg/modelid"
|
||||
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
|
||||
@@ -4,10 +4,12 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
type RoutingCache struct {
|
||||
@@ -27,13 +29,15 @@ func NewRoutingCache(
|
||||
return &RoutingCache{
|
||||
modelRepo: modelRepo,
|
||||
providerRepo: providerRepo,
|
||||
logger: logger,
|
||||
logger: pkglogger.WithModule(logger, "service.routing_cache"),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
|
||||
if v, ok := c.providers.Load(id); ok {
|
||||
return v.(*domain.Provider), nil
|
||||
if provider, ok := v.(*domain.Provider); ok {
|
||||
return provider, nil
|
||||
}
|
||||
}
|
||||
|
||||
provider, err := c.providerRepo.GetByID(id)
|
||||
@@ -42,7 +46,9 @@ func (c *RoutingCache) GetProvider(id string) (*domain.Provider, error) {
|
||||
}
|
||||
|
||||
if v, ok := c.providers.Load(id); ok {
|
||||
return v.(*domain.Provider), nil
|
||||
if provider, ok := v.(*domain.Provider); ok {
|
||||
return provider, nil
|
||||
}
|
||||
}
|
||||
|
||||
c.providers.Store(id, provider)
|
||||
@@ -53,7 +59,9 @@ func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, er
|
||||
key := providerID + "/" + modelName
|
||||
|
||||
if v, ok := c.models.Load(key); ok {
|
||||
return v.(*domain.Model), nil
|
||||
if model, ok := v.(*domain.Model); ok {
|
||||
return model, nil
|
||||
}
|
||||
}
|
||||
|
||||
model, err := c.modelRepo.FindByProviderAndModelName(providerID, modelName)
|
||||
@@ -62,7 +70,9 @@ func (c *RoutingCache) GetModel(providerID, modelName string) (*domain.Model, er
|
||||
}
|
||||
|
||||
if v, ok := c.models.Load(key); ok {
|
||||
return v.(*domain.Model), nil
|
||||
if model, ok := v.(*domain.Model); ok {
|
||||
return model, nil
|
||||
}
|
||||
}
|
||||
|
||||
c.models.Store(key, model)
|
||||
@@ -96,7 +106,12 @@ func (c *RoutingCache) invalidateModelsByProvider(providerID string) {
|
||||
prefix := providerID + "/"
|
||||
count := 0
|
||||
c.models.Range(func(key, value interface{}) bool {
|
||||
if strings.HasPrefix(key.(string), prefix) {
|
||||
keyStr, ok := key.(string)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
if strings.HasPrefix(keyStr, prefix) {
|
||||
c.models.Delete(key)
|
||||
count++
|
||||
}
|
||||
|
||||
@@ -5,11 +5,11 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
type mockModelRepo struct {
|
||||
@@ -189,7 +189,8 @@ func TestRoutingCache_InvalidateProvider_CascadingModels(t *testing.T) {
|
||||
|
||||
var openaiCount, anthropicCount int
|
||||
cache.models.Range(func(key, value interface{}) bool {
|
||||
if key.(string) == "anthropic/claude" {
|
||||
keyStr, ok := key.(string)
|
||||
if ok && keyStr == "anthropic/claude" {
|
||||
anthropicCount++
|
||||
}
|
||||
return true
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
type routingService struct {
|
||||
|
||||
@@ -3,11 +3,12 @@ package service
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestProviderService_Update(t *testing.T) {
|
||||
@@ -119,7 +120,7 @@ func TestModelService_Delete_NotFound(t *testing.T) {
|
||||
|
||||
func TestStatsService_Aggregate_Default(t *testing.T) {
|
||||
statsRepo := repository.NewStatsRepository(nil)
|
||||
buffer := NewStatsBuffer(statsRepo, nil)
|
||||
buffer := NewStatsBuffer(statsRepo, zap.NewNop())
|
||||
svc := NewStatsService(statsRepo, buffer)
|
||||
|
||||
stats := []domain.UsageStats{
|
||||
@@ -132,7 +133,9 @@ func TestStatsService_Aggregate_Default(t *testing.T) {
|
||||
|
||||
totalCount := 0
|
||||
for _, r := range result {
|
||||
totalCount += r["request_count"].(int)
|
||||
count, ok := r["request_count"].(int)
|
||||
require.True(t, ok)
|
||||
totalCount += count
|
||||
}
|
||||
assert.Equal(t, 15, totalCount)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -13,8 +16,6 @@ import (
|
||||
|
||||
testHelpers "nex/backend/tests"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
"nex/backend/internal/repository"
|
||||
appErrors "nex/backend/pkg/errors"
|
||||
)
|
||||
|
||||
@@ -318,7 +319,8 @@ func TestStatsService_Aggregate_ByModel(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
buffer := NewStatsBuffer(statsRepo, nil); svc := NewStatsService(statsRepo, buffer)
|
||||
buffer := NewStatsBuffer(statsRepo, zap.NewNop())
|
||||
svc := NewStatsService(statsRepo, buffer)
|
||||
|
||||
result := svc.Aggregate(tt.stats, "model")
|
||||
|
||||
@@ -379,7 +381,8 @@ func TestStatsService_Aggregate_ByDate(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := setupServiceTestDB(t)
|
||||
statsRepo := repository.NewStatsRepository(db)
|
||||
buffer := NewStatsBuffer(statsRepo, nil); svc := NewStatsService(statsRepo, buffer)
|
||||
buffer := NewStatsBuffer(statsRepo, zap.NewNop())
|
||||
svc := NewStatsService(statsRepo, buffer)
|
||||
|
||||
result := svc.Aggregate(tt.stats, "date")
|
||||
|
||||
|
||||
@@ -6,9 +6,11 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/repository"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/repository"
|
||||
pkglogger "nex/backend/pkg/logger"
|
||||
)
|
||||
|
||||
type StatsBuffer struct {
|
||||
@@ -46,7 +48,7 @@ func NewStatsBuffer(
|
||||
) *StatsBuffer {
|
||||
b := &StatsBuffer{
|
||||
statsRepo: statsRepo,
|
||||
logger: logger,
|
||||
logger: pkglogger.WithModule(logger, "service.stats_buffer"),
|
||||
flushInterval: 5 * time.Second,
|
||||
flushThreshold: 100,
|
||||
stopCh: make(chan struct{}),
|
||||
@@ -66,13 +68,21 @@ func (b *StatsBuffer) Increment(providerID, modelName string) {
|
||||
|
||||
var counter *int64
|
||||
if v, ok := b.counters.Load(key); ok {
|
||||
counter = v.(*int64)
|
||||
if existing, ok := v.(*int64); ok {
|
||||
counter = existing
|
||||
} else {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
val := int64(0)
|
||||
counter = &val
|
||||
actual, loaded := b.counters.LoadOrStore(key, counter)
|
||||
if loaded {
|
||||
counter = actual.(*int64)
|
||||
existing, ok := actual.(*int64)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
counter = existing
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,13 +126,20 @@ func (b *StatsBuffer) flush() {
|
||||
|
||||
var entries []statEntry
|
||||
b.counters.Range(func(key, value interface{}) bool {
|
||||
keyStr := key.(string)
|
||||
keyStr, ok := key.(string)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
parts := strings.Split(keyStr, "/")
|
||||
if len(parts) != 3 {
|
||||
return true
|
||||
}
|
||||
|
||||
counter := value.(*int64)
|
||||
counter, ok := value.(*int64)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
count := atomic.SwapInt64(counter, 0)
|
||||
|
||||
if count > 0 {
|
||||
@@ -142,8 +159,17 @@ func (b *StatsBuffer) flush() {
|
||||
|
||||
success := 0
|
||||
for _, entry := range entries {
|
||||
date, _ := time.Parse("2006-01-02", entry.date)
|
||||
err := b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count))
|
||||
date, err := time.Parse("2006-01-02", entry.date)
|
||||
if err != nil {
|
||||
b.logger.Error("解析统计日期失败",
|
||||
zap.String("provider_id", entry.providerID),
|
||||
zap.String("model_name", entry.modelName),
|
||||
zap.String("date", entry.date),
|
||||
zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
err = b.statsRepo.BatchUpdate(entry.providerID, entry.modelName, date, int(entry.count))
|
||||
if err != nil {
|
||||
b.logger.Error("批量更新统计失败",
|
||||
zap.String("provider_id", entry.providerID),
|
||||
@@ -153,9 +179,11 @@ func (b *StatsBuffer) flush() {
|
||||
|
||||
key := entry.providerID + "/" + entry.modelName + "/" + entry.date
|
||||
if v, ok := b.counters.Load(key); ok {
|
||||
counter := v.(*int64)
|
||||
counter, ok := v.(*int64)
|
||||
if ok {
|
||||
atomic.AddInt64(counter, entry.count)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
success++
|
||||
}
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"nex/backend/internal/domain"
|
||||
)
|
||||
|
||||
type mockStatsRepo struct {
|
||||
@@ -58,8 +58,10 @@ func TestStatsBuffer_Increment(t *testing.T) {
|
||||
|
||||
var count int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
counter, ok := value.(*int64)
|
||||
if ok {
|
||||
count += atomic.LoadInt64(counter)
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(3), count)
|
||||
@@ -82,8 +84,10 @@ func TestStatsBuffer_ConcurrentIncrement(t *testing.T) {
|
||||
|
||||
var count int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
counter, ok := value.(*int64)
|
||||
if ok {
|
||||
count = atomic.LoadInt64(counter)
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(100), count)
|
||||
@@ -161,8 +165,10 @@ func TestStatsBuffer_SwapInt64(t *testing.T) {
|
||||
|
||||
var beforeCount int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
counter, ok := value.(*int64)
|
||||
if ok {
|
||||
beforeCount = atomic.LoadInt64(counter)
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(2), beforeCount)
|
||||
@@ -171,8 +177,10 @@ func TestStatsBuffer_SwapInt64(t *testing.T) {
|
||||
|
||||
var afterCount int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
counter, ok := value.(*int64)
|
||||
if ok {
|
||||
afterCount = atomic.LoadInt64(counter)
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(0), afterCount)
|
||||
@@ -190,8 +198,10 @@ func TestStatsBuffer_FailRetry(t *testing.T) {
|
||||
|
||||
var count int64
|
||||
buffer.counters.Range(func(key, value interface{}) bool {
|
||||
counter := value.(*int64)
|
||||
counter, ok := value.(*int64)
|
||||
if ok {
|
||||
count = atomic.LoadInt64(counter)
|
||||
}
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, int64(2), count)
|
||||
|
||||
22
backend/pkg/buildinfo/buildinfo.go
Normal file
22
backend/pkg/buildinfo/buildinfo.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package buildinfo
|
||||
|
||||
var (
|
||||
version = "dev"
|
||||
commit = "unknown"
|
||||
buildTime = "unknown"
|
||||
)
|
||||
|
||||
// Version 返回构建注入的版本号。
|
||||
func Version() string {
|
||||
return version
|
||||
}
|
||||
|
||||
// Commit 返回构建注入的 git commit。
|
||||
func Commit() string {
|
||||
return commit
|
||||
}
|
||||
|
||||
// BuildTime 返回构建注入的构建时间。
|
||||
func BuildTime() string {
|
||||
return buildTime
|
||||
}
|
||||
17
backend/pkg/buildinfo/buildinfo_test.go
Normal file
17
backend/pkg/buildinfo/buildinfo_test.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package buildinfo
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaults(t *testing.T) {
|
||||
if Version() == "" {
|
||||
t.Fatal("Version() 不应为空")
|
||||
}
|
||||
|
||||
if Commit() == "" {
|
||||
t.Fatal("Commit() 不应为空")
|
||||
}
|
||||
|
||||
if BuildTime() == "" {
|
||||
t.Fatal("BuildTime() 不应为空")
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
stderrors "errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
@@ -70,22 +71,11 @@ func AsAppError(err error) (*AppError, bool) {
|
||||
if err == nil {
|
||||
return nil, false
|
||||
}
|
||||
var appErr *AppError
|
||||
if ok := is(err, &appErr); ok {
|
||||
return appErr, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func is(err error, target interface{}) bool {
|
||||
// 简单的类型断言
|
||||
if e, ok := err.(*AppError); ok {
|
||||
// 直接赋值
|
||||
switch t := target.(type) {
|
||||
case **AppError:
|
||||
*t = e
|
||||
return true
|
||||
var appErr *AppError
|
||||
if !stderrors.As(err, &appErr) {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
||||
return appErr, true
|
||||
}
|
||||
|
||||
@@ -104,7 +104,8 @@ func TestPredefinedErrors(t *testing.T) {
|
||||
|
||||
func TestAsAppError(t *testing.T) {
|
||||
t.Run("nil输入", func(t *testing.T) {
|
||||
_, ok := AsAppError(nil)
|
||||
appErr, ok := AsAppError(nil)
|
||||
assert.Nil(t, appErr)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
@@ -122,7 +123,8 @@ func TestAsAppError(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("非AppError类型", func(t *testing.T) {
|
||||
_, ok := AsAppError(errors.New("普通错误"))
|
||||
appErr, ok := AsAppError(errors.New("普通错误"))
|
||||
assert.Nil(t, appErr)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
package logger
|
||||
|
||||
import "go.uber.org/zap"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type ctxKey struct{}
|
||||
|
||||
const requestIDKey = "request_id"
|
||||
|
||||
// WithRequestID 向 logger 添加 request_id 字段
|
||||
func WithRequestID(logger *zap.Logger, requestID string) *zap.Logger {
|
||||
return logger.With(zap.String("request_id", requestID))
|
||||
return logger.With(zap.String(requestIDKey, requestID))
|
||||
}
|
||||
|
||||
// WithContext 向 logger 添加多个自定义字段
|
||||
func WithContext(logger *zap.Logger, fields map[string]interface{}) *zap.Logger {
|
||||
zapFields := make([]zap.Field, 0, len(fields))
|
||||
for k, v := range fields {
|
||||
@@ -15,3 +22,37 @@ func WithContext(logger *zap.Logger, fields map[string]interface{}) *zap.Logger
|
||||
}
|
||||
return logger.With(zapFields...)
|
||||
}
|
||||
|
||||
func RequestIDFromGinContext(c *gin.Context) zap.Field {
|
||||
requestID, exists := c.Get("request_id")
|
||||
if !exists {
|
||||
return zap.Skip()
|
||||
}
|
||||
if id, ok := requestID.(string); ok {
|
||||
return RequestID(id)
|
||||
}
|
||||
return zap.Skip()
|
||||
}
|
||||
|
||||
func RequestIDFromContext(ctx context.Context) zap.Field {
|
||||
requestID := ctx.Value(ctxKey{})
|
||||
if requestID == nil {
|
||||
return zap.Skip()
|
||||
}
|
||||
if id, ok := requestID.(string); ok {
|
||||
return RequestID(id)
|
||||
}
|
||||
return zap.Skip()
|
||||
}
|
||||
|
||||
func ContextWithRequestID(ctx context.Context, requestID string) context.Context {
|
||||
return context.WithValue(ctx, ctxKey{}, requestID)
|
||||
}
|
||||
|
||||
func LoggerFromContext(ctx context.Context, baseLogger *zap.Logger) *zap.Logger {
|
||||
field := RequestIDFromContext(ctx)
|
||||
if field == zap.Skip() {
|
||||
return baseLogger
|
||||
}
|
||||
return baseLogger.With(field)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user